Simple SSB implementation

This example contains a simplified implementation of incremental single side band (SSB) ptychography. It is intended as a reference for re-implementing the method in other languages and as a starting point for modifications.

  • Process whole frames, no slicing

  • No MaskContainer: each whole frame is multiplied directly with a precomputed sparse mask matrix

  • No CUDA

  • Only subpix method for mask generation

LiberTEM is still used to simplify decoding the input file. If you'd like to implement this algorithm as a simple loop without LiberTEM, see https://libertem.github.io/LiberTEM/udf.html#how-udfs-works for pseudo code equivalent to how LiberTEM runs a UDF.

The dataset can be downloaded at https://zenodo.org/record/5113449.

[1]:
%matplotlib nbagg
[2]:
import math

import numpy as np
import scipy.constants as const
import matplotlib.pyplot as plt

import sparse
[3]:
from libertem import api
from libertem.executor.inline import InlineJobExecutor
from libertem.udf.base import UDF
from libertem.masks import circular
from libertem.corrections.coordinates import flip_y, rotate_deg, identity
[4]:
# Calculation of the relativistic electron wavelength in meters
def wavelength(U):
    '''
    Return the relativistic electron wavelength in m.

    Parameters
    ----------

    U : float or array-like
        Acceleration voltage in kV.

    Returns
    -------

    lambda_e : float or numpy.ndarray
        Electron wavelength in m, with the same shape as U.
    '''
    # Constant values are taken from scipy.constants (CODATA)
    e = const.elementary_charge  # Elementary charge in C
    h = const.Planck  # Planck constant in J s
    c = const.speed_of_light  # Speed of light in vacuum in m/s
    m_0 = const.electron_mass  # Electron rest mass in kg

    # Kinetic energy in J; the factor 1000 converts kV to V.
    # np.asarray allows lists/arrays of voltages as well as scalars.
    T = e * np.asarray(U) * 1000
    # lambda = h c / sqrt(T^2 + 2 T m_0 c^2), i.e. h / p with relativistic momentum p
    lambda_e = h * c / np.sqrt(T**2 + 2 * T * m_0 * c**2)
    return lambda_e
[5]:
def empty_mask(mask_shape, dtype):
    '''
    Create an all-zero sparse mask.

    Exists purely for readability: call sites read as "this frequency
    contributes nothing" instead of a bare sparse constructor.

    Parameters
    ----------

    mask_shape : tuple(int)
        Shape of the mask
    dtype : numpy dtype
        dtype of the mask

    Returns
    -------

    mask
        Sparse array of zeros created with sparse.zeros
    '''
    zeros = sparse.zeros(mask_shape, dtype=dtype)
    return zeros


def mask_pair_subpix(cy, cx, sy, sx, filter_center, semiconv_pix):
    '''
    Calculate the positive and negative trotter masks for circular illumination,
    using anti-aliased disks for subpixel accuracy.

    The positive trotter is the part of the zero order disk covered by the disk
    shifted by (+sy, +sx) but not by the disk shifted by (-sy, -sx); the negative
    trotter is the mirror case.

    Parameters
    ----------

    cy, cx : float
        Position of the optical axis on the detector in px, center of illumination
    sy, sx : float
        Trotter shift value in px
    filter_center : numpy.ndarray
        Center illumination, i.e. zero order disk. This has to be circular and match the radius
        semiconv_pix. It is passed as an argument for efficiency, to avoid unnecessary
        recalculation.
    semiconv_pix : float
        Semiconvergence angle measured in detector pixels, i.e. radius of the zero order disk.

    Returns
    -------

    mask_positive, mask_negative : numpy.ndarray
        Positive and negative trotter mask
    '''
    height = filter_center.shape[0]
    width = filter_center.shape[1]

    def _shifted_disk(center_y, center_x):
        # Anti-aliased disk with the radius of the zero order disk
        return circular(
            centerX=center_x, centerY=center_y,
            imageSizeX=width, imageSizeY=height,
            radius=semiconv_pix,
            antialiased=True,
        )

    disk_plus = _shifted_disk(cy + sy, cx + sx)
    disk_minus = _shifted_disk(cy - sy, cx - sx)

    # Pixels where the opposite disk has exactly zero weight; the region
    # covered by both shifted disks is excluded from both trotters
    outside_minus = disk_minus == 0
    outside_plus = disk_plus == 0

    mask_positive = filter_center * disk_plus * outside_minus
    mask_negative = filter_center * disk_minus * outside_plus
    return mask_positive, mask_negative



