import h5py
import numpy
import logging
from math import ceil
from typing import Optional, Iterable, Union
from tomoscan.series import Series
from tomoscan.identifier import BaseIdentifier
from nabu.stitching.stitcher.base import _StitcherBase, get_obj_constant_side_length
from nabu.stitching.stitcher_2D import stitch_raw_frames
from nabu.stitching.utils.utils import ShiftAlgorithm, from_slice_to_n_elements
from nabu.stitching.overlap import (
check_overlaps,
ImageStichOverlapKernel,
)
from nabu.stitching.config import (
SingleAxisStitchingConfiguration,
KEY_RESCALE_MIN_PERCENTILES,
KEY_RESCALE_MAX_PERCENTILES,
)
from nabu.misc.utils import rescale_data
from nabu.stitching.sample_normalization import normalize_frame as normalize_frame_by_sample
from nabu.stitching.stitcher.dumper.base import DumperBase
from silx.io.utils import get_data
from silx.io.url import DataUrl
from scipy.ndimage import shift as shift_scipy
_logger = logging.getLogger(__name__)
PROGRESS_BAR_STITCH_VOL_DESC = "stitch volumes"
# description of the progress bar used when stitching volume.
# Needed to retrieve advancement from file when stitching remotely
class _SingleAxisMetaClass(type):
"""
Metaclass for single axis stitcher in order to aggregate dumper class and axis
"""
def __new__(mcls, name, bases, attrs, axis=None, dumper_cls=None):
mcls = super().__new__(mcls, name, bases, attrs)
mcls._axis = axis
mcls._dumperCls = dumper_cls
return mcls
[docs]
class SingleAxisStitcher(_StitcherBase, metaclass=_SingleAxisMetaClass):
"""
Any single-axis base class
"""
def __init__(self, configuration, *args, **kwargs) -> None:
super().__init__(configuration, *args, **kwargs)
if self._dumperCls is not None:
self._dumper = self._dumperCls(configuration=configuration)
else:
self._dumper = None
# initial shifts
self._axis_0_rel_ini_shifts = []
"""Shift between two juxtapose objects along axis 0 found from position metadata or given by the user"""
self._axis_1_rel_ini_shifts = []
"""Shift between two juxtapose objects along axis 1 found from position metadata or given by the user"""
self._axis_2_rel_ini_shifts = []
"""Shift between two juxtapose objects along axis 2 found from position metadata or given by the user"""
# shifts to add once refine
self._axis_0_rel_final_shifts = []
"""Shift over axis 0 found once refined by the cross correlation algorithm"""
self._axis_1_rel_final_shifts = []
"""Shift over axis 1 found once refined by the cross correlation algorithm"""
self._axis_2_rel_final_shifts = []
"""Shift over axis 2 found once refined by the cross correlation algorithm"""
self._slices_to_stitch = None
# slices to be stitched. Obtained from calling Configuration.settle_slices
self._stitching_constant_length = None
# stitching width: larger volume width. Other volume will be pad
def shifts_is_scalar(shifts):
return isinstance(shifts, ShiftAlgorithm) or numpy.isscalar(shifts)
# 'expend' shift algorithm
if shifts_is_scalar(self.configuration.axis_0_pos_px):
self.configuration.axis_0_pos_px = [
self.configuration.axis_0_pos_px,
] * (len(self.series) - 1)
if shifts_is_scalar(self.configuration.axis_1_pos_px):
self.configuration.axis_1_pos_px = [
self.configuration.axis_1_pos_px,
] * (len(self.series) - 1)
if shifts_is_scalar(self.configuration.axis_1_pos_px):
self.configuration.axis_1_pos_px = [
self.configuration.axis_1_pos_px,
] * (len(self.series) - 1)
if numpy.isscalar(self.configuration.axis_0_params):
self.configuration.axis_0_params = [
self.configuration.axis_0_params,
] * (len(self.series) - 1)
if numpy.isscalar(self.configuration.axis_1_params):
self.configuration.axis_1_params = [
self.configuration.axis_1_params,
] * (len(self.series) - 1)
if numpy.isscalar(self.configuration.axis_2_params):
self.configuration.axis_2_params = [
self.configuration.axis_2_params,
] * (len(self.series) - 1)
@property
def axis(self) -> int:
return self._axis
@property
def dumper(self):
return self._dumper
@property
def stitching_axis_in_frame_space(self):
"""
stitching is operated in 2D (frame) space. So the axis in frame space is different than the one in 3D ebs-tomo space (https://tomo.gitlab-pages.esrf.fr/bliss-tomo/master/modelization.html)
"""
raise NotImplementedError("Base class")
[docs]
def stitch(self, store_composition: bool = True) -> BaseIdentifier:
if self.progress is not None:
self.progress.set_description("order scans")
self.order_input_tomo_objects()
if self.progress is not None:
self.progress.set_description("check inputs")
self.check_inputs()
self.settle_flips()
if self.progress is not None:
self.progress.set_description("compute shifts")
self._compute_positions_as_px()
self.pre_processing_computation()
self.compute_estimated_shifts()
self._compute_shifts()
self._createOverlapKernels()
if self.progress is not None:
self.progress.set_description(PROGRESS_BAR_STITCH_VOL_DESC)
self._create_stitching(store_composition=store_composition)
if self.progress is not None:
self.progress.set_description("dump configuration")
self.dumper.save_configuration()
return self.dumper.output_identifier
@property
def serie_label(self) -> str:
"""return serie name for logs"""
return "single axis serie"
[docs]
def get_n_slices_to_stitch(self):
"""Return the number of slice to be stitched"""
if self._slices_to_stitch is None:
raise RuntimeError("Slices needs to be settled first")
return from_slice_to_n_elements(self._slices_to_stitch)
[docs]
def get_final_axis_positions_in_px(self) -> dict:
"""
compute the final position (**in pixel**) from the initial position of the first object and the final relative shift computed (1)
(1): the final relative shift is obtained from the initial shift (from motor position of provided by the user) + the refinement shift from cross correlation algorithm
:return: dict with tomo object identifier (str) as key and a tuple of position in pixel (axis_0_pos, axis_1_pos, axis_2_pos)
"""
pos_0_shift = numpy.concatenate(
(
numpy.atleast_1d(0.0),
numpy.array(self._axis_0_rel_final_shifts) - numpy.array(self._axis_0_rel_ini_shifts),
)
)
pos_0_cum_shift = numpy.cumsum(pos_0_shift)
final_pos_axis_0 = self.configuration.axis_0_pos_px + pos_0_cum_shift
pos_1_shift = numpy.concatenate(
(
numpy.atleast_1d(0.0),
numpy.array(self._axis_1_rel_final_shifts) - numpy.array(self._axis_1_rel_ini_shifts),
)
)
pos_1_cum_shift = numpy.cumsum(pos_1_shift)
final_pos_axis_1 = self.configuration.axis_1_pos_px + pos_1_cum_shift
pos_2_shift = numpy.concatenate(
(
numpy.atleast_1d(0.0),
numpy.array(self._axis_2_rel_final_shifts) - numpy.array(self._axis_2_rel_ini_shifts),
)
)
pos_2_cum_shift = numpy.cumsum(pos_2_shift)
final_pos_axis_2 = self.configuration.axis_2_pos_px + pos_2_cum_shift
assert len(final_pos_axis_0) == len(final_pos_axis_1)
assert len(final_pos_axis_0) == len(final_pos_axis_2)
assert len(final_pos_axis_0) == len(self.series)
return {
tomo_obj.get_identifier().to_str(): (pos_0, pos_1, pos_2)
for tomo_obj, (pos_0, pos_1, pos_2) in zip(
self.series, zip(final_pos_axis_0, final_pos_axis_1, final_pos_axis_2)
)
}
[docs]
def settle_flips(self):
"""
User can provide some information on existing flips at frame level.
The goal of this step is to get one flip_lr and on flip_ud value per scan or volume
"""
if numpy.isscalar(self.configuration.flip_lr):
self.configuration.flip_lr = tuple([self.configuration.flip_lr] * len(self.series))
else:
if not len(self.configuration.flip_lr) == len(self.series):
raise ValueError("flip_lr expects a scalar value or one value per element to stitch")
self.configuration.flip_lr = tuple(self.configuration.flip_lr)
for elmt in self.configuration.flip_lr:
if not isinstance(elmt, bool):
raise TypeError
if numpy.isscalar(self.configuration.flip_ud):
self.configuration.flip_ud = tuple([self.configuration.flip_ud] * len(self.series))
else:
if not len(self.configuration.flip_ud) == len(self.series):
raise ValueError("flip_ud expects a scalar value or one value per element to stitch")
self.configuration.flip_ud = tuple(self.configuration.flip_ud)
for elmt in self.configuration.flip_ud:
if not isinstance(elmt, bool):
raise TypeError
def _createOverlapKernels(self):
"""
after this stage the overlap kernels must be created and with the final overlap size
"""
if self.axis == 0:
stitched_axis_rel_shifts = self._axis_0_rel_final_shifts
stitched_axis_params = self.configuration.axis_0_params
elif self.axis == 1:
stitched_axis_rel_shifts = self._axis_1_rel_final_shifts
stitched_axis_params = self.configuration.axis_1_params
elif self.axis == 2:
stitched_axis_rel_shifts = self._axis_2_rel_final_shifts
stitched_axis_params = self.configuration.axis_2_params
else:
raise NotImplementedError
if stitched_axis_rel_shifts is None or len(stitched_axis_rel_shifts) == 0:
raise RuntimeError(
f"axis {self.axis} shifts have not been defined yet. Please define them before calling this function"
)
overlap_size = stitched_axis_params.get("overlap_size", None)
if overlap_size in (None, "None", ""):
overlap_size = -1
else:
overlap_size = int(overlap_size)
self._stitching_constant_length = max(
[get_obj_constant_side_length(obj, axis=self.axis) for obj in self.series]
)
for stitched_axis_shift in stitched_axis_rel_shifts:
if overlap_size == -1:
height = abs(stitched_axis_shift)
else:
height = overlap_size
self._overlap_kernels.append(
ImageStichOverlapKernel(
stitching_axis=self.stitching_axis_in_frame_space,
frame_unstitched_axis_size=self._stitching_constant_length,
stitching_strategy=self.configuration.stitching_strategy,
overlap_size=height,
extra_params=self.configuration.stitching_kernels_extra_params,
)
)
@property
def series(self) -> Series:
return self._series
@property
def configuration(self) -> SingleAxisStitchingConfiguration:
return self._configuration
@property
def progress(self):
return self._progress
@staticmethod
def _data_bunch_iterator(slices, bunch_size):
"""util to get indices by bunch until we reach n_frames"""
if isinstance(slices, slice):
# note: slice step is handled at a different level
start = end = slices.start
while True:
start, end = end, min((end + bunch_size), slices.stop)
yield (start, end)
if end >= slices.stop:
break
# in the case of non-contiguous frames
elif isinstance(slices, Iterable):
for s in slices:
yield (s, s + 1)
else:
raise TypeError(f"slices is provided as {type(slices)}. When Iterable or slice is expected")
[docs]
def rescale_frames(self, frames: tuple):
"""
rescale_frames if requested by the configuration
"""
_logger.info("apply rescale frames")
def cast_percentile(percentile) -> int:
if isinstance(percentile, str):
percentile.replace(" ", "").rstrip("%")
return int(percentile)
rescale_min_percentile = cast_percentile(self.configuration.rescale_params.get(KEY_RESCALE_MIN_PERCENTILES, 0))
rescale_max_percentile = cast_percentile(
self.configuration.rescale_params.get(KEY_RESCALE_MAX_PERCENTILES, 100)
)
new_min = numpy.percentile(frames[0], rescale_min_percentile)
new_max = numpy.percentile(frames[0], rescale_max_percentile)
def rescale(data):
# FIXME: takes time because browse several time the dataset, twice for percentiles and twices to get min and max when calling rescale_data...
data_min = numpy.percentile(data, rescale_min_percentile)
data_max = numpy.percentile(data, rescale_max_percentile)
return rescale_data(data, new_min=new_min, new_max=new_max, data_min=data_min, data_max=data_max)
return tuple([rescale(data) for data in frames])
[docs]
def normalize_frame_by_sample(self, frames: tuple):
"""
normalize frame from a sample picked on the left or the right
"""
_logger.info("apply normalization by a sample")
return tuple(
[
normalize_frame_by_sample(
frame=frame,
side=self.configuration.normalization_by_sample.side,
method=self.configuration.normalization_by_sample.method,
margin_before_sample=self.configuration.normalization_by_sample.margin,
sample_width=self.configuration.normalization_by_sample.width,
)
for frame in frames
]
)
[docs]
@staticmethod
def stitch_frames(
frames: Union[tuple, numpy.ndarray],
axis,
x_relative_shifts: tuple,
y_relative_shifts: tuple,
output_dtype: numpy.ndarray,
stitching_axis: int,
overlap_kernels: tuple,
dumper: DumperBase = None,
check_inputs=True,
shift_mode="nearest",
i_frame=None,
return_composition_cls=False,
alignment="center",
pad_mode="constant",
new_width: Optional[int] = None,
) -> numpy.ndarray:
"""
shift frames according to provided `shifts` (as y, x tuples) then stitch all the shifted frames together and
save them to output_dataset.
:param tuple frames: element must be a DataUrl or a 2D numpy array
:param stitching_regions_hdf5_dataset:
"""
if check_inputs:
if len(frames) < 2:
raise ValueError(f"Not enought frames provided for stitching ({len(frames)} provided)")
if len(frames) != len(x_relative_shifts) + 1:
raise ValueError(
f"Incoherent number of shift provided ({len(x_relative_shifts)}) compare to number of frame ({len(frames)}). len(frames) - 1 expected"
)
if len(x_relative_shifts) != len(overlap_kernels):
raise ValueError(
f"expect to have the same number of x_relative_shifts ({len(x_relative_shifts)}) and y_overlap ({len(overlap_kernels)})"
)
if len(y_relative_shifts) != len(overlap_kernels):
raise ValueError(
f"expect to have the same number of y_relative_shifts ({len(y_relative_shifts)}) and y_overlap ({len(overlap_kernels)})"
)
relative_positions = [(0, 0, 0)]
for y_rel_pos, x_rel_pos in zip(y_relative_shifts, x_relative_shifts):
relative_positions.append(
(
y_rel_pos + relative_positions[-1][0],
0, # position over axis 1 (aka y) is not handled yet
x_rel_pos + relative_positions[-1][2],
)
)
check_overlaps(
frames=tuple(frames),
positions=tuple(relative_positions),
axis=axis,
raise_error=False,
)
def check_frame_is_2d(frame):
if frame.ndim != 2:
raise ValueError(f"2D frame expected when {frame.ndim}D provided")
# step_0 load data if from url
data = []
for frame in frames:
if isinstance(frame, DataUrl):
data_frame = get_data(frame)
if check_inputs:
check_frame_is_2d(data_frame)
data.append(data_frame)
elif isinstance(frame, numpy.ndarray):
if check_inputs:
check_frame_is_2d(frame)
data.append(frame)
else:
raise TypeError(f"frames are expected to be DataUrl or 2D numpy array. Not {type(frame)}")
# step 1: shift each frames (except the first one)
if stitching_axis == 0:
relative_shift_along_stitched_axis = y_relative_shifts
relative_shift_along_unstitched_axis = x_relative_shifts
elif stitching_axis == 1:
relative_shift_along_stitched_axis = x_relative_shifts
relative_shift_along_unstitched_axis = y_relative_shifts
else:
raise NotImplementedError("")
shifted_data = [data[0]]
for frame, relative_shift in zip(data[1:], relative_shift_along_unstitched_axis):
# note: for now we only shift data in x. the y shift is handled in the FrameComposition
relative_shift = numpy.asarray(relative_shift).astype(numpy.int8)
if relative_shift == 0:
shifted_frame = frame
else:
# TO speed up: should use the Fourier transform
shifted_frame = shift_scipy(
frame,
mode=shift_mode,
shift=[0, -relative_shift] if stitching_axis == 0 else [-relative_shift, 0],
order=1,
)
shifted_data.append(shifted_frame)
# step 2: create stitched frame
stitched_frame, composition_cls = stitch_raw_frames(
frames=shifted_data,
key_lines=(
[
(int(frame.shape[stitching_axis] - abs(relative_shift / 2)), int(abs(relative_shift / 2)))
for relative_shift, frame in zip(relative_shift_along_stitched_axis, frames)
]
),
overlap_kernels=overlap_kernels,
check_inputs=check_inputs,
output_dtype=output_dtype,
return_composition_cls=True,
alignment=alignment,
pad_mode=pad_mode,
new_unstitched_axis_size=new_width,
)
dumper.save_stitched_frame(
stitched_frame=stitched_frame,
composition_cls=composition_cls,
i_frame=i_frame,
axis=1,
)
if return_composition_cls:
return stitched_frame, composition_cls
else:
return stitched_frame