Source code for spartan.dense.extent

from spartan import util
from spartan.util import Assert
import numpy as np

class TileExtent(object):
    '''A rectangular tile of a distributed array.

    These correspond (roughly) to a `slice` taken from an array
    (without any step component).

    Arrays are indexed from the upper-left; for an array of shape
    (sx, sy, sz): (0,0...) is the upper-left corner of an array, and
    (sx,sy,sz...) the lower-right.

    Extents are represented by an upper-left corner (inclusive) and a
    lower right corner (exclusive): [ul, lr).  In addition, they carry
    the shape of the array they are a part of; this is used to compute
    global position information.

    Instances are normally built with the module-level ``create``, which
    populates ``ul``, ``lr``, ``array_shape`` and the cached numpy copies
    ``ul_array`` / ``lr_array`` that the methods below read.
    '''

    @property
    def size(self):
        '''Number of elements in the tile (product of ``shape``).'''
        return np.prod(self.shape)

    @property
    def shape(self):
        '''Per-dimension length of the tile, as a tuple.

        A dimension where ``ul == lr`` is reported as length 1, not 0.
        '''
        result = self.lr_array - self.ul_array
        result[result == 0] = 1
        return tuple(result)

    @property
    def ndim(self):
        '''Number of dimensions of the extent.'''
        return len(self.lr)

    def __reduce__(self):
        # Pickle support: rebuild through ``create`` so the cached numpy
        # arrays are regenerated instead of being serialized.
        return (create, (self.ul, self.lr, self.array_shape))

    def to_slice(self):
        '''Return the extent as a tuple of step-less ``slice`` objects.'''
        return tuple([slice(ul, lr, None) for ul, lr in zip(self.ul, self.lr)])

    def __repr__(self):
        return 'extent(' + ','.join('%s:%s' % (a, b) for a, b in zip(self.ul, self.lr)) + ')'

    def __getitem__(self, idx):
        '''Return a one-dimensional extent covering only dimension ``idx``.'''
        return create([self.ul[idx]], [self.lr[idx]], [self.array_shape[idx]])

    def __hash__(self):
        # Only the upper-left corner is hashed.  Equal extents always share
        # ``ul``, so this stays consistent with ``__eq__``.
        return hash(self.ul)

    def __eq__(self, other):
        '''Two extents are equal when both corners match exactly.

        Comparing whole tuples (rather than index-by-index over
        ``len(self.ul)``) means extents of different rank never compare
        equal, and comparing against a non-extent yields ``False`` instead
        of raising ``AttributeError``.
        '''
        if not isinstance(other, TileExtent):
            return NotImplemented
        return self.ul == other.ul and self.lr == other.lr

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; keep them in sync.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def ravelled_pos(self):
        '''Flat (row-major) position of ``ul`` within the full array.'''
        return ravelled_pos(self.ul, self.array_shape)

    def to_global(self, idx, axis):
        '''Convert ``idx`` from a local offset in this tile to a global offset.

        :param idx: offset along ``axis`` or, when ``axis`` is None, a flat
          row-major offset into this tile.
        :param axis: dimension ``idx`` refers to, or None for a flat offset.
        :rtype: the global offset along ``axis``, or the flat row-major
          position in the full array when ``axis`` is None.
        '''
        if axis is not None:
            return idx + self.ul[axis]

        # Unravel the flat offset into a local coordinate.  The last
        # dimension varies fastest (row-major, matching the module-level
        # ``ravelled_pos``), so dimensions are peeled off from last to first.
        # ``divmod`` keeps the arithmetic integral on Python 2 and 3.
        remaining = idx
        local_pos = []
        for dim in reversed(self.shape):
            remaining, offset = divmod(remaining, dim)
            local_pos.append(offset)
        local_pos.reverse()

        # Shift into global coordinates and flatten against the full array.
        global_pos = np.array(local_pos)
        global_pos += self.ul
        return ravelled_pos(global_pos, self.array_shape)

    def add_dim(self):
        '''Return a copy with a trailing dimension appended.

        The new dimension has ``ul == lr == 0`` and array extent 1.
        '''
        return create(self.ul + (0,), self.lr + (0,), self.array_shape + (1,))

    def clone(self):
        '''Return a new extent with the same corners and array shape.'''
        return create(self.ul, self.lr, self.array_shape)