def generate_mask(cy, cx, sy, sx, filter_center, semiconv_pix,
                  cutoff, dtype):
    '''
    Generate the trotter mask for a specific shift sy, sx

    Parameters
    ----------

    cy, cx : float
        Position of the optical axis on the detector in px, center of illumination
    sy, sx : float
        Trotter shift value in px
    filter_center : numpy.ndarray
        Center illumination, i.e. zero order disk. This has to be circular and match the radius
        semiconv_pix, since the trotters are calculated with the subpix method. It is passed
        for re-use to avoid unnecessary recalculation.
    semiconv_pix : float
        Semiconvergence angle measured in detector pixels, i.e. radius of the zero order disk.
    cutoff : int
        Minimum integrated weight of both the positive and the negative trotter. Since the
        disks are anti-aliased, this is approximately a number of pixels. This can be used to
        purge very small trotters to reduce noise. Empty trotters are always purged, also for
        cutoff <= 0, to avoid division by zero.
    dtype : numpy dtype
        dtype to use for the mask

    Returns
    -------
    mask : sparse.COO
        Mask in sparse.pydata.org COO format

    '''
    mask_shape = filter_center.shape
    # 1st diffraction order and primary beam don't overlap: the shift exceeds
    # the diameter 2 * semiconv_pix of the disks
    if sx**2 + sy**2 > 4*np.sum(semiconv_pix**2):
        return empty_mask(mask_shape, dtype=dtype)

    if np.allclose((sy, sx), (0, 0)):
        # The zero order component (0, 0) is special: both shifted disks coincide,
        # so the trotter calculation below would give empty masks.
        # Use the normalized zero order disk instead.
        m_0 = filter_center / filter_center.sum()
        return sparse.COO(m_0.astype(dtype))

    mask_positive, mask_negative = mask_pair_subpix(
        cy=cy, cx=cx, sy=sy, sx=sx,
        filter_center=filter_center,
        semiconv_pix=semiconv_pix,
    )

    # Integrated (anti-aliased) weight of each trotter, used for normalization
    weight_positive = mask_positive.sum()
    weight_negative = mask_negative.sum()

    # Exclude small, missing or unbalanced trotters. The explicit > 0 checks keep
    # a non-positive cutoff from letting an empty trotter through to the division.
    keep = (
        weight_positive > 0 and weight_negative > 0
        and weight_positive >= cutoff and weight_negative >= cutoff
    )
    if not keep:
        return empty_mask(mask_shape, dtype=dtype)

    m = (
        mask_positive / weight_positive
        - mask_negative / weight_negative
    ) / 2
    return sparse.COO(m.astype(dtype))


