"""
operators
=========
The classes defined in this module implement different operators that
act on input signals. These operators are used to define problems.
Subclass opBase to create new operators.
..
This module is based on MATLAB SPARCO Toolbox.
Copyright 2008, Ewout van den Berg and Michael P. Friedlander
http://www.cs.ubc.ca/labs/scl/sparco
.. codeauthor:: Amit Aides <amitibo@tx.technion.ac.il>
"""
from __future__ import division
import numpy as np
import numpy.fft as npfft
import scipy.fftpack as spfft
import rwt
from rwt import wavelets
class opBase(object):
    """
    Base class for operators.

    An operator of shape ``(m, n)`` maps an input signal of `n` elements to
    an output signal of `m` elements. Subclasses implement `_apply`; calling
    the operator flattens the input to a column vector, checks its size,
    applies `_apply` and reshapes the result to `out_signal_shape`.

    Attributes
    ----------
    name : string
        Name of operator.
    shape : (int, int)
        The shape of the operator.
    in_signal_shape : tuple of ints
        The shape of the input signal.
    out_signal_shape : tuple of ints
        The shape of the output signal.
    T : type(self)
        The transpose of the operator: a shallow copy with the `_conj`
        flag toggled, so ``op.T.T`` behaves like ``op``.

    Methods
    -------
    __call__(x)
        Apply the operator to the signal `x`.
    """

    def __init__(self, name, shape, in_signal_shape=None, out_signal_shape=None):
        """
        Parameters
        ----------
        name : string
            Name of operator.
        shape : (int, int)
            The shape of the operator. `shape[1]` is the size of the input
            signal and `shape[0]` is the size of the output signal.
        in_signal_shape : tuple of ints, optional (default=None)
            The shape of the input signal. Its product must equal `shape[1]`.
            If `None`, it is set to ``(shape[1], 1)``.
        out_signal_shape : tuple of ints, optional (default=None)
            The shape of the output signal. Its product must equal `shape[0]`.
            If `None`, it is set to `in_signal_shape` when the operator is
            square (``shape[0] == shape[1]``) and to ``(shape[0], 1)``
            otherwise.

        Raises
        ------
        AssertionError
            If a signal shape does not conform to `shape`.
        """
        # `is None` (not `== None`) so that shapes given as numpy arrays
        # do not trigger an element-wise comparison.
        if in_signal_shape is None:
            in_signal_shape = (shape[1], 1)
        if out_signal_shape is None:
            if shape[0] == shape[1]:
                out_signal_shape = in_signal_shape
            else:
                out_signal_shape = (shape[0], 1)

        assert np.prod(in_signal_shape)==shape[1], 'Input signal shape does not conform to the shape of the operator'
        assert np.prod(out_signal_shape)==shape[0], 'Output signal shape does not conform to the shape of the operator'

        self._name = name
        self._shape = shape
        self._in_signal_shape = in_signal_shape
        self._out_signal_shape = out_signal_shape
        # True when this object represents the transpose of the operator.
        self._conj = False

    @property
    def name(self):
        """Name of operator.
        """
        return self._name

    @property
    def shape(self):
        """The shape of the operator (reversed for the transpose).
        """
        if self._conj:
            return self._shape[::-1]
        else:
            return self._shape

    @property
    def in_signal_shape(self):
        """The shape of the input signal for the operator.
        """
        if self._conj:
            return self._out_signal_shape
        else:
            return self._in_signal_shape

    @property
    def out_signal_shape(self):
        """The shape of the output signal for the operator.
        """
        if self._conj:
            return self._in_signal_shape
        else:
            return self._out_signal_shape

    @property
    def T(self):
        """The transpose of the operator.

        Returns a shallow copy with the `_conj` flag toggled, so taking
        the transpose twice gives back the original operator.
        """
        import copy

        new_copy = copy.copy(self)
        new_copy._conj = not self._conj
        return new_copy

    def _checkDimensions(self, x):
        """Check that the size of the input signal is correct.

        This function is called by the `__call__` method.

        Parameters
        ----------
        x : array
            Input signal as a column vector of shape ``(n, 1)``.

        Raises
        ------
        Exception
            If `x` is a scalar, has the wrong length, or has more than
            one column.
        """
        if x.shape == (1, 1) and self._shape != (1, 1):
            raise Exception('Operator-scalar multiplication not yet supported')

        if x.shape[0] != self.shape[1]:
            raise Exception('Incompatible dimensions')

        if x.shape[1] != 1:
            raise Exception('Operator-matrix multiplication not yet supported')

    def _apply(self, x):
        """Apply the operator on the input signal. Should be overwritten by the operator.

        This function is called by the `__call__` method.

        Parameters
        ----------
        x : array
            Input signal as a column vector of shape ``(n, 1)``.
        """
        raise NotImplementedError()

    def __call__(self, x):
        """Apply the operator to the signal `x`.

        Parameters
        ----------
        x : array like
            Input signal with ``shape[1]`` elements in any shape; it is
            flattened in C (row-major) order.

        Returns
        -------
        y : array
            The output signal reshaped to `out_signal_shape`.
        """
        x = np.asarray(x).reshape((-1, 1))
        self._checkDimensions(x)
        return self._apply(x).reshape(self.out_signal_shape)