def create(ul, lr, array_shape):
    '''
    Create a new extent with the given coordinates and array shape.

    :param ul: sequence of ints; upper-left corner (inclusive).
    :param lr: sequence of ints; lower-right corner (exclusive).
    :param array_shape: shape of the array the extent belongs to, or None
      to skip the bounds check.
    :rtype: `TileExtent`
    :raises AssertionError: if any ``lr`` component is smaller than the
      matching ``ul`` component, or ``lr`` exceeds ``array_shape``.
    '''
    ex = TileExtent()
    ex.ul = tuple(ul)
    ex.lr = tuple(lr)

    assert np.all(np.array(ex.lr) >= np.array(ex.ul)),\
        'Negative extent size: (%s, %s)' % (ul, lr)

    if array_shape is not None:
        ex.array_shape = tuple(array_shape)
        assert np.all(np.array(ex.lr) <= np.array(array_shape)),\
            'Extent lr (%s) falls outside of the array(%s)' % (lr, array_shape)
    else:
        ex.array_shape = None

    # cache some values as numpy arrays for faster access.
    # ``np.int`` was only an alias of the builtin ``int`` and has been
    # removed from NumPy; the builtin gives the same dtype everywhere.
    ex.ul_array = np.asarray(ex.ul, dtype=int)
    ex.lr_array = np.asarray(ex.lr, dtype=int)
    return ex
def from_shape(shp):
    '''Return the extent covering an entire array of shape ``shp``.

    :param shp: shape of the array.
    :rtype: `TileExtent` running from the origin to ``shp``.
    '''
    origin = (0,) * len(shp)
    far_corner = tuple(shp)
    return create(origin, far_corner, shp)
def unravelled_pos(idx, array_shape):
    '''
    Unravel ``idx`` into an index into an array of shape ``array_shape``.

    The flat index is interpreted in row-major order (last dimension
    varies fastest).

    :param idx: `int`
    :param array_shape: `tuple`
    :rtype: `tuple` indexing into ``array_shape``
    '''
    unravelled = []
    for dim in reversed(array_shape):
        # divmod performs floor division, so the running quotient stays an
        # integer on Python 3 (plain ``/`` would turn it into a float).
        idx, offset = divmod(idx, dim)
        unravelled.append(offset)
    return tuple(reversed(unravelled))
def ravelled_pos(idx, array_shape):
    '''Flatten the coordinate ``idx`` into a row-major offset.

    :param idx: sequence with one coordinate per dimension.
    :param array_shape: shape of the array being indexed.
    :rtype: flat position of ``idx`` within ``array_shape``.
    '''
    flat = 0
    stride = 1
    # Walk from the last (fastest varying) dimension to the first.
    for axis in reversed(range(len(array_shape))):
        flat += stride * idx[axis]
        stride *= array_shape[axis]
    return flat
def find_overlapping(extents, region):
    '''
    Yield every extent that overlaps ``region``.

    :param extents: List of extents to search over.
    :param region: `Extent` to match.
    :rtype: generator of ``(extent, overlap)`` pairs, where ``overlap`` is
      the result of ``intersection(extent, region)``.
    '''
    for candidate in extents:
        common = intersection(candidate, region)
        if common is None:
            continue
        yield (candidate, common)
def compute_slice(base, idx):
    '''Return a new ``TileExtent`` representing ``base[idx]``

    Dimensions of ``base`` beyond ``len(idx)`` are carried over unchanged.
    The step of each slice is not used.

    :param base: `TileExtent`
    :param idx: slice, or tuple(slice,...); scalars are rejected.
    '''
    assert not np.isscalar(idx), idx
    selectors = idx if isinstance(idx, tuple) else (idx,)
    n_given = len(selectors)

    upper_left = []
    lower_right = []
    for axis, origin in enumerate(base.ul):
        if axis < n_given:
            # Resolve the slice against the tile's own length, then shift
            # by the tile origin to get global coordinates.
            start, stop, _ = selectors[axis].indices(base.shape[axis])
            upper_left.append(origin + start)
            lower_right.append(origin + stop)
        else:
            upper_left.append(origin)
            lower_right.append(base.lr[axis])

    return create(upper_left, lower_right, base.array_shape)
def offset_from(base, other):
    '''
    Re-express ``other`` relative to the upper-left corner of ``base``.

    :param base: `TileExtent` to use as basis
    :param other: `TileExtent` into the same array.
    :rtype: A new extent using this extent as a basis, instead of (0,0,0...)
    :raises AssertionError: if ``other`` is not contained in ``base``.
    '''
    base_ul = np.asarray(base.ul)
    base_lr = np.asarray(base.lr)
    other_ul = np.asarray(other.ul)
    other_lr = np.asarray(other.lr)

    # Compare as arrays: ``ul`` / ``lr`` are tuples, and tuple comparison is
    # lexicographic, which would accept extents that stick out of ``base``
    # in any dimension after the first differing one.
    assert np.all(other_ul >= base_ul), (other, base)
    assert np.all(other_lr <= base_lr), (other, base)

    return create(other_ul - base_ul,
                  other_lr - base_ul,
                  other.array_shape)
