Source code for plotpy.image

# -*- coding: utf-8 -*-
#
# Copyright © 2009-2010 CEA
# Pierre Raybaut
# Licensed under the terms of the CECILL License
# (see plotpy/__init__.py for details)

# pylint: disable=C0103

"""
plotpy.image
------------

The `image` module provides image-related objects and functions:

    * :py:class:`plotpy.image.ImagePlot`: a 2D curve and image plotting widget,
      derived from :py:class:`plotpy.curve.CurvePlot`
    * :py:class:`plotpy.image.ImageItem`: simple images
    * :py:class:`plotpy.image.TrImageItem`: images supporting arbitrary
      affine transform
    * :py:class:`plotpy.image.XYImageItem`: images with non-linear X/Y axes
    * :py:class:`plotpy.image.Histogram2DItem`: 2D histogram
    * :py:class:`plotpy.image.ImageFilterItem`: rectangular filtering area
      that may be resized and moved onto the processed image
    * :py:func:`plotpy.image.assemble_imageitems`
    * :py:func:`plotpy.image.get_plot_qrect`
    * :py:func:`plotpy.image.get_image_from_plot`

``ImageItem``, ``TrImageItem``, ``XYImageItem``, ``Histogram2DItem`` and
``ImageFilterItem`` objects are plot items (derived from QwtPlotItem) that
may be displayed on a :py:class:`plotpy.image.ImagePlot` plotting widget.

.. seealso::

    Module :py:mod:`plotpy.curve`
        Module providing curve-related plot items and plotting widgets

    Module :py:mod:`plotpy.plot`
        Module providing ready-to-use curve and image plotting widgets and
        dialog boxes

Examples
~~~~~~~~

Create a basic image plotting widget:
    
    * before creating any widget, a `QApplication` must be instantiated (that
      is a `Qt` internal requirement):

>>> import guidata
>>> app = guidata.qapplication()

    * that is mostly equivalent to the following (the only difference is that
      the `guidata` helper function also installs the `Qt` translation
      corresponding to the system locale):

>>> from PyQt4.QtGui import QApplication
>>> app = QApplication([])

    * now that a `QApplication` object exists, we may create the plotting
      widget:

>>> from plotpy.image import ImagePlot
>>> plot = ImagePlot(title="Example")

Generate random data for testing purpose:

>>> import numpy as np
>>> data = np.random.rand(100, 100)

Create a simple image item:
    
    * from the associated plot item class (e.g. `XYImageItem` to create
      an image with non-linear X/Y axes): the item properties are then
      assigned by creating the appropriate style parameters object
      (e.g. :py:class:`plotpy.styles.ImageParam`)

>>> from plotpy.image import ImageItem
>>> from plotpy.styles import ImageParam
>>> param = ImageParam()
>>> param.label = 'My image'
>>> image = ImageItem(param)
>>> image.set_data(data)

    * or using the `plot item builder` (see :py:func:`plotpy.builder.make`):

>>> from plotpy.builder import make
>>> image = make.image(data, title='My image')

Attach the image to the plotting widget:

>>> plot.add_item(image)

Display the plotting widget:

>>> plot.show()
>>> app.exec_()

Reference
~~~~~~~~~

.. autoclass:: ImagePlot
   :members:
   :inherited-members:
.. autoclass:: BaseImageItem
   :members:
   :inherited-members:
.. autoclass:: RawImageItem
   :members:
   :inherited-members:
.. autoclass:: ImageItem
   :members:
   :inherited-members:
.. autoclass:: TrImageItem
   :members:
   :inherited-members:
.. autoclass:: XYImageItem
   :members:
   :inherited-members:
.. autoclass:: RGBImageItem
   :members:
   :inherited-members:
.. autoclass:: MaskedImageItem
   :members:
   :inherited-members:
.. autoclass:: ImageFilterItem
   :members:
   :inherited-members:
.. autoclass:: XYImageFilterItem
   :members:
   :inherited-members:
.. autoclass:: Histogram2DItem
   :members:
   :inherited-members:

.. autofunction:: assemble_imageitems
.. autofunction:: get_plot_qrect
.. autofunction:: get_image_from_plot
"""

#FIXME: traceback in scaler when adding here 'from __future__ import division'

from __future__ import print_function, unicode_literals

import sys
import os.path as osp
from math import fabs

import numpy as np

from guidata.qt.QtGui import QColor, QImage
from guidata.qt.QtCore import QRectF, QPointF, QRect

from guidata.utils import assert_interfaces_valid, update_dataset
from guidata.py3compat import getcwd, is_text_string

# Local imports
from plotpy.transitional import QwtPlotItem, QwtInterval
from plotpy.config import _
from plotpy.interfaces import (IBasePlotItem, IBaseImageItem, IHistDataSource,
                               IImageItemType, ITrackableItemType,
                               IColormapImageItemType, IVoiImageItemType,
                               ISerializableType, ICSImageItemType,
                               IExportROIImageItemType, IStatsImageItemType)
from plotpy.curve import CurvePlot, CurveItem, PolygonMapItem
from plotpy.colormap import FULLRANGE, get_cmap, get_cmap_name
from plotpy.styles import (ImageParam, ImageAxesParam, TrImageParam,
                           RGBImageParam, MaskedImageParam, XYImageParam,
                           RawImageParam)
from plotpy.shapes import RectangleShape
from plotpy import io
from plotpy.geometry import translate, scale, rotate, colvector
from plotpy.baseplot import canvas_to_axes, axes_to_canvas

stderr = sys.stderr
try:
    from plotpy.histogram2d import histogram2d, histogram2d_func
    from plotpy._scaler import (_histogram, _scale_tr, _scale_xy, _scale_rect,
                                _scale_quads,
                                INTERP_NEAREST, INTERP_LINEAR, INTERP_AA)
except ImportError:
    # The scaling/histogram routines come from a compiled C extension: give
    # the user a hint on how to build it in place, then propagate the error
    print("Module 'plotpy.image': missing C extension", file=sys.stderr)
    print("try running: python setup.py build_ext --inplace -c mingw32",
          file=sys.stderr)
    raise

# Number of entries of the colormap look-up table
LUT_SIZE = 1024
# Highest LUT index, stored as a float so that it can be used directly in
# the floating-point level -> LUT index transform
LUT_MAX = float(LUT_SIZE - 1)

def _nanmin(data):
    if isinstance(data, np.ma.MaskedArray):
        data = data.data
    if data.dtype.name in ("float32", "float64", "float128"):
        return np.nanmin(data)
    else:
        return data.min()

def _nanmax(data):
    if isinstance(data, np.ma.MaskedArray):
        data = data.data
    if data.dtype.name in ("float32", "float64", "float128"):
        return np.nanmax(data)
    else:
        return data.max()


def pixelround(x, corner=None):
    """
    Convert a (float) pixel coordinate into a pixel index.

    *corner* tells which kind of point *x* is:
      * None: a plain point (rounded down),
      * 'TL': a top-left corner (rounded down),
      * 'BR': a bottom-right corner (rounded up).
    """
    assert corner in (None, 'TL', 'BR')
    if corner == 'BR':
        return np.ceil(x)
    if corner in (None, 'TL'):
        return np.floor(x)


