# Source code for NeuroTools.plotting

"""
NeuroTools.plotting
===================

This module contains a collection of tools for plotting and image processing that 
shall facilitate the generation and handling of NeuroTools data visualizations.
It utilizes the Matplotlib and the Python Imaging Library (PIL) packages.


Classes
-------

SimpleMultiplot     - object that creates and handles a figure consisting of multiple panels, all with the same datatype and the same x-range.


Functions
---------

get_display         - returns a pylab object with a plot() function to draw the plots.
progress_bar        - prints a progress bar to stdout, filled to the given ratio.
pylab_params        - returns a dictionary with a set of parameters that help to nicely format figures by updating the pylab run command parameters dictionary 'pylab.rcParams'.
set_axis_limits     - defines the axis limits in a plot.
set_labels          - defines the axis labels of a plot.
set_pylab_params    - updates a set of parameters within the pylab run command parameters dictionary 'pylab.rcParams' in order to achieve nicely formatted figures.
save_2D_image       - saves a 2D numpy array of gray shades between 0 and 1 to a PNG file.
save_2D_movie       - saves a list of 2D numpy arrays of gray shades between 0 and 1 to a zipped tree of PNG files.
"""

import sys, numpy
from NeuroTools import check_dependency


# Check availability of pylab (essential!)
if check_dependency('matplotlib'):
    from matplotlib import use
    use('Agg')
    from matplotlib.figure import Figure
    from matplotlib.lines import Line2D
    from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
if check_dependency('pylab'):
    import pylab

# Check availability of PIL
PILIMAGEUSE = check_dependency('PIL')
if PILIMAGEUSE:
    import PIL.Image as Image



########################################################
# UNIVERSAL FUNCTIONS AND CLASSES FOR NORMAL PYLAB USE #
########################################################



def get_display(display):
    """
    Returns a pylab object with a plot() function to draw the plots.

    Inputs:
        display - if True, a new figure is created and the pylab module is
                  returned. If False, None is returned. Any other object
                  (e.g. a subplot object) is returned unchanged.
    """
    # Identity checks on purpose: only the literal booleans are special-cased,
    # so falsy/truthy objects such as 0 or a subplot are passed straight through.
    if display is True:
        pylab.figure()
        return pylab
    if display is False:
        return None
    return display
def progress_bar(progress):
    """
    Prints a progress bar to stdout.

    The bar is 50 characters wide and is terminated with a carriage return
    (no newline), so that a subsequent call overwrites it in place.

    Inputs:
        progress - a number between 0. and 1. (both inclusive)

    Raises:
        AssertionError - if progress is not a real number within [0., 1.]

    Example:
        >> progress_bar(0.7)
        |===================================               |
    """
    progressConditionStr = "ERROR: The argument of function NeuroTools.plotting.progress_bar(...) must be a float between 0. and 1.!"
    # isinstance() instead of an exact type match, so that the boundary
    # integers 0 and 1 and float subclasses (e.g. numpy.float64) are accepted.
    # bool is excluded explicitly because it is a subclass of int.
    is_number = isinstance(progress, (int, float)) and not isinstance(progress, bool)
    if not (is_number and 0. <= progress <= 1.):
        raise AssertionError(progressConditionStr)
    length = 50
    filled = int(round(length * progress))
    # sys.stdout.write() emits exactly this string under both Python 2 and 3
    # (the former 'print ... ,' statement is a syntax error under Python 3).
    sys.stdout.write("|" + "=" * filled + " " * (length - filled) + "|\r")
    sys.stdout.flush()
