Source code for nabu.reconstruction.mlem

import numpy as np


try:
    import corrct as cct

    __have_corrct__ = True
except ImportError:
    __have_corrct__ = False


[docs] class MLEMReconstructor: """ A reconstructor for MLEM reconstruction using the CorrCT toolbox. """ default_extra_options = { "compute_shifts": False, "tomo_consistency": False, "v_min_for_v_shifts": 0, "v_max_for_v_shifts": None, "v_min_for_u_shifts": 0, "v_max_for_u_shifts": None, } def __init__( self, sinos_shape, angles_rad, shifts_uv=None, cor=None, n_iterations=50, extra_options=None, ): """ """ if not (__have_corrct__): raise ImportError("Need corrct package") self.angles_rad = angles_rad self.n_iterations = n_iterations self._configure_extra_options(extra_options) self._set_sino_shape(sinos_shape) self._set_shifts(shifts_uv, cor) def _configure_extra_options(self, extra_options): self.extra_options = self.default_extra_options.copy() self.extra_options.update(extra_options or {}) def _set_sino_shape(self, sinos_shape): if len(sinos_shape) != 3: raise ValueError("Expected a 3D shape") self.sinos_shape = sinos_shape self.n_sinos, self.n_angles, self.prj_width = sinos_shape if self.n_angles != len(self.angles_rad): raise ValueError( f"Number of angles ({len(self.angles_rad)}) does not match size of sinograms ({self.n_angles})." ) def _set_shifts(self, shifts_uv, cor): if shifts_uv is None: self.shifts_uv = np.zeros([self.n_angles, 2]) else: if shifts_uv.shape[0] != self.n_angles: raise ValueError( f"Number of shifts given ({shifts_uv.shape[0]}) does not mathc the number of projections ({self.n_angles})." ) self.shifts_uv = shifts_uv.copy() self.cor = cor
[docs] def reconstruct(self, data_vwu): """ data_align_vwu: numpy.ndarray or pycuda.gpuarray Raw data, with shape (n_sinograms, n_angles, width) output: optional Output array. If not provided, a new numpy array is returned """ if not isinstance(data_vwu, np.ndarray): data_vwu = data_vwu.get() data_vwu /= data_vwu.mean() # MLEM recons self.vol_geom_align = cct.models.VolumeGeometry.get_default_from_data(data_vwu) self.prj_geom_align = cct.models.ProjectionGeometry.get_default_parallel() # Vertical shifts were handled in pipeline. Set them to ZERO self.shifts_uv[:, 1] = 0.0 self.prj_geom_align.set_detector_shifts_vu(self.shifts_uv.T[::-1]) variances_align = cct.processing.compute_variance_poisson(data_vwu) self.weights_align = cct.processing.compute_variance_weight(variances_align, normalized=True) # , use_std=True self.data_term_align = cct.data_terms.DataFidelity_wl2(self.weights_align) solver = cct.solvers.MLEM(verbose=True, data_term=self.data_term_align) self.solver_opts = dict(lower_limit=0) # , x_mask=cct.processing.circular_mask(vol_geom_align.shape_xyz[:-2]) with cct.projectors.ProjectorUncorrected( self.vol_geom_align, self.angles_rad, rot_axis_shift_pix=self.cor, prj_geom=self.prj_geom_align ) as A: rec, _ = solver(A, data_vwu, iterations=self.n_iterations, **self.solver_opts) return rec