class opMatrix(opBase):
    """
    Operator that wraps a simple matrix.

    The forward operator computes ``A x``; the transpose (``op.T``)
    computes ``A^H x``, the product with the conjugate transpose of `A`.
    """

    def __init__(self, A):
        """
        Parameters
        ----------
        A : array like, [m, n]
            Matrix of dimension m, n.

        Raises
        ------
        Exception
            If `A` cannot be converted to an array.
        AssertionError
            If `A` is not two dimensional.
        """
        try:
            self._A = np.array(A)
        except (TypeError, ValueError):
            raise Exception('Parameter A must be array like object')

        assert self._A.ndim == 2, "opMatrix supports only 2D matrices"

        m, n = self._A.shape
        super(opMatrix, self).__init__(
            name='Matrix',
            shape=(m, n),
            in_signal_shape=(n, 1),
            out_signal_shape=(m, 1)
        )

    def _apply(self, x):
        if self._conj:
            # Adjoint: multiply by the conjugate transpose, (n, m) x (m, 1).
            return np.dot(self._A.conj().T, x)
        # Forward: (m, n) x (n, 1).
        return np.dot(self._A, x)
class opBlur(opBase):
    """
    Two-dimensional blurring operator. Creates a blurring operator
    for M by N images. This function is used for the GPSR-based test
    problems and is based on the implementation by Figueiredo, Nowak
    and Wright, 2007.

    The blur is a circular convolution, computed with the 2D FFT, with a
    9x9 kernel whose weight at offset ``(di, dj)`` is
    ``1 / (1 + di**2 + dj**2)``, normalized to sum to one.

    Parameters
    ----------
    shape : (int, int)
        Shape of target images.
    """

    def __init__(self, shape):
        assert len(shape) == 2, "opBlur supports operations on 2D matrices only"

        m, n = shape
        size = m * n
        super(opBlur, self).__init__(
            name='Blur',
            shape=(size, size),
            in_signal_shape=shape
        )

        #
        # Create blurring mask. The kernel is placed directly around pixel
        # (0, 0) with circular wrap-around, so the convolution does not
        # translate the image. np.add.at accumulates overlapping taps when
        # the image is smaller than the kernel.
        #
        offsets = np.arange(-4, 5)
        rows = offsets[:, np.newaxis]
        cols = offsets[np.newaxis, :]
        kernel = 1.0 / (1 + rows**2 + cols**2)
        h = np.zeros((m, n))
        np.add.at(h, (rows % m, cols % n), kernel)
        h /= h.sum()
        self._h = npfft.fft2(h)

    def _apply(self, x):
        # The transpose convolves with the conjugate transfer function.
        h = self._h.conj() if self._conj else self._h

        y = npfft.ifft2(h * npfft.fft2(x.reshape(self._in_signal_shape))).reshape((-1, 1))

        # Drop the round-off imaginary part for real inputs.
        if np.isrealobj(x):
            y = np.real(y)

        return y
