# Source code for pytadbit.utils.extraviews

"""
06 Aug 2013


"""
from warnings import warn
import numpy as np
from subprocess import Popen, PIPE
from itertools import combinations

try:
    from matplotlib import pyplot as plt
except ImportError:
    warn('matplotlib not found\n')


def nicer(res):
    """
    Writes a resolution (a number of bases) for human beings.

    The largest unit among Gb, Mb and Kb that divides *res* exactly is used;
    any other value, zero included, is written in plain bases.

    :param res: resolution, as an integer number of bases

    :returns: a string like '1Mb', '20Kb' or '1500b'
    """
    # zero is a multiple of every unit, but it has no digits left to print
    # once the unit is stripped: write it in plain bases
    if res:
        for size, unit in ((1000000000, 'Gb'), (1000000, 'Mb'), (1000, 'Kb')):
            if not res % size:
                return str(res // size) + unit
    return str(res) + 'b'


# ANSI escape sequences used by colorize(), indexed by a score from 0 to 10.
# None shares the sequence of 10.
COLOR = {
    None: '\033[31m',  # red
    0: '\033[34m',     # blue
    1: '\033[34m',     # blue
    2: '\033[34m',     # blue
    3: '\033[36m',     # cyan
    4: '\033[0m',      # white (attributes reset)
    5: '\033[1m',      # bold white
    6: '\033[33m',     # yellow
    7: '\033[33m',     # yellow
    8: '\033[35m',     # purple
    9: '\033[35m',     # purple
    10: '\033[31m',    # red
}

# Opening HTML tags used by colorize(), indexed by a score from 0 to 10.
# None shares the tag of 10.
COLORHTML = {
    None: '<span style="color:red;">',    # red
    0: '<span>',                          # no colour set
    1: '<span style="color:blue;">',      # blue
    2: '<span style="color:blue;">',      # blue
    3: '<span style="color:purple;">',    # purple
    4: '<span style="color:purple;">',    # purple
    5: '<span style="color:teal;">',      # cyan
    6: '<span style="color:teal;">',      # cyan
    7: '<span style="color:olive;">',     # yellow
    8: '<span style="color:olive;">',     # yellow
    9: '<span style="color:red;">',       # red
    10: '<span style="color:red;">',      # red
}


def colorize(string, num, ftype='ansi'):
    """
    Wraps a string in the colour associated to a number between 0 and 10,
    either with ANSI escape codes (to print in a shell) or with an HTML span.

    :param string: the string to colorize
    :param num: a key of the colour tables: a number between 0 and 10, or
       None (which gets the same colour as 10)
    :param 'ansi' ftype: 'ansi' to use ANSI colour codes, any other value to
       use HTML span tags

    :returns: the string 'decorated' with the opening and closing colour
       codes
    """
    if ftype == 'ansi':
        opening, closing = COLOR[num], '\033[m'
    else:
        opening, closing = COLORHTML[num], '</span>'
    return '%s%s%s' % (opening, string, closing)


def color_residues(n_part):
    """
    Builds a colour gradient with one colour per particle: the red component
    grows linearly up to 1 for the last particle, the blue one is its
    complement and green is always 0.

    :param n_part: number of particles

    :returns: a list of rgb tuples (red, green, blue); empty if n_part is 0
    """
    # range (not xrange) so that the function runs on Python 2 and 3 alike
    reds = [float(n + 1) / n_part for n in range(n_part)]
    return [(red, 0, 1 - red) for red in reds]


def augmented_dendrogram(clust_count=None, dads=None, objfun=None, color=False,
                         axe=None, savefig=None, *args, **kwargs):
    """
    Draws a dendrogram of clusters of models: the tree is computed by scipy's
    dendrogram function and then redrawn with, for each branch, a length
    shifted according to the objective function of the cluster and a width
    depending on the number of models in the cluster.

    :param None clust_count: dictionary with the number of models of each
       cluster, indexed by leaf number + 1. Required in practice (its values
       are summed)
    :param None dads: container indexed by leaf number + 1, whose values are
       used to label the node joining two branches. Only used when the tree
       is drawn
    :param None objfun: dictionary with the objective function value of each
       cluster, indexed by leaf number + 1. Required in practice (its minimum
       and maximum are computed)
    :param False color: if True the branches take the colors proposed by
       scipy's dendrogram, otherwise they are all grey
    :param None axe: a matplotlib axes object in which to draw. If None, a
       new axes is created and styled in a new figure
    :param None savefig: path passed to the figure's savefig method. If None
       and no axe was given, the figure is shown
    :param args: positional arguments passed to scipy's dendrogram
    :param kwargs: keyword arguments passed to scipy's dendrogram; if
       'no_plot' is set, the branches are not redrawn

    :returns: the dictionary returned by scipy's dendrogram
    """

    from scipy.cluster.hierarchy import dendrogram
    # NOTE(review): a figure is opened here even when `axe` is given; it looks
    # like it receives scipy's own drawing, which is wiped by plt.clf() below
    # -- confirm before moving this line
    fig = plt.figure(figsize=(8, 8))
    if axe:
        ax = axe
        fig = axe.get_figure()
        ddata = dendrogram(*args, **kwargs)
        plt.clf()
    else:
        ddata = dendrogram(*args, **kwargs)
        # only the coordinates computed by scipy are kept, not its drawing
        plt.clf()
        ax = fig.add_subplot(111)
        ax.patch.set_facecolor('lightgrey')
        ax.patch.set_alpha(0.4)
        ax.grid(ls='-', color='w', lw=1.5, alpha=0.6, which='major')
        ax.grid(ls='-', color='w', lw=1, alpha=0.3, which='minor')
        ax.set_axisbelow(True)
        # remove tick marks
        ax.tick_params(axis='both', direction='out', top=False, right=False,
                       left=False, bottom=False)
        ax.tick_params(axis='both', direction='out', top=False, right=False,
                       left=False, bottom=False, which='minor')
    
    # set dict to store data of each cluster (count and energy), depending on
    # x position in graph.
    leaves = {}
    # horizontal distance between the two branches of the first link, used as
    # the spacing between consecutive leaves
    dist = ddata['icoord'][0][2] - ddata['icoord'][0][1]
    for i, x in enumerate(ddata['leaves']):
        leaves[dist*i + dist/2] = x
    minnrj = min(objfun.values())
    maxnrj = max(objfun.values())
    difnrj = maxnrj - minnrj
    total = sum(clust_count.values())
    if not kwargs.get('no_plot', False):
        for i, d, c in zip(ddata['icoord'], ddata['dcoord'],
                           ddata['color_list']):
            x = 0.5 * sum(i[1:3])
            y = d[1]
            # plt.plot(x, y, 'ro')
            # horizontal bar joining the two branches of the link
            plt.hlines(y, i[1], i[2], lw=2, color='grey')
            # for each branch
            for i1, d1, d2 in zip(i[1:3], [d[0], d[3]], [d[1], d[2]]):
                # width proportional to the fraction of models in the cluster;
                # NOTE(review): len(leaves) grows while looping, as nodes are
                # added to `leaves` at the end of each iteration
                try:
                    lw = float(clust_count[leaves[i1] + 1])/total*10*len(leaves)
                except KeyError:
                    lw = 1.0
                # clusters without objective function get the maximum value
                nrj = objfun[leaves[i1] + 1] if (leaves[i1] + 1) in objfun else maxnrj
                # the branch starts lower the higher the objective function is
                ax.vlines(i1, d1-(difnrj-(nrj-minnrj)), d2, lw=lw,
                          color=(c if color else 'grey'))
                # write the cluster number under the branches that have an
                # objective function value
                if leaves[i1] + 1 in objfun:
                    ax.annotate("%.3g" % (leaves[i1] + 1),
                                (i1, d1-(difnrj-(nrj-minnrj))),
                                xytext=(0, -8),
                                textcoords='offset points',
                                va='top', ha='center')
            # register the node joining the two branches, so that links drawn
            # later can find it from its x position
            leaves[(i[1] + i[2])/2] = dads[leaves[i[1]] + 1]
    # order of magnitude of the range of objective function values
    try:
        cutter = 10**int(np.log10(difnrj))
    except OverflowError: # case that the two are exactly the same
        cutter = 1
    cut = 10 if cutter >= 10 else 1
    # NOTE(review): the divisions below and the xrange calls look written for
    # Python 2 (integer division, xrange builtin) -- confirm before running
    # this under Python 3
    bot = (-int(difnrj)/cutter * cutter) or -1 # do not want this to be null
    # just to display nice numbers
    # (inserts a comma every three characters, counting from the right)
    form = lambda x: ''.join([(s + ',') if not i%3 and i else s
                              for i, s in enumerate(str(x)[::-1])][::-1])
    plt.yticks([bot+i for i in xrange(0, -bot-bot/cut, -bot/cut)],
               # ["{:,}".format (int(minnrj)/cutter * cutter  + i)
               ["%s" % (form(int(minnrj)/cutter * cutter  + i))
                for i in xrange(0, -bot-bot/cut, -bot/cut)], size='small')
    ax.set_ylabel('Minimum IMP objective function')
    ax.set_xticks([])
    ax.set_xlim((plt.xlim()[0] - 2, plt.xlim()[1] + 2))
    ax.figure.suptitle("Dendogram of clusters of 3D models")
    ax.set_title("Branch length proportional to model's objective function " +
                 "final value\n" +
                 "Branch width to the number of models in the cluster",
                 size='small')
    if savefig:
        fig.savefig(savefig)
    elif not axe:
        plt.show()
    return ddata


def _style_hist_box_axes(axes):
    """
    Gives an axes the look used by plot_hist_box: light grey background,
    white major/minor grid drawn below the data, and no tick marks.
    """
    axes.patch.set_facecolor('lightgrey')
    axes.patch.set_alpha(0.4)
    axes.grid(ls='-', color='w', lw=1.5, alpha=0.6, which='major')
    axes.grid(ls='-', color='w', lw=1, alpha=0.3, which='minor')
    axes.set_axisbelow(True)
    axes.minorticks_on() # always on, not only for log
    # remove tick marks
    axes.tick_params(axis='both', direction='out', top=False, right=False,
                     left=False, bottom=False)
    axes.tick_params(axis='both', direction='out', top=False, right=False,
                     left=False, bottom=False, which='minor')


def plot_hist_box(data, part1, part2, axe=None, savefig=None):
    """
    Plots the histogram of a list of distances and, on top of it and with the
    same x range, the corresponding horizontal boxplot. Both ends of the box
    and the median are annotated with their value, and the mean is marked
    with a white star.

    :param data: non-empty sequence of numbers (the distances)
    :param part1: first particle, only used to write the title
    :param part2: second particle, only used to write the title
    :param None axe: a matplotlib axes object; only its figure is used, the
       boxplot and the histogram are drawn in two axes added to this figure.
       If None a new figure is created
    :param None savefig: path passed to the figure's savefig method. If None
       and no axe was given, the figure is shown
    """
    # setup the figure and axes
    fig = axe.get_figure() if axe else plt.figure(figsize=(6, 6))

    # boxplot panel: left, bottom, width, height (adjust as necessary)
    bpAx = fig.add_axes([0.2, 0.7, 0.7, 0.2])
    _style_hist_box_axes(bpAx)
    bp = bpAx.boxplot(data, vert=False)
    plt.setp(bp['boxes'], color='black')
    plt.setp(bp['whiskers'], color='black')
    plt.setp(bp['medians'], color='darkred')
    plt.setp(bp['fliers'], color='darkred', marker='+')
    # float() so that integer data do not get a truncated mean on Python 2
    bpAx.plot(float(sum(data)) / len(data), 1,
              color='w', marker='*', markeredgecolor='k')

    # write the value of both ends of the box above it, and the median below
    box = bp['boxes'][0]
    text_style = dict(ha='center', textcoords='offset points', size='small')
    for xpos in (box.get_xdata()[0], box.get_xdata()[2]):
        bpAx.annotate('%.4f' % (xpos), (xpos, box.get_ydata()[1]),
                      va='bottom', xytext=(0, 2), **text_style)
    median = bp['medians'][0].get_xdata()[0]
    bpAx.annotate('%.4f' % (median), (median, box.get_ydata()[0]),
                  va='top', xytext=(0, -2), color='darkred', **text_style)

    # histogram panel: left specs should match and bottom + height on this
    # line should equal bottom on bpAx line
    histAx = fig.add_axes([0.2, 0.2, 0.7, 0.5])
    _style_hist_box_axes(histAx)
    histAx.hist(data, bins=20, alpha=0.5, color='darkgreen')

    # confirm that the axes line up
    xlims = np.array([bpAx.get_xlim(), histAx.get_xlim()])
    for ax in [bpAx, histAx]:
        ax.set_xlim([xlims.min(), xlims.max()])
    bpAx.set_xticklabels([])  # clear out overlapping xlabels
    bpAx.set_yticks([])  # don't need that 1 tick mark
    # label the histogram explicitly rather than whatever axes pyplot
    # considers current (they differ when `axe` belongs to another figure)
    histAx.set_xlabel('Distance between particles (nm)')
    histAx.set_ylabel('Number of observations')
    bpAx.set_title('Histogram and boxplot of distances between particles ' +
                   '%s and %s' % (part1, part2))
    if savefig:
        fig.savefig(savefig)
    elif not axe:
        plt.show()



def chimera_view(cmm_files, chimera_bin='chimera',
                 shape='tube', chimera_cmd=None,
                 savefig=None):
    """
    Opens a list of cmm files in Chimera: a Chimera command script is written
    in a temporary file and Chimera is run on it. The call blocks until the
    Chimera process ends.

    :param cmm_files: list of paths to the files to open. When there are
       several, each one is matched onto the first (model #0)
    :param 'chimera' chimera_bin: command used to launch Chimera. It is
       pasted as is in a shell command line
    :param 'tube' shape: currently not used
    :param None chimera_cmd: list of Chimera commands to run after loading
       the files, instead of the default ones. When given, savefig is ignored
    :param None savefig: path to an output file: ending with '.png' to save
       an image, with '.mov' or 'webm' to record a movie. Only used with the
       default commands

    :raises Exception: if the extension of savefig is not supported
    """
    import os
    import tempfile

    script = ['open %s\n' % (cmm_file) for cmm_file in cmm_files]
    # superimpose each model on the first one; model #0 is the reference and
    # needs no match command
    script.extend('match #%s #0\n' % (i) for i in range(1, len(cmm_files)))
    if chimera_cmd:
        script.append('\n'.join(chimera_cmd) + '\n')
    else:
        script.append('''
focus
set bg_color white
windowsize 800 600
bonddisplay never #0
represent wire
shape tube #0 radius 5 bandLength 100 segmentSubdivisions 1 followBonds on
clip yon -500
~label
set subdivision 1
set depth_cue
set dc_color black
set dc_start 0.5
set dc_end 1
scale 0.8
''')
        if savefig:
            if savefig.endswith('.png'):
                script.append('copy file %s png' % (savefig))
            elif savefig.endswith(('.mov', 'webm')):
                script.append('''
movie record supersample 1
turn y 3 120
wait 120
movie stop
movie encode output %s
''' % (savefig))
            else:
                raise Exception('Not supported format, must be png, mov or webm\n')

    # a private temporary file: concurrent calls do not overwrite each other's
    # script (the '.cmd' suffix is kept from the previous fixed file name)
    handle, pref_f = tempfile.mkstemp(suffix='.cmd')
    try:
        with os.fdopen(handle, 'w') as out:
            out.write(''.join(script))
        # NOTE: shell=True is kept so that chimera_bin may carry options, but
        # it means chimera_bin must never come from untrusted input
        Popen('%s %s' % (chimera_bin, pref_f), shell=True).communicate()
    finally:
        os.remove(pref_f)


def plot_3d_optimization_result(result,
                                axes=('scale', 'maxdist', 'upfreq', 'lowfreq')):
    """
    Displays a three dimensional scatter plot representing the result of the
    optimization: one subplot per value of the first axis, each point being
    colored according to the correlation value.

    :param result: a tuple of three elements: the list of names of the axes
       of the array, the list of values taken along each of these axes and
       the four-dimensional numpy array containing the correlation values
    :param 'scale','maxdist','upfreq','lowfreq' axes: tuple of axes to
       represent. The order will define which parameter will be placed on the
       w, z, y or x axe.
    """
    # NOTE(review): not used directly; presumably imported to make the '3d'
    # projection available on older matplotlib versions -- confirm
    from mpl_toolkits.mplot3d import Axes3D

    ori_axes, axes_range, result = result
    trans = [ori_axes.index(a) for a in axes]
    axes_range = [axes_range[i] for i in trans]
    # transpose results
    result = result.transpose(trans)
    wax = [my_round(i, 3) for i in axes_range[0]]
    zax = [my_round(i, 3) for i in axes_range[1]]
    xax = [my_round(i, 3) for i in axes_range[3]]
    yax = [my_round(i, 3) for i in axes_range[2]]
    # best combination, as (correlation, w, z, x, y)
    sort_result = sorted([(result[i, j, k, l], wax[i], zax[j], xax[l], yax[k])
                          for i in range(len(wax))
                          for j in range(len(zax))
                          for k in range(len(yax))
                          for l in range(len(xax))
                          if not np.isnan(result[i, j, k, l])],
                         key=lambda x: x[0], reverse=True)[0]
    # coordinates of each point, in the same order as the colors below
    x = [i for i in axes_range[1] for j in axes_range[2] for k in axes_range[3]]
    y = [j for i in axes_range[1] for j in axes_range[2] for k in axes_range[3]]
    z = [k for i in axes_range[1] for j in axes_range[2] for k in axes_range[3]]

    ncols = int(np.sqrt(len(wax)) + 0.999)
    nrows = int(np.sqrt(len(wax)) + 0.5)
    fig = plt.figure(figsize=((ncols)*6, (nrows)*4.5))
    for i in range(len(wax)):
        col = [result[i, j, k, l]
               for j in range(len(axes_range[1]))
               for k in range(len(axes_range[2]))
               for l in range(len(axes_range[3]))]
        # subplot positions are numbered from 1; separate arguments also keep
        # working with more than 9 rows, columns or subplots
        ax = fig.add_subplot(nrows, ncols, i + 1, projection='3d')
        ax.set_xlabel(axes[1])
        ax.set_ylabel(axes[2])
        ax.set_zlabel(axes[3])
        lol = ax.scatter(x, y, z, c=col, s=100, alpha=0.9)
        cbar = fig.colorbar(lol)
        cbar.ax.set_ylabel('Correlation value')
        tit = 'Optimal IMP parameters (subplot %s=%s)\n' % (axes[0], wax[i])
        # same order as the values stored in sort_result: w, z, x, y
        tit += 'Best: %s=%%s, %s=%%s, %s=%%s, %s=%%s' % (axes[0], axes[1],
                                                          axes[3], axes[2])
        ax.set_title(tit % tuple([my_round(r, 3) for r in sort_result[1:]]))
    plt.show()
def my_round(num, val):
    """
    Rounds a number, and returns it as an integer when nothing is left after
    the decimal point (nicer to print: '2' instead of '2.0').

    :param num: the number to round
    :param val: number of decimals to keep

    :returns: the rounded number, as an int if it is a whole number. NaN and
       infinite values are returned as they are
    """
    num = round(num, val)
    try:
        whole = int(num)
    except (ValueError, OverflowError):  # NaN and infinity have no int form
        return num
    return whole if num == whole else num
def plot_2d_optimization_result(result,
                                axes=('scale', 'maxdist', 'upfreq', 'lowfreq'),
                                show_best=0, skip=None):
    """
    A grid of heatmaps representing the result of the optimization: one row
    per value of the first axis, one column per value of the second axis, each
    heatmap showing the correlation along the two last axes.

    :param result: a tuple of three elements: the list of names of the axes
       of the array, the list of values taken along each of these axes and
       the four-dimensional numpy array containing the correlation values
       (NaN where no value is available)
    :param 'scale','maxdist','upfreq','lowfreq' axes: tuple of axes to
       represent. The order will define which parameter will be placed on the
       w, z, y or x axe.
    :param 0 show_best: number of best correlation values to identify: their
       rank is written in the corresponding cell.
    :param None skip: a dict can be passed here in order to fix a given axe,
       e.g.: {'scale': 0.001, 'maxdist': 500}

    :raises Exception: if a key of skip is not one of the two first axes
    """
    import copy
    from mpl_toolkits.axes_grid1 import AxesGrid
    import matplotlib.patches as patches
    from matplotlib.cm import jet

    ori_axes, axes_range, result = result
    trans = [ori_axes.index(a) for a in axes]
    axes_range = [axes_range[i] for i in trans]
    # transpose results
    result = result.transpose(trans)
    # set NaNs
    result = np.ma.array(result, mask=np.isnan(result))
    # work on a copy: the colormap imported from matplotlib is shared
    cmap = copy.copy(jet)
    cmap.set_bad('w', 1.)
    # defines axes
    vmin = result.min()
    vmax = result.max()
    wax = [my_round(i, 3) for i in axes_range[0]]
    zax = [my_round(i, 3) for i in axes_range[1]]
    xax = [my_round(i, 3) for i in axes_range[3]]
    yax = [my_round(i, 3) for i in axes_range[2]]
    # get best correlations, as (correlation, w, z, x, y); masked cells are
    # left out
    sort_result = sorted([(result[i, j, k, l], wax[i], zax[j], xax[l], yax[k])
                          for i in range(len(wax))
                          for j in range(len(zax))
                          for k in range(len(yax))
                          for l in range(len(xax))
                          if result[i, j, k, l] is not np.ma.masked],
                         key=lambda x: x[0], reverse=True)[:show_best+1]
    # skip axes?
    wax_range = list(range(len(wax)))[::-1]
    zax_range = list(range(len(zax)))
    skip = {} if not skip else skip
    for i, k in enumerate(axes):
        if not k in skip:
            continue
        if i == 0:
            wax_range = [wax.index(skip[k])]
        elif i == 1:
            zax_range = [zax.index(skip[k])]
        else:
            raise Exception(('ERROR: skip keys must be one of the two first' +
                             ' keywords passed as axes parameter'))
    # best number of rows/columns
    ncols = len(zax_range)
    nrows = len(wax_range)
    fig = plt.figure(figsize=(max(6, float(ncols) * len(xax) / 3),
                              max(6, float(nrows) * len(yax) / 3)))
    # one extra row and one extra column: the first row and the last column
    # of the grid are left empty, the labels of the heatmaps overflow on them
    grid = AxesGrid(fig, [.1, .1, .9, .75],
                    nrows_ncols = (nrows+1, ncols+1),
                    axes_pad = 0.0,
                    label_mode = "1",
                    share_all = False,
                    cbar_location="right",
                    cbar_mode="single",
                    cbar_size="%s%%" % (7./(float(ncols) * len(xax) / 3)),
                    cbar_pad="10%",
                    )
    cell = ncols
    used = []
    for ii in wax_range:
        # jump over the last cell of the previous row
        cell += 1
        for i in zax_range:
            used.append(cell)
            im = grid[cell].imshow(result[ii, i, :, :],
                                   interpolation="nearest", origin='lower',
                                   vmin=vmin, vmax=vmax, cmap=cmap)
            grid[cell].tick_params(axis='both', direction='out', top=False,
                                   right=False, left=False, bottom=False)
            # write the rank of the best correlations in their cell
            for j, best in enumerate(sort_result[:-1]):
                if best[2] == zax[i] and best[1] == wax[ii]:
                    grid[cell].text(xax.index(best[3]), yax.index(best[4]),
                                    str(j), {'ha':'center', 'va':'center'})
            # column header, above the heatmaps of the first row
            if ii == wax_range[0]:
                rect = patches.Rectangle((-0.5, len(yax)-0.5), len(xax), 1.5,
                                         facecolor='grey', alpha=0.5)
                rect.set_clip_on(False)
                grid[cell].add_patch(rect)
                # true division: centered for odd lengths on Python 2 too
                grid[cell].text(len(xax) / 2.0 - 0.5, len(yax),
                                axes[1] + ' ' + str(my_round(zax[i], 3)),
                                {'ha':'center', 'va':'center'})
            cell += 1
        # row header, at the right of the last heatmap of the row
        rect = patches.Rectangle((len(xax)-.5, -0.5), 1.5, len(yax),
                                 facecolor='grey', alpha=0.5)
        rect.set_clip_on(False)
        grid[cell-1].add_patch(rect)
        grid[cell-1].text(len(xax)+.5, len(yax) / 2.0 - .5,
                          axes[0] + ' ' + str(my_round(wax[ii], 3)),
                          {'ha':'right', 'va':'center'}, rotation=90)
    # hide the cells of the grid in which nothing was drawn
    for i in range(cell+1):
        if not i in used:
            grid[i].set_visible(False)
    # ticks and labels are only set on the lower-left axes of the grid
    grid.axes_llc.set_xticks(range(0, len(xax), 2))
    grid.axes_llc.set_yticks(range(0, len(yax), 2))
    grid.axes_llc.set_xticklabels([my_round(i, 3) for i in xax][::2])
    grid.axes_llc.set_yticklabels([my_round(i, 3) for i in yax][::2])
    grid.axes_llc.set_ylabel(axes[2])
    grid.axes_llc.set_xlabel(axes[3])
    grid.cbar_axes[0].colorbar(im)
    grid.cbar_axes[0].set_ylabel('Correlation value')
    tit = 'Optimal IMP parameters\n'
    # same order as the values stored in sort_result: w, z, x, y
    tit += 'Best: %s=%%s, %s=%%s, %s=%%s, %s=%%s' % (axes[0], axes[1],
                                                      axes[3], axes[2])
    fig.suptitle(tit % tuple([my_round(i, 3) for i in sort_result[0][1:]]),
                 size='large')
    plt.show()
def compare_models(sm1, sm2, cutoff=150,
                   models1=None, cluster1=None,
                   models2=None, cluster2=None):
    """
    Plots the difference of contact maps of two group of structural models
    (contact map of sm2 minus contact map of sm1).

    :param sm1: a StructuralModel
    :param sm2: a StructuralModel
    :param 150 cutoff: distance threshold (nm) to determine if two particles
       are in contact
    :param None models1: passed as 'models' to the get_contact_matrix method
       of sm1: if None (default) the contact map will be computed using all
       the models. A list of numbers corresponding to a given set of models
       can be passed
    :param None cluster1: passed as 'cluster' to the get_contact_matrix
       method of sm1: compute the contact map only for the models in the
       cluster number 'cluster1'
    :param None models2: same as models1, for sm2
    :param None cluster2: same as cluster1, for sm2
    """
    mtx1 = sm1.get_contact_matrix(models=models1, cluster=cluster1,
                                  cutoff=cutoff)
    mtx2 = sm2.get_contact_matrix(models=models2, cluster=cluster2,
                                  cutoff=cutoff)
    # both matrices are read as squares with the side of the first one
    size = len(mtx1)
    mtx3 = [[mtx2[i][j] - mtx1[i][j] for j in range(size)]
            for i in range(size)]
    fig = plt.figure(figsize=(8, 6))
    axe = fig.add_subplot(111)
    im = axe.imshow(mtx3, origin='lower', interpolation="nearest")
    axe.set_ylabel('Particle')
    axe.set_xlabel('Particle')
    cbar = axe.figure.colorbar(im)
    cbar.ax.set_ylabel('Signed log difference between models')
    plt.show()

# Table Of Contents