def pylab_params(fig_width_pt=246.0,
                 ratio=(numpy.sqrt(5)-1.0)/2.0,  # Aesthetic golden mean ratio by default
                 text_fontsize=10, tick_labelsize=8, useTex=False):
    """
    Returns a dictionary with a set of parameters that help to nicely format figures.
    The return object can be used to update the pylab run command parameters dictionary 'pylab.rcParams'.

    Inputs:
        fig_width_pt   - figure width in points. If you want to use your figure inside LaTeX,
                         get this value from LaTeX using '\\showthe\\columnwidth'.
        ratio          - ratio between the height and the width of the figure.
        text_fontsize  - size of axes and in-pic text fonts.
        tick_labelsize - size of tick label font.
        useTex         - enables or disables the use of LaTeX for all labels and texts
                         (for details on how to do that, see http://www.scipy.org/Cookbook/Matplotlib/UsingTex).

    Returns:
        a dict with the keys 'axes.labelsize', 'font.size', 'xtick.labelsize',
        'ytick.labelsize', 'text.usetex' and 'figure.figsize'
        (the latter being [width, height] in inches).
    """
    inches_per_pt = 1.0/72.27                 # Convert pt to inch
    fig_width = fig_width_pt*inches_per_pt    # width in inches
    fig_height = fig_width*ratio              # height in inches
    fig_size = [fig_width, fig_height]
    params = {
        'axes.labelsize'  : text_fontsize,
        'font.size'       : text_fontsize,
        'xtick.labelsize' : tick_labelsize,
        'ytick.labelsize' : tick_labelsize,
        'text.usetex'     : useTex,
        'figure.figsize'  : fig_size,
    }
    return params
def set_axis_limits(subplot, xmin, xmax, ymin, ymax):
    """
    Defines the axis limits of a plot.

    Inputs:
        subplot     - the targeted plot; either an object offering xlim()/ylim()
                      (such as the pylab module) or one offering
                      set_xlim()/set_ylim() (such as an axes object)
        xmin, xmax  - the limits of the x axis
        ymin, ymax  - the limits of the y axis

    Raises:
        Exception - if the plot offers neither pair of limit functions

    Example:
        >> x = range(10)
        >> y = []
        >> for i in x: y.append(i*i)
        >> pylab.plot(x,y)
        >> plotting.set_axis_limits(pylab, 0., 10., 0., 100.)
    """
    # The pylab-style interface is probed first, then the axes-style one.
    if hasattr(subplot, 'xlim'):
        subplot.xlim(xmin, xmax)
        subplot.ylim(ymin, ymax)
    elif hasattr(subplot, 'set_xlim'):
        subplot.set_xlim(xmin, xmax)
        subplot.set_ylim(ymin, ymax)
    else:
        raise Exception('ERROR: The plot passed to function NeuroTools.plotting.set_axis_limits(...) does not provide limit defining functions.')
def set_labels(subplot, xlabel, ylabel):
    """
    Defines the axis labels of a plot.

    Inputs:
        subplot - the targeted plot; either an object offering xlabel()/ylabel()
                  (such as the pylab module) or one offering
                  set_xlabel()/set_ylabel() (such as an axes object)
        xlabel  - a string for the x label
        ylabel  - a string for the y label

    Raises:
        Exception - if the plot offers neither pair of labelling functions

    Example:
        >> x = range(10)
        >> y = []
        >> for i in x: y.append(i*i)
        >> pylab.plot(x,y)
        >> plotting.set_labels(pylab, 'x', 'y=x^2')
    """
    # The pylab-style interface is probed first, then the axes-style one.
    if hasattr(subplot, 'xlabel'):
        subplot.xlabel(xlabel)
        subplot.ylabel(ylabel)
    elif hasattr(subplot, 'set_xlabel'):
        subplot.set_xlabel(xlabel)
        subplot.set_ylabel(ylabel)
    else:
        # The message names this function (it previously referred to a
        # non-existent 'set_label').
        raise Exception('ERROR: The plot passed to function NeuroTools.plotting.set_labels(...) does not provide labelling functions.')
def set_pylab_params(fig_width_pt=246.0,
                     ratio=(numpy.sqrt(5)-1.0)/2.0,  # Aesthetic golden mean ratio by default
                     text_fontsize=10, tick_labelsize=8, useTex=False):
    """
    Updates a set of parameters within the pylab run command parameters dictionary 'pylab.rcParams'
    in order to achieve nicely formatted figures.

    The values are computed by pylab_params(), to which all arguments are forwarded.

    Inputs:
        fig_width_pt   - figure width in points. If you want to use your figure inside LaTeX,
                         get this value from LaTeX using '\\showthe\\columnwidth'
        ratio          - ratio between the height and the width of the figure
        text_fontsize  - size of axes and in-pic text fonts
        tick_labelsize - size of tick label font
        useTex         - enables or disables the use of LaTeX for all labels and texts
                         (for details on how to do that, see http://www.scipy.org/Cookbook/Matplotlib/UsingTex)
    """
    pylab.rcParams.update(pylab_params(fig_width_pt=fig_width_pt, ratio=ratio,
                                       text_fontsize=text_fontsize,
                                       tick_labelsize=tick_labelsize, useTex=useTex))



