Source code for ptychography40.reconstruction.ssb.trotters

import functools

import numpy as np
import sparse
import numba

from libertem.masks import circular
from libertem.corrections.coordinates import identity

from ptychography40.reconstruction.common import get_shifted


@functools.lru_cache(maxsize=None)
def empty_mask(mask_shape, dtype):
    '''
    Return an all-zero sparse mask of the given shape and dtype.

    Results are memoized per ``(mask_shape, dtype)`` pair with
    :func:`functools.lru_cache`, so repeated calls with equal arguments
    return the very same object. Both arguments must therefore be hashable.

    Parameters
    ----------

    mask_shape : tuple(int)
        Shape of the mask
    dtype : numpy dtype
        dtype of the mask

    Returns
    -------

    mask
        Empty sparse mask as created by :func:`sparse.zeros`
    '''
    zero_mask = sparse.zeros(mask_shape, dtype=dtype)
    return zero_mask


def mask_pair_subpix(cy, cx, sy, sx, filter_center, semiconv_pix):
    '''
    Calculate positive and negative trotter mask for circular illumination
    using a method with subpixel capability.

    Two antialiased disks of radius :code:`semiconv_pix` are drawn, centered at
    :code:`(cy + sy, cx + sx)` and :code:`(cy - sy, cx - sx)`. Each trotter is the
    product of :code:`filter_center` with one of these disks, restricted to the
    pixels where the opposite disk is exactly zero.

    Parameters
    ----------

    cy, cx : float
        Position of the optical axis on the detector in px, center of illumination
    sy, sx : float
        Trotter shift value in px
    filter_center : numpy.ndarray
        Center illumination, i.e. zero order disk. This has to be circular and match the radius
        semiconv_pix. It is only passed as an argument for efficiency, to avoid unnecessary
        recalculation.
    semiconv_pix : float
        Semiconvergence angle measured in detector pixels, i.e. radius of the zero order disk.

    Returns
    -------

    mask_positive, mask_negative : numpy.ndarray
        Positive and negative trotter mask
    '''
    size_y = filter_center.shape[0]
    size_x = filter_center.shape[1]

    def shifted_disk(offset_y, offset_x):
        # Antialiased disk of the illumination radius, displaced from the optical axis
        return circular(
            centerX=offset_x, centerY=offset_y,
            imageSizeX=size_x, imageSizeY=size_y,
            radius=semiconv_pix,
            antialiased=True
        )

    disk_plus = shifted_disk(cy+sy, cx+sx)
    disk_minus = shifted_disk(cy-sy, cx-sx)

    # Keep only the part of each shifted disk that overlaps the zero order
    # disk but not the oppositely shifted disk
    mask_positive = filter_center * disk_plus * (disk_minus == 0)
    mask_negative = filter_center * disk_minus * (disk_plus == 0)
    return mask_positive, mask_negative