#==============================================================================
# Base image item class
#==============================================================================
class BaseImageItem(QwtPlotItem):
    """
    Abstract base class of image plot items

        * data (optional): image data, passed to `set_data`
        * param (optional): image parameters DataSet (when omitted, the one
          returned by `get_default_param` is used)

    Subclasses must implement `get_default_param` and `set_data` (the latter
    is called by the constructor when *data* is given).
    """
    __implements__ = (IBasePlotItem, IBaseImageItem, IHistDataSource,
                      IVoiImageItemType, ICSImageItemType,
                      IStatsImageItemType, IExportROIImageItemType)

    _can_select = True
    _can_resize = False
    _can_move = False
    _can_rotate = False
    _readonly = False
    _private = False

    def __init__(self, data=None, param=None):
        super(BaseImageItem, self).__init__()

        self.bg_qcolor = QColor()

        self.bounds = QRectF()

        # BaseImageItem needs:
        # param.background
        # param.alpha_mask
        # param.alpha
        # param.colormap
        if param is None:
            param = self.get_default_param()
        self.imageparam = param
        self.selected = False

        self.data = None

        self.min = 0.0
        self.max = 1.0
        self.cmap_table = None
        self.cmap = None
        self.colormap_axis = None

        # Placeholder offscreen buffer: its shape (2,) never equals the
        # (H, W) canvas shape tested in `draw`, so the first call to `draw`
        # always allocates the real buffer together with `self._image`
        self._offscreen = np.array((1, 1), np.uint32)

        # Linear interpolation is the default interpolation algorithm:
        # it's almost as fast as 'nearest pixel' method but far smoother
        self.interpolate = None
        self.set_interpolation(INTERP_LINEAR)

        x1, y1 = self.bounds.left(), self.bounds.top()
        x2, y2 = self.bounds.right(), self.bounds.bottom()
        self.border_rect = RectangleShape(x1, y1, x2, y2)
        self.border_rect.set_style("plot", "shape/imageborder")
        # A, B, Background, Colormap
        self.lut = (1.0, 0.0, None, np.zeros((LUT_SIZE, ), np.uint32))

        self.set_lut_range([0., 255.])
        self.setItemAttribute(QwtPlotItem.AutoScale)
        self.setItemAttribute(QwtPlotItem.Legend, True)
        self._filename = None  # The file this image comes from

        self.histogram_cache = None
        if data is not None:
            self.set_data(data)
        self.imageparam.update_image(self)

    #---- Public API ----------------------------------------------------------
    def get_default_param(self):
        """Return instance of the default imageparam DataSet"""
        raise NotImplementedError

    def set_filename(self, fname):
        """Set the name of the file this image comes from"""
        self._filename = fname

    def get_filename(self):
        """
        Return the name of the file this image comes from (or None)

        If the stored path no longer exists but a file with the same base
        name exists in the current working directory, that file is used
        (and remembered) instead.
        """
        fname = self._filename
        if fname is not None and not osp.isfile(fname):
            other_try = osp.join(getcwd(), osp.basename(fname))
            if osp.isfile(other_try):
                self.set_filename(other_try)
                fname = other_try
        return fname

    def get_filter(self, filterobj, filterparam):
        """Provides a filter object over this image's content"""
        raise NotImplementedError

    def get_pixel_coordinates(self, xplot, yplot):
        """
        Return (image) pixel coordinates
        Transform the plot coordinates (arbitrary plot Z-axis unit)
        into the image coordinates (pixel unit)

        Rounding is necessary to obtain array indexes from these coordinates
        """
        return xplot, yplot

    def get_plot_coordinates(self, xpixel, ypixel):
        """
        Return plot coordinates
        Transform the image coordinates (pixel unit)
        into the plot coordinates (arbitrary plot Z-axis unit)
        """
        return xpixel, ypixel

    def get_closest_indexes(self, x, y, corner=None):
        """
        Return closest image pixel indexes, clipped to the image bounds

        corner: None (not a corner), 'TL' (top-left corner),
        'BR' (bottom-right corner)
        """
        x, y = self.get_pixel_coordinates(x, y)
        i_max = self.data.shape[1]-1
        j_max = self.data.shape[0]-1
        if corner == 'BR':
            # A bottom-right corner may lie one past the last pixel
            i_max += 1
            j_max += 1
        i = max([0, min([i_max, int(pixelround(x, corner))])])
        j = max([0, min([j_max, int(pixelround(y, corner))])])
        return i, j

    def get_closest_index_rect(self, x0, y0, x1, y1):
        """
        Return closest image rectangular pixel area index bounds
        Avoid returning empty rectangular area (return 1x1 pixel area instead)
        Handle reversed/not-reversed Y-axis orientation
        """
        ix0, iy0 = self.get_closest_indexes(x0, y0, corner='TL')
        ix1, iy1 = self.get_closest_indexes(x1, y1, corner='BR')
        if ix0 > ix1:
            ix1, ix0 = ix0, ix1
        if iy0 > iy1:
            iy1, iy0 = iy0, iy1
        if ix0 == ix1:
            ix1 += 1
        if iy0 == iy1:
            iy1 += 1
        return ix0, iy0, ix1, iy1

    def align_rectangular_shape(self, shape):
        """Align rectangular shape to image pixels"""
        ix0, iy0, ix1, iy1 = self.get_closest_index_rect(*shape.get_rect())
        x0, y0 = self.get_plot_coordinates(ix0, iy0)
        x1, y1 = self.get_plot_coordinates(ix1, iy1)
        shape.set_rect(x0, y0, x1, y1)

    def get_closest_pixel_indexes(self, x, y):
        """
        Return closest pixel indexes
        Instead of returning indexes of an image pixel like the method
        'get_closest_indexes', this method returns the indexes of the
        closest pixel which is not necessarily on the image itself
        (i.e. indexes may be outside image index bounds: negative or
        superior than the image dimension)

        .. note::

            This is *not* the same as retrieving the canvas pixel coordinates
            (which depends on the zoom level)
        """
        x, y = self.get_pixel_coordinates(x, y)
        i = int(pixelround(x))
        j = int(pixelround(y))
        return i, j

    def get_x_values(self, i0, i1):
        """Return X-axis values of pixel columns i0 (included) to i1 (excluded)"""
        return np.arange(i0, i1)

    def get_y_values(self, j0, j1):
        """Return Y-axis values of pixel rows j0 (included) to j1 (excluded)"""
        return np.arange(j0, j1)

    def get_data(self, x0, y0, x1=None, y1=None):
        """
        Return image data
        Arguments: x0, y0 [, x1, y1]
        Return image level at coordinates (x0,y0)
        If x1,y1 are specified:
          Return image levels (np.ndarray) in rectangular area (x0,y0,x1,y1)
        """
        i0, j0 = self.get_closest_indexes(x0, y0)
        if x1 is None or y1 is None:
            return self.data[j0, i0]
        else:
            i1, j1 = self.get_closest_indexes(x1, y1)
            # Make the (x1, y1) pixel part of the returned area
            i1 += 1
            j1 += 1
            return (self.get_x_values(i0, i1), self.get_y_values(j0, j1),
                    self.data[j0:j1, i0:i1])

    def get_closest_coordinates(self, x, y):
        """Return closest image pixel coordinates"""
        return self.get_closest_indexes(x, y)

    def get_coordinates_label(self, xc, yc):
        """Return the HTML label shown by the cursor tracker at (xc, yc)"""
        title = self.title().text()
        z = self.get_data(xc, yc)
        return "%s:<br>x = %d<br>y = %d<br>z = %g" % (title, xc, yc, z)

    def set_background_color(self, qcolor):
        """Set the LUT background color (None: no background color)"""
        #mask = np.uint32(255*self.imageparam.alpha+0.5).clip(0,255) << 24
        self.bg_qcolor = qcolor
        a, b, _bg, cmap = self.lut
        if qcolor is None:
            self.lut = (a, b, None, cmap)
        else:
            # Keep only the RGB part of the color (alpha bits cleared)
            self.lut = (a, b, np.uint32(QColor(qcolor).rgb() & 0xffffff), cmap)

    def set_color_map(self, name_or_table):
        """
        Set the colormap, given either its name or a colormap table object,
        and rebuild the ARGB look-up table accordingly
        """
        if name_or_table is self.cmap_table:
            # This avoids rebuilding the LUT all the time
            return
        if is_text_string(name_or_table):
            table = get_cmap(name_or_table)
        else:
            table = name_or_table
        self.cmap_table = table
        self.cmap = table.colorTable(FULLRANGE)
        cmap_a = self.lut[3]
        alpha = self.imageparam.alpha
        alpha_mask = self.imageparam.alpha_mask
        for i in range(LUT_SIZE):
            if alpha_mask:
                # Opacity grows linearly with the LUT index
                pix_alpha = alpha*(i/float(LUT_SIZE-1))
            else:
                pix_alpha = alpha
            alpha_channel = np.uint32(255*pix_alpha+0.5).clip(0, 255) << 24
            cmap_a[i] = np.uint32((table.rgb(FULLRANGE, i/LUT_MAX))
                                  & 0xffffff) | alpha_channel
        plot = self.plot()
        if plot:
            plot.update_colormap_axis(self)

    def get_color_map(self):
        """Return the current colormap table object"""
        return self.cmap_table

    def get_color_map_name(self):
        """Return the name of the current colormap"""
        return get_cmap_name(self.get_color_map())

    def set_interpolation(self, interp_mode, size=None):
        """
        Set image interpolation mode

        interp_mode: INTERP_NEAREST, INTERP_LINEAR, INTERP_AA
        size (integer): (for anti-aliasing only) AA matrix size
        """
        if interp_mode in (INTERP_NEAREST, INTERP_LINEAR):
            self.interpolate = (interp_mode,)
        if interp_mode == INTERP_AA:
            # The AA kernel takes the data dtype: image data (and *size*)
            # must be available before selecting this mode
            aa = np.ones((size, size), self.data.dtype)
            self.interpolate = (interp_mode, aa)

    def get_interpolation(self):
        """Get interpolation mode"""
        return self.interpolate

    def set_lut_range(self, lut_range):
        """
        Set LUT transform range
        *lut_range* is a tuple: (min, max)

        The LUT coefficients (a, b) are chosen so that level*a + b maps
        [min, max] onto [0, LUT_MAX].
        """
        self.min, self.max = lut_range
        _a, _b, bg, cmap = self.lut
        if self.max == self.min:
            # NOTE(review): degenerate range keeps (LUT_MAX, min) as (a, b);
            # confirm this is the intended mapping for constant images
            self.lut = (LUT_MAX, self.min, bg, cmap)
        else:
            fmin, fmax = float(self.min), float(self.max) # avoid overflows
            self.lut = (LUT_MAX/(fmax-fmin), -LUT_MAX*fmin/(fmax-fmin),
                        bg, cmap)

    def get_lut_range(self):
        """Return the LUT transform range tuple: (min, max)"""
        return self.min, self.max

    def get_lut_range_full(self):
        """Return full dynamic range"""
        return _nanmin(self.data), _nanmax(self.data)

    def get_lut_range_max(self):
        """Get maximum range for this dataset"""
        # np.inexact covers both real floating-point and complex dtypes
        # (np.finfo accepts both); integer dtypes use np.iinfo
        if np.issubdtype(self.data.dtype, np.inexact):
            info = np.finfo(self.data.dtype)
        else:
            info = np.iinfo(self.data.dtype)
        return info.min, info.max

    def update_border(self):
        """Update image border rectangle to fit image shape"""
        bounds = self.boundingRect().getCoords()
        self.border_rect.set_rect(*bounds)

    def draw_border(self, painter, xMap, yMap, canvasRect):
        """Draw image border rectangle"""
        self.border_rect.draw(painter, xMap, yMap, canvasRect)

    def draw_image(self, painter, canvasRect, src_rect, dst_rect, xMap, yMap):
        """
        Draw image with painter on canvasRect

        .. warning::

            `src_rect` and `dst_rect` are coordinates tuples
            (xleft, ytop, xright, ybottom)
        """
        dest = _scale_rect(self.data, src_rect, self._offscreen, dst_rect,
                           self.lut, self.interpolate)
        qrect = QRectF(QPointF(dest[0], dest[1]), QPointF(dest[2], dest[3]))
        painter.drawImage(qrect, self._image, qrect)

    def export_roi(self, src_rect, dst_rect, dst_image,
                   apply_lut=False, apply_interpolation=False,
                   original_resolution=False):
        """Export Region Of Interest to array"""
        if apply_lut:
            a, b, _bg, _cmap = self.lut
        else:
            a, b = 1., 0.
        interp = self.interpolate if apply_interpolation else (INTERP_NEAREST,)
        _scale_rect(self.data, src_rect, dst_image, dst_rect,
                    (a, b, None), interp)

    #---- QwtPlotItem API -----------------------------------------------------
    def draw(self, painter, xMap, yMap, canvasRect):
        """Draw the image and its border on the canvas"""
        x1, y1, x2, y2 = canvasRect.getCoords()
        i1, i2 = xMap.invTransform(x1), xMap.invTransform(x2)
        j1, j2 = yMap.invTransform(y1), yMap.invTransform(y2)

        xl, yt, xr, yb = self.boundingRect().getCoords()
        dest = (xMap.transform(xl), yMap.transform(yt),
                xMap.transform(xr)+1, yMap.transform(yb)+1)

        W = canvasRect.right()
        H = canvasRect.bottom()
        if self._offscreen.shape != (H, W):
            # (Re)allocate the offscreen buffer to the canvas size
            self._offscreen = np.empty((H, W), np.uint32)
            self._image = QImage(self._offscreen, W, H, QImage.Format_ARGB32)
            # presumably keeps the NumPy buffer alive as long as the QImage
            # that wraps it
            self._image.ndarray = self._offscreen
            self.notify_new_offscreen()
        self.draw_image(painter, canvasRect, (i1, j1, i2, j2), dest, xMap, yMap)
        self.draw_border(painter, xMap, yMap, canvasRect)

    def boundingRect(self):
        """Return the image bounding rectangle (plot coordinates)"""
        return self.bounds

    def notify_new_offscreen(self):
        """Called when the offscreen buffer is reallocated (hook)"""
        # callback for those derived classes who need it
        pass

    def setVisible(self, enable):
        """Show/hide the item (hiding it also unselects it)"""
        if not enable:
            self.unselect() # when hiding item, unselect it
        if enable:
            self.border_rect.show()
        else:
            self.border_rect.hide()
        QwtPlotItem.setVisible(self, enable)

    #---- IBasePlotItem API ----------------------------------------------------
    def types(self):
        """Return the item types (interfaces) supported by this item"""
        return (IImageItemType, IVoiImageItemType, IColormapImageItemType,
                ITrackableItemType, ICSImageItemType, IExportROIImageItemType,
                IStatsImageItemType)

    def set_readonly(self, state):
        """Set object readonly state"""
        self._readonly = state

    def is_readonly(self):
        """Return object readonly state"""
        return self._readonly

    def set_private(self, state):
        """Set object as private"""
        self._private = state

    def is_private(self):
        """Return True if object is private"""
        return self._private

    def select(self):
        """Select item"""
        self.selected = True
        self.border_rect.select()

    def unselect(self):
        """Unselect item"""
        self.selected = False
        self.border_rect.unselect()

    def is_empty(self):
        """Return True if item data is empty"""
        return self.data is None or self.data.size == 0

    def set_selectable(self, state):
        """Set item selectable state"""
        self._can_select = state

    def set_resizable(self, state):
        """Set item resizable state
        (or any action triggered when moving an handle, e.g. rotation)"""
        self._can_resize = state

    def set_movable(self, state):
        """Set item movable state"""
        self._can_move = state

    def set_rotatable(self, state):
        """Set item rotatable state"""
        self._can_rotate = state

    def can_select(self):
        """Return True if item can be selected"""
        return self._can_select

    def can_resize(self):
        """Return True if item can be resized"""
        return self._can_resize

    def can_move(self):
        """Return True if item can be moved"""
        return self._can_move

    def can_rotate(self):
        """Return True if item can be rotated"""
        return self._can_rotate

    def hit_test(self, pos):
        """Hit-test *pos* (canvas coordinates) against the image border"""
        plot = self.plot()
        ax = self.xAxis()
        ay = self.yAxis()
        return self.border_rect.poly_hit_test(plot, ax, ay, pos)

    def update_item_parameters(self):
        """Update item parameters DataSets from item state (no-op here)"""
        pass

    def get_item_parameters(self, itemparams):
        """Append item parameters DataSets to *itemparams*"""
        itemparams.add("ShapeParam", self, self.border_rect.shapeparam)

    def set_item_parameters(self, itemparams):
        """Apply *itemparams* DataSets to the item"""
        self.border_rect.set_item_parameters(itemparams)

    def move_local_point_to(self, handle, pos, ctrl=None):
        """Move a handle as returned by hit_test to the new position pos
        ctrl: True if <Ctrl> button is being pressed, False otherwise"""
        pass

    def move_local_shape(self, old_pos, new_pos):
        """Translate the shape such that old_pos becomes new_pos
        in canvas coordinates"""
        pass

    def move_with_selection(self, delta_x, delta_y):
        """
        Translate the shape together with other selected items
        delta_x, delta_y: translation in plot coordinates
        """
        pass

    #---- IBaseImageItem API --------------------------------------------------
    def can_setfullscale(self):
        """Return True if the LUT range may be set to the full data range"""
        return True

    def can_sethistogram(self):
        """Return True if the item may be used as a histogram source"""
        return False

    def get_histogram(self, nbins):
        """
        Return the (counts, bin_edges) histogram of the image data
        (IHistDataSource interface)

        The result is cached until the data or the number of bins changes.
        """
        if self.data is None:
            return [0,], [0, 1]
        if self.histogram_cache is None \
           or nbins != self.histogram_cache[0].shape[0]:
            # NOTE: the C extension `_histogram` would be faster, but caching
            # was buggy with it; np.histogram is used instead
            kwargs = {}
            if self.data.dtype.kind == 'f':
                # np.histogram cannot autodetect a finite range when the data
                # contains NaNs: give it a NaN-aware one (identical to the
                # autodetected range when there are no NaNs)
                kwargs['range'] = (_nanmin(self.data), _nanmax(self.data))
            res = np.histogram(self.data, nbins, **kwargs)
            self.histogram_cache = res
        else:
            res = self.histogram_cache
        return res

    def __process_cross_section(self, ydata, apply_lut):
        """Optionally map cross-section levels to LUT indexes"""
        if apply_lut:
            a, b, bg, cmap = self.lut
            return (ydata*a+b).clip(0, LUT_MAX)
        else:
            return ydata

    def get_stats(self, x0, y0, x1, y1):
        """Return formatted string with stats on image rectangular area
        (output should be compatible with AnnotatedShape.get_infos)"""
        ix0, iy0, ix1, iy1 = self.get_closest_index_rect(x0, y0, x1, y1)
        data = self.data[iy0:iy1, ix0:ix1]
        xfmt = self.imageparam.xformat
        yfmt = self.imageparam.yformat
        zfmt = self.imageparam.zformat
        return "<br>".join([
                            "<b>%s</b>" % self.imageparam.label,
                            "%sx%s %s" % (self.data.shape[1],
                                          self.data.shape[0],
                                          str(self.data.dtype)),
                            "",
                            "%s ≤ x ≤ %s" % (xfmt % x0, xfmt % x1),
                            "%s ≤ y ≤ %s" % (yfmt % y0, yfmt % y1),
                            "%s ≤ z ≤ %s" % (zfmt % data.min(),
                                             zfmt % data.max()),
                            "‹z› = " + zfmt % data.mean(),
                            "σ(z) = " + zfmt % data.std(),
                            ])

    def get_xsection(self, y0, apply_lut=False):
        """Return cross section along x-axis at y=y0"""
        _ix, iy = self.get_closest_indexes(0, y0)
        return (self.get_x_values(0, self.data.shape[1]),
                self.__process_cross_section(self.data[iy, :], apply_lut))

    def get_ysection(self, x0, apply_lut=False):
        """Return cross section along y-axis at x=x0"""
        ix, _iy = self.get_closest_indexes(x0, 0)
        return (self.get_y_values(0, self.data.shape[0]),
                self.__process_cross_section(self.data[:, ix], apply_lut))

    def get_average_xsection(self, x0, y0, x1, y1, apply_lut=False):
        """Return average cross section along x-axis"""
        ix0, iy0, ix1, iy1 = self.get_closest_index_rect(x0, y0, x1, y1)
        ydata = self.data[iy0:iy1, ix0:ix1].mean(axis=0)
        return (self.get_x_values(ix0, ix1),
                self.__process_cross_section(ydata, apply_lut))

    def get_average_ysection(self, x0, y0, x1, y1, apply_lut=False):
        """Return average cross section along y-axis"""
        ix0, iy0, ix1, iy1 = self.get_closest_index_rect(x0, y0, x1, y1)
        ydata = self.data[iy0:iy1, ix0:ix1].mean(axis=1)
        return (self.get_y_values(iy0, iy1),
                self.__process_cross_section(ydata, apply_lut))