####################################################################
# SPECIAL PLOTTING FUNCTIONS AND CLASSES FOR SPECIFIC REQUIREMENTS #
####################################################################
def save_2D_image(mat, filename):
    """
    Saves a 2D numpy array of gray shades between 0 and 1 to a PNG file.

    The values are mapped onto the 256 gray levels 0..255 of an 8-bit
    ('L' mode) PIL image, 0. being black and 1. being white.

    Inputs:
        mat      - a 2D numpy array of floats between 0 and 1
        filename - string specifying the filename where to save the data, has to end on '.png'

    Raises:
        AssertionError - if PIL has not been detected or if one of the
                         arguments violates the conditions stated above

    Example:
        >> import numpy
        >> a = numpy.random.random([100,100]) # creates a 2D numpy array with random values between 0. and 1.
        >> save_2D_image(a,'randomarray100x100.png')
    """
    assert PILIMAGEUSE, "ERROR: Since PIL has not been detected, the function NeuroTools.plotting.save_2D_image(...) is not supported!"
    # The messages name this function (they previously referred to 'imsave').
    matConditionStr = "ERROR: First argument of function NeuroTools.plotting.save_2D_image(...) must be a 2D numpy array of floats between 0. and 1.!"
    filenameConditionStr = "ERROR: Second argument of function NeuroTools.plotting.save_2D_image(...) must be a string ending on \".png\"!"
    assert (type(mat) == numpy.ndarray) and (mat.ndim == 2) and (mat.min() >= 0.) and (mat.max() <= 1.), matConditionStr
    assert (type(filename) == str) and (len(filename) > 4) and (filename[-4:].lower() == '.png'), filenameConditionStr
    mode = 'L'
    # PIL asks for a permuted (col,line) shape corresponding to the natural (x,y) space
    pilImage = Image.new(mode, (mat.shape[1], mat.shape[0]))
    # floor(x*256) alone maps x == 1.0 to 256, which lies outside the 8-bit
    # range; cap at 255 so that full white is stored explicitly instead of
    # depending on how PIL treats out-of-range values.
    data = numpy.minimum(numpy.floor(numpy.ravel(mat) * 256.), 255.)
    pilImage.putdata(data)
    pilImage.save(filename)