@numba.njit
def mask_tile_pair(center_tile, tile_origin, tile_shape, filter_center, sy, sx):
    '''
    Numerical work horse for :meth:`mask_pair_shift`, including tiling support.

    Calculates the positive and negative trotter for one tile. It copies
    :code:`filter_center` shifted by the rounded shift in both directions, then
    multiplies each shifted copy with :code:`center_tile`. Each trotter keeps only
    the pixels where the oppositely shifted copy is exactly zero.

    The tiling support could be used to calculate the mask stack on the fly,
    including support for UDF.process_tile().

    This function is compiled with :func:`numba.njit`. :meth:`mask_pair_shift`
    passes numpy arrays for :code:`center_tile`, :code:`tile_origin`,
    :code:`tile_shape` and :code:`filter_center`.

    Parameters
    ----------

    center_tile : numpy.ndarray
        Tile cut out from :code:`filter_center` for re-use to increase efficiency.
        The returned trotter tiles have the same shape and dtype as this array.
    tile_origin : numpy.ndarray of int
        Origin (y, x) of the tile to calculate
    tile_shape : numpy.ndarray of int
        Shape (y, x) of the tile. Presumably has to match :code:`center_tile.shape`
        -- TODO confirm against :func:`get_shifted`.
    filter_center : numpy.ndarray
        Center illumination, i.e. zero order disk, for the whole detector.
    sy, sx : float
        Trotter shift value in px. Rounded to integers with :code:`np.round`,
        i.e. this method has no subpixel precision.

    Returns
    -------
    mask_positive : numpy.ndarray
        Positive trotter tile
    target_tup_p : numpy.ndarray of int
        Indices returned by :func:`get_shifted` for the positive trotter. Flattened,
        they are (start_y, stop_y, start_x, stop_x) within the tile.
    offsets_p : numpy.ndarray
        Offsets (y, x) that are added to the tile indices to get the source
        indices in :code:`filter_center` for the positive trotter tile.
    mask_negative : numpy.ndarray
        Negative trotter tile
    target_tup_n : numpy.ndarray
        Indices returned by :func:`get_shifted` for the negative trotter. Flattened,
        they are (start_y, stop_y, start_x, stop_x) within the tile.
    offsets_n : numpy.ndarray
        Offsets (y, x) that are added to the tile indices to get the source
        indices in :code:`filter_center` for the negative trotter tile.
    '''
    # Whole-pixel shifts only: this method does not interpolate
    sy, sx, = int(np.round(sy)), int(np.round(sx))
    # Pixels that receive no data from the shifted source region stay zero
    positive_tile = np.zeros_like(center_tile)
    negative_tile = np.zeros_like(center_tile)
    # Copy from coordinates offset by -(sy, sx),
    # so the copied disk appears shifted by +(sy, sx)
    target_tup_p, offsets_p = get_shifted(
        arr_shape=np.array(filter_center.shape),
        tile_origin=tile_origin,
        tile_shape=tile_shape,
        shift=np.array((-sy, -sx))
    )
    # Copy from coordinates offset by +(sy, sx),
    # so the copied disk appears shifted by -(sy, sx)
    target_tup_n, offsets_n = get_shifted(
        arr_shape=np.array(filter_center.shape),
        tile_origin=tile_origin,
        tile_shape=tile_shape,
        shift=np.array((sy, sx))
    )

    # Target slice within the tile, source slice in filter_center = target + offset
    sta_y, sto_y, sta_x, sto_x = target_tup_p.flatten()
    off_y, off_x = offsets_p
    positive_tile[sta_y:sto_y, sta_x:sto_x] = filter_center[
        sta_y+off_y:sto_y+off_y,
        sta_x+off_x:sto_x+off_x
    ]
    sta_y, sto_y, sta_x, sto_x = target_tup_n.flatten()
    off_y, off_x = offsets_n
    negative_tile[sta_y:sto_y, sta_x:sto_x] = filter_center[
        sta_y+off_y:sto_y+off_y,
        sta_x+off_x:sto_x+off_x
    ]

    # Each trotter: overlap with the zero order disk, excluding pixels covered
    # by the oppositely shifted disk
    mask_positive = center_tile * positive_tile * (negative_tile == 0)
    mask_negative = center_tile * negative_tile * (positive_tile == 0)

    return (mask_positive, target_tup_p, offsets_p, mask_negative, target_tup_n, offsets_n)


def mask_pair_shift(cy, cx, sy, sx, filter_center, semiconv_pix):
    '''
    Calculate positive and negative trotter mask using a fast shifting method.

    The signature matches :meth:`mask_pair_subpix`, so the two methods are
    interchangeable. Parameters that only matter for the subpix method are
    ignored here. The shift is applied in whole pixels by
    :meth:`mask_tile_pair`, which handles the whole detector as a single tile.

    Parameters
    ----------

    cy, cx : float
        Ignored, given implicitly by filter_center
    sy, sx : float
        Trotter shift value in px
    filter_center : numpy.ndarray
        Center illumination, i.e. zero order disk.
    semiconv_pix : float
        Ignored, given implicitly by filter_center

    Returns
    -------

    mask_positive, mask_negative : numpy.ndarray
        Positive and negative trotter mask
    '''
    # mask_tile_pair only reads its input arrays, so one copy can serve as
    # both the "tile" and the full zero order disk
    center = np.array(filter_center)
    result = mask_tile_pair(
        center_tile=center,
        tile_origin=np.array((0, 0)),
        tile_shape=np.array(filter_center.shape),
        filter_center=center,
        sy=sy,
        sx=sx
    )
    # Only the two masks are needed; indices and offsets are discarded
    mask_positive = result[0]
    mask_negative = result[3]
    return mask_positive, mask_negative