class opWavelet(opBase):
    """Wavelet operator.

    Builds an operator that applies a wavelet transform to a 2D input
    signal. The forward operator uses `rwt.idwt` with the reconstruction
    filters; the transpose uses `rwt.dwt` with the decomposition filters.
    Complex signals are transformed as separate real and imaginary parts.
    """

    def __init__(self, shape, family='Daubechies', filter=8, levels=5, type='min'):
        """
        Parameters
        ----------
        shape : (m, n)
            Shape of the 2D input signal.
        family : {'Daubechies', 'haar'}, optional
            The wavelet family (case-insensitive). Any other name is passed
            lower-cased to `wavelets.waveletCoeffs`.
        filter : integer, optional (default=8)
            Used for the Daubechies family, which requests the coefficients
            named ``'db<filter>'``.
        levels : integer, optional (default=5)
            Number of levels in the transformation. Both `m` and `n` must
            be divisible by 2**levels.
        type : {'min', 'max', 'mid'}, optional (default='min')
            Intended to select minimum, maximum or mid phase solutions.
            NOTE(review): this argument is currently accepted but not used.
        """
        assert len(shape) == 2, "opWavelet supports operations on 2D matrices only"

        n_elements = shape[0] * shape[1]
        super(opWavelet, self).__init__(
            name='Wavelet',
            shape=(n_elements, n_elements),
            in_signal_shape=shape
        )

        key = family.lower()
        if key == 'daubechies':
            coeffs_name = 'db%d' % filter
        elif key == 'haar':
            coeffs_name = 'db1'
        else:
            coeffs_name = key
        coeffs = wavelets.waveletCoeffs(coeffs_name)
        self._cl, self._ch, self._rl, self._rh = coeffs

        self._level = levels

    def _apply(self, x):
        if not self._conj:
            transform, low, high = rwt.idwt, self._rl, self._rh
        else:
            transform, low, high = rwt.dwt, self._cl, self._ch

        grid_shape = self._in_signal_shape

        def run(component):
            # The transform returns a pair; only the first item is kept.
            result, _ = transform(component.reshape(grid_shape), low, high, self._level)
            return result

        if not np.isrealobj(x):
            y = run(x.real) + 1j * run(x.imag)
        else:
            y = run(x)

        y.shape = (-1, 1)
        return y
class opFFT2d(opBase):
    """Two-dimensional fast Fourier transform (FFT) operator.

    Builds an operator that applies the unitary (``1/sqrt(m*n)``
    normalized) 2D Fourier transform to a 2D input signal. Its transpose
    is the correspondingly scaled inverse FFT.
    """

    def __init__(self, shape):
        """
        Parameters
        ----------
        shape : (m, n)
            Shape of the 2D input signal.
        """
        assert len(shape) == 2, "opFFT2d supports operations on 2D matrices only"

        total = shape[0] * shape[1]
        super(opFFT2d, self).__init__(name='FFT2d', shape=(total, total), in_signal_shape=shape)
        self._normalization_coeff = np.sqrt(total)

    def _apply(self, x):
        grid = x.reshape(self._in_signal_shape)
        scale = self._normalization_coeff
        if not self._conj:
            spectrum = npfft.fft2(grid) / scale
        else:
            spectrum = npfft.ifft2(grid) * scale
        column = np.ascontiguousarray(spectrum).reshape((-1, 1))
        # Return a real array when every imaginary part is negligible.
        return np.real_if_close(column, tol=1e6)
class opDCT(opBase):
    """Arbitrary dimensional discrete cosine transform (DCT).

    Create an operator that applies the orthonormal discrete cosine
    transform to signals of arbitrary dimension. The forward operator
    applies `scipy.fftpack.dct` and the transpose applies
    `scipy.fftpack.idct`, both with ``norm='ortho'``.
    """

    def __init__(self, shape, axis=-1):
        """
        Parameters
        ----------
        shape : list of integers
            Shape of the input signal.
        axis : integer, optional (default=-1)
            Axis along which the DCT is computed. If -1, the transform is
            multidimensional, i.e. applied successively along every axis.
            Any other value must index a dimension of `shape`; negative
            values count from the end.

        Raises
        ------
        AssertionError
            If `shape` is not a list of integers or `axis` is out of range.
        """
        _shape = [int(i) for i in shape]
        assert list(shape) == _shape, "shape must be a list of integers"
        # Reject axes that cannot index `shape`. -1 is reserved for the
        # multidimensional transform.
        assert axis == -1 or -len(shape) <= axis < len(shape), \
            "axis must be either -1 or one of the dimension indices"

        size = np.prod(shape)
        super(opDCT, self).__init__(
            name='DCT',
            shape=(size, size),
            in_signal_shape=shape
        )
        self._axis = axis

    def _apply(self, x):
        f = spfft.idct if self._conj else spfft.dct

        y = x.reshape(self._in_signal_shape)
        if self._axis == -1:
            for i in range(y.ndim):
                y = f(y, axis=i, norm='ortho')
        else:
            y = f(y, axis=self._axis, norm='ortho')

        # Use reshape rather than assigning `.shape`: the transformed array
        # may not be C-contiguous, and in-place shape assignment would fail.
        return y.reshape((-1, 1))