assert_interfaces_valid(BaseImageItem)


#==============================================================================
# Raw Image item (image item without scale)
#==============================================================================
class RawImageItem(BaseImageItem):
    """
    Construct a simple image item

        * data: 2D NumPy array
        * param (optional): image parameters
          (:py:class:`plotpy.styles.RawImageParam` instance)

    The image is displayed in pixel coordinates: it spans (0, 0) to
    (number of columns, number of rows).
    """
    __implements__ = (IBasePlotItem, IBaseImageItem, IHistDataSource,
                      IVoiImageItemType, ISerializableType)

    #---- BaseImageItem API ---------------------------------------------------
    def get_default_param(self):
        """Build the default parameters DataSet (a RawImageParam)"""
        return RawImageParam(_("Image"))

    #---- Serialization methods -----------------------------------------------
    def __reduce__(self):
        filename = self.get_filename()
        # Reference the source file when there is one, else embed the array
        source = self.data if filename is None else filename
        state = (self.imageparam, self.get_lut_range(), source, self.z())
        return (self.__class__, (), state)

    def __setstate__(self, state):
        param, lut_range, source, zorder = state
        self.imageparam = param
        if is_text_string(source):
            self.set_filename(source)
            self.load_data()
        elif source is not None:
            # Embedded array: should happen only with previous API
            self.set_data(source)
        self.set_lut_range(lut_range)
        self.setZ(zorder)
        self.imageparam.update_image(self)

    def serialize(self, writer):
        """Write this item to an HDF5 writer"""
        filename = self.get_filename()
        from_file = filename is not None
        # The array itself is only stored when there is no source file
        payload = None if from_file else self.data
        writer.write(from_file, group_name='load_from_fname')
        writer.write(filename, group_name='fname')
        writer.write(payload, group_name='Zdata')
        writer.write(self.get_lut_range(), group_name='lut_range')
        writer.write(self.z(), group_name='z')
        self.imageparam.update_param(self)
        writer.write(self.imageparam, group_name='imageparam')

    def deserialize(self, reader):
        """Restore this item from an HDF5 reader"""
        lut_range = reader.read(group_name='lut_range')
        if not reader.read(group_name='load_from_fname'):
            self.set_data(reader.read(group_name='Zdata',
                                      func=reader.read_array))
        else:
            self.set_filename(reader.read(group_name='fname',
                                          func=reader.read_unicode))
            self.load_data()
        self.set_lut_range(lut_range)
        self.setZ(reader.read('z'))
        self.imageparam = self.get_default_param()
        reader.read('imageparam', instance=self.imageparam)
        self.imageparam.update_image(self)

    #---- Public API ----------------------------------------------------------
    def load_data(self, lut_range=None):
        """
        Read the image (converted to grayscale) from the file set with
        'set_filename', then optionally apply *lut_range*
        """
        self.set_data(io.imread(self.get_filename(), to_grayscale=True),
                      lut_range=lut_range)

    def set_data(self, data, lut_range=None):
        """
        Replace the image data

            * data: 2D NumPy array
            * lut_range: LUT range -- tuple (levelmin, levelmax); when
              omitted, the NaN-aware data range is used
        """
        if lut_range is None:
            lut_range = (_nanmin(data), _nanmax(data))
        levelmin, levelmax = lut_range
        self.data = data
        self.histogram_cache = None  # data changed: drop cached histogram
        self.update_bounds()
        self.update_border()
        self.set_lut_range([levelmin, levelmax])

    def update_bounds(self):
        """Make the bounding rectangle match the pixel extent of the data"""
        if self.data is not None:
            ncols = self.data.shape[1]
            nrows = self.data.shape[0]
            self.bounds = QRectF(0, 0, ncols, nrows)

    #---- IBasePlotItem API ---------------------------------------------------
    def types(self):
        """Return the item types (interfaces) supported by this item"""
        return (IImageItemType, IVoiImageItemType, IColormapImageItemType,
                ITrackableItemType, ICSImageItemType, ISerializableType,
                IExportROIImageItemType, IStatsImageItemType)

    def update_item_parameters(self):
        """Copy the item state into its parameters DataSet"""
        self.imageparam.update_param(self)

    def get_item_parameters(self, itemparams):
        """Append the item parameters DataSets to *itemparams*"""
        BaseImageItem.get_item_parameters(self, itemparams)
        self.update_item_parameters()
        itemparams.add("ImageParam", self, self.imageparam)

    def set_item_parameters(self, itemparams):
        """Apply the DataSets from *itemparams* to the item"""
        update_dataset(self.imageparam, itemparams.get("ImageParam"),
                       visible_only=True)
        self.imageparam.update_image(self)
        BaseImageItem.set_item_parameters(self, itemparams)

    #---- IBaseImageItem API --------------------------------------------------
    def can_setfullscale(self):
        """The LUT range may be set to the full data range"""
        return True

    def can_sethistogram(self):
        """The item may be used as a histogram data source"""
        return True

assert_interfaces_valid(RawImageItem)


#==============================================================================
# Image item
#==============================================================================
[docs]class ImageItem(RawImageItem): """ Construct a simple image item * data: 2D NumPy array * param (optional): image parameters (:py:class:`plotpy.styles.ImageParam` instance) """ __implements__ = (IBasePlotItem, IBaseImageItem, IHistDataSource, IVoiImageItemType, IExportROIImageItemType) def __init__(self, data=None, param=None): self.xmin = None self.xmax = None self.ymin = None self.ymax = None super(ImageItem, self).__init__(data=data, param=param) #---- BaseImageItem API ---------------------------------------------------
[docs] def get_default_param(self): """Return instance of the default imageparam DataSet""" return ImageParam(_("Image")) #---- Serialization methods -----------------------------------------------
def __reduce__(self): fname = self.get_filename() if fname is None: fn_or_data = self.data else: fn_or_data = fname (xmin, xmax), (ymin, ymax) = self.get_xdata(), self.get_ydata() state = (self.imageparam, self.get_lut_range(), fn_or_data, self.z(), xmin, xmax, ymin, ymax) res = ( self.__class__, (), state ) return res def __setstate__(self, state): param, lut_range, fn_or_data, z, xmin, xmax, ymin, ymax = state self.set_xdata(xmin, xmax) self.set_ydata(ymin, ymax) self.imageparam = param if is_text_string(fn_or_data): self.set_filename(fn_or_data) self.load_data() elif fn_or_data is not None: # should happen only with previous API self.set_data(fn_or_data) self.set_lut_range(lut_range) self.setZ(z) self.imageparam.update_image(self)
[docs] def serialize(self, writer): """Serialize object to HDF5 writer""" super(ImageItem, self).serialize(writer) (xmin, xmax), (ymin, ymax) = self.get_xdata(), self.get_ydata() writer.write(xmin, group_name='xmin') writer.write(xmax, group_name='xmax') writer.write(ymin, group_name='ymin') writer.write(ymax, group_name='ymax')
[docs] def deserialize(self, reader): """Deserialize object from HDF5 reader""" super(ImageItem, self).deserialize(reader) for attr in ('xmin', 'xmax', 'ymin', 'ymax'): # Note: do not be tempted to write the symetric code in `serialize` # because calling `get_xdata` and `get_ydata` is necessary setattr(self, attr, reader.read(attr, func=reader.read_float)) #---- Public API ----------------------------------------------------------
[docs] def get_xdata(self): """Return (xmin, xmax)""" xmin, xmax = self.xmin, self.xmax if xmin is None: xmin = 0. if xmax is None: xmax = self.data.shape[1] return xmin, xmax
[docs] def get_ydata(self): """Return (ymin, ymax)""" ymin, ymax = self.ymin, self.ymax if ymin is None: ymin = 0. if ymax is None: ymax = self.data.shape[0] return ymin, ymax
def set_xdata(self, xmin=None, xmax=None): self.xmin, self.xmax = xmin, xmax def set_ydata(self, ymin=None, ymax=None): self.ymin, self.ymax = ymin, ymax def update_bounds(self): if self.data is None: return (xmin, xmax), (ymin, ymax) = self.get_xdata(), self.get_ydata() self.bounds = QRectF(QPointF(xmin, ymin), QPointF(xmax, ymax)) #---- BaseImageItem API ---------------------------------------------------
[docs] def get_pixel_coordinates(self, xplot, yplot): """Return (image) pixel coordinates (from plot coordinates)""" (xmin, xmax), (ymin, ymax) = self.get_xdata(), self.get_ydata() xpix = self.data.shape[1]*(xplot-xmin)/float(xmax-xmin) ypix = self.data.shape[0]*(yplot-ymin)/float(ymax-ymin) return xpix, ypix
[docs] def get_plot_coordinates(self, xpixel, ypixel): """Return plot coordinates (from image pixel coordinates)""" (xmin, xmax), (ymin, ymax) = self.get_xdata(), self.get_ydata() xplot = xmin+(xmax-xmin)*xpixel/float(self.data.shape[1]) yplot = ymin+(ymax-ymin)*ypixel/float(self.data.shape[0]) return xplot, yplot
def get_x_values(self, i0, i1): xmin, xmax = self.get_xdata() xfunc = lambda index: xmin+(xmax-xmin)*index/float(self.data.shape[1]) return np.linspace(xfunc(i0), xfunc(i1), i1-i0) def get_y_values(self, j0, j1): ymin, ymax = self.get_ydata() yfunc = lambda index: ymin+(ymax-ymin)*index/float(self.data.shape[0]) return np.linspace(yfunc(j0), yfunc(j1), j1-j0)
[docs] def get_closest_coordinates(self, x, y): """Return closest image pixel coordinates""" (xmin, xmax), (ymin, ymax) = self.get_xdata(), self.get_ydata() i, j = self.get_closest_indexes(x, y) xpix = np.linspace(xmin, xmax, self.data.shape[1]+1) ypix = np.linspace(ymin, ymax, self.data.shape[0]+1) return xpix[i], ypix[j]
def _rescale_src_rect(self, src_rect): sxl, syt, sxr, syb = src_rect xl, yt, xr, yb = self.boundingRect().getCoords() H, W = self.data.shape[:2] x0 = W*(sxl-xl)/(xr-xl) x1 = W*(sxr-xl)/(xr-xl) y0 = H*(syt-yt)/(yb-yt) y1 = H*(syb-yt)/(yb-yt) return x0, y0, x1, y1 def draw_image(self, painter, canvasRect, src_rect, dst_rect, xMap, yMap): if self.data is None: return src2 = self._rescale_src_rect(src_rect) dst_rect = tuple([int(i) for i in dst_rect]) dest = _scale_rect(self.data, src2, self._offscreen, dst_rect, self.lut, self.interpolate) qrect = QRectF(QPointF(dest[0], dest[1]), QPointF(dest[2], dest[3])) painter.drawImage(qrect, self._image, qrect)
def export_roi(self, src_rect, dst_rect, dst_image,
               apply_lut=False, apply_interpolation=False,
               original_resolution=False):
    """Export Region Of Interest to array

    The data inside src_rect (plot coordinates) is resampled into dst_rect
    of dst_image. Without apply_lut, an identity linear transform (a=1, b=0)
    is used; without apply_interpolation, the INTERP_NEAREST mode is used.
    original_resolution is not used by this implementation."""
    if not apply_lut:
        a, b = 1., 0.
    else:
        a, b, _bg, _cmap = self.lut
    if apply_interpolation:
        interp = self.interpolate
    else:
        interp = (INTERP_NEAREST,)
    _scale_rect(self.data, self._rescale_src_rect(src_rect),
                dst_image, dst_rect, (a, b, None), interp)
assert_interfaces_valid(ImageItem)