def generate_mask(cy, cx, sy, sx, filter_center, semiconv_pix,
                  cutoff, dtype, method='subpix'):
    '''
    Generate the trotter mask for a specific shift sy, sx

    Parameters
    ----------

    cy, cx : float
        Position of the optical axis on the detector in px, center of illumination
    sy, sx : float
        Trotter shift value in px
    filter_center : numpy.ndarray
        Center illumination, i.e. zero order disk. This has to be circular and match the radius
        semiconv_pix if :code:`method=subpix`. It is passed for re-use to avoid unnecessary
        recalculation
    semiconv_pix : float
        Semiconvergence angle measured in detector pixels, i.e. radius of the zero order disk.
    cutoff : int
        Minimum sum over the positive and over the negative trotter mask, i.e. roughly their
        number of pixels. This can be used to purge very small trotters to reduce noise.
        Empty trotters are always purged, even for :code:`cutoff <= 0`, because normalizing
        them would divide by zero.
    dtype : numpy dtype
        dtype to use for the mask
    method : str, optional
        Can be :code:`'subpix'`(default) or :code:`'shift'` to switch between
        :meth:`mask_pair_subpix` and :meth:`mask_pair_shift` to generate the trotter pair.

    Returns
    -------
    mask : sparse.COO
        Mask in sparse.pydata.org COO format

    Raises
    ------
    ValueError
        If :code:`method` is not supported. This is checked before any mask is
        calculated, independent of the shift.
    '''
    # Validate up front so that an invalid method is reported for every shift,
    # not only for shifts that reach the trotter calculation below.
    if method == 'subpix':
        mask_pair = mask_pair_subpix
    elif method == 'shift':
        mask_pair = mask_pair_shift
    else:
        raise ValueError(f"Unsupported method {method}. Allowed are 'subpix' and 'shift'")

    mask_shape = filter_center.shape
    # The shifted disk and the zero order disk (both of radius semiconv_pix)
    # don't overlap if the shift exceeds two radii.
    if sx**2 + sy**2 > 4*np.sum(semiconv_pix**2):
        return empty_mask(mask_shape, dtype=dtype)

    if np.allclose((sy, sx), (0, 0)):
        # The zero order component (0, 0) is special: positive and negative disks
        # coincide, so the trotter calculation below would yield an all-zero mask.
        m_0 = filter_center / filter_center.sum()
        return sparse.COO(m_0.astype(dtype))

    mask_positive, mask_negative = mask_pair(
        cy=cy, cx=cx, sy=sy, sx=sx,
        filter_center=filter_center,
        semiconv_pix=semiconv_pix,
    )

    # Total weight of each trotter, used both for the cutoff and for normalization
    weight_positive = mask_positive.sum()
    weight_negative = mask_negative.sum()

    # Exclude small, missing or unbalanced trotters. The explicit "> 0" check keeps
    # empty trotters out even for cutoff <= 0, where the normalization below would
    # otherwise divide by zero and fill the mask with NaN/inf.
    keep = (
        weight_positive >= cutoff and weight_negative >= cutoff
        and weight_positive > 0 and weight_negative > 0
    )
    if not keep:
        return empty_mask(mask_shape, dtype=dtype)

    m = (
        mask_positive / weight_positive
        - mask_negative / weight_negative
    ) / 2
    return sparse.COO(m.astype(dtype))