class opDirac(opBase):
    """Identity operator.

    Create an operator whose output signal equals the input signal.
    """

    def __init__(self, shape):
        """
        Parameters
        ----------
        shape : integer or list of integers
            Shape of the input signal. A single integer is treated as
            ``[shape]``.
        """
        # numbers.Integral covers int (and long on Python 2) as well as
        # numpy integer scalars; the bare name `long` does not exist on
        # Python 3.
        import numbers

        if isinstance(shape, numbers.Integral):
            shape = [shape]

        _shape = [int(i) for i in list(shape)]
        assert list(shape) == _shape, "shape must be a list of integers"

        size = np.prod(shape)
        super(opDirac, self).__init__(name='Dirac', shape=(size, size), in_signal_shape=shape)

    def _apply(self, x):
        # Return a copy so callers can modify the output safely.
        return x.copy()
class opFoG(opBase):
    """Concatenate a sequence of operators into a single operator.

    For ``operators_list = [A0, A1, ..., Ak]`` the resulting operator is
    the composition ``A0 A1 ... Ak``.
    """

    def __init__(self, operators_list):
        """
        Parameters
        ----------
        operators_list : list
            List of operators. All the operators must be instances of
            `opBase` or its subclasses. The `opFoG` operator applies
            the operators to the input signal in reverse order, i.e.
            starting with `operators_list[-1]`.

        Raises
        ------
        Exception
            If the list is empty or two neighbouring operators have
            incompatible shapes.
        """
        if len(operators_list) < 1:
            raise Exception('At least one operator must be specified')

        # Each operator's output size must equal the input size of the
        # operator to its left.
        for left, right in zip(operators_list[:-1], operators_list[1:]):
            if right.shape[0] != left.shape[1]:
                raise Exception('Operator %s is not consistent with the previous operators' % right.name)

        super(opFoG, self).__init__(
            name='FoG',
            shape=(operators_list[0].shape[0], operators_list[-1].shape[1]),
            in_signal_shape=operators_list[-1].in_signal_shape,
            out_signal_shape=operators_list[0].out_signal_shape
        )

        self._operators_list = operators_list

    @property
    def operators_list(self):
        """The list of operators that make up the opFoG.

        For the transpose this is the reversed list of transposed operators.
        """
        if not self._conj:
            return self._operators_list
        return [op.T for op in reversed(self._operators_list)]

    def _apply(self, x):
        # The rightmost factor of the (possibly transposed) composition acts first.
        y = x
        for oper in reversed(self.operators_list):
            y = oper(y)
        return y
class op3DStack(opBase):
    """Extend an operator to process a stack of signals.

    The op3DStack operator is useful for example when the input signal
    is a stack of images and the base operator is applied to each
    image separately.

    The stacked layout is chosen separately for the input and the output
    side, from the base operator's signal shape `s`:

    * column signals (``s == (k,)`` or ``s == (k, 1)``) are concatenated
      into one ``(k*dim3, 1)`` column; slice ``i`` occupies rows
      ``i*k:(i+1)*k``.
    * any other signal is stacked along a new trailing axis, giving shape
      ``s + (dim3,)``; slice ``i`` is ``signal[..., i]``.
    """

    def __init__(self, operator, dim3):
        """
        Parameters
        ----------
        operator : instance of a subclass of opBase
            The base operator. This operator is applied separately
            to each of the sections that make up the stacked input
            signal.
        dim3 : integer
            The size of the stack.

        Raises
        ------
        Exception
            If `operator` is not an instance of opBase.
        """
        if not isinstance(operator, opBase):
            raise Exception('operator should be an instance of opBase.')

        m, n = operator.shape
        super(op3DStack, self).__init__(
            name='3DStack',
            shape=(m*dim3, n*dim3),
            in_signal_shape=op3DStack._stacked_shape(operator.in_signal_shape, dim3),
            out_signal_shape=op3DStack._stacked_shape(operator.out_signal_shape, dim3)
        )

        self._operator = operator
        self._dim3 = dim3

    @staticmethod
    def _stacked_shape(base_shape, dim3):
        """Return the signal shape of `dim3` stacked copies of `base_shape`."""
        base_shape = tuple(base_shape)
        if len(base_shape) == 1 or (len(base_shape) == 2 and base_shape[1] == 1):
            return (int(np.prod(base_shape)) * dim3, 1)
        return base_shape + (dim3,)

    @staticmethod
    def _split(x, stacked_shape, dim3):
        """Split a stacked signal into its `dim3` slices."""
        if len(stacked_shape) == 2:
            # Column layout: slices are consecutive blocks of rows.
            return np.split(np.ravel(x), dim3)
        # Trailing-axis layout: slice i is stacked[..., i].
        stacked = np.reshape(x, stacked_shape)
        return [stacked[..., i] for i in range(dim3)]

    @staticmethod
    def _merge(slices, stacked_shape):
        """Combine per-slice outputs into a column vector in `stacked_shape` order."""
        if len(stacked_shape) == 2:
            y = np.concatenate([np.ravel(s) for s in slices])
        else:
            y = np.concatenate([s[..., np.newaxis] for s in slices], axis=-1)
        return y.reshape((-1, 1))

    def _apply(self, x):
        op = self._operator.T if self._conj else self._operator

        # The conj-aware properties give the layouts for the current direction.
        slices = self._split(x, self.in_signal_shape, self._dim3)
        return self._merge([op(s) for s in slices], self.out_signal_shape)
