# Copyright European Space Agency, 2013
"""
This module resamples mappings in a given resolution in the plate carree
projection, relative to either geodetic or MLat/MLT coordinates.
"""
from __future__ import division, print_function, absolute_import
from six.moves import map
import time
import numpy as np
import numpy.ma as ma
from distutils.version import LooseVersion
import astropy
import copy
from functools import partial
# Refuse to run with old astropy releases: per the error message below, astropy<0.4
# mishandles non-contiguous arrays, which this module passes to astropy (e.g. Angle).
# NOTE(review): distutils (LooseVersion) was removed in Python 3.12; consider
# packaging.version once the file no longer needs to support legacy interpreters.
if LooseVersion(astropy.__version__) < '0.4':
    raise RuntimeError('astropy<0.4 is unsupported due to bugs in handling non-contiguous arrays')
from astropy.coordinates import Angle
import astropy.units as u
import scipy.interpolate
from auromat.utils import pointsInsidePolygon, extend
from auromat.mapping.mapping import BaseMapping, MappingCollection,\
convertMappingToSM, convertSMMappingToGeo
from auromat.coordinates.transform import rotatePole
from auromat.util.histogram import histogram2d
from auromat.coordinates import geodesic
from auromat.coordinates.geodesic import Location
def plateCarreeResolution(boundingBox, arcsecPerPx):
    """
    Approximate the plate carree resolution (pixels per degree of latitude and
    pixels per degree of longitude) corresponding to a spherical resolution.

    The latitude resolution follows directly from `arcsecPerPx`. The longitude
    resolution depends on latitude; it is evaluated along the parallel through
    the middle latitude of the bounding box.

    :type boundingBox: auromat.mapping.mapping.BoundingBox
    :param arcsecPerPx: spherical resolution in arcseconds per pixel
    :rtype: tuple (latPxPerDeg, lonPxPerDeg)
    """
    degPerPx = (arcsecPerPx * u.arcsec).to(u.degree).value
    west, east = boundingBox.lonWest, boundingBox.lonEast

    # number of pixels needed to cover the angular distance between the west
    # and east edge at the middle latitude
    latMiddle = (boundingBox.latNorth + boundingBox.latSouth) / 2
    edgeDistance = geodesic.angularDistance(Location(latMiddle, west),
                                            Location(latMiddle, east))
    pxAcross = edgeDistance / degPerPx

    # longitude span of the box; a west edge that lies east of the east edge
    # means the box crosses the 180° meridian
    lonSpan = (east + 360 - west) if west > east else (east - west)

    return 1 / degPerPx, pxAcross / lonSpan
def resampleMLatMLT(mapping, **kw):
    """ Resample a mapping such that MLat/MLT become regular grids.

    The mapping is converted with :func:`convertMappingToSM`, resampled with
    :func:`resample` and converted back with :func:`convertSMMappingToGeo`.
    See :func:`resample` for the accepted keyword arguments.
    """
    return convertSMMappingToGeo(resample(convertMappingToSM(mapping), **kw))