#==============================================================================
# QuadGrid item
#==============================================================================
class QuadGridItem(RawImageItem):
    """
    Construct a QuadGrid image

        * X, Y, Z: A structured grid of quadrilaterals; the three arrays
          must have the same shape. Each quad is defined by
          (X[i], Y[i]), (X[i], Y[i+1]), (X[i+1], Y[i+1]), (X[i+1], Y[i])
        * param (optional): image parameters (the default is a
          QuadGridParam instance, see get_default_param)
    """
    __implements__ = (IBasePlotItem, IBaseImageItem, IHistDataSource,
                      IVoiImageItemType)
    def __init__(self, X, Y, Z, param=None):
        assert X is not None
        assert Y is not None
        assert Z is not None
        # X and Y are stored *before* the base constructor runs:
        # NOTE(review) presumably RawImageItem.__init__ calls set_data(),
        # whose override below reads self.X/self.Y via update_bounds()
        self.X = X
        self.Y = Y
        assert X.shape == Y.shape
        assert Z.shape == X.shape
        super(QuadGridItem, self).__init__(Z, param)
        self.set_data(Z)
        # Both values are forwarded to _scale_quads in draw_image; their
        # exact meaning is defined by that helper
        self.grid = 1
        self.interpolate = (0, 0.5, 0.5)
        self.imageparam.update_image(self)

    #---- BaseImageItem API ---------------------------------------------------
    def get_default_param(self):
        """Return instance of the default imageparam DataSet"""
        return QuadGridParam(_("Quadrilaterals"))

    def types(self):
        """Return the item types implemented by this item"""
        return (IImageItemType, IVoiImageItemType, IColormapImageItemType,
                ITrackableItemType)

    def update_bounds(self):
        """Set the bounding rectangle to the extent of all grid vertices"""
        xmin = self.X.min()
        xmax = self.X.max()
        ymin = self.Y.min()
        ymax = self.Y.max()
        self.bounds = QRectF(xmin, ymin, xmax-xmin, ymax-ymin)

    def set_data(self, data, X=None, Y=None, lut_range=None):
        """
        Set Image item data

            * data: 2D NumPy array (quad values)
            * X, Y (optional): new vertex coordinates; Y is required
              whenever X is given
            * lut_range: LUT range -- tuple (levelmin, levelmax); when
              omitted, the range of data computed by _nanmin/_nanmax is used
        """
        if lut_range is not None:
            _min, _max = lut_range
        else:
            _min, _max = _nanmin(data), _nanmax(data)
        self.data = data
        # The cached histogram refers to the previous data
        self.histogram_cache = None
        if X is not None:
            assert Y is not None
            self.X = X
            self.Y = Y
        self.update_bounds()
        self.update_border()
        self.set_lut_range([_min, _max])

    def draw_image(self, painter, canvasRect, src_rect, dst_rect, xMap, yMap):
        """Render the quads on the offscreen buffer and paint the result

        The area written by _scale_quads (dest) is reset to 0 after
        painting, so the offscreen buffer is left clean for the next draw."""
        self._offscreen[...] = np.uint32(0)
        dest = _scale_quads(self.X, self.Y, self.data, src_rect,
                            self._offscreen, dst_rect,
                            self.lut, self.interpolate, self.grid)
        qrect = QRectF(QPointF(dest[0], dest[1]), QPointF(dest[2], dest[3]))
        painter.drawImage(qrect, self._image, qrect)
        xl, yt, xr, yb = dest
        self._offscreen[yt:yb, xl:xr] = 0

    def notify_new_offscreen(self):
        """Clear a newly allocated offscreen buffer"""
        # we always ensure the offscreen is clean before drawing
        self._offscreen[...] = 0

assert_interfaces_valid(QuadGridItem)

#==============================================================================
# Image with a custom linear transform
#==============================================================================
class TrImageItem(RawImageItem):
    """
    Construct a transformable image item

        * data: 2D NumPy array
        * param (optional): image parameters
          (:py:class:`plotpy.styles.TrImageParam` instance)

    Two 3x3 homogeneous matrices are maintained: ``tr`` maps plot
    coordinates to pixel coordinates and ``itr`` (its inverse) maps pixel
    coordinates to plot coordinates.
    """
    __implements__ = (IBasePlotItem, IBaseImageItem, IExportROIImageItemType)
    _can_select = True
    _can_resize = True
    _can_rotate = True
    _can_move = True

    def __init__(self, data=None, param=None):
        self.tr = np.eye(3, dtype=float)
        self.itr = np.eye(3, dtype=float)
        # Homogeneous pixel coordinates of the image corners (one column per
        # corner); placeholder values until set_data() resizes them
        self.points = np.array([[0, 0, 2, 2],
                                [0, 2, 2, 0],
                                [1, 1, 1, 1]], float)
        super(TrImageItem, self).__init__(data, param)

    #---- BaseImageItem API ---------------------------------------------------
    def get_default_param(self):
        """Return instance of the default imageparam DataSet"""
        return TrImageParam(_("Image"))

    #---- Public API ----------------------------------------------------------
    def _transform_factors(self, x0, y0, angle, dx, dy, hflip, vflip):
        """Return (tr1, sc, rot, tr2) such that tr1*sc*rot*tr2 maps plot
        coordinates to pixel coordinates (single source of truth shared by
        set_transform and debug_transform)"""
        ni, nj = self.data.shape
        rot = rotate(-angle)
        tr1 = translate(nj/2.+0.5, ni/2.+0.5)
        xflip = -1. if hflip else 1.
        yflip = -1. if vflip else 1.
        sc = scale(xflip/dx, yflip/dy)
        tr2 = translate(-x0, -y0)
        return tr1, sc, rot, tr2

    def set_transform(self, x0, y0, angle, dx=1.0, dy=1.0,
                      hflip=False, vflip=False):
        """Set the image position (x0, y0), rotation angle, pixel size
        (dx, dy) and flips, then recompute the transform matrices and bounds

        The parameters are always stored in imageparam; the matrices are only
        recomputed once data has been set."""
        self.imageparam.set_transform(x0, y0, angle, dx, dy, hflip, vflip)
        if self.data is None:
            return
        tr1, sc, rot, tr2 = self._transform_factors(x0, y0, angle, dx, dy,
                                                    hflip, vflip)
        self.tr = tr1*sc*rot*tr2
        self.itr = self.tr.I
        self.compute_bounds()

    def get_transform(self):
        """Return (x0, y0, angle, dx, dy, hflip, vflip)"""
        return self.imageparam.get_transform()

    def debug_transform(self, pt):
        """Print the successive stages of the pixel -> plot transform of pt

        pt: 3x1 homogeneous column vector, treated as pixel coordinates;
        the last printed stage equals self.itr*pt. Debugging aid only."""
        x0, y0, angle, dx, dy, hflip, vflip = self.get_transform()
        tr1, sc, rot, tr2 = self._transform_factors(x0, y0, angle, dx, dy,
                                                    hflip, vflip)
        # itr = tr2.I*rot.I*sc.I*tr1.I: apply the inverses in that order
        p1 = tr1.I*pt
        p2 = sc.I*p1
        p3 = rot.I*p2
        p4 = tr2.I*p3
        print("src=", pt.T)
        print("tr1:", p1.T)
        print("tr1+sc:", p2.T)
        print("tr1+sc+rot:", p3.T)
        print("tr1+sc+rot+tr2:", p4.T)

    def set_crop(self, left, top, right, bottom):
        """Set crop margins (stored in imageparam)"""
        self.imageparam.set_crop(left, top, right, bottom)

    def get_crop(self):
        """Return crop margins (left, top, right, bottom)"""
        return self.imageparam.get_crop()

    def get_crop_coordinates(self):
        """Return crop rectangle coordinates

        Bounding box of the transformed image corners (plot coordinates),
        shrunk by the crop margins."""
        tpos = np.array(np.dot(self.itr, self.points))
        xmin, ymin, _ = tpos.min(axis=1).flatten()
        xmax, ymax, _ = tpos.max(axis=1).flatten()
        left, top, right, bottom = self.imageparam.get_crop()
        return (xmin+left, ymin+top, xmax-right, ymax-bottom)

    def compute_bounds(self):
        """Update the bounding rectangle (cropped area) and the border"""
        x0, y0, x1, y1 = self.get_crop_coordinates()
        self.bounds = QRectF(QPointF(x0, y0), QPointF(x1, y1))
        self.update_border()

    #--- RawImageItem API -----------------------------------------------------
    def set_data(self, data, lut_range=None):
        """Set image data and reset the corner points to the data size"""
        RawImageItem.set_data(self, data, lut_range)
        ni, nj = self.data.shape
        # Corners (0, 0), (0, ni), (nj, ni), (nj, 0) in pixel coordinates
        self.points = np.array([[0, 0, nj, nj],
                                [0, ni, ni, 0],
                                [1, 1, 1, 1]], float)
        self.compute_bounds()

    #--- BaseImageItem API ----------------------------------------------------
    def get_filter(self, filterobj, filterparam):
        """Provides a filter object over this image's content"""
        raise NotImplementedError
        #TODO: Implement TrImageFilterItem
        # return TrImageFilterItem(self, filterobj, filterparam)

    def get_pixel_coordinates(self, xplot, yplot):
        """Return (image) pixel coordinates (from plot coordinates)"""
        v = self.tr*colvector(xplot, yplot)
        # Convert the matrix column to plain scalars (unpacking the matrix
        # column directly would yield 1x1 matrices)
        xpixel, ypixel, _ = v[:, 0].A.ravel()
        return xpixel, ypixel

    def get_plot_coordinates(self, xpixel, ypixel):
        """Return plot coordinates (from image pixel coordinates)"""
        v0 = self.itr*colvector(xpixel, ypixel)
        xplot, yplot, _ = v0[:, 0].A.ravel()
        return xplot, yplot

    def get_x_values(self, i0, i1):
        """Return i1-i0 X plot coordinates from pixel column i0 to i1"""
        v0 = self.itr*colvector(i0, 0)
        x0, _y0, _ = v0[:, 0].A.ravel()
        v1 = self.itr*colvector(i1, 0)
        x1, _y1, _ = v1[:, 0].A.ravel()
        return np.linspace(x0, x1, i1-i0)

    def get_y_values(self, j0, j1):
        """Return j1-j0 Y plot coordinates from pixel row j0 to j1"""
        v0 = self.itr*colvector(0, j0)
        _x0, y0, _ = v0[:, 0].A.ravel()
        v1 = self.itr*colvector(0, j1)
        _x1, y1, _ = v1[:, 0].A.ravel()
        return np.linspace(y0, y1, j1-j0)

    def get_closest_coordinates(self, x, y):
        """Return closest image pixel coordinates"""
        xi, yi = self.get_closest_indexes(x, y)
        v = self.itr*colvector(xi, yi)
        x, y, _ = v[:, 0].A.ravel()
        return x, y

    def update_border(self):
        """Place the border on the transformed image corners"""
        tpos = np.dot(self.itr, self.points)
        self.border_rect.set_points(tpos.T[:, :2])

    def draw_border(self, painter, xMap, yMap, canvasRect):
        """Draw the (possibly rotated) image border"""
        self.border_rect.draw(painter, xMap, yMap, canvasRect)

    def draw_image(self, painter, canvasRect, src_rect, dst_rect, xMap, yMap):
        """Render the transformed image onto *painter*"""
        W = canvasRect.width()
        H = canvasRect.height()
        if W <= 1 or H <= 1:
            return

        x0, y0, x1, y1 = src_rect
        cx = canvasRect.left()
        cy = canvasRect.top()
        sx = (x1-x0)/(W-1)
        sy = (y1-y0)/(H-1)
        # canvas -> plot: tr1 = tr(x0,y0)*scale(sx,sy)*tr(-cx,-cy)
        tr = np.matrix([[sx,  0, x0-cx*sx],
                        [ 0, sy, y0-cy*sy],
                        [ 0,  0, 1]], float)
        # canvas -> pixel
        mat = self.tr*tr

        dst_rect = tuple([int(i) for i in dst_rect])
        dest = _scale_tr(self.data, mat, self._offscreen, dst_rect,
                         self.lut, self.interpolate)
        qrect = QRectF(QPointF(dest[0], dest[1]), QPointF(dest[2], dest[3]))
        painter.drawImage(qrect, self._image, qrect)

    def export_roi(self, src_rect, dst_rect, dst_image,
                   apply_lut=False, apply_interpolation=False,
                   original_resolution=False):
        """Export Region Of Interest to array

        With original_resolution, the item's own pixel size (dx, dy) is used
        as the destination scale; otherwise the scale is src/dst size ratio.
        The destination rectangle is clipped to the cropped image area."""
        if apply_lut:
            a, b, _bg, _cmap = self.lut
        else:
            a, b = 1., 0.
        xs0, ys0, xs1, ys1 = src_rect
        xd0, yd0, xd1, yd1 = dst_rect
        if original_resolution:
            _t1, _t2, _t3, xscale, yscale, _t4, _t5 = self.get_transform()
        else:
            xscale, yscale = (xs1-xs0)/float(xd1-xd0), (ys1-ys0)/float(yd1-yd0)
        mat = self.tr*(translate(xs0, ys0)*scale(xscale, yscale))

        x0, y0, x1, y1 = self.get_crop_coordinates()
        # X bounds use xscale and Y bounds use yscale (the Y bounds were
        # previously clipped with xscale, wrong for non-square pixels)
        xd0 = max([xd0, xd0+int((x0-xs0)/xscale)])
        yd0 = max([yd0, yd0+int((y0-ys0)/yscale)])
        xd1 = min([xd1, xd1+int((x1-xs1)/xscale)])
        yd1 = min([yd1, yd1+int((y1-ys1)/yscale)])
        dst_rect = xd0, yd0, xd1, yd1

        interp = self.interpolate if apply_interpolation else (INTERP_NEAREST,)
        _scale_tr(self.data, mat, dst_image, dst_rect,
                  (a, b, None), interp)

    #---- IBasePlotItem API ---------------------------------------------------
    def move_local_point_to(self, handle, pos, ctrl=None):
        """Move a handle as returned by hit_test to the new position pos
        ctrl: True if <Ctrl> button is being pressed, False otherwise

        Rotates by the angle swept by the handle around the image center
        and/or rescales by the ratio of the handle distances to the center."""
        x0, y0, angle, dx, dy, hflip, vflip = self.get_transform()
        nx, ny = canvas_to_axes(self, pos)
        handles = self.itr*self.points
        p0 = colvector(nx, ny)
        center = handles.sum(axis=1)/4
        vec0 = handles[:, handle] - center
        vec1 = p0 - center
        a0 = np.arctan2(vec0[1, 0], vec0[0, 0])
        a1 = np.arctan2(vec1[1, 0], vec1[0, 0])
        if self.can_rotate():
            # compute angles
            angle += a1-a0
        if self.can_resize():
            # compute pixel size
            zoom = np.linalg.norm(vec1)/np.linalg.norm(vec0)
            dx = zoom*dx
            dy = zoom*dy
        self.set_transform(x0, y0, angle, dx, dy, hflip, vflip)

    def move_local_shape(self, old_pos, new_pos):
        """Translate the shape such that old_pos becomes new_pos
        in canvas coordinates"""
        x0, y0, angle, dx, dy, hflip, vflip = self.get_transform()
        nx, ny = canvas_to_axes(self, new_pos)
        ox, oy = canvas_to_axes(self, old_pos)
        self.set_transform(x0+nx-ox, y0+ny-oy, angle, dx, dy, hflip, vflip)
        if self.plot():
            self.plot().SIG_ITEM_MOVED.emit(self, ox, oy, nx, ny)

    def move_with_selection(self, delta_x, delta_y):
        """
        Translate the shape together with other selected items
        delta_x, delta_y: translation in plot coordinates
        """
        x0, y0, angle, dx, dy, hflip, vflip = self.get_transform()
        self.set_transform(x0+delta_x, y0+delta_y, angle, dx, dy, hflip, vflip)

