#!/usr/bin/env/ python
################################################################################
# Copyright 2016 Brecht Baeten
# This file is part of plottools.
#
# plottools is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# plottools is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with plottools. If not, see <http://www.gnu.org/licenses/>.
################################################################################
import matplotlib.pyplot as plt
import numpy as np
import itertools
def set_publication_rc():
    r"""
    Sets rc parameters for creating plots suitable for publication

    The defaults produce an 80 mm x 50 mm figure with small fonts and thin
    lines, saved as pdf.

    Notes
    -----
    The computer modern fonts are not installed by default on windows. But can
    be downloaded at https://sourceforge.net/projects/cm-unicode/
    To use newly installed fonts in matplotlib you must delete the font cache
    file located at C:\Users\yourusername\.matplotlib

    Examples
    --------
    .. plot::

        >>> import matplotlib.pyplot as plt
        >>> import numpy as np
        >>> import plottools
        >>>
        >>> plottools.set_publication_rc()
        >>> plt.plot(np.arange(10),10*np.random.random(10))
        >>> plt.xlabel('x-label')
        >>> plt.ylabel('y-label')
        >>> plt.show()
    """
    # figure: figsize is in inches, 80 mm x 50 mm
    plt.rc('figure', autolayout=True, figsize=(80/25.4, 50/25.4))
    plt.rc('savefig', format='pdf', dpi=150, bbox='tight', pad_inches=0.02)
    # text: enable the next line to render all text with LaTeX
    # plt.rc('text', usetex=True)
    # font
    plt.rc('font', size=6)
    plt.rc('font', **{'family': 'sans-serif',
                      'sans-serif': ['computer modern sans serif', 'CMU Sans Serif'],
                      'serif': ['computer modern roman', 'CMU Serif']})
    # axes
    plt.rc('axes', linewidth=0.4, labelsize=8)
    plt.rc('axes.formatter', useoffset=False)
    # legend
    plt.rc('legend', fontsize=8, frameon=True)
    # lines
    plt.rc('lines', linewidth=0.8, markersize=4)
    # patch
    plt.rc('patch', linewidth=0.4, edgecolor=(1.0, 1.0, 1.0))
    # ticks
    plt.rc('xtick.major', size=2, width=0.3, pad=3)
    plt.rc('ytick.major', size=2, width=0.3, pad=3)
    plt.rc('xtick.minor', size=1, width=0.3, pad=3)
    plt.rc('ytick.minor', size=1, width=0.3, pad=3)
def savefig(filename, width=None, height=None, ratio=8./5., **kwargs):
    """
    Resizes the current figure to a size given in cm and saves it

    If neither width nor height is specified an 8 cm x 5 cm figure is saved.
    If both width and height are specified, a figure of that size is saved and
    ratio is ignored. If only one of width or height is specified, ratio is
    used to define the other.

    Parameters
    ----------
    filename :
        passed on unchanged to :code:`plt.savefig`
    width : number, optional
        figure width in cm
    height : number, optional
        figure height in cm
    ratio : number, optional
        width / height ratio, only used when width or height is omitted
        (default 8/5)
    **kwargs :
        additional keyword arguments passed on to :code:`plt.savefig`
    """
    if width is None and height is None:
        width = 8.
        height = width/ratio
    elif height is None:
        height = width/ratio
    elif width is None:
        width = height*ratio
    # matplotlib figure sizes are expressed in inches
    plt.gcf().set_size_inches(width/2.54, height/2.54)
    plt.savefig(filename, **kwargs)