def resample(mappingOrCollection, pxPerDeg=25, arcsecPerPx=None, containsPole=None, method='mean'):
    """
    Returns a new mapping (or collection) where the colors and elevation
    are resampled into a regular latitude/longitude grid (plate carree projection)
    with y=latitude and x=longitude.

    If 'mean' binning is used as resampling method then take into account that
    this will lead to holes for low elevation angles if a high resampling resolution
    is used. This is because binning does not interpolate when there are zero data
    points in a given bin. Mask the mapping by elevation (e.g. 10deg) to get rid
    of the areas with holes.

    Grid cells without data are masked in both the resampled image and the
    resampled elevation. Integer images are rounded and cast back to their
    original dtype.

    :param mappingOrCollection: a single mapping (BaseMapping subclass) or a
                                MappingCollection whose mappings are resampled one by one
    :param None|number|tuple pxPerDeg: tuple (latPxPerDeg, lonPxPerDeg) or a number if both are the same
    :param None|number arcsecPerPx: spherical resolution, used to approximate pxPerDeg,
                                    has precedence over pxPerDeg
    :param None|bool containsPole: specify True|False to skip pole checking algorithm
    :param method: binning: 'mean'; interpolation: 'nearest', 'linear', 'cubic';
                   Note that linear and cubic take considerably longer and use much more memory while
                   they don't bring any benefit over 'nearest' if the goal is downsampling.
    :raises ValueError: if `mappingOrCollection` is neither a mapping nor a mapping collection
    :rtype: a subclass of BaseMapping or MappingCollection
    """
    def doResample(mapping, pxPerDeg, arcsecPerPx, containsPole):
        # trigger calculation of properties so that they are not included in the timing measurements
        mapping.lats
        mapping.latsCenter
        mapping.elevation
        mapping.img

        t0 = time.time()

        if containsPole is None:
            containsPole = mapping.containsPole

        if arcsecPerPx:
            # derived per mapping from its own bounding box, so mappings of a
            # collection may end up with different pxPerDeg values
            pxPerDeg = plateCarreeResolution(mapping.boundingBox, arcsecPerPx)
        else:
            try:
                _, _ = pxPerDeg
            except TypeError:
                # a single number applies to both latitude and longitude
                assert pxPerDeg is not None
                pxPerDeg = (pxPerDeg, pxPerDeg)
        print('pxPerDeg: ' + str(pxPerDeg))

        imgDtype = mapping.img.dtype
        imgIsInt = np.issubdtype(imgDtype, np.integer)

        # merge elevation with rgb array and extract channels afterwards,
        # so that all of them are resampled in a single pass; NaN marks "no data"
        merged = np.dstack((mapping.img.astype(np.float64).filled(np.nan),
                            mapping.elevation.filled(np.nan)))
        lats, lons, latsCenter, lonsCenter, merged = \
            _resample(mapping.latsCenter.filled(np.nan), mapping.lonsCenter.filled(np.nan), mapping.altitude,
                      merged,
                      lambda: mapping.outline, mapping.boundingBox,
                      pxPerDeg, mapping.containsDiscontinuity, containsPole,
                      method=method)
        img, elevation = np.dsplit(merged, [-1])

        if imgIsInt:
            # round before casting back to the original integer dtype; NaN cells
            # (no data) are masked, so their cast value does not matter
            with np.errstate(invalid='ignore'):
                img = np.require(ma.masked_invalid(np.round(img), copy=False), imgDtype)
        else:
            # float images need the same masking of cells without data
            img = ma.masked_invalid(img, copy=False)
        if mapping.img.ndim == 2:
            # undo the extra channel axis introduced by dstack
            img = img.reshape(img.shape[0], img.shape[1])

        elevation = elevation.reshape(elevation.shape[0], elevation.shape[1])
        elevation = ma.masked_invalid(elevation, copy=False)

        resampledMapping = mapping.createResampled(lats, lons, latsCenter, lonsCenter, elevation, img)
        print('resampling:', time.time()-t0, 's')
        return resampledMapping

    if isinstance(mappingOrCollection, BaseMapping):
        return doResample(mappingOrCollection, pxPerDeg, arcsecPerPx, containsPole)
    if isinstance(mappingOrCollection, MappingCollection):
        mappings = [doResample(mapping, pxPerDeg, arcsecPerPx, containsPole)
                    for mapping in mappingOrCollection.mappings]
        return MappingCollection(mappings, mayOverlap=mappingOrCollection.mayOverlap)
    raise ValueError('First argument must be a mapping or a mapping collection, but is: {}'.
                     format(type(mappingOrCollection)))
def _rotatePoleDeg(lats, lons, altitude, angle):
    """Apply :func:`rotatePole` (axis=[1,0,0], given `angle`) to lat/lon arrays in
    degrees; returns the rotated lats/lons in degrees, shaped like the inputs."""
    latsRot, lonsRot = rotatePole(np.deg2rad(np.ravel(lats)), np.deg2rad(np.ravel(lons)),
                                  altitude, angle=angle, axis=[1, 0, 0])
    return (np.rad2deg(latsRot).reshape(np.shape(lats)),
            np.rad2deg(lonsRot).reshape(np.shape(lons)))


def _wrapLon180(lons):
    """Shift longitudes (degrees) by 180° and wrap them into [-180°, 180°).
    Applying it twice yields the original longitudes (modulo 360°)."""
    return Angle((lons + 180) * u.deg).wrap_at(180 * u.deg).degree