assert_interfaces_valid(TrImageItem)
def assemble_imageitems(items, src_qrect, destw, desth, align=None,
                        add_images=False, apply_lut=False,
                        apply_interpolation=False, original_resolution=False):
    """
    Assemble together image items in qrect (`QRectF` object)
    and return resulting pixel data

        * items: image items providing `export_roi`
        * src_qrect: source area, in plot coordinates
        * destw, desth: output width/height (destw is truncated to an
          integer and desth is scaled accordingly)
        * align: ignored since v2.2 (a warning is printed on stderr)
        * add_images: if True, the contributions of all items are summed;
          otherwise each item is exported directly into the output array
        * apply_lut, apply_interpolation, original_resolution: forwarded
          to each item's `export_roi`

    Returns a float32 array of shape (height, width).

    .. warning::

        Does not support `XYImageItem` objects
    """
    if align is not None:
        print("plotpy.image.assemble_imageitems: since v2.2, "\
              "the `align` option is ignored", file=sys.stderr)
    # Byte alignment is disabled until further notice (alignment of 1)
    aligned_destw = int(destw)
    aligned_desth = int(desth*aligned_destw/destw)
    try:
        output = np.zeros((aligned_desth, aligned_destw), np.float32)
    except ValueError:
        raise MemoryError
    dst_image = None if add_images else output
    dst_rect = (0, 0, aligned_destw, aligned_desth)

    # The source QRect is generally coming from a rectangle shape which is
    # adjusted to fit a given ROI on the image. So the rectangular area is
    # aligned with image pixel edges: to avoid any rounding error, we reduce
    # the rectangle area size by one half of a pixel, so that the area is now
    # aligned with the center of image pixels.
    left, top, right, bottom = src_qrect.getCoords()
    half_w = .5*(src_qrect.width()/float(destw))
    half_h = .5*(src_qrect.height()/float(desth))
    src_rect = [left+half_w, top+half_h, right-half_w, bottom-half_h]

    # NOTE(review): items are visited by decreasing z, so when add_images is
    # False lower-z items are exported after (and may overwrite) higher-z
    # ones -- confirm this matches the intended "foreground" semantics.
    for item in sorted(items, key=lambda obj: -obj.z()):
        if not item.isVisible() or \
           not src_qrect.intersects(item.boundingRect()):
            continue
        if add_images:
            dst_image = np.zeros_like(output)
        item.export_roi(src_rect=src_rect, dst_rect=dst_rect,
                        dst_image=dst_image, apply_lut=apply_lut,
                        apply_interpolation=apply_interpolation,
                        original_resolution=original_resolution)
        if add_images:
            output += dst_image
    return output
def get_plot_qrect(plot, p0, p1):
    """
    Return `QRectF` rectangle object in plot coordinates
    from top-left and bottom-right `QPointF` objects in canvas coordinates

    The bottom-right corner is offset by one canvas pixel before conversion
    (presumably so that the p1 pixel itself is included).
    """
    xaxis, yaxis = plot.X_BOTTOM, plot.Y_LEFT
    left = plot.invTransform(xaxis, p0.x())
    top = plot.invTransform(yaxis, p0.y())
    right = plot.invTransform(xaxis, p1.x()+1)
    bottom = plot.invTransform(yaxis, p1.y()+1)
    return QRectF(left, top, right-left, bottom-top)
def get_items_in_rectangle(plot, p0, p1, item_type=None):
    """Return items which bounding rectangle intersects (p0, p1)
    item_type: default is `IExportROIImageItemType`"""
    if item_type is None:
        item_type = IExportROIImageItemType
    # Honor the caller's item_type (it used to be silently ignored)
    items = plot.get_items(item_type=item_type)
    src_qrect = get_plot_qrect(plot, p0, p1)
    return [it for it in items if src_qrect.intersects(it.boundingRect())]

def compute_trimageitems_original_size(items, src_w, src_h):
    """Compute `TrImageItem` original size from max dx and dy

    Returns (src_w/dx_max, src_h/dy_max), dx_max/dy_max being the largest
    pixel sizes among the `TrImageItem` objects of *items*, or
    (src_w, src_h) unchanged when there is no such item."""
    trparams = [item.get_transform() for item in items
                if isinstance(item, TrImageItem)]
    if not trparams:
        return src_w, src_h
    dx_max = max([dx for _x, _y, _angle, dx, _dy, _hf, _vf in trparams])
    dy_max = max([dy for _x, _y, _angle, _dx, dy, _hf, _vf in trparams])
    return src_w/dx_max, src_h/dy_max

def get_image_from_qrect(plot, p0, p1, src_size=None, adjust_range=None,
                         item_type=None, apply_lut=False,
                         apply_interpolation=False, original_resolution=False,
                         add_images=False):
    """Return image array from `QRect` area
    (p0 and p1 are respectively the top-left and bottom-right `QPointF`
    objects)

    adjust_range: None (return raw data, dtype=np.float32), 'original'
    (return data with original data type), 'normalize' (normalize range
    with original data type)"""
    assert adjust_range in (None, 'normalize', 'original')
    items = get_items_in_rectangle(plot, p0, p1, item_type=item_type)
    if not items:
        raise TypeError(_("There is no supported image item in current plot."))
    if src_size is None:
        _src_x, _src_y, src_w, src_h = get_plot_qrect(plot, p0, p1).getRect()
    else:
        # The only benefit to pass the src_size list is to avoid any
        # rounding error in the transformation computed in `get_plot_qrect`
        src_w, src_h = src_size
    destw, desth = compute_trimageitems_original_size(items, src_w, src_h)
    data = get_image_from_plot(plot, p0, p1, destw=destw, desth=desth,
                               apply_lut=apply_lut, add_images=add_images,
                               apply_interpolation=apply_interpolation,
                               original_resolution=original_resolution)
    if adjust_range is None:
        return data
    # Use the widest data type among the source items
    dtype = None
    for item in items:
        if dtype is None or item.data.dtype.itemsize > dtype.itemsize:
            dtype = item.data.dtype
    if adjust_range == 'normalize':
        from plotpy import io
        data = io.scale_data_to_dtype(data, dtype=dtype)
    else:
        data = np.array(data, dtype=dtype)
    return data

def get_image_in_shape(obj, norm_range=False, item_type=None,
                       apply_lut=False, apply_interpolation=False):
    """Return image array from rectangle shape

    The image is extracted at the original resolution; with norm_range the
    data range is normalized, otherwise the original data type is kept."""
    x0, y0, x1, y1 = obj.get_rect()
    (x0, x1), (y0, y1) = sorted([x0, x1]), sorted([y0, y1])
    xc0, yc0 = axes_to_canvas(obj, x0, y0)
    xc1, yc1 = axes_to_canvas(obj, x1, y1)
    adjust_range = 'normalize' if norm_range else 'original'
    return get_image_from_qrect(obj.plot(), QPointF(xc0, yc0),
                                QPointF(xc1, yc1), src_size=(x1-x0, y1-y0),
                                adjust_range=adjust_range, item_type=item_type,
                                apply_lut=apply_lut,
                                apply_interpolation=apply_interpolation,
                                original_resolution=True)
def get_image_from_plot(plot, p0, p1, destw=None, desth=None,
                        add_images=False, apply_lut=False,
                        apply_interpolation=False, original_resolution=False):
    """
    Return pixel data of a rectangular plot area (image items only)

    p0, p1: resp. top-left and bottom-right points (`QPointF` objects)
    destw, desth: output size (default: the canvas size of the area,
    p1 included)
    apply_lut: apply contrast settings
    add_images: add superimposed images (instead of replace by the foreground)

    .. warning::

        Support only the image items implementing the `IExportROIImageItemType`
        interface, i.e. this does *not* support `XYImageItem` objects
    """
    destw = p1.x()-p0.x()+1 if destw is None else destw
    desth = p1.y()-p0.y()+1 if desth is None else desth
    export_items = plot.get_items(item_type=IExportROIImageItemType)
    area = get_plot_qrect(plot, p0, p1)
    return assemble_imageitems(export_items, area, destw, desth,
                               add_images=add_images, apply_lut=apply_lut,
                               apply_interpolation=apply_interpolation,
                               original_resolution=original_resolution)

#==============================================================================
# Image with custom X, Y axes
#==============================================================================
def to_bins(x):
    """Convert point center to point bounds

    x: 1D sequence (NumPy array, list, tuple...) of at least 2 values
    Returns a float array of len(x)+1 values: the midpoints between
    consecutive centers, extended at both ends by half of the first/last
    interval. Values are converted to float first, so integer input is never
    floor-divided."""
    x = np.asarray(x, dtype=float)
    bx = np.zeros((x.shape[0]+1,), float)
    bx[1:-1] = (x[:-1]+x[1:])/2
    bx[0] = x[0]-(x[1]-x[0])/2
    bx[-1] = x[-1]+(x[-1]-x[-2])/2
    return bx
