from __future__ import annotations
from contextlib import contextmanager
from typing import (
Any,
Callable,
Iterator,
Literal,
Generic,
Sequence,
TypeVar,
overload,
)
import numpy as np
from numpy.typing import NDArray
from . import _bandpass, _missing_wedge
from acryo._types import degree
from scipy.spatial.transform import Rotation
_T = TypeVar("_T", bound=np.generic)
_T1 = TypeVar("_T1", bound=np.generic)
# fmt: off
[docs]class AnyArray(Generic[_T]):
"""
Type representing a ndarray of numpy or cupy (or any other array
that has similar API).
"""
def __pos__(self) -> AnyArray[_T]: ...
def __neg__(self) -> AnyArray[_T]: ...
def __invert__(self) -> AnyArray[_T]: ...
def __add__(self, other: Any) -> AnyArray[_T]: ... # type: ignore
def __sub__(self, other: Any) -> AnyArray[_T]: ... # type: ignore
def __mul__(self, other: Any) -> AnyArray[_T]: ... # type: ignore
def __truediv__(self, other: Any) -> AnyArray[np.float_]: ... # type: ignore
def __radd__(self, other: Any) -> AnyArray[_T]: ... # type: ignore
def __rsub__(self, other: Any) -> AnyArray[_T]: ... # type: ignore
def __rmul__(self, other: Any) -> AnyArray[_T]: ... # type: ignore
def __rtruediv__(self, other: Any) -> AnyArray[np.float_]: ... # type: ignore
def __floordiv__(self, other: Any) -> AnyArray[np.intp]: ... # type: ignore
def __gt__(self, other: AnyArray[_T] | float) -> AnyArray[_T]: ... # type: ignore
def __lt__(self, other: AnyArray[_T] | float) -> AnyArray[_T]: ... # type: ignore
def __ge__(self, other: AnyArray[_T] | float) -> AnyArray[_T]: ... # type: ignore
def __le__(self, other: AnyArray[_T] | float) -> AnyArray[_T]: ... # type: ignore
def __eq__(self, other: AnyArray[_T] | float) -> AnyArray[_T]: ... # type: ignore
def __pow__(self, other: AnyArray[_T] | float) -> AnyArray[_T]: ... # type: ignore
def __getitem__(self, key) -> AnyArray[_T]: ... # type: ignore
def __setitem__(self, key, value) -> None: ... # type: ignore
def __iter__(self) -> Iterator[AnyArray[_T]]: ... # type: ignore
@property
def real(self) -> AnyArray[np.float32]: ... # type: ignore
@property
def imag(self) -> AnyArray[np.float32]: ... # type: ignore
[docs] def conj(self) -> AnyArray[_T]: ... # type: ignore
@property
def shape(self) -> tuple[int, ...]: ... # type: ignore
@property
def ndim(self) -> int: ...
@property
def dtype(self) -> np.dtype[_T]: ... # type: ignore
[docs] def dot(self, other: AnyArray[_T]) -> AnyArray[_T]: ... # type: ignore
[docs] def astype(self, dtype: type[_T1]) -> AnyArray[_T1]: ... # type: ignore
@overload
def mean(self, axis: None = None) -> _T: ... # type: ignore
@overload
def mean(self, axis: int | tuple[int, ...]) -> AnyArray[_T]: ... # type: ignore
# fmt: on
[docs]class Backend:
_default = "numpy"
def __init__(self, name: str | None = None) -> None:
if name is None:
name = self._default
if name == "numpy":
from scipy import ndimage, fft # type: ignore
self._xp_ = np
self._ndi_ = ndimage
self._fft_ = fft
elif name == "cupy":
import cupy
from cupyx.scipy import ndimage, fft
self._xp_ = cupy
self._ndi_ = ndimage
self._fft_ = fft
else:
raise ValueError(f"Unknown backend {name}")
@property
def name(self) -> str:
return self._xp_.__name__
def __hash__(self) -> int:
"""Hash using the backend module."""
return hash(self._xp_)
def __repr__(self) -> str:
return f"Backend<{self.name}>"
[docs] def asnumpy(self, x: AnyArray[_T] | NDArray[_T]) -> NDArray[_T]:
"""Convert to numpy array."""
if self._xp_ is np:
return x # type: ignore
return x.get() # type: ignore
[docs] def maycopy(self, x: AnyArray[_T]) -> AnyArray[_T]:
if self._xp_ is np:
return x
return x.copy() # type: ignore
@overload
def array(self, x, dtype: type[_T] | np.dtype[_T]) -> AnyArray[_T]:
...
@overload
def array(self, x: AnyArray[_T] | NDArray[_T], dtype: None = None) -> AnyArray[_T]:
...
[docs] def array(self, x, dtype=None): # type: ignore
"""Convert to numpy array."""
return self._xp_.array(x, dtype) # type: ignore
@overload
def asarray(self, x, dtype: type[_T] | np.dtype[_T]) -> AnyArray[_T]:
...
@overload
def asarray(
self, x: AnyArray[_T] | NDArray[_T], dtype: None = None
) -> AnyArray[_T]:
...
[docs] def asarray(self, x, dtype=None): # type: ignore
"""Convert to numpy array."""
return self._xp_.asarray(x, dtype) # type: ignore
@overload
def arange(self, *args, dtype: type[_T], **kwargs) -> AnyArray[_T]:
...
@overload
def arange(self, *args, dtype: None = None, **kwargs) -> AnyArray:
...
[docs] def arange(self, *args, dtype=None, **kwargs):
"""Return evenly spaced values within a given interval."""
return self._xp_.arange(*args, dtype=dtype, **kwargs) # type: ignore
[docs] def zeros(
self, shape: int | tuple[int, ...], dtype: type[_T] | np.dtype[_T] | None = None
) -> AnyArray[_T]:
"""Return a new array of given shape and type, filled with zeros."""
return self._xp_.zeros(shape, dtype) # type: ignore
[docs] def full(
self,
shape: int | tuple[int, ...],
fill_value: Any,
dtype: type[_T] | np.dtype[_T] | None = None,
) -> AnyArray[_T]:
"""Return a new array of given shape and type, filled with fill_value."""
return self._xp_.full(shape, fill_value, dtype=dtype) # type: ignore
@overload
def sum(self, x: AnyArray[_T], axis: None = None) -> _T:
...
@overload
def sum(self, x: AnyArray[_T], axis: int | tuple[int, ...]) -> AnyArray[_T]:
...
[docs] def sum(self, x, axis=None):
"""Return the sum of array elements over a given axis."""
return self._xp_.sum(x, axis=axis)
@overload
def mean(self, x: AnyArray[_T], axis: None = None) -> _T:
...
@overload
def mean(self, x: AnyArray[_T], axis: int | tuple[int, ...]) -> AnyArray[_T]:
...
[docs] def mean(self, x, axis=None):
"""Return the mean of array elements over a given axis."""
return self._xp_.mean(x, axis=axis)
[docs] def cumsum(self, x: AnyArray[_T], axis: int | None = None) -> AnyArray[_T]:
return self._xp_.cumsum(x, axis=axis) # type: ignore
[docs] def sqrt(self, x: AnyArray[_T]) -> AnyArray[_T]:
"""Return the non-negative square-root of an array."""
return self._xp_.sqrt(x) # type: ignore
[docs] def exp(self, x: AnyArray[_T]) -> AnyArray[_T]:
"""Return the exponential of an array."""
return self._xp_.exp(x) # type: ignore
[docs] def pad(
self,
x: AnyArray[_T],
pad_width: int | Sequence[int] | Sequence[tuple[int, int]],
mode: str = "constant",
constant_values: float = 0.0,
) -> AnyArray[_T]:
"""Pad an array."""
return self._xp_.pad(x, pad_width, mode=mode, constant_values=constant_values) # type: ignore
[docs] def tensordot(
self, a: AnyArray[_T], b: AnyArray[_T], axes: int | tuple[int, ...] = 2
) -> AnyArray[_T]:
"""Return tensor dot product of two arrays."""
return self._xp_.tensordot(a, b, axes) # type: ignore
@overload
def max(self, x: AnyArray[_T], axis: None = None) -> _T:
...
@overload
def max(self, x: AnyArray[_T], axis: int | tuple[int, ...]) -> AnyArray[_T]:
...
[docs] def max(self, x, axis=None):
"""Return the maximum of an array or maximum along an axis."""
return self._xp_.max(x, axis=axis)
@overload
def min(self, x: AnyArray[_T], axis: None = None) -> _T:
...
@overload
def min(self, x: AnyArray[_T], axis: int | tuple[int, ...]) -> AnyArray[_T]:
...
[docs] def min(self, x, axis=None):
"""Return the minimum of an array or minimum along an axis."""
return self._xp_.min(x, axis=axis)
@overload
def percentile(self, x: AnyArray[_T], q: float, axis: None = None) -> _T:
...
@overload
def percentile(
self, x: AnyArray[_T], q: float, axis: int | tuple[int, ...]
) -> AnyArray[_T]:
...
[docs] def percentile(self, x, q, axis=None):
"""Compute the q-th percentile of the data along the specified axis."""
return self._xp_.percentile(x, q, axis=axis)
@overload
def argmin(self, x: AnyArray[_T], axis: None = None) -> np.intp:
...
@overload
def argmin(self, x: AnyArray[_T], axis: int | tuple[int, ...]) -> AnyArray[np.intp]:
...
[docs] def argmin(self, x, axis=None): # type: ignore
return self._xp_.argmin(x, axis=axis) # type: ignore
@overload
def argmax(self, x: AnyArray[_T], axis: None = None) -> np.intp:
...
@overload
def argmax(self, x: AnyArray[_T], axis: int | tuple[int, ...]) -> AnyArray[np.intp]:
...
[docs] def argmax(self, x, axis=None): # type: ignore
return self._xp_.argmax(x, axis=axis) # type: ignore
[docs] def fix(self, x: AnyArray[_T]) -> AnyArray[_T]:
"""Round to nearest integer towards zero."""
return self._xp_.fix(x) # type: ignore
[docs] def fftn(
self,
x: AnyArray[np.float32] | AnyArray[np.complex64],
s: tuple[int, int, int] | None = None,
axes: int | tuple[int, ...] | None = None,
) -> AnyArray[np.complex64]:
"""N-dimensional FFT."""
return self._fft_.fftn(x, s, axes) # type: ignore
[docs] def ifftn(
self,
x: AnyArray[np.float32] | AnyArray[np.complex64],
s: tuple[int, int, int] | None = None,
axes: int | tuple[int, ...] | None = None,
) -> AnyArray[np.complex64]:
"""N-dimensional inverse FFT."""
return self._fft_.ifftn(x, s, axes) # type: ignore
[docs] def rfftn(
self,
x: AnyArray[np.float32],
s: tuple[int, int, int] | None = None,
axes: int | tuple[int, ...] | None = None,
) -> AnyArray[np.complex64]:
"""N-dimensional FFT of real part."""
return self._fft_.rfftn(x, s, axes) # type: ignore
[docs] def irfftn(
self,
x: AnyArray[np.complex64],
s: tuple[int, int, int] | None = None,
axes: int | tuple[int, ...] | None = None,
) -> AnyArray[np.float32]:
"""N-dimensional inverse FFT of real part."""
return self._fft_.irfftn(x, s, axes) # type: ignore
[docs] def fftshift(self, x: AnyArray[_T], axes=None) -> AnyArray[_T]:
"""Shift zero-frequency component to center."""
return self._xp_.fft.fftshift(x, axes=axes) # type: ignore
[docs] def ifftshift(self, x: AnyArray[_T], axes=None) -> AnyArray[_T]:
"""Inverse shift zero-frequency component to center."""
return self._xp_.fft.ifftshift(x, axes=axes) # type: ignore
[docs] def fftfreq(self, n: int, d: float = 1.0) -> AnyArray[np.float_]:
"""Return the Discrete Fourier Transform sample frequencies."""
return self._xp_.fft.fftfreq(n, d) # type: ignore
@overload
def meshgrid(
self,
x0: AnyArray[_T],
copy: bool = True,
sparse: bool = False,
indexing: Literal["xy", "ij"] = "xy",
) -> tuple[AnyArray[_T]]:
...
@overload
def meshgrid(
self,
x0: AnyArray[_T],
x1: AnyArray[_T],
copy: bool = True,
sparse: bool = False,
indexing: Literal["xy", "ij"] = "xy",
) -> tuple[AnyArray[_T], AnyArray[_T]]:
...
@overload
def meshgrid(
self,
x0: AnyArray[_T],
x1: AnyArray[_T],
x2: AnyArray[_T],
copy: bool = True,
sparse: bool = False,
indexing: Literal["xy", "ij"] = "xy",
) -> tuple[AnyArray[_T], AnyArray[_T], AnyArray[_T]]:
...
[docs] def meshgrid(self, *xi, copy=True, sparse=False, indexing="xy"): # type: ignore
"""Return coordinate matrices from coordinate vectors."""
return self._xp_.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing) # type: ignore
@overload
def indices(
self, shape: tuple[int], dtype: type[_T] = np.int32
) -> tuple[AnyArray[_T]]:
...
@overload
def indices(
self, shape: tuple[int, int], dtype: type[_T] = np.int32
) -> tuple[AnyArray[_T], AnyArray[_T]]:
...
@overload
def indices(
self, shape: tuple[int, int, int], dtype: type[_T] = np.int32
) -> tuple[AnyArray[_T], AnyArray[_T], AnyArray[_T]]:
...
@overload
def indices(
self, shape: tuple[int, ...], dtype: type[_T] = np.int32
) -> tuple[AnyArray[_T], ...]:
...
[docs] def indices(self, shape, dtype=np.int32): # type: ignore
"""Return an array representing the indices of a grid."""
return self._xp_.indices(shape, dtype=dtype) # type: ignore
[docs] def unravel_index(self, indices, shape: tuple[int, ...]) -> AnyArray[np.intp]:
"""Converts a flat index into a tuple of coordinate arrays."""
return self._xp_.asarray(self._xp_.unravel_index(indices, shape)) # type: ignore
[docs] def stack(self, arrays: Sequence[AnyArray[_T]], axis: int = 0) -> AnyArray[_T]:
"""Stack arrays in sequence along a new axis."""
return self._xp_.stack(arrays, axis=axis) # type: ignore
[docs] def spline_filter(
self,
input,
order: int = 3,
output: type[_T] = np.float64,
mode: str = "mirror",
) -> AnyArray[_T]:
return self._ndi_.spline_filter(
self.asarray(input), order=order, output=output, mode=mode # type: ignore
)
[docs] def map_coordinates(
self,
x: AnyArray[_T],
coords: AnyArray[_T],
order: int = 3,
mode: str = "constant",
cval: float = -1.0,
prefilter: bool = True,
) -> AnyArray[_T]:
return self._ndi_.map_coordinates(
x, coords, order=order, mode=mode, cval=cval, prefilter=prefilter
) # type: ignore
[docs] def rotated_crop(
self,
subimg,
mtx: NDArray[np.float32],
shape: tuple[int, int, int],
order: int,
cval: float | Callable[[AnyArray[np.float32]], Any],
) -> AnyArray[np.float32]:
if callable(cval):
_cval = cval(subimg)
else:
_cval = cval
out = self.affine_transform(
subimg,
matrix=self.asarray(mtx),
output_shape=shape,
order=order,
prefilter=order > 1,
mode="constant",
cval=float(_cval),
)
return out
[docs] def lowpass_filter_ft(self, img, cutoff, order: int = 2) -> AnyArray[np.complex64]:
"""Lowpass filter in Fourier space."""
return _bandpass.lowpass_filter_ft(self, self.asarray(img), cutoff, order)
[docs] def lowpass_filter(self, img, cutoff, order: int = 2) -> AnyArray[np.float32]:
"""Lowpass filter in real space."""
return _bandpass.lowpass_filter(self, self.asarray(img), cutoff, order)
[docs] def missing_wedge_mask(
self,
rotator: Rotation,
tilt_range: tuple[degree, degree],
shape: tuple[int, int, int],
):
return _missing_wedge.missing_wedge_mask(self, rotator, tilt_range, shape)
NUMPY_BACKEND = Backend("numpy")
[docs]@contextmanager
def using_backend(name: str):
"""Context manager to temporarily change the default backend."""
old_backend = Backend._default
Backend._default = name
try:
yield
finally:
Backend._default = old_backend
[docs]def set_backend(name: str):
"""Set the default backend."""
Backend._default = name