def set_style(style, axes=None):
    """
    Sets the style of a single axes object on top of the other rc parameters

    Parameters
    ----------
    style : {'horizontalgrid','horizontalgridwithoutticks'}
        'horizontalgrid' hides all spines except the bottom one, draws light
        horizontal lines at the y-ticks and hides the y tick marks (keeping
        the labels). 'horizontalgridwithoutticks' additionally hides the x tick
        marks. Other values leave the axes unchanged.
    axes : matplotlib axes object
        the axes to which to apply the style, if omitted the current axes
        obtained with plt.gca() is styled

    Examples
    --------
    >>> plt.plot(np.arange(10),10*np.random.random(10))
    >>> plottools.set_style('horizontalgrid')
    >>> plt.show()
    """
    if axes is None:
        axes = plt.gca()
    if style in ['horizontalgrid', 'horizontalgridwithoutticks']:
        # hide the spines except the bottom one
        axes.spines['top'].set_visible(False)
        axes.spines['right'].set_visible(False)
        axes.spines['left'].set_visible(False)
        # show ticks only on the left bottom
        axes.get_xaxis().tick_bottom()
        axes.get_yaxis().tick_left()
        # add horizontal lines at the y-ticks; axhline spans the full axes width
        # without adding x data, so the x-limits are not widened by autoscaling
        for y in axes.get_yticks():
            axes.axhline(y, linestyle='-', linewidth=0.3, color='k', alpha=0.3, zorder=-10)
        # tick visibility flags must be booleans: matplotlib deprecated the
        # 'on'/'off' strings, and a truthy 'off' would not hide anything
        axes.yaxis.set_tick_params(which='both', bottom=False, top=False, labelbottom=True,
                                   left=False, right=False, labelleft=True)
        if style == 'horizontalgridwithoutticks':
            axes.xaxis.set_tick_params(which='both', bottom=False, top=False, labelbottom=True,
                                       left=False, right=False, labelleft=True)
def _zoom_connect_corners(zoom_x, zoom_y, axes_x, axes_y):
    """
    Selects the pair of box corners joined by the zoom connecting lines.

    Corners are indexed 0: (xmin, ymin), 1: (xmin, ymax), 2: (xmax, ymax),
    3: (xmax, ymin). The same index is used on the zoom box and on the new
    axes, so each connecting line joins corresponding corners.
    """
    # horizontal position of the new axes relative to the zoom box
    if axes_x[1] < zoom_x[1]:
        horizontal = 'left'
    elif axes_x[0] > zoom_x[0]:
        horizontal = 'right'
    else:
        horizontal = 'center'
    # vertical position of the new axes relative to the zoom box
    if axes_y[0] > zoom_y[0]:
        vertical = 'top'
    elif axes_y[1] < zoom_y[1]:
        vertical = 'bottom'
    else:
        vertical = 'center'
    corners = {
        ('left', 'top'): (0, 2),
        ('left', 'bottom'): (1, 3),
        ('left', 'center'): (2, 3),
        ('right', 'top'): (1, 3),
        ('right', 'bottom'): (0, 2),
        ('right', 'center'): (0, 1),
        ('center', 'top'): (0, 3),
        ('center', 'bottom'): (1, 2),
        # the new axes is over the zoom box
        ('center', 'center'): (0, 0),
    }
    return corners[(horizontal, vertical)]


def _outside_box(x, y, box_x, box_y):
    """Returns True when (x, y) lies on or outside the rectangle box_x by box_y."""
    return x <= box_x[0] or x >= box_x[1] or y <= box_y[0] or y >= box_y[1]


def _cut_line(line, start_box_x, start_box_y, end_box_x, end_box_y, num=100):
    """
    Trims a straight line so it stays out of the boxes around its end points.

    ``line = ([x0, x1], [y0, y1])`` is sampled at ``num`` equally spaced
    points. The new start is the first sample, walking from the start, on or
    outside the start box; the new end is the first sample, walking back from
    the end, on or outside the end box. An end point is kept when no such
    sample exists. Returns new ``[x0, x1], [y0, y1]`` lists; ``line`` is not
    modified.
    """
    (x0, x1), (y0, y1) = line
    cut_x = [x0, x1]
    cut_y = [y0, y1]
    t = np.linspace(0, 1, num)
    for tt in t:
        x = x0*(1-tt) + x1*tt
        y = y0*(1-tt) + y1*tt
        if _outside_box(x, y, start_box_x, start_box_y):
            cut_x[0], cut_y[0] = x, y
            break
    for tt in t[::-1]:
        x = x0*(1-tt) + x1*tt
        y = y0*(1-tt) + y1*tt
        if _outside_box(x, y, end_box_x, end_box_y):
            cut_x[1], cut_y[1] = x, y
            break
    return cut_x, cut_y