def save_2D_movie(frame_list, filename, frame_duration):
    """
    Saves a list of 2D numpy arrays of gray shades between 0 and 1 to a zipped tree of PNG files.

    The archive contains one directory named like the archive (without '.zip') holding
        - one RGB PNG file per frame ('frame<zero-padded index>.png'),
        - a file 'parameters' with the line 'frame_duration = <frame_duration>',
        - a file 'frames' listing the PNG file names, one per line.
    A progress bar is printed to stdout while the frames are written.

    Inputs:
        frame_list     - a list of 2D numpy arrays of floats between 0 and 1
        filename       - string specifying the filename where to save the data, has to end on '.zip'
        frame_duration - specifier for the duration per frame, will be stored as additional meta-data

    Raises:
        AssertionError - if PIL has not been detected or if filename is not a
                         string ending on '.zip'

    Example:
        >> import numpy
        >> framelist = []
        >> for i in range(100): framelist.append(numpy.random.random([100,100])) # creates a list of 2D numpy arrays with random values between 0. and 1.
        >> save_2D_movie(framelist, 'randommovie100x100x100.zip', 0.1)
    """
    # Local imports, as before; io.BytesIO replaces the Python-2-only StringIO
    # module (PNG data is binary) and is available under Python 2.6+ and 3.
    import io
    import zipfile

    assert PILIMAGEUSE, "ERROR: Since PIL has not been detected, the function NeuroTools.plotting.save_2D_movie(...) is not supported!"
    filenameConditionStr = "ERROR: Second argument of function NeuroTools.plotting.save_2D_movie(...) must be a string ending on \".zip\"!"
    assert (type(filename) == str) and (len(filename) > 4) and (filename[-4:].lower() == '.zip'), filenameConditionStr

    n_frames = len(frame_list)
    container = filename[:-4] # remove .zip

    # Zero-pad the frame index to the width of the largest index. The names are
    # built once and used for both the archive members and the 'frames' index,
    # so the two always agree (the index used to hard-code three digits), and
    # an empty frame list no longer fails on log10(0).
    n_digits = len(str(max(n_frames - 1, 0)))
    frame_name_format = "frame%%.%dd.png" % n_digits
    frame_names = [frame_name_format % frame_num for frame_num in range(n_frames)]

    zf = zipfile.ZipFile(filename, 'w', zipfile.ZIP_DEFLATED)
    try:
        for frame_num, (frame, pngname) in enumerate(zip(frame_list, frame_names)):
            # Map the gray shades in [0, 1] onto the integer levels 0..255
            # (same mapping as save_2D_image); RGB pixels need integer channels.
            levels = numpy.minimum(numpy.floor(numpy.ravel(frame) * 256.), 255.).astype(int)
            frame_data = [(p, p, p) for p in levels.tolist()]
            # PIL asks for a (width, height) size, i.e. (columns, rows)
            im = Image.new('RGB', (frame.shape[1], frame.shape[0]), 'white')
            im.putdata(frame_data)
            png_buffer = io.BytesIO()
            im.save(png_buffer, format='png')
            arcname = "%s/%s" % (container, pngname)
            zf.writestr(arcname, png_buffer.getvalue())
            # Report the fraction of frames completed, so the bar ends full.
            progress_bar(float(frame_num + 1)/n_frames)

        # add 'parameters' and 'frames' files to the zip archive
        zf.writestr("%s/parameters" % container, 'frame_duration = %s' % frame_duration)
        zf.writestr("%s/frames" % container, '\n'.join(frame_names))
    finally:
        # Always release the file handle, also when a frame fails to convert.
        zf.close()
