import numpy as np
import dask.array as da
from math import sqrt
from pylops_distributed import LinearOperator
class FFT(LinearOperator):
    r"""One dimensional Fast-Fourier Transform.

    Apply Fast-Fourier Transform (FFT) along a specific direction ``dir`` of a
    multi-dimensional array of size ``dims``.

    Note that the FFT operator is an overload to the dask
    :py:func:`dask.array.fft.fft` (or :py:func:`dask.array.fft.rfft` for
    real models) in forward mode and to the dask :py:func:`dask.array.fft.ifft`
    (or :py:func:`dask.array.fft.irfft` for real models) in adjoint mode.

    Scaling is properly taken into account to guarantee
    that the operator is passing the dot-test.

    .. note:: For a real valued input signal, it is possible to store the
      values of the Fourier transform at positive frequencies only as values
      at negative frequencies are simply their complex conjugates.
      However as the operation of removing the negative part of the frequency
      axis in forward mode and adding the complex conjugates in adjoint mode
      is nonlinear, the Linear Operator FFT with ``real=True`` is not expected
      to pass the dot-test. It is thus *only* advised to use this flag when a
      forward and adjoint FFT is used in the same chained operator
      (e.g., ``FFT.H*Op*FFT``) such as in
      :py:func:`pylops_distributed.waveeqprocessing.mdd.MDC`.

    Parameters
    ----------
    dims : :obj:`tuple` or :obj:`int`
        Number of samples for each dimension
    dir : :obj:`int`, optional
        Direction along which FFT is applied.
    nfft : :obj:`int`, optional
        Number of samples in Fourier Transform (same as input if
        ``nfft=None``). When ``nfft > dims[dir]`` the model is zero-padded
        in forward mode and the adjoint output is truncated back to
        ``dims[dir]`` samples; when ``nfft < dims[dir]`` the model is
        truncated in forward mode and the adjoint output is zero-padded
        back to ``dims[dir]`` samples.
    sampling : :obj:`float`, optional
        Sampling step ``dt``.
    real : :obj:`bool`, optional
        Model to which fft is applied has real numbers (``True``) or not
        (``False``). Used to enforce that the output of adjoint of a real
        model is real.
    fftshift : :obj:`bool`, optional
        Apply ifftshift to the model before the forward FFT and fftshift
        to the output of the adjoint (``True``) or not (``False``)
    compute : :obj:`tuple`, optional
        Compute the outcome of forward and adjoint or simply define the graph
        and return a :obj:`dask.array.array`
    chunks : :obj:`tuple`, optional
        Chunk size for model and data. If provided it will rechunk the model
        before applying the forward pass and the data before applying the
        adjoint pass
    todask : :obj:`tuple`, optional
        Apply :func:`dask.array.from_array` to model and data before applying
        forward and adjoint respectively
    dtype : :obj:`str`, optional
        Type of elements in input array.

    Attributes
    ----------
    shape : :obj:`tuple`
        Operator shape
    explicit : :obj:`bool`
        Operator contains a matrix that can be solved explicitly
        (True) or not (False)
    f : :obj:`numpy.ndarray`
        Frequency axis (only non-negative frequencies when ``real=True``)

    Raises
    ------
    ValueError
        If ``dir`` is equal to or bigger than ``len(dims)``

    Notes
    -----
    Refer to :class:`pylops.signalprocessing.FFT` for implementation
    details.

    """
    def __init__(self, dims, dir=0, nfft=None, sampling=1.,
                 real=False, fftshift=False, compute=(False, False),
                 chunks=(None, None), todask=(None, None), dtype='float64'):
        if isinstance(dims, int):
            dims = (dims,)
        if dir > len(dims) - 1:
            raise ValueError('dir=%d must be smaller than '
                             'number of dims=%d...' % (dir, len(dims)))
        self.dir = dir
        self.nfft = nfft if nfft is not None else dims[self.dir]
        self.real = real
        self.fftshift = fftshift
        self.f = (np.fft.rfftfreq(self.nfft, d=sampling) if real
                  else np.fft.fftfreq(self.nfft, d=sampling))

        # Number of frequency samples stored along ``dir`` (only the
        # non-negative half of the spectrum when ``real=True``)
        nfft_out = self.nfft // 2 + 1 if self.real else self.nfft
        if len(dims) == 1:
            # 1D models are processed as flat arrays (no reshape); a dummy
            # trailing axis keeps ``self.dims`` indexable like the N-D case
            self.dims = np.array([dims[0], 1])
            self.reshape = False
        else:
            self.dims = np.array(dims)
            self.reshape = True
        self.dims_fft = self.dims.copy()
        self.dims_fft[self.dir] = nfft_out

        # Exact integer arithmetic (avoids float rounding for large models)
        nmodel = int(np.prod(dims))
        nothers = nmodel // int(dims[self.dir])
        self.shape = (nothers * nfft_out, nmodel)

        # Find types to enforce to forward and adjoint outputs. This is
        # required as np.fft.fft always returns complex128 even if input is
        # float32 or less
        self.dtype = np.dtype(dtype)
        self.cdtype = (np.ones(1, dtype=self.dtype) +
                       1j * np.ones(1, dtype=self.dtype)).dtype
        self.compute = compute
        self.chunks = chunks
        self.todask = todask
        self.Op = None
        self.explicit = False

    def _matvec(self, x):
        """Forward FFT along ``dir``, scaled by ``1/sqrt(nfft)``.

        The model is optionally ifftshifted first. The flattened spectrum
        is returned cast to ``self.cdtype``.
        """
        if self.reshape:
            x = da.reshape(x, self.dims)
        if self.chunks[0] is not None:
            x = x.rechunk(self.chunks[0])
        # Non-negative axis index; for a 1D model this is its only axis
        axis = self.dir % x.ndim
        if self.fftshift:
            x = da.fft.ifftshift(x, axes=axis)
        if self.real:
            y = da.fft.rfft(da.real(x), n=self.nfft, axis=axis)
        else:
            y = da.fft.fft(x, n=self.nfft, axis=axis)
        y = sqrt(1. / self.nfft) * y
        y = y.ravel()
        return y.astype(self.cdtype)

    def _rmatvec(self, x):
        """Adjoint (inverse FFT along ``dir``, scaled by ``sqrt(nfft)``).

        The result is brought back to ``dims[dir]`` samples along ``dir``
        and optionally fftshifted. The flattened output is cast to
        ``self.dtype``.
        """
        if self.reshape:
            x = da.reshape(x, self.dims_fft)
        if self.chunks[1] is not None:
            x = x.rechunk(self.chunks[1])
        # Non-negative axis index; for a 1D model this is its only axis
        axis = self.dir % x.ndim
        if self.real:
            y = sqrt(self.nfft) * da.fft.irfft(x, n=self.nfft, axis=axis)
            y = da.real(y)
        else:
            y = sqrt(self.nfft) * da.fft.ifft(x, n=self.nfft, axis=axis)

        ndir = int(self.dims[self.dir])
        if self.nfft > ndir:
            # Adjoint of the forward zero-padding: keep first ndir samples
            index = [slice(None)] * y.ndim
            index[axis] = slice(0, ndir)
            y = y[tuple(index)]
        elif self.nfft < ndir:
            # Adjoint of the forward truncation: zero-pad back to ndir
            # samples so the output matches the model size (shape[1])
            pad_width = [(0, 0)] * y.ndim
            pad_width[axis] = (0, ndir - self.nfft)
            y = da.pad(y, tuple(pad_width), mode='constant')

        if self.fftshift:
            y = da.fft.fftshift(y, axes=axis)
        y = y.ravel()
        return y.astype(self.dtype)