def zoom_axes(fig, ax, zoom_x, zoom_y, axes_x, axes_y, box=True, box_color='k', box_alpha=0.8,
              connect=True, connect_color='k', connect_alpha=0.3, spacing=4, tick_width=20,
              tick_height=12):
    """
    Creates a new axes which zooms in on a part of a given axes.

    A box is drawn around the area to be zoomed, specified in data coordinates.
    A new empty axes is created at the specified location, supplied in data
    coordinates. The new axis limits are set so that they match the zoom box.
    The zoom box and axis can be connected with two lines, connecting the outer
    most corner points while leaving space for the axis ticks.

    :code:`plt.tight_layout()` is called first, so that the position of the
    new axes is computed from the final layout.

    Parameters
    ----------
    fig : matplotlib figure
        the figure in which to create a zoom axis
    ax : matplotlib axes
        the axis in which to create a zoom axis
    zoom_x : list
        [min, max] specifying the zooming horizontal area in data
        coordinates
    zoom_y : list
        [min, max] specifying the zooming vertical area in data coordinates
    axes_x : list
        [min, max] specifying the new axes horizontal location in data
        coordinates
    axes_y : list
        [min, max] specifying the new axes vertical location in data
        coordinates
    box : bool, optional
        specifies whether a box is drawn
    box_color : color string or tuple, optional
        specifies the box color
    box_alpha : number
        between 0 and 1, specifies the box alpha
    connect : bool, optional
        specifies whether the connecting lines are drawn
    connect_color : color string or tuple, optional
        specifies the connecting lines color
    connect_alpha : number
        between 0 and 1, specifies the connecting lines alpha
    spacing : number
        specifies the spacing between the box, axis and the connecting lines
        in points
    tick_width : number
        specifies the width of the tick labels in points to avoid drawing
        connecting lines through the tick labels
    tick_height : number
        specifies the height of the tick labels in points to avoid drawing
        connecting lines through the tick labels

    Returns
    -------
    ax_zoom : matplotlib axes
        the new axes

    Notes
    -----
    * Axes limits should not be changed after a zoom axes has been added
    * :code:`zoom_axes` does not give the expected results when used on a
      subfigure

    Examples
    --------
    .. plot::

        >>> import matplotlib.pyplot as plt
        >>> import numpy as np
        >>> import plottools
        >>>
        >>> fig,ax = plt.subplots()
        >>> x = np.linspace(0,1,100)
        >>> y = 1-x + 0.02*(2*np.random.random(len(x))-1)
        >>> ax.plot(x,y)
        >>> ax_zoom = plottools.zoom_axes(fig,ax,[0.1,0.2],[0.8,0.9],[0.6,0.9],[0.6,0.9])
        >>> ax_zoom.plot(x,y)
        >>> plt.show()
    """
    plt.tight_layout()

    # position of the new axes: data coordinates of ax -> figure fraction
    data_to_fig = ax.transData + fig.transFigure.inverted()
    ax1_p0 = data_to_fig.transform_point((axes_x[0], axes_y[0]))
    ax1_p1 = data_to_fig.transform_point((axes_x[1], axes_y[1]))
    ax1 = plt.axes([ax1_p0[0], ax1_p0[1], ax1_p1[0]-ax1_p0[0], ax1_p1[1]-ax1_p0[1]])
    ax1.set_xlim(zoom_x)
    ax1.set_ylim(zoom_y)
    # plt.xticks / plt.yticks act on the current axes, which plt.axes set to ax1
    plt.xticks(fontsize=4)
    plt.yticks(fontsize=4)
    ax1.tick_params(axis='x', pad=3)
    ax1.tick_params(axis='y', pad=2)

    if box:
        ax.plot([zoom_x[0], zoom_x[1], zoom_x[1], zoom_x[0], zoom_x[0]],
                [zoom_y[0], zoom_y[0], zoom_y[1], zoom_y[1], zoom_y[0]],
                color=box_color, alpha=box_alpha, linewidth=0.4)

    if connect:
        # corners of the zoom box and of the new axes, indexed as in _zoom_connect_corners
        zoom_xx = [zoom_x[0], zoom_x[0], zoom_x[1], zoom_x[1]]
        zoom_yy = [zoom_y[0], zoom_y[1], zoom_y[1], zoom_y[0]]
        axes_xx = [axes_x[0], axes_x[0], axes_x[1], axes_x[1]]
        axes_yy = [axes_y[0], axes_y[1], axes_y[1], axes_y[0]]
        p1, p2 = _zoom_connect_corners(zoom_x, zoom_y, axes_x, axes_y)
        line1 = ([zoom_xx[p1], axes_xx[p1]], [zoom_yy[p1], axes_yy[p1]])
        line2 = ([zoom_xx[p2], axes_xx[p2]], [zoom_yy[p2], axes_yy[p2]])

        # convert the sizes given in points to data units of ax; display
        # coordinates are pixels, hence the points -> pixels factor dpi/72.
        # Horizontal and vertical sizes are converted along their own axis.
        px_per_pt = fig.dpi/72.
        display_to_data = ax.transData.inverted()
        origin = display_to_data.transform_point((0, 0))
        tick_width_data = display_to_data.transform_point((tick_width*px_per_pt, 0))[0] - origin[0]
        tick_height_data = display_to_data.transform_point((0, tick_height*px_per_pt))[1] - origin[1]
        spacing_x = display_to_data.transform_point((spacing*px_per_pt, 0))[0] - origin[0]
        spacing_y = display_to_data.transform_point((0, spacing*px_per_pt))[1] - origin[1]

        # fictional boxes around the new axes (including its tick labels on the
        # left and bottom) and around the zoom box where no lines should be
        box_axes_x = [axes_x[0]-tick_width_data, axes_x[1]+spacing_x]
        box_axes_y = [axes_y[0]-tick_height_data, axes_y[1]+spacing_y]
        box_zoom_x = [zoom_x[0]-spacing_x, zoom_x[1]+spacing_x]
        box_zoom_y = [zoom_y[0]-spacing_y, zoom_y[1]+spacing_y]

        # cut the lines at the boxes and draw them
        for line in (line1, line2):
            cut_x, cut_y = _cut_line(line, box_zoom_x, box_zoom_y, box_axes_x, box_axes_y)
            ax.plot(cut_x, cut_y, color=connect_color, alpha=connect_alpha, linewidth=0.4)
    return ax1