class XYImageItem(RawImageItem):
    """
    Construct an image item with non-linear X/Y axes

        * x: 1D NumPy array, must be increasing
        * y: 1D NumPy array, must be increasing
        * data: 2D NumPy array
        * param (optional): image parameters
          (:py:class:`plotpy.styles.XYImageParam` instance)

    Unsorted x (resp. y) values are sorted at construction time and the
    data columns (resp. rows) are reordered accordingly. x and y may hold
    either pixel centers (one value per column/row) or pixel edges (one
    more value), see set_xy.
    """
    __implements__ = (IBasePlotItem, IBaseImageItem, ISerializableType)

    def __init__(self, x=None, y=None, data=None, param=None):
        # if x and y are not increasing arrays, sort them and data accordingly
        # Both are optional: __reduce__ recreates the item without any
        # argument before calling __setstate__, so None must not be diffed.
        if x is not None:
            x = np.asarray(x)
            if not np.all(np.diff(x) >= 0):
                x_idx = np.argsort(x)
                x = x[x_idx]
                data = data[:, x_idx]
        if y is not None:
            y = np.asarray(y)
            if not np.all(np.diff(y) >= 0):
                y_idx = np.argsort(y)
                y = y[y_idx]
                data = data[y_idx, :]
        super(XYImageItem, self).__init__(data, param)
        self.x = None
        self.y = None
        if x is not None and y is not None:
            self.set_xy(x, y)

    #---- BaseImageItem API ---------------------------------------------------
    def get_default_param(self):
        """Return instance of the default imageparam DataSet"""
        return XYImageParam(_("Image"))

    #---- Pickle methods ------------------------------------------------------
    def __reduce__(self):
        fname = self.get_filename()
        if fname is None:
            fn_or_data = self.data
        else:
            fn_or_data = fname
        state = (self.imageparam, self.get_lut_range(), self.x, self.y,
                 fn_or_data, self.z())
        res = (self.__class__, (), state)
        return res

    def __setstate__(self, state):
        param, lut_range, x, y, fn_or_data, z = state
        self.imageparam = param
        if is_text_string(fn_or_data):
            self.set_filename(fn_or_data)
            self.load_data(lut_range)
        elif fn_or_data is not None: # should happen only with previous API
            self.set_data(fn_or_data, lut_range=lut_range)
        self.set_xy(x, y)
        self.setZ(z)
        self.imageparam.update_image(self)

    def serialize(self, writer):
        """Serialize object to HDF5 writer"""
        super(XYImageItem, self).serialize(writer)
        writer.write(self.x, group_name='Xdata')
        writer.write(self.y, group_name='Ydata')

    def deserialize(self, reader):
        """Deserialize object from HDF5 reader"""
        super(XYImageItem, self).deserialize(reader)
        x = reader.read(group_name='Xdata', func=reader.read_array)
        y = reader.read(group_name='Ydata', func=reader.read_array)
        self.set_xy(x, y)

    #---- Public API ----------------------------------------------------------
    def set_xy(self, x, y):
        """Set the X/Y axes

        x (resp. y) must be increasing and hold either one value per column
        (resp. row), interpreted as pixel centers and converted to edges by
        to_bins, or one more value (pixel edges, used as is).
        Raises ValueError if not increasing, IndexError on a length mismatch.
        """
        ni, nj = self.data.shape
        x = np.array(x, float)
        y = np.array(y, float)
        if not np.all(np.diff(x) >= 0):
            raise ValueError("x must be an increasing 1D array")
        if not np.all(np.diff(y) >= 0):
            raise ValueError("y must be an increasing 1D array")
        if x.shape[0] == nj:
            self.x = to_bins(x)
        elif x.shape[0] == nj+1:
            self.x = x
        else:
            raise IndexError("x must be a 1D array of length %d or %d" \
                             % (nj, nj+1))
        if y.shape[0] == ni:
            self.y = to_bins(y)
        elif y.shape[0] == ni+1:
            self.y = y
        else:
            raise IndexError("y must be a 1D array of length %d or %d" \
                             % (ni, ni+1))
        self.bounds = QRectF(QPointF(self.x[0], self.y[0]),
                             QPointF(self.x[-1], self.y[-1]))
        self.update_border()

    #--- BaseImageItem API ----------------------------------------------------
    def get_filter(self, filterobj, filterparam):
        """Provides a filter object over this image's content"""
        return XYImageFilterItem(self, filterobj, filterparam)

    def draw_image(self, painter, canvasRect, src_rect, dst_rect, xMap, yMap):
        """Render the visible part of the image onto *painter*"""
        xytr = (self.x, self.y, src_rect)
        dst_rect = tuple([int(i) for i in dst_rect])
        dest = _scale_xy(self.data, xytr, self._offscreen, dst_rect,
                         self.lut, self.interpolate)
        qrect = QRectF(QPointF(dest[0], dest[1]), QPointF(dest[2], dest[3]))
        painter.drawImage(qrect, self._image, qrect)

    def get_pixel_coordinates(self, xplot, yplot):
        """Return (image) pixel coordinates (from plot coordinates)"""
        return self.x.searchsorted(xplot), self.y.searchsorted(yplot)

    def get_plot_coordinates(self, xpixel, ypixel):
        """Return plot coordinates (from image pixel coordinates)"""
        return self.x[int(pixelround(xpixel))], self.y[int(pixelround(ypixel))]

    def get_x_values(self, i0, i1):
        """Return the X pixel edges of columns i0 to i1 (excluded)"""
        return self.x[i0:i1]

    def get_y_values(self, j0, j1):
        """Return the Y pixel edges of rows j0 to j1 (excluded)"""
        return self.y[j0:j1]

    def get_closest_coordinates(self, x, y):
        """Return closest image pixel coordinates"""
        i, j = self.get_closest_indexes(x, y)
        return self.x[i], self.y[j]

    #---- IBasePlotItem API ---------------------------------------------------
    def types(self):
        """Return the item types implemented by this item"""
        return (IImageItemType, IVoiImageItemType, IColormapImageItemType,
                ITrackableItemType, ISerializableType, ICSImageItemType)

    #---- IBaseImageItem API --------------------------------------------------
    def can_setfullscale(self):
        return True

    def can_sethistogram(self):
        return True

assert_interfaces_valid(XYImageItem)

#==============================================================================
# RGB Image with alpha channel
#==============================================================================
class RGBImageItem(ImageItem):
    """
    Construct a RGB/RGBA image item

        * data: NumPy array of uint8 (shape: NxMx[34] -- 3: RGB, 4: RGBA)
          (last dimension: 0: Red, 1: Green, 2: Blue {, 3:Alpha})
        * param (optional): image parameters
          (:py:class:`plotpy.styles.RGBImageParam` instance)

    The channels are packed into a single uint32 array (A<<24 | R<<16 |
    G<<8 | B), so no LUT/colormap is used (self.lut is kept to None).
    """
    __implements__ = (IBasePlotItem, IBaseImageItem, ISerializableType)

    def __init__(self, data=None, param=None):
        self.orig_data = None
        super(RGBImageItem, self).__init__(data, param)
        self.lut = None

    #---- BaseImageItem API ---------------------------------------------------
    def get_default_param(self):
        """Return instance of the default imageparam DataSet"""
        return RGBImageParam(_("Image"))

    #---- Public API ----------------------------------------------------------
    def recompute_alpha_channel(self):
        """Pack orig_data channels and alpha into the uint32 self.data array

        The alpha comes from the 4th channel when it exists and
        imageparam.alpha_mask is set; otherwise the constant
        int(255*imageparam.alpha) is used."""
        data = self.orig_data
        if self.orig_data is None:
            return
        H, W, NC = data.shape
        R = data[..., 0].astype(np.uint32)
        G = data[..., 1].astype(np.uint32)
        B = data[..., 2].astype(np.uint32)
        use_alpha = self.imageparam.alpha_mask
        alpha = self.imageparam.alpha
        if NC > 3 and use_alpha:
            A = data[..., 3].astype(np.uint32)
        else:
            A = np.zeros((H, W), np.uint32)
            A[:, :] = int(255*alpha)
        self.data[:, :] = (A << 24)+(R << 16)+(G << 8)+B

    #--- BaseImageItem API ----------------------------------------------------
    # Override lut/bg handling
    def set_lut_range(self, range):
        pass

    def set_background_color(self, qcolor):
        self.lut = None

    def set_color_map(self, name_or_table):
        self.lut = None

    #---- RawImageItem API ----------------------------------------------------
    def load_data(self, lut_range=None):
        """
        Load data from *filename*
        *filename* has been set using method 'set_filename'

        lut_range: ignored (RGB images use no LUT). Accepted so that the
        generic unpickling/deserialization code, which calls
        load_data(lut_range), also works for this item.
        """
        data = io.imread(self.get_filename(), to_grayscale=False)
        self.set_data(data)

    def set_data(self, data, lut_range=None):
        """Set RGB/RGBA data (NxMx3 or NxMx4 array)

        lut_range: ignored (RGB images use no LUT). Accepted so that callers
        using set_data(data, lut_range=...) also work for this item."""
        H, W, NC = data.shape
        self.orig_data = data
        self.data = np.empty((H, W), np.uint32)
        self.recompute_alpha_channel()
        self.update_bounds()
        self.update_border()
        self.lut = None

    #---- IBasePlotItem API ---------------------------------------------------
    def types(self):
        """Return the item types implemented by this item"""
        return (IImageItemType, ITrackableItemType, ISerializableType)

    #---- IBaseImageItem API --------------------------------------------------
    def can_setfullscale(self):
        return True

    def can_sethistogram(self):
        return False
assert_interfaces_valid(RGBImageItem)

#==============================================================================
# Masked Image
#==============================================================================
class MaskedArea(object):
    """Defines masked areas for a masked image item

    geometry: shape name (e.g. 'rectangular' or 'circular')
    x0, y0, x1, y1: bounding rectangle of the area
    inside: True to mask the inside of the area, False for the outside
    Instances compare equal when all these attributes are equal.
    """
    def __init__(self, geometry=None, x0=None, y0=None, x1=None, y1=None,
                 inside=None):
        self.geometry = geometry
        self.x0 = x0
        self.y0 = y0
        self.x1 = x1
        self.y1 = y1
        self.inside = inside

    def __eq__(self, other):
        # Let Python try the reflected comparison for foreign types instead
        # of failing with AttributeError
        if not isinstance(other, MaskedArea):
            return NotImplemented
        return self.geometry == other.geometry and self.x0 == other.x0 and \
               self.y0 == other.y0 and self.x1 == other.x1 and \
               self.y1 == other.y1 and self.inside == other.inside

    def __ne__(self, other):
        # Python 2 does not derive __ne__ from __eq__
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def serialize(self, writer):
        """Serialize object to HDF5 writer"""
        for name in ('geometry', 'inside', 'x0', 'y0', 'x1', 'y1'):
            writer.write(getattr(self, name), name)

    def deserialize(self, reader):
        """Deserialize object from HDF5 reader"""
        self.geometry = reader.read('geometry')
        self.inside = reader.read('inside')
        for name in ('x0', 'y0', 'x1', 'y1'):
            setattr(self, name, reader.read(name, func=reader.read_float))