class SimpleMultiplot(object):
    """
    A figure consisting of multiple panels, all with the same datatype and the
    same x-range.

    The panels are stored in self.axes (alias self.all_panels) column by
    column, each column running from its top panel to its bottom panel.
    """

    def __init__(self, nrows, ncolumns, title="", xlabel=None, ylabel=None,
                 scaling=('linear','linear')):
        """
        Creates the figure and its nrows x ncolumns panels.

        Inputs:
            nrows    - number of panel rows
            ncolumns - number of panel columns
            title    - figure title, drawn by finalise()
            xlabel   - x label, attached to the bottom panel of the first column
            ylabel   - y label, drawn once, vertically centred at the left figure margin
            scaling  - (x scaling, y scaling), each being 'linear' or 'log';
                       selects the axes method (plot, loglog, semilogx or
                       semilogy) that panels expose as 'plot1'

        Raises:
            Exception - if scaling is none of the four supported combinations
        """
        self.fig = Figure()
        self.canvas = FigureCanvas(self.fig)
        self.axes = []
        self.all_panels = self.axes
        self.nrows = nrows
        self.ncolumns = ncolumns
        self.n = nrows*ncolumns
        self._curr_panel = 0
        self.title = title

        # Layout in figure-relative coordinates: fixed margins, and panel
        # separations of 10% of the per-panel share of the remaining space.
        topmargin = 0.06
        rightmargin = 0.02
        bottommargin = 0.1
        leftmargin = 0.1
        v_panelsep = 0.1*(1 - topmargin - bottommargin)/nrows #0.05
        h_panelsep = 0.1*(1 - leftmargin - rightmargin)/ncolumns
        panelheight = (1 - topmargin - bottommargin - (nrows-1)*v_panelsep)/nrows
        panelwidth = (1 - leftmargin - rightmargin - (ncolumns-1)*h_panelsep)/ncolumns
        assert panelheight > 0

        bottomlist = [bottommargin + i*v_panelsep + i*panelheight for i in range(nrows)]
        leftlist = [leftmargin + j*h_panelsep + j*panelwidth for j in range(ncolumns)]
        # Reversed so that row index 0 is the topmost panel.
        bottomlist.reverse()
        for j in range(ncolumns):
            for i in range(nrows):
                ax = self.fig.add_axes([leftlist[j],bottomlist[i],panelwidth,panelheight])
                # Only the left and the bottom frame side are drawn.
                self.set_frame(ax,[True,True,False,False])
                ax.xaxis.tick_bottom()
                ax.yaxis.tick_left()
                self.axes.append(ax)
        if xlabel:
            self.axes[self.nrows-1].set_xlabel(xlabel)
        if ylabel:
            self.fig.text(0.5*leftmargin,0.5,ylabel,
                          rotation='vertical',
                          horizontalalignment='center',
                          verticalalignment='center')
        if scaling == ("linear","linear"):
            self.plot_function = "plot"
        elif scaling == ("log", "log"):
            self.plot_function = "loglog"
        elif scaling == ("log", "linear"):
            self.plot_function = "semilogx"
        elif scaling == ("linear", "log"):
            self.plot_function = "semilogy"
        else:
            raise Exception("Invalid value for scaling parameter")

    def finalise(self):
        """
        Adjustments to be made after all panels have been plotted.

        Draws the title at the top of the figure and removes the x tick labels
        of every panel except the bottom panel of the first column. Each call
        adds a new title text object to the figure.
        """
        self.fig.text(0.5, 0.99, self.title, horizontalalignment='center', verticalalignment='top')
        # Turn off tick labels for all x-axes except the bottom one
        for ax in self.axes[0:self.nrows-1]+self.axes[self.nrows:]:
            ax.xaxis.set_ticklabels([])

    def save(self, filename):
        """
        Saves/prints the figure to file, after calling finalise().

        Inputs:
            filename - string specifying the filename where to save the data
        """
        self.finalise()
        self.canvas.print_figure(filename)

    def next_panel(self):
        """
        Changes to next panel within figure.

        Returns the current panel, with its scaling-dependent plot method
        attached as 'plot1', and advances the internal panel counter, which
        wraps around to the first panel after the last one.
        """
        ax = self.axes[self._curr_panel]
        self._curr_panel += 1
        if self._curr_panel >= self.n:
            self._curr_panel = 0
        ax.plot1 = getattr(ax, self.plot_function)
        return ax

    def panel(self, i):
        """
        Returns panel i, with its scaling-dependent plot method attached as
        'plot1'. The internal panel counter is left untouched.
        """
        ax = self.axes[i]
        ax.plot1 = getattr(ax, self.plot_function)
        return ax

    def set_frame(self, ax, boollist, linewidth=2):
        """
        Defines frames for the chosen axis.

        If all four sides are requested, the axis is left untouched.
        Otherwise its frame is switched off and the requested sides are
        drawn as individual lines in axes coordinates.

        Inputs:
            ax        - the targeted axis
            boollist  - a list (or numpy array) of four booleans stating which
                        frame sides are drawn, in the order
                        [left, bottom, right, top]
            linewidth - the line width of the drawn frame sides
        """
        assert type(boollist) in [list, numpy.ndarray]
        assert len(boollist) == 4
        # list(...) so that a numpy array is compared as a whole; comparing the
        # array itself is element-wise and has no single truth value.
        if list(boollist) != [True,True,True,True]:
            bottom = Line2D([0, 1], [0, 0], transform=ax.transAxes, linewidth=linewidth, color='k')
            left   = Line2D([0, 0], [0, 1], transform=ax.transAxes, linewidth=linewidth, color='k')
            top    = Line2D([0, 1], [1, 1], transform=ax.transAxes, linewidth=linewidth, color='k')
            # Vertical line at x == 1 (this used to duplicate the top edge).
            right  = Line2D([1, 1], [0, 1], transform=ax.transAxes, linewidth=linewidth, color='k')
            ax.set_frame_on(False)
            for side,draw in zip([left,bottom,right,top],boollist):
                if draw:
                    ax.add_line(side)