def _resample(latsCenter, lonsCenter, altitude, data, outlineLatLonFn, boundingBox, pxPerDeg,
              containsDiscontinuity=False, containsPole=False, method='mean'):
    """
    Resample per-pixel-center data onto a regular plate carree grid.

    Note: Each channel is resampled on its own.

    If `containsPole` is set, coordinates are rotated with rotatePole (angle 90,
    axis [1,0,0]) before gridding and the grid is rotated back (angle -90)
    afterwards. Otherwise, if `containsDiscontinuity` is set, longitudes are
    shifted by 180° before gridding and shifted back afterwards.

    :param latsCenter: (h,w)
    :param lonsCenter: (h,w)
    :param altitude: forwarded to rotatePole
    :param data: float data for each pixel center, (h,w,n) with n>0, or (h,w)
    :param outlineLatLonFn: callable returning the outline as an array with
                            (lat, lon) rows; its result is never modified in place
    :param boundingBox: object with latSouth, latNorth, lonWest, lonEast
    :param pxPerDeg: tuple (latPxPerDeg, lonPxPerDeg)
    :param method: see :func:`_resampleCenterData`
    :rtype: tuple (lat, lon, latCenter, lonCenter, data) where lat/lon are the
            corner grid (one row/column larger than the center grid) and data
            matches the center grid
    """
    latMin, latMax = boundingBox.latSouth, boundingBox.latNorth
    lonMin, lonMax = boundingBox.lonWest, boundingBox.lonEast

    # Outline in the (possibly rotated) working frame of this function. When it
    # has to be rotated, a copy is taken: rotating in place would modify the
    # caller's array (e.g. a cached mapping property), and the masking step
    # below must test grid points against the outline in the same frame.
    outlineLatLon = None

    if containsPole:
        print('contains pole')
        # rotation of latitude/poles needs to happen in cartesian space based on earth as sphere
        # only a very small error will be introduced here as the outline form is not a segment of a sphere but an ellipsoid
        # -> as the outline is very small in size, this won't be a problem
        outlineLatLon = np.array(outlineLatLonFn(), dtype=np.float64)
        outlineLats, outlineLons = _rotatePoleDeg(outlineLatLon[:, 0], outlineLatLon[:, 1], altitude, 90)
        outlineLatLon[:, 0] = outlineLats
        outlineLatLon[:, 1] = outlineLons
        latMin, latMax = np.min(outlineLats), np.max(outlineLats)
        lonMin, lonMax = np.min(outlineLons), np.max(outlineLons)
        latsCenter, lonsCenter = _rotatePoleDeg(latsCenter, lonsCenter, altitude, 90)
    elif containsDiscontinuity:
        print('contains discontinuity')
        # rotate longitudes out of 180° discontinuity; poles stay where they are
        # this introduces no additional error (e.g. due to ellipsoidal form)
        outlineLatLon = np.array(outlineLatLonFn(), dtype=np.float64)
        outlineLatLon[:, 1] = _wrapLon180(outlineLatLon[:, 1])
        lonMin, lonMax = np.min(outlineLatLon[:, 1]), np.max(outlineLatLon[:, 1])
        lonsCenter = _wrapLon180(lonsCenter)

    # create regular plate carree grid within bounding box where y=lat and x=lon
    # Note: For a given pxPerDeg, all resamplings are aligned to the same global grid.
    latPxPerDeg, lonPxPerDeg = pxPerDeg
    assert latPxPerDeg > 0 and lonPxPerDeg > 0
    nLat, nLon, latMinInGrid, latMaxInGrid, lonMinInGrid, lonMaxInGrid = \
        fixedGrid(pxPerDeg, latMin, latMax, lonMin, lonMax)
    assert nLat > 1, 'nlat={}, latMax={}, latMin={}, pxperdeg={}'.format(nLat, latMaxInGrid, latMinInGrid, pxPerDeg)
    assert nLon > 1, 'nlon={}, lonMax={}, lonMin={}, pxperdeg={}'.format(nLon, lonMaxInGrid, lonMinInGrid, pxPerDeg)

    # the center coordinates are the ones which lie on the grid, corners are calculated;
    # latitudes run from north to south, so latStep is negative
    latSpaceCenter, latStep = np.linspace(latMaxInGrid, latMinInGrid, num=nLat, retstep=True)
    lonSpaceCenter, lonStep = np.linspace(lonMinInGrid, lonMaxInGrid, num=nLon, retstep=True)

    # skip first and last center coordinate, otherwise we would have to calculate corner
    # coordinates outside the determined range, which could trigger certain edge cases
    latSpace = latSpaceCenter[:-1] + latStep/2
    lonSpace = lonSpaceCenter[:-1] + lonStep/2
    latSpaceCenter = latSpaceCenter[1:-1]
    lonSpaceCenter = lonSpaceCenter[1:-1]

    # equivalent to np.meshgrid(lat, lon, indexing='ij'); 'indexing' is not supported in np 1.6
    latGrid, lonGrid = np.dstack(np.meshgrid(latSpace, lonSpace)).T
    latGridCenter, lonGridCenter = np.dstack(np.meshgrid(latSpaceCenter, lonSpaceCenter)).T

    # do the actual resampling
    dataResampled = _resampleCenterData(latsCenter, lonsCenter,
                                        data, latSpaceCenter, lonSpaceCenter, latStep, lonStep,
                                        method)

    # mask grid points which are outside the outline
    # This is needed as 'linear' and 'cubic' only mask points outside the *convex hull*,
    # which is not enough as we have concave forms. In those corner cases the data is interpolated.
    # With 'nearest', nothing is masked.
    # With 'mean', there is no inter/extrapolation, so we can skip the additional masking.
    if method != 'mean':
        if outlineLatLon is None:
            # no rotation happened, the original outline is already in the grid's frame
            outlineLatLon = outlineLatLonFn()
        # Based on the masked grid points the data is masked if any of its 4 corner points is masked.
        latLonGridFlat = np.asarray([np.ravel(latGrid), np.ravel(lonGrid)]).T
        isOutside = ~pointsInsidePolygon(latLonGridFlat, outlineLatLon).reshape(latGrid.shape)
        mask = np.logical_or.reduce((isOutside[:-1, :-1], isOutside[1:, :-1], isOutside[:-1, 1:], isOutside[1:, 1:]))
        dataResampled[mask] = np.nan

    # rotate back coordinates if previously rotated
    if containsPole:
        latGrid, lonGrid = _rotatePoleDeg(latGrid, lonGrid, altitude, -90)
        latGridCenter, lonGridCenter = _rotatePoleDeg(latGridCenter, lonGridCenter, altitude, -90)
    elif containsDiscontinuity:
        lonGrid = _wrapLon180(lonGrid)
        lonGridCenter = _wrapLon180(lonGridCenter)

    return latGrid, lonGrid, latGridCenter, lonGridCenter, dataResampled