class opRandMask(opBase):
    """Random binary mask.

    The opRandMask operator creates a random binary mask at construction
    and multiplies input signals by it element-wise. The mask is real and
    diagonal, so the same multiplication is used for the transpose.
    """

    def __init__(self, shape, fill_ratio):
        """
        Parameters
        ----------
        shape : list of integers
            Shape of the input signal.
        fill_ratio : float
            Ratio of non zero (1) values in the mask; exactly
            ``int(size * fill_ratio)`` entries are set.

        Raises
        ------
        AssertionError
            If `shape` is not a list of integers or `fill_ratio` is not in
            the open interval (0, 1).
        """
        _shape = [int(i) for i in shape]
        assert list(shape) == _shape, "shape must be a list of integers"
        assert 0 < fill_ratio < 1, "fill_ratio must be a float in the range (0, 1)"

        size = np.prod(shape)
        super(opRandMask, self).__init__(
            name='RandomMask',
            shape=(size, size),
            in_signal_shape=shape
        )

        # `bool` rather than the removed `np.bool` alias. The mask draws from
        # numpy's global RNG, so np.random.seed makes it reproducible.
        self._mask = np.zeros(shape, dtype=bool)
        indices = np.arange(size)
        np.random.shuffle(indices)
        self._mask.flat[indices[:int(size*fill_ratio)]] = True

    def _apply(self, x):
        # reshape (not `x.shape = ...`) so the argument is never mutated.
        y = x.reshape(self.in_signal_shape) * self._mask
        return y.reshape((-1, 1))
def test_DCT():
    """
    Test the opDCT operator.

    Visual demo: transforms a sample image with ``op.T``, hard-thresholds
    the result at 0.5, maps it back with ``op``, and displays the image,
    the transformed image and the reconstruction.
    """
    import matplotlib.pyplot as plt
    from compsense.utilities import hardThreshold

    try:
        from scipy.misc import lena as load_image
    except ImportError:
        # `lena` was removed from SciPy; `ascent` is its replacement.
        try:
            from scipy.misc import ascent as load_image
        except ImportError:
            from scipy.datasets import ascent as load_image

    img = load_image().astype(np.double)
    img /= img.max()

    op = opDCT(img.shape)
    img_conv = op.T(img)
    img_recon = op(hardThreshold(img_conv, 0.5))

    plt.figure()
    plt.gray()
    plt.imshow(img)
    plt.figure()
    plt.imshow(img_conv)
    plt.figure()
    plt.imshow(img_recon)
    plt.show()
def test_FFT():
    """
    Test the opFFT2d operator.

    Visual demo: displays a sample image, the magnitude of its normalized
    2D FFT, and the image recovered by applying the transpose operator.
    """
    import matplotlib.pyplot as plt

    try:
        from scipy.misc import lena as load_image
    except ImportError:
        # `lena` was removed from SciPy; `ascent` is its replacement.
        try:
            from scipy.misc import ascent as load_image
        except ImportError:
            from scipy.datasets import ascent as load_image

    img = load_image().astype(np.double)
    img /= img.max()

    op = opFFT2d(img.shape)
    img_conv = op(img)
    img_recon = op.T(img_conv)

    plt.figure()
    plt.gray()
    plt.imshow(img)
    plt.figure()
    plt.imshow(np.abs(img_conv))
    plt.figure()
    plt.imshow(img_recon)
    plt.show()
def test_RandomMask():
    """
    Test the opRandMask operator.

    Visual demo: displays a sample image with 60% of its pixels kept by a
    random binary mask.
    """
    import matplotlib.pyplot as plt

    try:
        from scipy.misc import lena as load_image
    except ImportError:
        # `lena` was removed from SciPy; `ascent` is its replacement.
        try:
            from scipy.misc import ascent as load_image
        except ImportError:
            from scipy.datasets import ascent as load_image

    img = load_image().astype(np.double)
    img /= img.max()

    op = opRandMask(img.shape, fill_ratio=0.6)
    img_masked = op(img)

    plt.figure()
    plt.gray()
    plt.imshow(img_masked)
    plt.show()
if __name__ == '__main__':
    # Running this module as a script shows the random-mask demo.
    test_RandomMask()