class MaskedImageItem(ImageItem):
    """
    Construct a masked image item

        * data: 2D NumPy array
        * mask (optional): 2D NumPy array
        * param (optional): image parameters
          (:py:class:`plotpy.styles.MaskedImageParam` instance)

    The data is stored as a `numpy.ma.MaskedArray` view (see set_data).
    Masks are persisted either through a mask filename or through the list
    of masked areas (see set_mask_filename).
    """
    __implements__ = (IBasePlotItem, IBaseImageItem, IHistDataSource,
                      IVoiImageItemType)

    def __init__(self, data=None, mask=None, param=None):
        self.orig_data = None
        # Initial mask: applied then released by set_data
        self._mask = mask
        self._mask_filename = None
        self._masked_areas = []
        super(MaskedImageItem, self).__init__(data, param)

    #---- BaseImageItem API ---------------------------------------------------
    def get_default_param(self):
        """Return instance of the default imageparam DataSet"""
        return MaskedImageParam(_("Image"))

    #---- Pickle methods -------------------------------------------------------
    def __reduce__(self):
        fname = self.get_filename()
        if fname is None:
            fn_or_data = self.data
        else:
            fn_or_data = fname
        state = (self.imageparam, self.get_lut_range(), fn_or_data, self.z(),
                 self.get_mask_filename(), self.get_masked_areas())
        res = (self.__class__, (), state)
        return res

    def __setstate__(self, state):
        param, lut_range, fn_or_data, z, mask_fname, old_masked_areas = state
        if old_masked_areas and isinstance(old_masked_areas[0], MaskedArea):
            masked_areas = old_masked_areas
        else:
            # Compatibility with old format (plain tuples)
            masked_areas = []
            for geometry, x0, y0, x1, y1, inside in old_masked_areas:
                area = MaskedArea(geometry=geometry, x0=x0, y0=y0, x1=x1,
                                  y1=y1, inside=inside)
                masked_areas.append(area)
        self.imageparam = param
        if is_text_string(fn_or_data):
            self.set_filename(fn_or_data)
            self.load_data(lut_range)
        elif fn_or_data is not None: # should happen only with previous API
            self.set_data(fn_or_data, lut_range=lut_range)
        self.setZ(z)
        self.imageparam.update_image(self)
        if mask_fname is not None:
            self.set_mask_filename(mask_fname)
            self.load_mask_data()
        elif masked_areas and self.data is not None:
            self.set_masked_areas(masked_areas)
            self.apply_masked_areas()

    def serialize(self, writer):
        """Serialize object to HDF5 writer"""
        super(MaskedImageItem, self).serialize(writer)
        writer.write(self.get_mask_filename(), group_name='mask_fname')
        writer.write_object_list(self._masked_areas, 'masked_areas')

    def deserialize(self, reader):
        """Deserialize object from HDF5 reader"""
        super(MaskedImageItem, self).deserialize(reader)
        mask_fname = reader.read(group_name='mask_fname',
                                 func=reader.read_unicode)
        masked_areas = reader.read_object_list('masked_areas', MaskedArea)
        if mask_fname:
            self.set_mask_filename(mask_fname)
            self.load_mask_data()
        elif masked_areas and self.data is not None:
            self.set_masked_areas(masked_areas)
            self.apply_masked_areas()

    #---- Public API -----------------------------------------------------------
    def update_mask(self):
        """Apply imageparam.filling_value as the masked array fill value"""
        if isinstance(self.data, np.ma.MaskedArray):
            # `fill_value` property (set_fill_value is only a legacy alias)
            self.data.fill_value = self.imageparam.filling_value

    def set_mask(self, mask):
        """Set image mask"""
        self.data.mask = mask

    def get_mask(self):
        """Return image mask"""
        return self.data.mask

    def set_mask_filename(self, fname):
        """
        Set mask filename

        There are two ways for pickling mask data of `MaskedImageItem` objects:
            1. using the mask filename (as for data itself)
            2. using the mask areas (`MaskedAreas` instance, see set_mask_areas)

        When saving objects, the first method is tried and then, if no
        filename has been defined for mask data, the second method is used.
        """
        self._mask_filename = fname

    def get_mask_filename(self):
        """Return mask filename (None if not set)"""
        return self._mask_filename

    def load_mask_data(self):
        """Load the mask from the mask filename (read as grayscale)"""
        data = io.imread(self.get_mask_filename(), to_grayscale=True)
        self.set_mask(data)
        self._mask_changed()

    def set_masked_areas(self, areas):
        """Set masked areas (see set_mask_filename)"""
        self._masked_areas = areas

    def get_masked_areas(self):
        """Return the list of masked areas (`MaskedArea` instances)"""
        return self._masked_areas

    def add_masked_area(self, geometry, x0, y0, x1, y1, inside):
        """Record a masked area, unless an equal one is already recorded"""
        area = MaskedArea(geometry=geometry, x0=x0, y0=y0, x1=x1, y1=y1,
                          inside=inside)
        for _area in self._masked_areas:
            if area == _area:
                return
        self._masked_areas.append(area)

    def _mask_changed(self):
        """Emit the :py:data:`plotpy.baseplot.BasePlot.SIG_MASK_CHANGED` signal"""
        plot = self.plot()
        if plot is not None:
            plot.SIG_MASK_CHANGED.emit(self)

    def apply_masked_areas(self):
        """Apply masked areas"""
        for area in self._masked_areas:
            if area.geometry == 'rectangular':
                self.mask_rectangular_area(area.x0, area.y0, area.x1, area.y1,
                                           area.inside, trace=False,
                                           do_signal=False)
            else:
                self.mask_circular_area(area.x0, area.y0, area.x1, area.y1,
                                        area.inside, trace=False,
                                        do_signal=False)
        self._mask_changed()

    def mask_all(self):
        """Mask all pixels"""
        self.data.mask = True
        self._mask_changed()

    def unmask_all(self):
        """Unmask all pixels"""
        self.data.mask = np.ma.nomask
        self.set_masked_areas([])
        self._mask_changed()

    def mask_rectangular_area(self, x0, y0, x1, y1, inside=True,
                              trace=True, do_signal=True):
        """
        Mask rectangular area
        If inside is True (default), mask the inside of the area
        Otherwise, mask the outside
        """
        ix0, iy0, ix1, iy1 = self.get_closest_index_rect(x0, y0, x1, y1)
        if inside:
            self.data[iy0:iy1, ix0:ix1] = np.ma.masked
        else:
            # builtin bool: the `np.bool` alias was removed in NumPy 1.24
            indexes = np.ones(self.data.shape, dtype=bool)
            indexes[iy0:iy1, ix0:ix1] = False
            self.data[indexes] = np.ma.masked
        if trace:
            self.add_masked_area('rectangular', x0, y0, x1, y1, inside)
        if do_signal:
            self._mask_changed()

    def mask_circular_area(self, x0, y0, x1, y1, inside=True,
                           trace=True, do_signal=True):
        r"""
        Mask circular area, inside the rectangle (x0, y0, x1, y1), i.e.
        circle with a radius of ``.5\*(x1-x0)``
        If inside is True (default), mask the inside of the area
        Otherwise, mask the outside
        """
        ix0, iy0, ix1, iy1 = self.get_closest_index_rect(x0, y0, x1, y1)
        xc, yc = .5*(x0+x1), .5*(y0+y1)
        radius = .5*(x1-x0)
        xdata = np.asarray(self.get_x_values(ix0, ix1))
        ydata = np.asarray(self.get_y_values(iy0, iy1))
        # Distance to the center of every pixel of the bounding rectangle
        # (rows follow y, columns follow x), in one vectorized pass
        distance = np.sqrt((xdata[np.newaxis, :]-xc)**2
                           + (ydata[:, np.newaxis]-yc)**2)
        if inside:
            selected = distance <= radius
        else:
            selected = distance > radius
        if selected.any():
            # Full-size boolean index: assigning `masked` through a sliced
            # view may not propagate to the parent mask
            indexes = np.zeros(self.data.shape, dtype=bool)
            indexes[iy0:iy1, ix0:ix1] = selected
            self.data[indexes] = np.ma.masked
        if not inside:
            # Also mask everything outside the bounding rectangle; the mask
            # change signal (if requested) is emitted only once, below
            self.mask_rectangular_area(x0, y0, x1, y1, inside, trace=False,
                                       do_signal=False)
        if trace:
            self.add_masked_area('circular', x0, y0, x1, y1, inside)
        if do_signal:
            self._mask_changed()

    def is_mask_visible(self):
        """Return mask visibility"""
        return self.imageparam.show_mask

    def set_mask_visible(self, state):
        """Set mask visibility"""
        self.imageparam.show_mask = state
        plot = self.plot()
        if plot is not None:
            plot.replot()

    #---- BaseImageItem API ----------------------------------------------------
    def draw_image(self, painter, canvasRect, src_rect, dst_rect, xMap, yMap):
        """Render the image, then (if visible) the mask overlay on top

        The overlay uses a 2-entry colormap: unmasked pixels in black with
        alpha_unmasked, masked pixels in white with alpha_masked."""
        ImageItem.draw_image(self, painter, canvasRect, src_rect, dst_rect,
                             xMap, yMap)
        if self.data is None:
            return
        if self.is_mask_visible():
            _a, _b, bg, _cmap = self.lut
            alpha_masked = np.uint32(255*self.imageparam.alpha_masked+0.5
                                     ).clip(0, 255) << 24
            alpha_unmasked = np.uint32(255*self.imageparam.alpha_unmasked+0.5
                                       ).clip(0, 255) << 24
            cmap = np.array([np.uint32(0x000000 & 0xffffff) | alpha_unmasked,
                             np.uint32(0xffffff & 0xffffff) | alpha_masked],
                            dtype=np.uint32)
            lut = (1, 0, bg, cmap)
            shown_data = np.ma.getmaskarray(self.data)
            src2 = self._rescale_src_rect(src_rect)
            dst_rect = tuple([int(i) for i in dst_rect])
            dest = _scale_rect(shown_data, src2, self._offscreen, dst_rect,
                               lut, (INTERP_NEAREST,))
            qrect = QRectF(QPointF(dest[0], dest[1]),
                           QPointF(dest[2], dest[3]))
            painter.drawImage(qrect, self._image, qrect)

    #---- RawImageItem API -----------------------------------------------------
    def set_data(self, data, lut_range=None):
        """
        Set Image item data

            * data: 2D NumPy array
            * lut_range: LUT range -- tuple (levelmin, levelmax)
        """
        ImageItem.set_data(self, data, lut_range)
        self.orig_data = data
        self.data = data.view(np.ma.MaskedArray)
        self.set_mask(self._mask)
        self._mask = None # removing reference to this temporary array
        if self.imageparam.filling_value is None:
            self.imageparam.filling_value = self.data.fill_value
        self.update_mask()

#==============================================================================
# Image filter
#==============================================================================
#TODO: Implement get_filter methods for image items other than XYImageItem!
class ImageFilterItem(BaseImageItem):
    """
    Construct a rectangular area image filter item

        * image: :py:class:`plotpy.image.RawImageItem` instance
        * filter: function (x, y, data) --> data
        * param: image filter parameters
          (:py:class:`plotpy.styles.ImageFilterParam` instance)

    When ``use_source_cmap`` is true (presumably set by
    ``ImageFilterParam.update_imagefilter``), the color map and LUT range are
    read from / written to the source image. When no source image is
    attached, the item's own color map and LUT range are used instead.
    """
    __implements__ = (IBasePlotItem, IBaseImageItem)
    _can_select = True
    _can_resize = True
    _can_move = True

    def __init__(self, image, filter, param):
        self.use_source_cmap = None
        # BaseImageItem constructor will try to set this item's color map
        # using the method 'set_color_map', which reads self.image
        self.image = None
        super(ImageFilterItem, self).__init__(param=param)
        self.border_rect.set_style("plot", "shape/imagefilter")
        self.image = image
        self.filter = filter

        self.imagefilterparam = param
        self.imagefilterparam.update_imagefilter(self)

    #---- Public API -----------------------------------------------------------
    def set_image(self, image):
        """
        Set the image item on which the filter will be applied

            * image: :py:class:`plotpy.image.RawImageItem` instance
        """
        self.image = image

    def set_filter(self, filter):
        """
        Set the filter function

            * filter: function (x, y, data) --> data
        """
        self.filter = filter

    #---- QwtPlotItem API ------------------------------------------------------
    def boundingRect(self):
        x0, y0, x1, y1 = self.border_rect.get_rect()
        return QRectF(x0, y0, x1-x0, y1-y0)

    #---- IBasePlotItem API ----------------------------------------------------
    def update_item_parameters(self):
        BaseImageItem.update_item_parameters(self)
        self.imagefilterparam.update_param(self)

    def get_item_parameters(self, itemparams):
        BaseImageItem.get_item_parameters(self, itemparams)
        self.update_item_parameters()
        itemparams.add("ImageFilterParam", self, self.imagefilterparam)

    def set_item_parameters(self, itemparams):
        update_dataset(self.imagefilterparam,
                       itemparams.get("ImageFilterParam"), visible_only=True)
        self.imagefilterparam.update_imagefilter(self)
        BaseImageItem.set_item_parameters(self, itemparams)

    def move_local_point_to(self, handle, pos, ctrl=None):
        """Move a handle as returned by hit_test to the new position pos
        ctrl: True if <Ctrl> button is being pressed, False otherwise"""
        npos = canvas_to_axes(self, pos)
        self.border_rect.move_point_to(handle, npos)

    def move_local_shape(self, old_pos, new_pos):
        """Translate the shape such that old_pos becomes new_pos
        in canvas coordinates"""
        old_pt = canvas_to_axes(self, old_pos)
        new_pt = canvas_to_axes(self, new_pos)
        self.border_rect.move_shape(old_pt, new_pt)
        plot = self.plot()
        if plot:
            plot.SIG_ITEM_MOVED.emit(self, *(old_pt+new_pt))

    def move_with_selection(self, delta_x, delta_y):
        """
        Translate the shape together with other selected items
        delta_x, delta_y: translation in plot coordinates
        """
        self.border_rect.move_with_selection(delta_x, delta_y)

    def _uses_source_cmap(self):
        """Return True if color map / LUT range go to the source image"""
        return bool(self.use_source_cmap) and self.image is not None

    def set_color_map(self, name_or_table):
        if self._uses_source_cmap():
            self.image.set_color_map(name_or_table)
        else:
            BaseImageItem.set_color_map(self, name_or_table)

    def get_color_map(self):
        if self._uses_source_cmap():
            return self.image.get_color_map()
        return BaseImageItem.get_color_map(self)

    def get_lut_range(self):
        if self._uses_source_cmap():
            return self.image.get_lut_range()
        return BaseImageItem.get_lut_range(self)

    def set_lut_range(self, lut_range):
        if self._uses_source_cmap():
            self.image.set_lut_range(lut_range)
        else:
            BaseImageItem.set_lut_range(self, lut_range)

    #---- IBaseImageItem API ---------------------------------------------------
    def types(self):
        return (IImageItemType, IVoiImageItemType, IColormapImageItemType,
                ITrackableItemType)

    def can_setfullscale(self):
        return False

    def can_sethistogram(self):
        return True
class XYImageFilterItem(ImageFilterItem):
    """
    Construct a rectangular area image filter item

        * image: :py:class:`plotpy.image.XYImageItem` instance
        * filter: function (x, y, data) --> data
        * param: image filter parameters
          (:py:class:`plotpy.styles.ImageFilterParam` instance)
    """
    def __init__(self, image, filter, param):
        ImageFilterItem.__init__(self, image, filter, param)

    def set_image(self, image):
        """
        Set the image item on which the filter will be applied

            * image: :py:class:`plotpy.image.XYImageItem` instance
        """
        ImageFilterItem.set_image(self, image)

    def draw_image(self, painter, canvasRect, src_rect, dst_rect, xMap, yMap):
        """
        Draw the filtered data of the part of the source image covered by
        this item (nothing is drawn without a source image)
        """
        if self.image is None:
            return
        bounds = self.boundingRect()

        filt_qrect = bounds & self.image.boundingRect()
        x0, y0, x1, y1 = filt_qrect.getCoords()
        # QwtScaleMap.transform returns floats while QRect only accepts
        # integers (floats raise TypeError with recent PyQt/Python)
        i0, i1 = int(xMap.transform(x0)), int(xMap.transform(x1))
        j0, j1 = int(yMap.transform(y0)), int(yMap.transform(y1))

        dstRect = QRect(i0, j0, i1-i0, j1-j0)
        if not dstRect.intersects(canvasRect):
            return

        x, y, data = self.image.get_data(x0, y0, x1, y1)
        new_data = self.filter(x, y, data)
        self.data = new_data
        if self.use_source_cmap:
            lut = self.image.lut
        else:
            lut = self.lut
        dest = _scale_xy(new_data, (x, y, src_rect),
                         self._offscreen, dstRect.getCoords(),
                         lut, self.interpolate)
        qrect = QRectF(QPointF(dest[0], dest[1]), QPointF(dest[2], dest[3]))
        painter.drawImage(qrect, self._image, qrect)
# Import-time check of ImageFilterItem against the interfaces listed in its
# __implements__ (assert_interfaces_valid is defined outside this chunk;
# presumably it fails if a declared interface method is missing)
assert_interfaces_valid(ImageFilterItem)