def fixedGrid(pxPerDeg, latMin, latMax, lonMin, lonMax):
    """
    Aligns the given bounding box to a fixed plate carree grid as defined
    by `pxPerDeg`.

    The global grid has ``round(latPxPerDeg*180 + 1)`` equally spaced latitudes
    from -90 to 90 and ``round(lonPxPerDeg*360 + 1)`` equally spaced longitudes
    from -180 to 180. The box is widened outwards to grid points: minimum edges
    snap down, maximum edges snap up. Bounds slightly outside the valid range
    (e.g. due to floating-point rounding) are clamped to the outermost grid point
    instead of wrapping around to the opposite end.

    :param pxPerDeg: tuple (latPxPerDeg, lonPxPerDeg) or a number if both are the same
    :param latMin,latMax: latitude range in degrees
    :param lonMin,lonMax: longitude range in degrees, must NOT contain the discontinuity
    :rtype: tuple (nLat, nLon, latMinInGrid, latMaxInGrid, lonMinInGrid, lonMaxInGrid)
            where nLat/nLon count the global grid points in the aligned box (both ends included)
    """
    try:
        latPxPerDeg, lonPxPerDeg = pxPerDeg
    except TypeError:
        latPxPerDeg = lonPxPerDeg = pxPerDeg

    latSpaceAll = np.linspace(-90, 90, int(round(latPxPerDeg*180 + 1)))
    lonSpaceAll = np.linspace(-180, 180, int(round(lonPxPerDeg*360 + 1)))

    def snapOutwards(space, lo, hi):
        # index of the largest grid point <= lo, clamped to the first point
        iLo = max(int(np.searchsorted(space, lo, side='right')) - 1, 0)
        # index of the smallest grid point >= hi, clamped to the last point
        iHi = min(int(np.searchsorted(space, hi, side='left')), len(space) - 1)
        return iLo, iHi

    iLatMin, iLatMax = snapOutwards(latSpaceAll, latMin, latMax)
    iLonMin, iLonMax = snapOutwards(lonSpaceAll, lonMin, lonMax)

    # counting by index keeps the regional grid exactly on the global grid points
    nLat = iLatMax - iLatMin + 1
    nLon = iLonMax - iLonMin + 1
    return (nLat, nLon,
            latSpaceAll[iLatMin], latSpaceAll[iLatMax],
            lonSpaceAll[iLonMin], lonSpaceAll[iLonMax])
