import numpy as np
import numba
from libertem.udf import UDF
from libertem.common.container import MaskContainer
from ptychography40.reconstruction.ssb.trotters import generate_masks
@numba.njit(fastmath=True, cache=True, parallel=True, nogil=True)
def rmatmul_csc_fourier(n_threads, left_dense, right_data, right_indices, right_indptr,
                        coordinates, row_exp, col_exp, res_inout):
    '''
    Fold :meth:`SSB_UDF.merge_dot_result` into a sparse dot product from
    :meth:`libertem.common.numba._rmatmul_csc`

    The result can be directly merged into the result buffer instead of
    instantiating the intermediate dot result, which can be large. That way,
    this method can process entire memory-mapped partitions efficiently.
    Furthermore, it allows early skipping of trotters that are empty in this tile.

    It uses multithreading to process different parts of a tile in parallel.

    Parameters
    ----------
    n_threads : int
        Number of blocks the rows of :code:`left_dense` are split into; each
        block is processed by one thread with its own accumulator.
    left_dense : numpy.ndarray
        Tile data with flattened signal dimension,
        shape (n_frames, n_sig_pixels).
    right_data, right_indices, right_indptr : numpy.ndarray
        CSC components of the transposed mask (trotter) stack. Each CSC
        column is one mask, flattened as q * p_size + p.
    coordinates : numpy.ndarray
        (y, x) scan coordinates in the navigation dimension, one row per
        row of :code:`left_dense`.
    row_exp, col_exp : numpy.ndarray
        Precalculated 1-D DFT phase factors; indexed as row_exp[y, q]
        and col_exp[x, p].
    res_inout : numpy.ndarray
        Accumulator of shape (q_size, p_size); the contribution of this
        tile is added in place.
    '''
    left_rows = left_dense.shape[0]
    p_size = col_exp.shape[1]
    q_size = row_exp.shape[1]
    # We subdivide in blocks per thread so that each thread
    # writes exclusively to its own part of an intermediate result buffer.
    # Using prange and automated '+=' merge leads to wrong results when using threading.
    blocksize = max(int(np.ceil(left_rows / n_threads)), 1)
    resbuf = np.zeros((n_threads, q_size, p_size), dtype=res_inout.dtype)
    # The blocks are processed in parallel
    for block in numba.prange(n_threads):
        start = block * blocksize
        stop = min((block + 1) * blocksize, left_rows)
        for left_row in range(start, stop):
            # Pixel coordinates in nav dimension
            y, x = coordinates[left_row]
            for q in range(q_size):
                for p in range(p_size):
                    # right_column is the mask index
                    right_column = q * p_size + p
                    # Descent into CSC data structure
                    offset = right_indptr[right_column]
                    items = right_indptr[right_column+1] - offset
                    if items > 0:
                        # We accumulate for the whole mask into acc
                        # before applying the phase factor
                        acc = 0
                        # Iterate over non-zero entries in this mask
                        for i in range(items):
                            index = i + offset
                            right_row = right_indices[index]
                            right_value = right_data[index]
                            acc += left_dense[left_row, right_row] * right_value
                        # Phase factor for this scan point and mask
                        factor = row_exp[y, q] * col_exp[x, p]
                        # Applying the factor, accumulate in per-thread buffer
                        resbuf[block, q, p] += acc * factor
    # Finally, accumulate per-thread buffer into result
    res_inout += np.sum(resbuf, axis=0)
class SSB_UDF(UDF):
    '''
    UDF to perform ptychography using the single side band (SSB) method :cite:`Pennycook2015`.

    The Fourier transform of the reconstructed object is accumulated tile by
    tile via a dot product with a precomputed trotter (mask) stack, and the
    final amplitude and phase are obtained by an inverse 2D FFT in
    :meth:`get_results`.
    '''
    def __init__(self, lamb, dpix, semiconv, semiconv_pix,
                 dtype=np.float32, cy=None, cx=None, transformation=None,
                 cutoff=1, method='subpix', mask_container=None,):
        '''
        Parameters
        ----------
        lamb: float
            The illumination wavelength in m. The function :meth:`ptychography40.common.wavelength`
            allows to calculate the electron wavelength as a function of acceleration voltage.
        dpix: float or Iterable(y, x)
            STEM pixel size in m
        semiconv: float
            STEM semiconvergence angle in radians
        semiconv_pix: float
            Diameter of the primary beam in the diffraction pattern in pixels
        dtype: np.dtype
            dtype to perform the calculation in
        cy, cx : float, optional
            Position of the optical axis on the detector in px, center of illumination.
            Default: Center of the detector
        transformation: numpy.ndarray() of shape (2, 2) or None
            Transformation matrix to apply to shift vectors. This allows to adjust for scan rotation
            and mismatch of detector coordinate system handedness, such as flipped y axis for MIB.
        cutoff : int
            Minimum number of pixels in a trotter
        method : 'subpix' or 'shift'
            Method to use for generating the mask stack
        mask_container: MaskContainer
            Allows to pass in a precomputed mask stack when using with single thread live data
            or with an inline executor as a work-around. The number of masks is sanity-checked
            to match the other parameters.
            The proper fix is https://github.com/LiberTEM/LiberTEM/issues/335
        '''
        # All parameters are stored on self.params by the UDF base class.
        super().__init__(lamb=lamb, dpix=dpix, semiconv=semiconv, semiconv_pix=semiconv_pix,
                         dtype=dtype, cy=cy, cx=cx, mask_container=mask_container,
                         transformation=transformation, cutoff=cutoff, method=method)