#==============================================================================
# 2-D Histogram
#==============================================================================
class Histogram2DItem(BaseImageItem):
    """
    Construct a 2D histogram item

        * X: data (1-D array)
        * Y: data (1-D array)
        * param (optional): style parameters
          (:py:class:`plotpy.styles.Histogram2DParam` instance)
        * Z (optional): data (1-D array) aggregated in each bin according to
          ``param.computation`` instead of counting points
    """
    __implements__ = (IBasePlotItem, IBaseImageItem, IHistDataSource,
                      IVoiImageItemType,)

    def __init__(self, X, Y, param=None, Z=None):
        if param is None:
            # Histogram-specific parameters are required (update_histogram,
            # computation, auto_lut are used by this item): a plain
            # ImageParam made the default construction fail.
            from plotpy.styles import Histogram2DParam
            param = Histogram2DParam(_("Image"))
        # Set before the parameters are applied, so that set_bins (presumably
        # called through update_histogram) allocates the Z accumulator
        self._z = Z
        super(Histogram2DItem, self).__init__(param=param)

        # Set by parameters
        self.nx_bins = 0
        self.ny_bins = 0
        self.logscale = None

        # internal use
        self._x = None
        self._y = None

        # Histogram parameters
        self.histparam = param
        self.histparam.update_histogram(self)

        self.set_lut_range([0, 10.])
        self.set_data(X, Y, Z)

    #---- Public API -----------------------------------------------------------
    def set_bins(self, NX, NY):
        """Set histogram bins"""
        self.nx_bins = NX
        self.ny_bins = NY
        self.data = np.zeros((self.ny_bins, self.nx_bins), float)
        if self._z is not None:
            self.data_tmp = np.zeros((self.ny_bins, self.nx_bins), float)

    def _ensure_data_tmp(self):
        """Allocate the Z accumulator if missing or not matching the bins
        (e.g. Z data supplied through set_data after construction)"""
        data_tmp = getattr(self, "data_tmp", None)
        if data_tmp is None or data_tmp.shape != self.data.shape:
            self.data_tmp = np.zeros(self.data.shape, float)

    def set_data(self, X, Y, Z=None):
        """Set histogram data"""
        self._x = X
        self._y = Y
        self._z = Z
        self.bounds = QRectF(QPointF(X.min(), Y.min()),
                             QPointF(X.max(), Y.max()))
        self.update_border()

    #---- QwtPlotItem API ------------------------------------------------------
    fill_canvas = True

    def draw_image(self, painter, canvasRect, src_rect, dst_rect, xMap, yMap):
        computation = self.histparam.computation
        i1, j1, i2, j2 = src_rect

        if computation == -1 or self._z is None:
            # Plain point count per bin
            self.data[:, :] = 0.0
            histogram2d(self._x, self._y, i1, i2, j1, j2,
                        self.data, self.logscale)
        else:
            self._ensure_data_tmp()
            self.data_tmp[:, :] = 0.0
            if computation in (2, 4):  # sum, avg
                self.data[:, :] = 0.0
            elif computation in (1, 5):  # min, argmin
                val = np.inf
                self.data[:, :] = val
            elif computation in (0, 6):  # max, argmax
                val = -np.inf
                self.data[:, :] = val
            elif computation == 3:
                self.data[:, :] = 1.
            histogram2d_func(self._x, self._y, self._z, i1, i2, j1, j2,
                             self.data_tmp, self.data, computation)
            if computation in (0, 1, 5, 6):
                # Bins still holding the initial sentinel received no point
                self.data[self.data == val] = np.nan
            else:
                # data_tmp is presumably the per-bin point count: empty bins
                # are shown as NaN
                self.data[self.data_tmp == 0.0] = np.nan
        if self.histparam.auto_lut:
            nmin = _nanmin(self.data)
            nmax = _nanmax(self.data)
            self.set_lut_range([nmin, nmax])
            self.plot().update_colormap_axis(self)
        src_rect = (0, 0, self.nx_bins, self.ny_bins)
        if self.fill_canvas:
            x1, y1, x2, y2 = canvasRect.getCoords()
            dst_rect = (x1, y1, x2, y2)
        BaseImageItem.draw_image(self, painter, canvasRect, src_rect,
                                 dst_rect, xMap, yMap)

    #---- IBasePlotItem API ---------------------------------------------------
    def types(self):
        return (IColormapImageItemType, IImageItemType, ITrackableItemType,
                IVoiImageItemType, ICSImageItemType)

    def update_item_parameters(self):
        BaseImageItem.update_item_parameters(self)
        self.histparam.update_param(self)

    def get_item_parameters(self, itemparams):
        BaseImageItem.get_item_parameters(self, itemparams)
        itemparams.add("Histogram2DParam", self, self.histparam)

    def set_item_parameters(self, itemparams):
        update_dataset(self.histparam, itemparams.get("Histogram2DParam"),
                       visible_only=True)
        self.histparam = itemparams.get("Histogram2DParam")
        self.histparam.update_histogram(self)
        BaseImageItem.set_item_parameters(self, itemparams)

    #---- IBaseImageItem API --------------------------------------------------
    def can_setfullscale(self):
        return True

    def can_sethistogram(self):
        return True

    def get_histogram(self, nbins):
        """
        IHistDataSource interface: return ``(counts, bin_edges)`` of the
        levels histogram of the binned data (``([0], [0, 1])`` if no data)
        """
        if self.data is None:
            return [0,], [0, 1]
        _min = _nanmin(self.data)
        _max = _nanmax(self.data)
        if self.data.dtype in (np.float64, np.float32):
            bins = np.unique(np.array(np.linspace(_min, _max, nbins+1),
                                      dtype=self.data.dtype))
        else:
            bins = np.arange(_min, _max+2,
                             dtype=self.data.dtype)
        res2 = np.zeros((bins.size+1,), np.uint32)
        _histogram(self.data.flatten(), bins, res2)
        res = res2[1:-1], bins
        return res
# Import-time check of Histogram2DItem against the interfaces listed in its
# __implements__ (assert_interfaces_valid is defined outside this chunk;
# presumably it fails if a declared interface method is missing)
assert_interfaces_valid(Histogram2DItem)

#==============================================================================
# Image Plot Widget
#==============================================================================
class ImagePlot(CurvePlot):
    """
    Construct a 2D curve and image plotting widget
    (this class inherits :py:class:`plotpy.curve.CurvePlot`)

        * parent: parent widget
        * title: plot title (string)
        * xlabel, ylabel, zlabel: resp. bottom, left and right axis titles
          (strings)
        * xunit, yunit, zunit: resp. bottom, left and right axis units
          (strings)
        * yreverse: reversing y-axis direction of increasing values (bool)
        * aspect_ratio: height to width ratio (float)
        * lock_aspect_ratio: locking aspect ratio (bool)
    """
    DEFAULT_ITEM_TYPE = IImageItemType
    AUTOSCALE_TYPES = (CurveItem, BaseImageItem, PolygonMapItem)
    AXIS_CONF_OPTIONS = ("image_axis", "color_axis", "image_axis", None)

    def __init__(self, parent=None, title=None, xlabel=None, ylabel=None,
                 zlabel=None, xunit=None, yunit=None, zunit=None,
                 yreverse=True, aspect_ratio=1.0, lock_aspect_ratio=True,
                 gridparam=None, section="plot"):
        self.lock_aspect_ratio = lock_aspect_ratio

        def with_right_axis(left_value, right_value):
            # The right (colormap) axis text travels with the left axis text
            # as a (left, right) pair; a left value that is already a
            # sequence is reduced to its first element.
            if right_value is None:
                return left_value
            if left_value is not None and not is_text_string(left_value):
                left_value = left_value[0]
            return (left_value, right_value)

        ylabel = with_right_axis(ylabel, zlabel)
        yunit = with_right_axis(yunit, zunit)

        super(ImagePlot, self).__init__(parent=parent, title=title,
                                        xlabel=xlabel, ylabel=ylabel,
                                        xunit=xunit, yunit=yunit,
                                        gridparam=gridparam, section=section)

        # The right axis displays the color bar
        self.colormap_axis = self.Y_RIGHT
        colorbar_widget = self.axisWidget(self.colormap_axis)
        colorbar_widget.setColorBarEnabled(True)
        self.enableAxis(self.colormap_axis)
        self.__aspect_ratio = None
        self.set_axis_direction('left', yreverse)
        self.set_aspect_ratio(aspect_ratio, lock_aspect_ratio)
        self.replot()  # Workaround for the empty image widget bug

    #---- QwtPlot API ---------------------------------------------------------
[docs] def showEvent(self, event): """Override BasePlot method""" if self.lock_aspect_ratio: self._start_autoscaled = True CurvePlot.showEvent(self, event) #---- CurvePlot API -------------------------------------------------------
[docs] def do_zoom_view(self, dx, dy): """Reimplement CurvePlot method""" CurvePlot.do_zoom_view(self, dx, dy, lock_aspect_ratio=self.lock_aspect_ratio)
[docs] def do_zoom_rect_view(self, start, end): """Reimplement CurvePlot method""" CurvePlot.do_zoom_rect_view(self, start, end) if self.lock_aspect_ratio: self.apply_aspect_ratio() #---- Levels histogram-related API ----------------------------------------
[docs] def update_lut_range(self, _min, _max): """update the LUT scale""" #self.set_items_lut_range(_min, _max, replot=False) self.updateAxes() #---- Image scale/aspect ratio -related API -------------------------------
def set_full_scale(self, item): if item.can_setfullscale(): bounds = item.boundingRect() self.set_plot_limits(bounds.left(), bounds.right(), bounds.top(), bounds.bottom())
[docs] def get_current_aspect_ratio(self): """Return current aspect ratio""" dx = self.axisScaleDiv(self.X_BOTTOM).range() dy = self.axisScaleDiv(self.Y_LEFT).range() h = self.canvasMap(self.Y_LEFT).pDist() w = self.canvasMap(self.X_BOTTOM).pDist() return fabs((h*dx)/(w*dy))
[docs] def get_aspect_ratio(self): """Return aspect ratio""" return self.__aspect_ratio
[docs] def set_aspect_ratio(self, ratio=None, lock=None): """Set aspect ratio""" if ratio is not None: self.__aspect_ratio = ratio if lock is not None: self.lock_aspect_ratio = lock self.apply_aspect_ratio()
def apply_aspect_ratio(self, full_scale=False): if not self.isVisible(): return ymap = self.canvasMap(self.Y_LEFT) xmap = self.canvasMap(self.X_BOTTOM) h = ymap.pDist() w = xmap.pDist() dx1, dy1 = xmap.sDist(), fabs(ymap.sDist()) x0, y0 = xmap.s1(), ymap.s1() x1, y1 = xmap.s2(), ymap.s2() if y0 > y1: y0, y1 = y1, y0 if full_scale: if w == 0: return # avoid division by zero dy2 = (h*dx1)/(w*self.__aspect_ratio) fix_yaxis = dy2 > dy1 else: fix_yaxis = True if fix_yaxis: if w == 0: return # avoid division by zero dy2 = (h*dx1)/(w*self.__aspect_ratio) delta_y = .5*(dy2-dy1) y0 -= delta_y y1 += delta_y else: if h == 0: return # avoid division by zero dx2 = (w*dy1*self.__aspect_ratio)/h delta_x = .5*(dx2-dx1) x0 -= delta_x x1 += delta_x self.set_plot_limits(x0, x1, y0, y1) #---- LUT/colormap-related API --------------------------------------------
[docs] def notify_colormap_changed(self): """Levels histogram range has changed""" item = self.get_last_active_item(IColormapImageItemType) if item is not None: self.update_colormap_axis(item) self.replot() self.SIG_LUT_CHANGED.emit(self)
def update_colormap_axis(self, item): if IColormapImageItemType not in item.types(): return zaxis = self.colormap_axis axiswidget = self.axisWidget(zaxis) self.setAxisScale(zaxis, item.min, item.max) # XXX: the colormap can't be displayed if min>max, to fix this # we should pass an inverted colormap along with _max, _min values axiswidget.setColorMap(QwtInterval(item.min, item.max), item.get_color_map()) self.updateAxes() #---- QwtPlot API ---------------------------------------------------------
[docs] def resizeEvent(self, event): """Reimplement Qt method to resize widget""" CurvePlot.resizeEvent(self, event) if self.lock_aspect_ratio: self.apply_aspect_ratio() self.replot() #---- BasePlot API --------------------------------------------------------
[docs] def add_item(self, item, z=None, autoscale=True): """ Add a *plot item* instance to this *plot widget* * item: :py:data:`qwt.QwtPlotItem` object implementing the :py:data:`plotpy.interfaces.IBasePlotItem` interface * z: item's z order (None -> z = max(self.get_items())+1) autoscale: True -> rescale plot to fit image bounds """ CurvePlot.add_item(self, item, z) if isinstance(item, BaseImageItem): parent = self.parent() if parent is not None: parent.setUpdatesEnabled(False) self.update_colormap_axis(item) if autoscale: self.do_autoscale() if parent is not None: parent.setUpdatesEnabled(True)
[docs] def set_active_item(self, item): """Override base set_active_item to change the grid's axes according to the selected item""" old_active = self.active_item CurvePlot.set_active_item(self, item) if item is not None and old_active is not item: self.update_colormap_axis(item)
[docs] def disable_unused_axes(self): """Disable unused axes""" CurvePlot.disable_unused_axes(self) self.enableAxis(self.colormap_axis)
[docs] def do_autoscale(self, replot=True, axis_id=None): """Do autoscale on all axes""" CurvePlot.do_autoscale(self, replot=False, axis_id=axis_id) self.updateAxes() if self.lock_aspect_ratio: self.replot() self.apply_aspect_ratio(full_scale=True) if replot: self.replot() self.SIG_PLOT_AXIS_CHANGED.emit(self)
[docs] def get_axesparam_class(self, item): """Return AxesParam dataset class associated to item's type""" if isinstance(item, BaseImageItem): return ImageAxesParam else: return CurvePlot.get_axesparam_class(self, item)
[docs] def edit_axis_parameters(self, axis_id): """Edit axis parameters""" #XXX: removed the following workaround as the associated bug can't be # reproduced anymore with plotpy 3. However, keeping the workaround # here (commented) as it could become useful eventually. #----- # #FIXME: without the following workaround, aspect ratio is changed # # when applying axis parameters # # (see also plotpy.styles.ItemParameters.update) # ratio = self.get_current_aspect_ratio() #----- if axis_id != self.colormap_axis: CurvePlot.edit_axis_parameters(self, axis_id) #----- # self.set_aspect_ratio(ratio=ratio) # self.replot() #-----