def _format_ticklabel(fmt, value):
    """
    Formats a single categorized tick label.

    Returns ``fmt.format(value)``, or ``value`` itself when it cannot be
    formatted with ``fmt``.
    """
    try:
        return fmt.format(value)
    except (ValueError, TypeError, KeyError, IndexError, AttributeError):
        return value


def categorized_xticklabels(xticks, xticklabels, xticklabelnames=None, fmt=None, size=None,
                            rotation=None, spacing=1.4):
    r"""
    Creates categorized ticks on the x-axis

    Parameters
    ----------
    xticks : array-like
        The x-locations of the data points
    xticklabels : list of array-likes
        A list of lists or arrays of which each must have the same length as
        xticks. These are all used as x-tick labels, the first array is
        displayed highest, the next arrays are printed below the previous
        one. Results are the most appealing if the 1st array has the highest
        variation and the last array the lowest.
    xticklabelnames: list of strings, optional
        A list of names for the labels. It must have the same length as the
        xticklabels list.
    fmt: list of format strings, optional
        A list of formatting strings as used by :code:`format` for the tick
        labels. It must have the same length as the xticklabels list. Values
        that cannot be formatted are displayed as they are.
    size: list of numbers, optional
        A list of numbers specifying the size of the ticklabels in points.
        It must have the same length as the xticklabels list. Defaults to the
        font size of the current x tick labels.
    rotation: list of numbers, optional
        A list of numbers specifying the rotation of the ticklabels in
        degrees. It must have the same length as the xticklabels list.
    spacing: number, optional
        Controls the vertical spacing between the different ticklabels

    Examples
    --------
    .. plot::

        >>> import matplotlib.pyplot as plt
        >>> import numpy as np
        >>> import plottools
        >>>
        >>> plottools.set_publication_rc()
        >>>
        >>> # generate data
        >>> C,B,A = np.meshgrid([10,20],[0.4,0.6,0.8],[1,2,3],indexing='ij')
        >>> xticklabels = [ A.reshape((-1,)), B.reshape((-1,)), C.reshape((-1,)) ]
        >>> values = [np.random.random(len(t)) for t in xticklabels]
        >>> xticks = np.arange(len(values[0]))
        >>>
        >>> xticklabelnames = ['coord','$B_\mathrm{value}$','CCC']
        >>> labels = ['set1','set2','set3']
        >>> fmt = ['${:.2f}$','${}$ m','$10^{{{:.0f}}}$']
        >>> rotation = [70,0,0]
        >>>
        >>> # create the figure
        >>> plt.figure()
        >>> bottom = np.zeros_like(values[0])
        >>> for val,lab in zip(values,labels):
        ...     plt.bar(xticks+0.05,val,0.9,bottom=bottom,label=lab,color=plottools.color.default.next())
        ...     bottom += val
        ...
        >>>
        >>> plt.legend(framealpha=0.7,loc='upper right')
        >>> plt.ylabel('y-label')
        >>>
        >>> # add categories on the x-axis
        >>> plottools.categorized_xticklabels(xticks+0.5,xticklabels,xticklabelnames=xticklabelnames,fmt=fmt,rotation=rotation)
        >>> plt.show()
    """
    nrows = len(xticklabels)
    # input parsing
    if xticklabelnames is None:
        xticklabelnames = ['']*nrows
    if fmt is None:
        fmt = ['{}']*nrows
    if size is None:
        # label1 is the primary tick label; the Tick.label alias is deprecated in matplotlib
        size = [plt.gca().xaxis.get_major_ticks()[0].label1.get_fontsize()]*nrows
    if rotation is None:
        rotation = [0]*nrows

    # distance to the next tick, the last tick reuses the previous distance
    # (or 1 when there is only a single tick)
    dxticks = np.zeros_like(xticks)
    dxticks[:-1] = np.diff(xticks)
    dxticks[-1] = dxticks[-2] if len(xticks) > 1 else 1

    # set the x limits
    plt.xlim([xticks[0]-dxticks[0]/2, xticks[-1]+dxticks[-1]/2])
    # get both axis limits
    xlim = plt.xlim()
    ylim = plt.ylim()

    # vertical offset of each label row below the axis, in points
    yp = []
    ypi = -1
    for i in range(nrows):
        ypi += -spacing*size[i] - 1*size[i]*np.sin(1.*rotation[i]/180*np.pi)
        yp.append(ypi)

    def add_separator(x, row):
        # vertical separator from the bottom of the axes down to label row `row`
        plt.annotate('', xy=(x, ylim[0]), xycoords='data', xytext=(0, yp[row]),
                     textcoords='offset points',
                     arrowprops={'arrowstyle': '-', 'color': (0.3, 0.3, 0.3)})

    def add_label(value, x, row):
        # tick label `value` centred on x in label row `row`
        plt.annotate(_format_ticklabel(fmt[row], value), xy=(x, ylim[0]), xycoords='data',
                     xytext=(0, yp[row]), textcoords='offset points', ha="center",
                     va="bottom", size=size[row], rotation=rotation[row])

    # manual xticks and labels
    plt.xticks([])
    linepositions = set()
    # categorized rows, deepest first so each separator is drawn at its longest
    for i in range(nrows-1, 0, -1):
        c = xticklabels[i]
        xtick_old = None
        xticklabel_old = None
        for j, xtl in enumerate(c):
            if xtl != xticklabel_old:
                if j not in linepositions:
                    add_separator(xticks[j]-0.5*dxticks[j], i)
                    linepositions.add(j)
                if xticklabel_old is not None:
                    # label centred on the group that just ended
                    xtick_avg = 0.5*(xtick_old - 0.5*dxticks[j]) + 0.5*(xticks[j] - 0.5*dxticks[j])
                    add_label(xticklabel_old, xtick_avg, i)
                xtick_old = xticks[j]
                xticklabel_old = xtl
        # add the final tick label of this row
        j = len(c)-1
        xtick_avg = 0.5*(xtick_old - 0.5*dxticks[j]) + 0.5*(xticks[j] + 0.5*dxticks[j])
        add_label(xticklabel_old, xtick_avg, i)

    # add the deepest ticklabel, one label per tick
    for j, xtl in enumerate(xticklabels[0]):
        add_label(xtl, xticks[j], 0)

    # add the final separator line
    j = len(xticklabels[-1])-1
    add_separator(xticks[j]+0.5*dxticks[j], nrows-1)

    # add the ticklabelnames left of the axes
    xp = 3
    for i, name in enumerate(xticklabelnames):
        plt.annotate(name, xy=(xlim[0], ylim[0]), xycoords='data',
                     xytext=(-xp, yp[i]+0.1*size[i]), textcoords='offset points',
                     ha="right", va="bottom", size=size[i])