def generate_masks(reconstruct_shape, mask_shape, dtype, lamb, dpix, semiconv, semiconv_pix,
                   transformation=None, cy=None, cx=None, cutoff=1,
                   cutoff_freq=np.float32('inf'), method='subpix'):
    '''
    Generate the trotter mask stack.

    The y dimension is trimmed to size(y)//2 + 1 to exploit the
    inherent symmetry of the mask stack.

    Parameters
    ----------

    reconstruct_shape : tuple(int)
        Shape of the reconstructed area
    mask_shape : tuple(int)
        Shape of the detector. Any sequence of ints is accepted; it is converted to a
        tuple internally so that it can serve as a cache key for :func:`empty_mask`.
    dtype : numpy dtype
        dtype to use for the mask stack
    lamb : float
        Wavelength of the illuminating radiation in m
    dpix : float or (float, float)
        Scan step in m. Tuple (y, x) in case scan step is different in x and y direction.
    semiconv : float
        Semiconvergence angle of the illumination in radians
    semiconv_pix : float
        Semiconvergence angle measured in detector pixels, i.e. radius of the zero order disk.
    transformation : numpy.ndarray, optional
        Matrix for affine transformation from the scan coordinate directions
        to the detector coordinate directions. This does not include the scale, which is
        handled by dpix, lamb, semiconv and semiconv_pix. It should only be used to rotate
        and flip the coordinate system as necessary. See also
        https://github.com/LiberTEM/LiberTEM/blob/master/src/libertem/corrections/coordinates.py
    cy, cx : float, optional
        Position of the optical axis on the detector in px, center of illumination.
        Default: Center of the detector
    cutoff : int, optional
        Minimum number of pixels in the positive and negative trotter. This can be used to
        purge very small trotters to reduce noise. Default is 1, i.e. no cutoff unless one
        trotter is empty.
    cutoff_freq: float
        Trotters belonging to a spatial frequency higher than this value in reciprocal pixel
        coordinates will be cut off.
    method : str, optional
        Can be :code:`'subpix'`(default) or :code:`'shift'` to switch between
        :meth:`mask_pair_subpix` and :meth:`mask_pair_shift` to generate a trotter pair.

    Returns
    -------
    masks : sparse.COO
        Masks in sparse.pydata.org COO format. The y and x frequency indices are in standard
        (not fftshift-ed) FFT order. That means the zero frequency is at (0, 0), and indices
        above half the size of an axis represent negative frequencies, counted from the end of
        the axis. For an even size, the index size/2 is treated as the positive frequency
        +size/2. The y frequency index is cut in half with size(y)//2 + 1 to exploit the
        inherent symmetry of a real-valued Fourier transform. The y and x index are then
        flattened to make it suitable for using it with MaskContainer.
    '''
    reconstruct_shape = np.array(reconstruct_shape)
    # empty_mask() is memoized with lru_cache and needs a hashable shape
    mask_shape = tuple(mask_shape)
    dpix = np.array(dpix)

    # Scale of one detector pixel in reciprocal space: sin(semiconv)/lamb
    # corresponds to semiconv_pix detector pixels
    d_Kf = np.sin(semiconv)/lamb/semiconv_pix
    # Reciprocal-space step per frequency index of the reconstruction
    d_Qp = 1/dpix/reconstruct_shape

    if cy is None:
        cy = mask_shape[0] / 2
    if cx is None:
        cx = mask_shape[1] / 2
    if transformation is None:
        transformation = identity()

    filter_center = circular(
        centerX=cx, centerY=cy,
        imageSizeX=mask_shape[1], imageSizeY=mask_shape[0],
        radius=semiconv_pix,
        antialiased=True
    )

    half_reconstruct = (reconstruct_shape[0]//2 + 1, reconstruct_shape[1])
    # Frequency indices above this threshold represent negative frequencies
    half_shape = reconstruct_shape / 2

    masks = []
    for row in range(half_reconstruct[0]):
        for column in range(half_reconstruct[1]):
            # Convert the FFT index (q, p) into a signed frequency
            qp = np.array((row, column))
            flip = qp > half_shape
            real_qp = qp.copy()
            real_qp[flip] = qp[flip] - reconstruct_shape[flip]

            if np.sum(real_qp**2) > cutoff_freq**2:
                masks.append(empty_mask(mask_shape, dtype=dtype))
                continue

            # Shift of diffraction order relative to zero order
            # without rotation in physical coordinates
            real_sy_phys, real_sx_phys = real_qp * d_Qp
            # We apply the transformation backwards to go
            # from physical orientation to detector orientation,
            # while the forward direction in center of mass analysis
            # goes from detector coordinates to physical coordinates.
            # Afterwards, we transform from physical detector coordinates
            # to pixel coordinates.
            sy, sx = ((real_sy_phys, real_sx_phys) @ transformation) / d_Kf
            masks.append(generate_mask(
                cy=cy, cx=cx, sy=sy, sx=sx,
                filter_center=filter_center,
                semiconv_pix=semiconv_pix,
                cutoff=cutoff,
                dtype=dtype,
                method=method,
            ))
    # Since we go through masks in order, this gives a mask stack with
    # flattened (q, p) dimension to work with dot product and mask container
    masks = sparse.stack(masks)
    return masks