def generate_masks(reconstruct_shape, mask_shape, dtype, lamb, dpix, semiconv,
        semiconv_pix, transformation=None, cy=None, cx=None, cutoff=1,
        cutoff_freq=np.float32('inf')):
    '''
    Generate the stack of trotter masks, one mask per spatial frequency of the scan.

    Only the first size(y)//2 + 1 frequency rows are generated. The remaining rows
    follow from the symmetry of the mask stack and are not calculated here.

    Parameters
    ----------

    reconstruct_shape : tuple(int)
        Shape of the reconstructed area
    mask_shape : tuple(int)
        Shape of the detector
    dtype : numpy dtype
        dtype to use for the mask stack
    lamb : float
        Wavelength of the illuminating radiation in m
    dpix : float or (float, float)
        Scan step in m. Tuple (y, x) in case scan step is different in x and y direction.
    semiconv : float
        Semiconvergence angle of the illumination in radians
    semiconv_pix : float
        Semiconvergence angle measured in detector pixels, i.e. radius of the zero order disk.
    transformation : numpy.ndarray, optional
        Matrix for affine transformation from the scan coordinate directions
        to the detector coordinate directions. This does not include the scale, which is handled by
        dpix, lamb, semiconv and semiconv_pix. It should only be used to rotate and flip
        the coordinate system as necessary. See also
        https://github.com/LiberTEM/LiberTEM/blob/master/src/libertem/corrections/coordinates.py
    cy, cx : float, optional
        Position of the optical axis on the detector in px, center of illumination.
        Default: Center of the detector
    cutoff : int, optional
        Minimum number of pixels in the positive and negative trotter. This can be used to purge
        very small trotters to reduce noise. Default is 1, i.e. no cutoff unless one trotter is
        empty.
    cutoff_freq : float, optional
        Trotters belonging to a spatial frequency higher than this value in reciprocal pixel
        coordinates will be cut off. Default: no cutoff.

    Returns
    -------
    masks : sparse.COO
        Masks in sparse.pydata.org COO format with shape
        ((size(y)//2 + 1) * size(x), *mask_shape). The frequency axis is the flattened
        (y, x) frequency index in row-major order. Both indices use the standard FFT
        layout: index 0 is the zero frequency, and indices larger than size/2 stand for
        the negative frequencies index - size.
    '''
    nav_shape = np.array(reconstruct_shape)
    scan_step = np.array(dpix)

    # Reciprocal-space extent of one detector pixel
    d_Kf = np.sin(semiconv)/lamb/semiconv_pix
    # Reciprocal-space spacing of the scan frequencies
    d_Qp = 1/scan_step/nav_shape

    cy = mask_shape[0] / 2 if cy is None else cy
    cx = mask_shape[1] / 2 if cx is None else cx
    transformation = identity() if transformation is None else transformation

    # Zero order disk, shared by all masks
    filter_center = circular(
        centerX=cx, centerY=cy,
        imageSizeX=mask_shape[1], imageSizeY=mask_shape[0],
        radius=semiconv_pix,
        antialiased=True
    )

    n_rows = nav_shape[0]//2 + 1
    n_cols = nav_shape[1]
    half_size = nav_shape / 2

    def _signed_frequency(row, column):
        # FFT index -> signed frequency: indices above half the size wrap
        # around to negative frequencies (index - size)
        index = np.array((row, column))
        return np.where(index > half_size, index - nav_shape, index)

    stack = []
    for row in range(n_rows):
        for column in range(n_cols):
            real_qp = _signed_frequency(row, column)

            if np.sum(real_qp**2) > cutoff_freq**2:
                # Beyond the frequency cutoff: this frequency contributes nothing
                stack.append(empty_mask(mask_shape, dtype=dtype))
                continue

            # Shift of the diffraction order relative to the zero order,
            # in physical scan orientation
            shift_phys = real_qp * d_Qp
            # Multiplying the row vector from the left applies the transposed
            # matrix, which is the inverse for the rotations/flips this is meant
            # for. This goes from physical to detector orientation, the reverse of
            # center of mass analysis (detector -> physical). Dividing by d_Kf then
            # converts reciprocal-space units to detector pixels.
            sy, sx = (shift_phys @ transformation) / d_Kf

            stack.append(generate_mask(
                cy=cy, cx=cx, sy=sy, sx=sx,
                filter_center=filter_center,
                semiconv_pix=semiconv_pix,
                cutoff=cutoff,
                dtype=dtype,
            ))

    # Row-major traversal above yields the flattened (q, p) axis expected
    # by the dot product with the frames
    return sparse.stack(stack)