def cmapval(v, vmin=0, vmax=1, cmap=None):
    """
    Extracts the color belonging to a value or list of values from a colormap

    Parameters
    ----------
    v : float, list, tuple or numpy array
        the value or values for which to return the colors
    vmin : float
        the value corresponding to the first color in the colormap
    vmax : float
        the value corresponding to the last color in the colormap
    cmap : colormap
        a matplotlib colormap, defaults to viridis

    Returns
    -------
    hexcolor : string or list of strings
        an html representation of the colors corresponding to the values, a
        single string for a scalar v and a list for a sequence

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> cmapval(0.6)
    '#22a784'
    >>> cmapval([1000,2000,3000],vmin=100,vmax=3150,cmap=plt.cm.plasma)
    ['#8f0da3', '#e46a5d', '#f7e024']
    """
    if cmap is None:
        cmap = plt.cm.viridis
    # scalars (including numpy scalars) return a single string, sequences a list
    extractvalue = np.ndim(v) == 0
    if extractvalue:
        v = [v]
    # [vmin, vmax] maps onto [0.01, 0.99] of the colormap; np.interp clips
    # values outside [vmin, vmax] to the end points
    color = cmap(np.interp(v, [vmin, vmax], [0.01, 0.99]))
    # %x needs integers on Python 3; int() truncates the 0-255 channel values
    hexcolor = ['#%02x%02x%02x' % (int(rgb[0]*255), int(rgb[1]*255), int(rgb[2]*255))
                for rgb in color[:, 0:-1]]
    if extractvalue:
        return hexcolor[0]
    return hexcolor
def marker(i):
    """
    Returns a marker symbol from a fixed cycle of markers

    Parameters
    ----------
    i : int
        index in the cycle; indices beyond the number of available markers
        (or negative ones) wrap around using modulo arithmetic

    Returns
    -------
    marker : str
        a matplotlib marker symbol
    """
    cycle = ('^', 's', '<', 'o', '>', '*', 'v', '1')
    position = np.mod(i, len(cycle))
    return cycle[position]