def _resampleCenterData(latsCenter, lonsCenter, centerData, latSpaceCenter, lonSpaceCenter, latStep, lonStep,
                        method):
    """
    Resample data given at scattered pixel centers onto the regular grid formed
    by `latSpaceCenter` (rows) x `lonSpaceCenter` (columns).

    Each channel is resampled on its own. Pixels whose center latitude is NaN are ignored.

    :param latsCenter: pixel center latitudes (h,w)
    :param lonsCenter: pixel center longitudes (h,w)
    :param centerData: float data for each pixel center, (h,w,n) or (h,w)
    :param latSpaceCenter: grid row latitudes; 'mean' expects them decreasing with a
                           negative `latStep` (as built by _resample)
    :param lonSpaceCenter: grid column longitudes; 'mean' expects them increasing
                           with a positive `lonStep`
    :param latStep,lonStep: grid spacing, only used by 'mean' to derive the bin edges
    :param method: binning: 'mean'; interpolation: 'nearest', 'linear', 'cubic';
                   Note that linear and cubic take considerably longer and use much more memory while
                   they don't bring any benefit over 'nearest' if the goal is downsampling.
    :raises NotImplementedError: for any other method (including 'median')
    :rtype: array (len(latSpaceCenter), len(lonSpaceCenter), n), or 2D if `centerData` is 2D;
            with 'mean', bins without samples are NaN
    """
    scalarData = centerData.ndim == 2
    if scalarData:
        centerData = centerData[..., None]
    nChannels = centerData.shape[2]

    # drop pixels without valid center coordinates
    latsFlat, lonsFlat = np.ravel(latsCenter), np.ravel(lonsCenter)
    valid = ~np.isnan(latsFlat)
    latsValid = latsFlat[valid]
    lonsValid = lonsFlat[valid]
    dataValid = centerData.reshape(-1, nChannels)[valid]

    if method in ('nearest', 'linear', 'cubic'):
        # interpolate center data at grid point centers
        centerResampled = scipy.interpolate.griddata(
            (latsValid, lonsValid), dataValid,
            (latSpaceCenter[:, None], lonSpaceCenter[None, :]), method=method)
    elif method == 'mean':
        # this is about 20-50% slower than griddata's 'nearest'
        bins = (len(lonSpaceCenter), len(latSpaceCenter))
        # the bin edges must be monotonically increasing, therefore we do that and flip it afterwards
        # so that latitudes are decreasing
        range_ = [[lonSpaceCenter[0] - lonStep/2, lonSpaceCenter[-1] + lonStep/2],
                  [latSpaceCenter[-1] + latStep/2, latSpaceCenter[0] - latStep/2]]
        channelWeights = [dataValid[:, d] for d in range(nChannels)]
        # first histogram: sample counts; the others: per-channel sums
        countAndData, _, _ = histogram2d(lonsValid, latsValid, bins=bins, range=range_,
                                         weights=[None] + channelWeights)
        count = countAndData[0].T
        resampled = []
        for d in range(nChannels):
            dataOut = countAndData[d + 1].T
            dataOut[count == 0.0] = np.nan
            with np.errstate(invalid='ignore'):
                dataOut /= count
            # flip so that latitudes are decreasing
            resampled.append(np.flipud(dataOut))
        centerResampled = np.dstack(resampled)
    else:
        # 'median' binning is not implemented yet, possible approaches:
        # https://stackoverflow.com/a/15488537
        # https://stackoverflow.com/a/10324083
        # http://fspaolo.net/code/bin-data.html
        raise NotImplementedError('unsupported resampling method: {!r}'.format(method))

    expectedShape = (len(latSpaceCenter), len(lonSpaceCenter), nChannels)
    assert centerResampled.shape == expectedShape, \
        str(centerResampled.shape) + ' != ' + str(expectedShape)

    if scalarData:
        centerResampled = centerResampled.reshape(centerResampled.shape[0], centerResampled.shape[1])
    return centerResampled
def ResampleProvider(provider, **kw):
    """
    Wrap the given mapping provider so that every mapping it returns is resampled.

    A shallow copy of `provider` is extended with the resampling methods; the
    object passed in is not handed to `extend` itself.

    :param provider: the provider to wrap
    See :func:`resample` for the accepted keyword arguments.
    """
    resampleFn = partial(resample, **kw)

    class ResamplingProvider(object):
        # NOTE(review): relies on `extend` mixing this class into the copy's type so
        # that super() reaches the wrapped provider's methods — confirm in utils.
        def get(self, *args, **kwargs):
            return resampleFn(super(ResamplingProvider, self).get(*args, **kwargs))

        def getById(self, *args, **kwargs):
            return resampleFn(super(ResamplingProvider, self).getById(*args, **kwargs))

        def getSequence(self, *args, **kwargs):
            # map (six.moves) resamples lazily while the sequence is consumed
            return map(resampleFn, super(ResamplingProvider, self).getSequence(*args, **kwargs))

    wrapped = copy.copy(provider)
    extend(wrapped, ResamplingProvider)
    return wrapped