[6]:
class SimpleSSB(UDF):
    '''
    Incremental single side band (SSB) reconstruction, processing one whole
    frame at a time.

    For every frame, the dot product with the precalculated trotter mask stack
    yields one value per spatial frequency of the scan. Only the non-redundant
    half of the frequencies in y is calculated. These values are multiplied with
    the 2D DFT factors for the frame's scan position and accumulated in the
    'fourier' result buffer. postprocess() fills in the skipped half by symmetry.
    '''
    def __init__(self, converted_masks, dtype):
        '''
        Parameters
        ----------

        converted_masks : scipy.sparse.csc_matrix
            Trotter mask stack with the signal dimensions flattened and then
            transposed, i.e. shape (number of detector pixels, number of
            frequencies), so that ``frame.reshape((-1, )) @ converted_masks``
            gives one value per frequency.

        dtype : dtype
            dtype to use for computation
        '''
        super().__init__(converted_masks=converted_masks, dtype=dtype)

    def get_result_buffers(self):
        '''
        The 'fourier' buffer contains the reconstruction in the Fourier domain.
        It has the navigation shape of the dataset and a complex dtype.
        '''
        dtype = np.result_type(np.complex64, self.params.dtype)
        return {
            'fourier': self.buffer(
                kind="single", dtype=dtype, extra_shape=self.reconstruct_shape,
                where='device'
            ),
        }

    def get_task_data(self):
        '''
        Precalculated lookup tables for the two-dimensional discrete Fourier
        transform, re-used between frames:

        row_exp[y, k] = exp(-2 pi i y k / n_y) for k < n_y // 2 + 1
        col_exp[x, k] = exp(-2 pi i x k / n_x) for k < n_x

        Only half of the y frequencies are needed since the full trotter stack
        is symmetric, i.e. the missing half can be reconstructed from the other
        results in postprocess().
        '''
        n_y = self.reconstruct_shape[0]
        n_x = self.reconstruct_shape[1]
        half_y = n_y // 2 + 1

        # -2 pi i * (scan position / size) for every scan position
        row_steps = -2j*np.pi*np.linspace(0, 1, n_y, endpoint=False)
        col_steps = -2j*np.pi*np.linspace(0, 1, n_x, endpoint=False)

        # Outer product (scan position x frequency index) gives the DFT phase factors
        row_exp = np.exp(row_steps[:, np.newaxis] * np.arange(half_y)[np.newaxis, :])
        col_exp = np.exp(col_steps[:, np.newaxis] * np.arange(n_x)[np.newaxis, :])

        return {
            "row_exp": row_exp.astype(np.complex128, copy=False),
            "col_exp": col_exp.astype(np.complex128, copy=False),
        }

    @property
    def reconstruct_shape(self):
        '''Shape of the reconstruction, i.e. the navigation shape of the dataset.'''
        return tuple(self.meta.dataset_shape.nav)

    def merge(self, dest, src):
        # Partial results of different partitions are summed
        dest.fourier[:] += src.fourier

    def process_frame(self, frame):
        # One value per frequency: flattened frame times mask stack
        dot_result = frame.reshape((-1, )) @ self.params.converted_masks
        self.merge_dot_result(dot_result)

    def merge_dot_result(self, dot_result):
        '''
        Add the Fourier-domain contribution of the current frame to the result.
        '''
        n_x = self.results.fourier.shape[1]
        half_y = self.results.fourier.shape[0] // 2 + 1

        # Real y and x index within the dataset navigation dimension of the
        # current frame. Since we use process_frame, this is one value for each
        # dimension.
        y_index = self.meta.coordinates[0, 0]
        x_index = self.meta.coordinates[0, 1]

        # Shapes (half_y, 1) and (1, n_x): their product implicitly builds the
        # part of the full 2D DFT matrix for this scan position via broadcasting
        fourier_factors_row = self.task_data.row_exp[y_index, :half_y, np.newaxis]
        fourier_factors_col = self.task_data.col_exp[x_index, np.newaxis, :]

        # Fold the flat result up to (half_y, n_x), matching the row-major
        # frequency order of the mask stack
        dot_result = dot_result.reshape((half_y, n_x))

        self.results.fourier[:half_y] += (dot_result * fourier_factors_row * fourier_factors_col)

    def postprocess(self):
        '''
        Patch in the half of the result buffer that was skipped during processing.

        For a shift q != 0, the trotter mask for -q is the negative of the mask
        for q, since the positive and negative trotter swap. Assuming real-valued
        frames, this gives F(-q) = -conj(F(q)).
        '''
        n_y = self.results.fourier.shape[0]
        half_y = n_y // 2 + 1
        # FIXME make sure this is correct using an example that transmits also
        # the high spatial frequencies
        # Rows 1 .. n_y - half_y hold the y frequencies whose negatives belong
        # in the missing rows half_y .. n_y - 1 (works for even and odd n_y).
        # Row 0 (y frequency zero) is its own mirror and was calculated directly.
        extracted = self.results.fourier[1:n_y - half_y + 1]
        # Row n_y - k receives row k, so the rows are flipped. Column l receives
        # column (-l) mod n_x, so the columns are flipped and rolled by one to
        # keep column 0 in place. -conj() inverts the real part and keeps the
        # imaginary part.
        self.results.fourier[half_y:] = -np.conj(
            np.roll(np.flip(extracted, axis=(0, 1)), shift=1, axis=1)
        )