EFFICIENT_THREADS = 4
def get_result_buffers(self):
''
dtype = np.result_type(np.complex64, self.params.dtype)
component_dtype = np.result_type(np.float32, self.params.dtype)
return {
'fourier': self.buffer(
kind="single", dtype=dtype, extra_shape=self.reconstruct_shape,
where='device'
),
'complex': self.buffer(
kind="single", dtype=dtype, extra_shape=self.reconstruct_shape,
use='result_only',
),
'amplitude': self.buffer(
kind="single", dtype=component_dtype, extra_shape=self.reconstruct_shape,
use='result_only',
),
'phase': self.buffer(
kind="single", dtype=component_dtype, extra_shape=self.reconstruct_shape,
use='result_only',
),
}
@property
def reconstruct_shape(self):
return tuple(self.meta.dataset_shape.nav)
    def get_task_data(self):
        ''
        # shorthand, cupy or numpy
        xp = self.xp
        # Select the array backend matching the device this task runs on
        if self.meta.device_class == 'cpu':
            backend = 'numpy'
        elif self.meta.device_class == 'cuda':
            backend = 'cupy'
        else:
            raise ValueError("Unknown device class")
        # Hack to pass a fixed external container
        # In particular useful for single-process live processing
        # or inline executor
        if self.params.mask_container is None:
            # Generate the trotter stack from the physical parameters
            masks = generate_masks(
                reconstruct_shape=self.reconstruct_shape,
                mask_shape=tuple(self.meta.dataset_shape.sig),
                dtype=self.params.dtype,
                lamb=self.params.lamb,
                dpix=self.params.dpix,
                semiconv=self.params.semiconv,
                semiconv_pix=self.params.semiconv_pix,
                cy=self.params.cy,
                cx=self.params.cx,
                transformation=self.params.transformation,
                cutoff=self.params.cutoff,
                method=self.params.method,
            )
            container = MaskContainer(
                mask_factories=lambda: masks, dtype=masks.dtype,
                use_sparse='scipy.sparse.csr', count=masks.shape[0], backend=backend
            )
        else:
            container = self.params.mask_container
            # Sanity check: the external stack must contain one mask per
            # spatial frequency of the trimmed (half-y) Fourier result,
            # each with the signal shape of the dataset.
            target_size = (self.reconstruct_shape[0] // 2 + 1)*self.reconstruct_shape[1]
            container_shape = container.computed_masks.shape
            expected_shape = (target_size, ) + tuple(self.meta.dataset_shape.sig)
            if container_shape != expected_shape:
                raise ValueError(
                    f"External mask container doesn't have the expected shape. "
                    f"Got {container_shape}, expected {expected_shape}. "
                    "Mask count (self.meta.dataset_shape.nav[0] // 2 + 1) "
                    "* self.meta.dataset_shape.nav[1], "
                    "Mask shape self.meta.dataset_shape.sig. "
                    "The methods generate_masks_*() help to generate a suitable mask stack."
                )
        # Precalculated LUT for Fourier transform
        # The y axis is trimmed in half since the full trotter stack is symmetric,
        # i.e. the missing half can be reconstructed from the other results
        row_steps = -2j*np.pi*np.linspace(0, 1, self.reconstruct_shape[0], endpoint=False)
        col_steps = -2j*np.pi*np.linspace(0, 1, self.reconstruct_shape[1], endpoint=False)
        half_y = self.reconstruct_shape[0] // 2 + 1
        full_x = self.reconstruct_shape[1]
        # This creates a 2D array of row x spatial frequency
        row_exp = np.exp(
            row_steps[:, np.newaxis]
            * np.arange(half_y)[np.newaxis, :]
        )
        # This creates a 2D array of col x spatial frequency
        col_exp = np.exp(
            col_steps[:, np.newaxis]
            * np.arange(full_x)[np.newaxis, :]
        )
        # Complex dtype matching the precision of the calculation
        steps_dtype = np.result_type(np.complex64, self.params.dtype)
        return {
            "masks": container,
            # DFT factor LUTs, moved to the device (xp is numpy or cupy)
            "row_exp": xp.array(row_exp.astype(steps_dtype)),
            "col_exp": xp.array(col_exp.astype(steps_dtype)),
            "backend": backend
        }
def merge(self, dest, src):
''
dest.fourier[:] += src.fourier
    def merge_dot_result(self, dot_result):
        '''
        Accumulate the dot product of a tile with the mask stack into the
        trimmed (half-y) Fourier-space result buffer.

        Parameters
        ----------
        dot_result
            Dot product of the flattened tile with the flattened trotter
            stack, shape (tile_depth, n_masks); numpy or cupy array
            depending on the backend.
        '''
        # shorthand, cupy or numpy
        xp = self.xp
        tile_depth = dot_result.shape[0]
        # We calculate only half of the Fourier transform due to the
        # inherent symmetry of the mask stack. In this case we
        # cut the y axis in half. The "+ 1" accounts for odd sizes
        # The mask stack is already trimmed in y direction to only contain
        # one of the trotter pairs
        half_y = self.results.fourier.shape[0] // 2 + 1
        # Get the real x and y indices within the dataset navigation dimension
        # for the current tile
        y_indices = self.meta.coordinates[:, 0]
        x_indices = self.meta.coordinates[:, 1]
        # This loads the correct entries for the current tile from the pre-calculated
        # 1-D DFT matrices using the x and y indices of the frames in the current tile
        # fourier_factors_row is already trimmed for half_y, but the explicit index
        # is kept for clarity
        fourier_factors_row = self.task_data.row_exp[y_indices, :half_y, np.newaxis]
        fourier_factors_col = self.task_data.col_exp[x_indices, np.newaxis, :]
        # The masks are in order [row, col], but flattened. Here we undo the flattening
        dot_result = dot_result.reshape((tile_depth, half_y, self.results.fourier.shape[1]))
        # Calculate the part of the Fourier transform for this tile.
        # Reconstructed shape corresponds to depth of mask stack, see above.
        # Shape of operands:
        # dot_result: (tile_depth, reconstructed_shape_y // 2 + 1, reconstructed_shape_x)
        # fourier_factors_row: (tile_depth, reconstructed_shape_y // 2 + 1, 1)
        # fourier_factors_col: (tile_depth, 1, reconstructed_shape_x)
        # The einsum is equivalent to
        # (dot_result*fourier_factors_row*fourier_factors_col).sum(axis=0)
        # The product of the Fourier factors for row and column implicitly builds part
        # of the full 2D DFT matrix through NumPy broadcasting
        # The sum of axis 0 (tile depth) already performs the accumulation for the tile
        # stack before patching the missing half for the full result.
        # Einsum is about 3x faster in this scenario, likely because of not building a large
        # intermediate array before summation
        self.results.fourier[:half_y] += xp.einsum(
            'i...,i...,i...',
            dot_result,
            fourier_factors_row,
            fourier_factors_col
        )
    def postprocess(self):
        ''
        # Reconstruct the bottom half of the Fourier result from the
        # computed top half, exploiting the symmetry of the trotter stack.
        # shorthand, cupy or numpy
        xp = self.xp
        half_y = self.results.fourier.shape[0] // 2 + 1
        # patch accounts for even and odd sizes
        # FIXME make sure this is correct using an example that transmits also
        # the high spatial frequencies
        patch = self.results.fourier.shape[0] % 2
        # We skip the first row since it would be outside the FOV
        extracted = self.results.fourier[1:self.results.fourier.shape[0] // 2 + patch]
        # The coordinates of the bottom half are inverted and
        # the zero column is rolled around to the front
        # The real part is inverted
        self.results.fourier[half_y:] = -xp.conj(
            xp.roll(xp.flip(xp.flip(extracted, axis=0), axis=1), shift=1, axis=1)
        )
def get_results(self):
'''
Since we derive the wave front with a linear function from intensities,
but the wave front is inherently an amplitude,
we take the square root of the calculated amplitude and
combine it with the calculated phase.
'''
inverse = np.fft.ifft2(self.results.fourier)
amp = np.sqrt(np.abs(inverse))
phase = np.angle(inverse)
return {
'fourier': self.results.fourier,
'complex': amp * np.exp(1j*phase),
'amplitude': amp,
'phase': phase,
}
    def process_tile(self, tile):
        ''
        # We flatten the signal dimension of the tile in preparation of the dot product
        tile_flat = tile.reshape(tile.shape[0], -1)
        if self.task_data.backend == 'cupy':
            # Load the preconditioned trotter stack from the mask container:
            # * Coutout for current tile
            # * Flattened signal dimension
            # * Clean CSR matrix
            # * On correct device
            masks = self.task_data.masks.get(
                self.meta.slice, transpose=False, backend=self.task_data.backend,
                sparse_backend='scipy.sparse.csr'
            )
            # Skip an empty tile since the contribution is 0
            if masks.nnz == 0:
                return
            # This performs the trotter integration
            # As of now, cupy doesn't seem to support __rmatmul__ with sparse matrices
            dot_result = masks.dot(tile_flat.T).T
            self.merge_dot_result(dot_result)
        else:
            # We calculate only half of the Fourier transform due to the
            # inherent symmetry of the mask stack. In this case we
            # cut the y axis in half. The "+ 1" accounts for odd sizes
            # The mask stack is already trimmed in y direction to only contain
            # one of the trotter pairs
            half_y = self.results.fourier.shape[0] // 2 + 1
            tpw = self.meta.threads_per_worker
            # With enough threads, the multithreaded Numba kernel is faster;
            # otherwise fall back to the sparse matrix product.
            if (tpw is None) or (tpw >= self.EFFICIENT_THREADS):
                # Load the preconditioned trotter stack from the mask container:
                # * Coutout for current tile
                # * Flattened signal dimension
                # * Transposed for right hand side of dot product
                #   (rmatmul_csc_fourier() takes care of putting things where they belong
                #   in the result)
                # * Clean CSC matrix
                # * On correct device (CPU)
                masks = self.task_data.masks.get(
                    self.meta.slice, transpose=True, backend=self.task_data.backend,
                    sparse_backend='scipy.sparse.csc'
                )
                # Skip an empty tile since the contribution is 0
                if masks.nnz == 0:
                    return
                # This combines the trotter integration with merge_dot_result
                # into a Numba loop that eliminates a potentially large intermediate result
                # and allows efficient multithreading on a large tile
                rmatmul_csc_fourier(
                    n_threads=self.meta.threads_per_worker,
                    left_dense=tile_flat,
                    right_data=masks.data,
                    right_indices=masks.indices,
                    right_indptr=masks.indptr,
                    coordinates=self.meta.coordinates,
                    row_exp=self.task_data.row_exp,
                    col_exp=self.task_data.col_exp,
                    res_inout=self.results.fourier[:half_y]
                )
            else:
                masks = self.task_data.masks.get(
                    self.meta.slice, transpose=False, backend=self.task_data.backend,
                    sparse_backend='scipy.sparse.csr'
                )
                # Skip an empty tile since the contribution is 0
                if masks.nnz == 0:
                    return
                # This performs the trotter integration
                # Transposed copy of the input data for fast scipy.sparse __matmul__()
                dot_result = masks.dot(tile_flat.T.copy()).T
                self.merge_dot_result(dot_result)
def get_backends(self):
''
return ('numpy', 'cupy')
def get_tiling_preferences(self):
''
dtype = np.result_type(np.complex64, self.params.dtype)
result_size = np.prod(self.reconstruct_shape) * dtype.itemsize
if self.meta.device_class == 'cuda':
free, total = self.xp.cuda.runtime.memGetInfo()
total_size = min(100e6, free // 4)
good_depth = max(1, total_size / result_size * 4)
return {
"depth": good_depth,
"total_size": total_size,
}
else:
tpw = self.meta.threads_per_worker
if (tpw is None) or (tpw >= self.EFFICIENT_THREADS):
# The parallel Numba loop can process entire partitions efficiently
# since it accumulates in the result without large intermediate
# data
return {
"depth": self.TILE_DEPTH_MAX,
"total_size": self.TILE_SIZE_MAX,
}
else:
# We limit the depth of a tile so that the intermediate
# results from processing a tile fit into the CPU cache.
good_depth = max(1, 1e6 / result_size)
return {
"depth": int(good_depth),
# We reduce the size of a tile since a transposed copy of
# each tile will be taken to speed up the sparse matrix product
"total_size": 0.5e6,
}