def offset_slice(base, other):
    '''
    Compute where ``other`` sits inside ``base``, as slices.

    :param base: `TileExtent` to use as basis
    :param other: `TileExtent` into the same array.
    :rtype: A slice representing the local offsets of ``other`` into this tile.
    '''
    relative = offset_from(base, other)
    return relative.to_slice()
def from_slice(idx, shape):
    '''
    Construct a `TileExtent` from a slice or tuple of slices.

    Missing trailing dimensions are filled with full slices.  Slice steps
    are not used.  A negative scalar index counts from the end of its
    dimension.

    :param idx: int, slice, or tuple(slice...)
    :param shape: shape of the input array
    :rtype: `TileExtent` corresponding to ``idx``.
    :raises AssertionError: if a positive start/stop lies beyond the
      dimension, or the resulting extent is rejected by ``create``.
    '''
    if not isinstance(idx, tuple):
        idx = (idx,)

    if len(idx) < len(shape):
        idx = idx + (slice(None, None, None),) * (len(shape) - len(idx))

    ul = []
    lr = []

    for dim, slc in zip(shape, idx):
        if np.isscalar(slc):
            slc = int(slc)
            if slc < 0:
                # slice(-1, 0) would resolve to an empty, inverted range;
                # normalise so that -1 selects the last element.
                slc += dim
            slc = slice(slc, slc + 1, None)

        # start/stop are None for open-ended slices (including the padding
        # added above); None cannot be ordered against an int on Python 3.
        if slc.start is not None and slc.start > 0:
            assert slc.start <= dim
        if slc.stop is not None and slc.stop > 0:
            assert slc.stop <= dim

        indices = slc.indices(dim)
        ul.append(indices[0])
        lr.append(indices[1])

    return create(ul, lr, shape)
def intersection(a, b):
    '''
    Compute the region shared by two extents.

    :rtype: The intersection of the 2 extents as a `TileExtent`, or None if
      the intersection is empty.
    '''
    for axis in range(len(a.lr)):
        # Strict comparison: extents that merely touch along an axis are
        # not rejected here.
        disjoint = b.lr[axis] < a.ul[axis] or a.lr[axis] < b.ul[axis]
        if disjoint:
            return None

    Assert.eq(a.array_shape, b.array_shape)

    upper_left = np.maximum(b.ul_array, a.ul_array)
    lower_right = np.minimum(b.lr_array, a.lr_array)
    return create(upper_left, lower_right, a.array_shape)
# Also expose intersection() as a method, so ``a.intersection(b)`` works.
TileExtent.intersection = intersection
def shape_for_reduction(input_shape, axis):
    '''
    Return the shape for the result of applying a reduction along ``axis``
    to an input of shape ``input_shape``.

    :param input_shape: shape of the array being reduced.
    :param axis: dimension to remove, or None to reduce over every
      dimension.
    :rtype: the empty tuple ``()`` when ``axis`` is None, otherwise a `list`
      holding ``input_shape`` without the entry at ``axis``.
    '''
    # Identity test: ``== None`` would dispatch to a custom ``__eq__``.
    if axis is None:
        return ()
    input_shape = list(input_shape)
    del input_shape[axis]
    return input_shape
def shapes_match(offset, data):
    '''
    Return true if the shape of ``data`` matches the extent ``offset``.

    :param offset: object with a ``shape`` attribute (the extent).
    :param data: object with a ``shape`` attribute (the data).
    '''
    expected = offset.shape
    actual = data.shape
    return np.all(expected == actual)
def drop_axis(ex, axis):
    '''Return ``ex`` with dimension ``axis`` removed.

    :param ex: `TileExtent`
    :param axis: dimension to drop; negative values count from the end.
      None yields the zero-dimensional extent.
    :rtype: `TileExtent`
    '''
    if axis is None:
        return create((), (), ())

    if axis < 0:
        axis += len(ex.ul)

    remaining_ul = list(ex.ul)
    remaining_lr = list(ex.lr)
    remaining_shape = list(ex.array_shape)
    for values in (remaining_ul, remaining_lr, remaining_shape):
        values.pop(axis)

    return create(remaining_ul, remaining_lr, remaining_shape)
def index_for_reduction(index, axis):
    '''Return the extent ``index`` with ``axis`` dropped (see ``drop_axis``).

    :param index: `TileExtent`
    :param axis: dimension being reduced, or None.
    '''
    reduced = drop_axis(index, axis)
    return reduced
def find_shape(extents):
    '''
    Given a list of extents, return the shape of the array necessary to
    fit all of them.

    The result is the per-dimension maximum of the ``lr`` corners, with any
    zero replaced by 1.

    :param extents: list of objects exposing an ``lr`` corner.
    :rtype: `tuple`
    '''
    corners = [ex.lr for ex in extents]
    bounds = np.max(corners, axis=0)
    empty_dims = bounds == 0
    bounds[empty_dims] = 1
    return tuple(bounds)