Tools for iterative ptychography
This notebook showcases how to use the tools for fast iterative ptychography in ptychography40.reconstruction.common. They are designed to perform the “heavy lifting” for such algorithms in a consistent and performant way. This allows code re-use and ensures that parameters are handled the same way across different implementations and methods, in line with the FAIR principles.
In this demonstration, a forward diffraction simulation is performed first. The diffraction pattern and the object are then transformed individually to different coordinate systems to reflect the typical situation in real measurements, where coordinates are not matched pixel-by-pixel but defined by experimental geometry and settings.
After that, the transformed detector data is compared with a forward model in the transformed coordinate system in a way that is typical for iterative ptychography implementations. The tools in ptychography40.reconstruction.common are compared to and benchmarked against a simple reference implementation based on scipy.ndimage and NumPy.
[1]:
%matplotlib nbagg
#%load_ext line_profiler
[2]:
import numpy as np
import scipy.ndimage
import matplotlib.pyplot as plt
[3]:
from ptychography40.reconstruction.common import (
wavelength, diffraction_to_detector, image_transformation_matrix, apply_matrix,
fftshift_coords, ifftshift_coords, rolled_object_probe_product_cpu
)
Define parameters for the initial simulation
Instead of measured data, we use a simulation that is better suited to verifying the calculation.
[4]:
size = 256
semiconv = 0.020 # radian
lmbda = wavelength(300) # 300 kV
pixel_size_real_sim = 0.5 * lmbda # Simulate at high resolution
# One pixel in diffraction space corresponds to the full real-space field of view (reciprocity).
pixel_size_detector_sim = 1/size/pixel_size_real_sim*lmbda
[5]:
lmbda, pixel_size_real_sim, pixel_size_detector_sim
[5]:
(1.9687489006848795e-12, 9.843744503424398e-13, 0.0078125)
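As a sanity check, these values can be reproduced from first principles. Below is a minimal sketch that is not part of the original notebook; it assumes that wavelength() implements the standard relativistic de Broglie relation, and relativistic_wavelength is a hypothetical helper defined only for this check:
import scipy.constants as const

def relativistic_wavelength(U_kV):
    # Relativistic de Broglie wavelength of an electron accelerated through U_kV kilovolts
    U = U_kV * 1e3
    e, m, h, c = const.e, const.m_e, const.h, const.c
    return h / np.sqrt(2 * m * e * U * (1 + e * U / (2 * m * c**2)))

relativistic_wavelength(300)  # approx. 1.97e-12 m, matching wavelength(300) above

# The angular detector pixel is the reciprocal of the real-space field of view:
# lmbda / (size * pixel_size_real_sim) = 2 / size = 0.0078125 rad
lmbda / (size * pixel_size_real_sim)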
Create an object
This object has low symmetry and contains a wave modulation, which leaves a characteristic signature in the diffraction pattern.
[6]:
obj = np.ones((size, size), dtype=np.complex64)
y, x = np.ogrid[-size//2:size//2, -size//2:size//2]
outline = (((y*1.2)**2 + x**2) > 110**2) & ((((y*1.2)**2 + x**2) < 120**2))
obj[outline] = 0
left_eye = ((y + 40)**2 + (x + 40)**2) < 20**2
obj[left_eye] = 0
right_eye = (np.abs(y + 40) < 15) & (np.abs(x - 40) < 30)
obj[right_eye] = 0
nose = (y + 20 + x > 0) & (x < 0) & (y < 10)
obj[nose] = (0.05j * x + 0.05j * y)[nose]
mouth = (((y*1)**2 + x**2) > 50**2) & ((((y*1)**2 + x**2) < 70**2)) & (y > 20)
obj[mouth] = 0
tongue = (((y - 50)**2 + (x - 50)**2) < 20**2) & ((y**2 + x**2) > 70**2)
obj[tongue] = 0
# This wave modulation introduces a strong signature in the diffraction pattern
# that allows confirming the correct scale and orientation.
signature_wave = np.exp(1j*(3 * y + 7 * x) * 2*np.pi/size)
obj += 0.3*signature_wave - 0.3
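As an illustrative check (not part of the original notebook), the modulation by itself produces a single dominant Fourier peak at spatial frequency (3, 7), which is what later makes scale and orientation easy to confirm in the diffraction pattern:
# The pure modulation has a single dominant Fourier component at frequency index (3, 7)
peak = np.unravel_index(np.argmax(np.abs(np.fft.fft2(signature_wave))), signature_wave.shape)
peak  # expected: (3, 7)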
[7]:
fig, axes = plt.subplots(1, 3)
axes[0].imshow(np.abs(obj))
axes[1].imshow(np.log1p(np.abs(np.fft.fftshift(np.fft.fft2(obj)))))
axes[2].imshow(np.angle(np.fft.fftshift(np.fft.fft2(obj))))
[7]:
<matplotlib.image.AxesImage at 0x1bb76858cd0>
Calculate the illumination
The illumination is first defined in reciprocal (angular) space with a circular aperture, i.e. a convergent beam, as is common in scanning transmission electron microscopy. It is then projected into the object plane using an inverse Fourier transform.
The fftshift takes care of the proper shift to the center.
[8]:
illum_radial = np.zeros((size, size))
illum_radial[np.sqrt(y**2 + x**2) * pixel_size_detector_sim <= semiconv] = 1
[9]:
illum = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(illum_radial)))
[10]:
fig, axes = plt.subplots(1, 3)
axes[0].imshow(illum_radial)
axes[0].set_title("Radial")
axes[1].imshow(np.abs(illum))
axes[1].set_title("Amplitude")
axes[2].imshow(np.angle(illum))
axes[2].set_title("Phase")
[10]:
Text(0.5, 1.0, 'Phase')
Forward simulation
Multiply the illumination with the object, project it to the far field and transform amplitude to intensity.
[11]:
exitwave = illum * obj
projection = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(exitwave)))
diffpattern_sim = np.abs(projection)**2
[12]:
fig, axes = plt.subplots(1, 3)
axes[0].imshow(np.abs(projection))
axes[0].set_title("Amplitude")
axes[1].imshow(np.angle(projection))
axes[1].set_title("Phase")
axes[2].imshow(diffpattern_sim)
axes[2].set_title("Intensity")
[12]:
Text(0.5, 1.0, 'Intensity')
Transform the “perfect” simulation to a typical experimental result
The data is resampled with user-defined pixel sizes that are mismatched between object and diffraction pattern.
Furthermore, the diffraction pattern is rotated and shifted. In real experiments the rotation is usually close to a multiple of 90°, but electron microscopes can also include a free rotation in their projection. The detector data might also be mirrored along an axis, i.e. recorded with a different handedness than the object coordinates. This is not shown here for simplicity; such a mirroring could be handled with the flip_y parameter of diffraction_to_detector() that appears below.
[13]:
pixel_size_real = 10e-12 # m
pixel_size_detector = 0.001 # radian
scan_rotation = 65 # degrees
shift = (7, 3)
[14]:
# Transform both object and illumination.
rec_obj = scipy.ndimage.zoom(obj, pixel_size_real_sim/pixel_size_real)
rec_illum = scipy.ndimage.zoom(illum, pixel_size_real_sim/pixel_size_real)
# Step-by-step transformation of the diffraction pattern.
real_diffpattern = scipy.ndimage.zoom(diffpattern_sim, pixel_size_detector_sim/pixel_size_detector)
# Positive angle means counter-clockwise in scipy.ndimage.rotate by default
# Note that this implies a left-handed coordinate system in scipy.ndimage, unlike the right-handed
# convention in LiberTEM. See also https://libertem.github.io/LiberTEM/concepts.html#coordinate-system.
# That means by using the positive scan rotation both here and later in creating the transformation matrix,
# the matrix undoes the rotation introduced here.
real_diffpattern = scipy.ndimage.rotate(real_diffpattern, scan_rotation, reshape=False)
real_diffpattern = scipy.ndimage.shift(real_diffpattern, shift)
center = real_diffpattern.shape[0] // 2
# Crop
real_diffpattern = real_diffpattern[center-64:center+65, center-64:center+65]
# Truncate negative values from interpolation
real_diffpattern = np.maximum(real_diffpattern, 0)
Show mismatch between detector and forward model
A forward model multiplies the rescaled object with the rescaled illumination, projects the resulting exit wave into the far field and compares the result with the detector data. Here we see that a pixel-by-pixel comparison is no longer possible, since the object and the detector data underwent different, incompatible transformations.
[15]:
forward = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(rec_illum * rec_obj)))
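A look at the array shapes (an illustrative check, not part of the original notebook) already reveals part of the mismatch; the remaining rotation and shift are visible in the plots below:
forward.shape, real_diffpattern.shape  # different sizes, so no pixel-by-pixel comparison is possible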
[16]:
fig, axes = plt.subplots(2, 2)
axes[0, 0].imshow(np.abs(rec_obj))
axes[0, 0].set_title("object")
axes[0, 1].imshow(np.abs(rec_illum))
axes[0, 1].set_title("illumination")
axes[1, 0].imshow(np.abs(forward))
axes[1, 0].set_title("forward model")
axes[1, 1].imshow(np.log1p(real_diffpattern))
axes[1, 1].set_title("detector")
[16]:
Text(0.5, 1.0, 'detector')
Forward models
Here we define a reference forward model that can shift object and illumination relative to each other, and compare it with an optimized implementation based on the functions in ptychography40.reconstruction.common.
The optimized implementation avoids several copies of the data, since rolled_object_probe_product_cpu() folds rolling, product and inverse FFT shift into a single operation. Furthermore, it can process a whole stack of shifts in one go, which reduces overheads for small illuminated areas. Note that the fast implementation omits the last FFT shift; this is performed later on the detector data instead.
rolled_object_probe_product_cpu() also supports illuminations that are smaller than the object, which improves the scaling behavior for large objects with a small illuminated area. Furthermore, it can perform subpixel shifts if an array of subpixel-shifted illuminations is provided. These features are not demonstrated here.
[17]:
def reference_forward_model(illum, obj, shifts):
    result = np.empty((len(shifts), *illum.shape), dtype=np.complex64)
    for i, shift in enumerate(shifts):
        # Shift the object relative to the (static) illumination
        tmp_obj = np.roll(obj, -np.array(shift), axis=(0, 1))
        result[i] = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(illum * tmp_obj)))
    return result
[18]:
def fast_forward_model(illum, obj, shifts):
    fast_forward_buffer = np.empty((len(shifts), *illum.shape), dtype=np.complex64)
    # Rolling, product and inverse FFT shift are folded into a single operation
    rolled_object_probe_product_cpu(
        obj=obj,
        probe=illum[np.newaxis, np.newaxis, ...],
        shifts=np.array(shifts),
        result_out=fast_forward_buffer,
        ifftshift=True
    )
    # The final FFT shift is omitted here; it is applied to the detector data instead
    return np.fft.fft2(fast_forward_buffer)
[19]:
ref = reference_forward_model(rec_illum, rec_obj, [(0, 0), (3, 4)])
fast = fast_forward_model(rec_illum, rec_obj, [(0, 0), (3, 4)])
[20]:
np.allclose(np.fft.fftshift(fast, axes=(1, 2)), ref)
[20]:
True
[21]:
fig, axes = plt.subplots(1, 2)
axes[0].imshow(np.abs(ref[1]))
axes[1].imshow(np.abs(fast[1]))
[21]:
<matplotlib.image.AxesImage at 0x1bb7a710d90>
Transformation of the detector data
The detector data has to be transformed in such a way that it matches the forward model. That way, an update for the object and/or illumination can be computed from a comparison of both.
Again, we compare two methods: a reference method that uses scipy.ndimage to unwind the transformations that have been applied to the simulated pattern and scale it to the correct resolution, and a faster method that uses functions from ptychography40.reconstruction.common to calculate and apply a sparse transformation matrix.
Once the matrix is pre-computed, this method can apply all transformations in a single step to a whole stack of images. Furthermore, it can perform an inverse FFT shift with zero overhead, which avoids this step in the forward model calculation.
[22]:
def reference_transform(diffpatterns, target_shape):
    pixel_size_detector_reconstruction = 1/target_shape[0]/pixel_size_real*lmbda
    matched_patterns = np.empty((len(diffpatterns), *target_shape))
    zoom = pixel_size_detector/pixel_size_detector_reconstruction
    # We apply the ndimage transformations step by step since the 3D version
    # that processes a whole stack is slower
    for i in range(len(diffpatterns)):
        # Undo the shift, rotation and scaling applied to the simulated pattern
        matched_diffpattern = scipy.ndimage.shift(diffpatterns[i], (-shift[0], -shift[1]))
        matched_diffpattern = scipy.ndimage.rotate(matched_diffpattern, -scan_rotation, reshape=True)
        matched_diffpattern = scipy.ndimage.zoom(matched_diffpattern, zoom)
        # Pad symmetrically to the target shape
        padding = (target_shape[0] - matched_diffpattern.shape[0]) / 2
        padding_spec = ((int(np.floor(padding)), int(np.ceil(padding))), )*2
        matched_patterns[i] = np.pad(matched_diffpattern, pad_width=padding_spec)
    return matched_patterns
The fast method first defines coordinate transformations, then calculates a sparse transformation matrix from them, and finally applies it to the data. This is particularly efficient for processing a large number of images with the same transformation, as required for ptychography.
[23]:
affine_transformation = diffraction_to_detector(
    lamb=lmbda,
    diffraction_shape=rec_obj.shape,
    pixel_size_real=pixel_size_real,
    pixel_size_detector=pixel_size_detector,
    # Center of the diffraction pattern on the detector, including the shift applied above
    cy=real_diffpattern.shape[0]/2 + shift[0],
    cx=real_diffpattern.shape[1]/2 + shift[1],
    flip_y=False,
    scan_rotation=scan_rotation
)
transformation_fast = image_transformation_matrix(
    source_shape=real_diffpattern.shape,
    target_shape=rec_obj.shape,
    affine_transformation=affine_transformation,
    # Fold the inverse FFT shift into the transformation matrix
    pre_transform=ifftshift_coords(reconstruct_shape=rec_obj.shape),
)
[24]:
transformed_reference = reference_transform(
    real_diffpattern[np.newaxis, ...],
    rec_obj.shape
)[0]
transformed_fast = apply_matrix(
    sources=real_diffpattern[np.newaxis, ...],
    matrix=transformation_fast,
    target_shape=rec_obj.shape
)[0]
Verify match
Forward model and transformed detector data should be similar. Both transformation methods are lossy and give somewhat different results.
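For a quick quantitative impression of how close the two transformed patterns are, they could also be compared directly. This is an illustrative sketch, not part of the original workflow; since the fast transform already folds in the inverse FFT shift, it is undone here before comparing:
rel_diff = (
    np.linalg.norm(np.fft.fftshift(transformed_fast) - transformed_reference)
    / np.linalg.norm(transformed_reference)
)
rel_diff  # not zero, since both resampling methods are lossy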
[25]:
fig, axes = plt.subplots(2, 2)
axes[0, 0].imshow(np.abs(reference_forward_model(rec_illum, rec_obj, [(0, 0)])[0])**2)
axes[0, 0].set_title("ref forward")
axes[0, 1].imshow(np.abs(fast_forward_model(rec_illum, rec_obj, [(0, 0)])[0])**2)
axes[0, 1].set_title("fast forward")
axes[1, 0].imshow((transformed_reference))
axes[1, 0].set_title("ref detector")
axes[1, 1].imshow((transformed_fast))
axes[1, 1].set_title("fast detector")
[25]:
Text(0.5, 1.0, 'fast detector')
Compare performance
For demonstration, this just calculates the difference between the detector data and the forward model. A real ptychographic routine would compute an update for the object and/or illumination instead. The method based on ptychography40.reconstruction.common is about 60x faster (2.6 ms vs. 168 ms in the timings below).
[26]:
def compare_ref(illum, obj, diffpatterns, probe_shifts):
    matched_patterns = reference_transform(diffpatterns, obj.shape)
    ref_diff = matched_patterns - np.abs(reference_forward_model(illum, obj, probe_shifts))**2
    return ref_diff
[27]:
def compare_fast(illum, obj, diffpatterns, probe_shifts, transformation):
    transformed_fast = apply_matrix(
        sources=diffpatterns,
        matrix=transformation,
        target_shape=obj.shape
    )
    fast_diff = transformed_fast - np.abs(fast_forward_model(illum, obj, probe_shifts))**2
    return fast_diff
[28]:
diffpattern_stack = np.stack((real_diffpattern, )* 32)
shifts = np.array([(i, i) for i in range(32)])
[29]:
%%timeit
compare_ref(rec_illum, rec_obj, diffpattern_stack, shifts)
168 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[30]:
%%timeit
compare_fast(rec_illum, rec_obj, diffpattern_stack, shifts, transformation_fast)
2.6 ms ± 140 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
[31]:
# %lprun -f compare_ref -f reference_transform -f reference_forward_model compare_ref(rec_illum, rec_obj, diffpattern_stack, shifts)
[32]:
# %lprun -f compare_fast compare_fast(rec_illum, rec_obj, diffpattern_stack, shifts, transformation_fast)
Conclusion
Besides being a lot faster, the tools in ptychography40.reconstruction.common accept parameters that are available as metadata from experiments, instead of requiring custom transformations. They are verified to accept similar parameters and to behave the same way as the SSB implementation and the center-of-mass analysis in LiberTEM, which facilitates compatible metadata across experimental modalities and reconstruction algorithms.