from __future__ import annotations
from typing import TYPE_CHECKING, Callable, Literal, TypeVar, overload
from functools import wraps
import numpy as np
import inspect
from ..array_api import xp
# Names used only in annotations. TYPE_CHECKING is False at runtime, so these
# imports are executed only by static type checkers (presumably to avoid an
# import cycle with the ``arrays`` package -- TODO confirm).
if TYPE_CHECKING:
    from ..arrays import LazyImgArray, ImgArray
    from ..arrays.axesmixin import AxesMixin
    from typing_extensions import ParamSpec
    # Captures the parameter list of a decorated function. It only exists for
    # type checkers; because of ``from __future__ import annotations`` the
    # annotations that mention it are stored as strings and are not evaluated
    # at runtime unless someone resolves them explicitly.
    _P = ParamSpec("_P")

# Return type of a decorated function.
_R = TypeVar("_R")

# Public decorators exported by this module.
__all__ = [
    "check_input_and_output",
    "check_input_and_output_lazy",
    "same_dtype",
    "dims_to_spatial_axes",
]
# Typing overloads for ``check_input_and_output``.
# - Called with keyword options only (``@check_input_and_output(only_binary=True)``)
#   it returns a decorator. ``func`` therefore needs the default ``None``;
#   otherwise that call form matches no overload for a type checker.
# - Applied directly to a function (``@check_input_and_output``) it returns the
#   decorated function with the same signature.
@overload
def check_input_and_output(
    func: Literal[None] = None,
    *,
    inherit_label_info: bool = False,
    only_binary: bool = False,
    need_labels: bool = False,
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
    ...


@overload
def check_input_and_output(
    func: Callable[_P, _R],
    *,
    inherit_label_info: bool = False,
    only_binary: bool = False,
    need_labels: bool = False,
) -> Callable[_P, _R]:
    ...
# Typing overloads for ``check_input_and_output_lazy``.
# - Applied directly to a function it returns the decorated function with the
#   same signature.
# - Called with keyword options only (``@check_input_and_output_lazy(only_binary=True)``)
#   it returns a decorator: a callable that takes one function and returns a
#   function of the same signature. ``func`` defaults to ``None`` so that this
#   call form matches the overload.
@overload
def check_input_and_output_lazy(
    func: Callable[_P, _R], *, only_binary: bool = False
) -> Callable[_P, _R]:
    ...


@overload
def check_input_and_output_lazy(
    func: Literal[None] = None, *, only_binary: bool = False
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
    ...
@overload
def same_dtype(func: Callable[_P, _R], asfloat: bool = False) -> Callable[_P, _R]:
    ...


@overload
def same_dtype(
    func: Literal[None] = None, asfloat: bool = False
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
    # ``func`` defaults to None so that ``@same_dtype(asfloat=True)`` (no
    # positional argument) matches this overload.
    ...


def same_dtype(func=None, asfloat: bool = False):
    """
    Decorator to assure output image has the same dtype as the input image.

    This decorator is compatible with both ImgArray and LazyImgArray. It can be
    applied bare (``@same_dtype``) or with options
    (``@same_dtype(asfloat=True)``).

    Parameters
    ----------
    func : callable, optional
        Method to decorate. Its first argument is the input image. If None, a
        decorator configured with ``asfloat`` is returned instead.
    asfloat : bool, optional
        If input image should be converted to float first, by default False.
        Only integer images (dtype kind "u" or "i") are converted, via
        ``as_float()``.

    Returns
    -------
    callable
        The decorated method, or a decorator if ``func`` is None. The
        decorated method returns ``func``'s result converted with
        ``as_img_type`` to the input image's original dtype.
    """

    def f(func: Callable[_P, _R]) -> Callable[_P, _R]:
        @wraps(func)
        def _same_dtype(self: ImgArray, *args, **kwargs):
            dtype = self.dtype
            # Promote integer inputs so that ``func`` computes in float; the
            # result is cast back to the original dtype below either way.
            src = self.as_float() if asfloat and dtype.kind in "ui" else self
            out: ImgArray = func(src, *args, **kwargs)
            return out.as_img_type(dtype)

        return _same_dtype

    return f if func is None else f(func)
def dims_to_spatial_axes(func: Callable[_P, _R]) -> Callable[_P, _R]:
    """
    Decorator to convert input `dims` to correct spatial axes. Compatible with ImgArray and
    LazyImgArray

    e.g.)
    dims=None (default) -> "yx" or "zyx" depend on the input image
    dims=2 -> "yx"
    dims=3 -> "zyx"
    dims="ty" -> "ty"

    The resolved axes are forwarded to ``func`` as a list of single-character
    axis names (e.g. ``["y", "x"]``). ``dims`` is resolved whether the caller
    passes it by keyword or positionally. Integer ``dims`` may be a Python int
    or a NumPy integer.

    Raises
    ------
    KeyError
        At decoration time, if ``func`` has no ``dims`` parameter.
    ValueError
        At call time, if ``dims`` is None or "" and the image does not contain
        between 1 and 3 of the axes "z", "y", "x".
    """
    # Inspect the signature once, at decoration time, instead of on every call.
    sig = inspect.signature(func)
    dims_param = sig.parameters["dims"]
    dims_default = dims_param.default
    # Index of ``dims`` within the wrapper's *args (the first parameter,
    # ``self``, is received separately). -1 means ``dims`` cannot be positional.
    if dims_param.kind in (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    ):
        dims_pos = list(sig.parameters).index("dims") - 1
    else:
        dims_pos = -1

    def _resolve(self: AxesMixin, dims) -> list:
        # Turn a user-facing ``dims`` value into a list of spatial axis names.
        if dims is None or (isinstance(dims, str) and dims == ""):
            dims = len([a for a in "zyx" if a in self.axes])
            if dims not in (1, 2, 3):
                raise ValueError(
                    f"Image spatial dimension must be 2 or 3, but {dims} was detected. If "
                    "image axes is not a standard one, such as 'tx' in kymograph, specify "
                    "the spatial axes by dims='tx' or dims='x'."
                )
        if isinstance(dims, (int, np.integer)):
            # Take the last ``dims`` of the standard spatial axes present.
            return [a for a in "zyx" if a in self.axes][-dims:]
        return list(dims)

    @wraps(func)
    def _dims_to_spatial_axes(self: AxesMixin, *args, **kwargs):
        if 0 <= dims_pos < len(args):
            # ``dims`` was passed positionally: replace it in place. Injecting it
            # as a keyword would make ``func`` receive ``dims`` twice.
            args = list(args)
            args[dims_pos] = _resolve(self, args[dims_pos])
        else:
            kwargs["dims"] = _resolve(self, kwargs.get("dims", dims_default))
        return func(self, *args, **kwargs)

    return _dims_to_spatial_axes