[7]:
# The notebook currently ONLY works well with the inline job executor,
# since the mask stack is pre-calculated externally in this process and
# handed to the UDF as a parameter (converted_masks).
executor = InlineJobExecutor()
# Context used below to load the dataset and run the UDF
ctx = api.Context(executor=executor)
[8]:
# Detector (signal) and scan (navigation) shape; must match the loaded dataset
ds_shape_sig = (256, 256)
ds_shape_nav = (128, 128)  # alternative: (256, 256)
# Acceleration voltage in kV (wavelength() multiplies by 1000 to get V)
U = 300  # alternative: 200
params = dict(
    dpix=12.7e-12,  # scan step in m
    cy=126,  # optical axis on the detector, y in px
    cx=123,  # optical axis on the detector, x in px
    semiconv=22.1346e-3,  # semiconvergence angle in rad
    semiconv_pix=31,  # radius of the zero order disk in detector px
    dtype=np.float64,  # dtype of the mask stack
    # scan -> detector orientation; matrices are applied right to left (flip first)
    transformation=rotate_deg(88) @ flip_y(),
    cutoff=15,  # minimum trotter weight, see generate_mask
)
[9]:
# Machine-specific location of the dataset (see https://zenodo.org/record/5113449);
# adjust to wherever the .hdr file was downloaded
hdr_path = r'E:\LargeData\LargeData\ER-C-1/projects/ptycho-4.0/data/live-ssb-paper/Ptycho01/20200518 165148/default.hdr'
ds = ctx.load("MIB", path=hdr_path)
[10]:
# The mask stack has to match the dataset: SimpleSSB takes its reconstruction
# shape from the dataset's navigation shape, and the masks are flattened with
# the dataset's signal shape. Fail early with a clear message instead of an
# obscure reshape error later.
if tuple(ds.shape.sig) != tuple(ds_shape_sig) or tuple(ds.shape.nav) != tuple(ds_shape_nav):
    raise ValueError(
        f"ds_shape_sig={tuple(ds_shape_sig)} / ds_shape_nav={tuple(ds_shape_nav)} "
        f"do not match the dataset: sig={tuple(ds.shape.sig)} / nav={tuple(ds.shape.nav)}"
    )

masks = generate_masks(
    **params,
    mask_shape=ds_shape_sig,
    lamb=wavelength(U),
    reconstruct_shape=ds_shape_nav,
)
[11]:
# Flatten sig and transpose for CSC dot product
# Flatten the signal dimensions of the mask stack and transpose it to
# (number of detector pixels, number of frequencies), stored as CSC for the
# frame @ masks dot product in SimpleSSB.process_frame
n_sig = np.prod(ds.shape.sig)
flat_masks = masks.reshape((-1, n_sig))
converted_masks = flat_masks.T.tocsc()
[12]:
# The result buffer dtype is np.result_type(np.complex64, dtype), i.e. complex128 here
udf = SimpleSSB(dtype=np.complex128, converted_masks=converted_masks)
[13]:
def get_phase(fourier):
    '''Phase of the real-space reconstruction, i.e. of the inverse 2D FFT of fourier.'''
    real_space = np.fft.ifft2(fourier)
    return np.angle(real_space)

def get_amplitude(fourier):
    '''Amplitude (modulus) of the real-space reconstruction, i.e. of the inverse 2D FFT of fourier.'''
    real_space = np.fft.ifft2(fourier)
    return np.abs(real_space)
[14]:
%%time
# Live plots of phase and amplitude of the current reconstruction
plot_channels = [(('fourier', get_phase), ('fourier', get_amplitude))]
result = ctx.run_udf(udf=udf, dataset=ds, plots=plot_channels, progress=True)
100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [01:57<00:00, 14.71s/it]
Wall time: 1min 58s
[ ]: