quantax.utils.ScaleArray
- class quantax.utils.ScaleArray
Array representation with a scale: value = significand * exp(exponent), where exponent is a normalization factor.
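Here is a quick numeric check of this definition. It uses plain (significand, exponent) pairs of NumPy values instead of the ScaleArray class itself, and the helper names dense_value and log_abs are hypothetical. The sketch shows two things: the pairs (e, 0) and (1, 1) describe the same number, and a pair such as (1, 1000) still carries a magnitude that overflows float64 once it is materialized.

```python
import numpy as np


def dense_value(significand, exponent):
    """Hypothetical helper: materialize significand * exp(exponent)."""
    return np.asarray(significand) * np.exp(exponent)


def log_abs(significand, exponent):
    """log|value| for a real exponent, computed without materializing the value."""
    return np.log(np.abs(significand)) + exponent


# Two different pairs that represent the same value e.
a = dense_value(np.e, 0.0)
b = dense_value(1.0, 1.0)
print(np.isclose(a, b))  # True

# Materializing exp(1000) overflows float64 ...
with np.errstate(over="ignore"):
    big = dense_value(1.0, 1000.0)
print(np.isinf(big))  # True

# ... while the pair (1.0, 1000.0) itself stays finite.
print(log_abs(1.0, 1000.0))  # 1000.0
```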
The array is a PyTree with two leaves: significand and exponent. To convert it to a dense JAX array, use arr.value() or jnp.asarray(arr).
Note
The same value can be represented by different (significand, exponent) pairs. For example, (e, 0) and (1, 1) both represent the value e. For better performance, we don’t enforce a canonical form, but the normalize method can be used to obtain a normalized ScaleArray where the maximum absolute value of the significand is 1.
Warning
JAX doesn’t fully support customized arrays, so one should be careful when using ScaleArray. Here we list several possible problems.
1. Manipulations like jnp.fn(array) convert customized arrays into jax.Array. To avoid this, call array.fn() whenever possible.
2. Computations like jax_array * customized_array always call jax_array.__mul__(customized_array), which returns a jax.Array. To avoid this, write customized_array * jax_array instead.
- __array__(dtype=None) ndarray
Convert to a numpy array.
- __jax_array__() Array
Convert to a JAX array.
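The class description says ScaleArray is a PyTree with two leaves and can be densified with jnp.asarray. The following minimal sketch shows how those pieces fit together, using a hypothetical stand-in class, TinyScaleArray. TinyScaleArray is much simpler than ScaleArray and is not its implementation. It registers the two leaves with jax.tree_util and implements JAX's __jax_array__ hook and NumPy's __array__ hook. Conversion through __jax_array__ is also, roughly, why a jax.numpy function applied to such an object returns a plain jax.Array, as item 1 of the Warning describes.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class TinyScaleArray:
    """Hypothetical stand-in: value = significand * exp(exponent)."""

    def __init__(self, significand, exponent):
        self.significand = significand
        self.exponent = exponent

    # PyTree protocol: the two arrays are the leaves; there is no static aux data.
    def tree_flatten(self):
        return (self.significand, self.exponent), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def value(self):
        return self.significand * jnp.exp(self.exponent)

    # Hook used by jnp.asarray(x) to obtain a dense jax.Array.
    def __jax_array__(self):
        return self.value()

    # Hook used by np.asarray(x); `copy` is accepted for NumPy 2 and ignored here.
    def __array__(self, dtype=None, copy=None):
        return np.asarray(self.value(), dtype=dtype)


x = TinyScaleArray(jnp.array([1.0, 2.0]), jnp.array(0.0))
print(len(jax.tree_util.tree_leaves(x)))  # 2
print(jnp.asarray(x))                     # [1. 2.]
print(type(np.asarray(x)))                # <class 'numpy.ndarray'>
```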
- __neg__() ScaleArray
Negate the represented value.
- __abs__() ScaleArray
Absolute value of the represented array.
- __mul__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) ScaleArray
Element-wise multiplication.
- __rmul__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) ScaleArray
Reversed element-wise multiplication.
- __truediv__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) ScaleArray
Element-wise division.
- __rtruediv__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) ScaleArray
Reversed element-wise division.
- __pow__(p: int | float | Array) ScaleArray
Element-wise power.
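The negation, absolute value, multiplication, division and power operators above can act on the two leaves directly, without materializing the value, because exp turns products into sums. The sketch below assumes a real-valued exponent and represents each operand as a plain (significand, exponent) tuple. The helper names are hypothetical and need not match how quantax implements these operators.

```python
import numpy as np

# Hypothetical (significand, exponent) pairs: value = s * exp(e), with real e.


def neg(a):
    s, e = a
    return -s, e                # -(s * exp(e)) = (-s) * exp(e)


def absolute(a):
    s, e = a
    return np.abs(s), e         # |s * exp(e)| = |s| * exp(e) because exp(e) > 0


def mul(a, b):
    (s1, e1), (s2, e2) = a, b
    return s1 * s2, e1 + e2     # exp(e1) * exp(e2) = exp(e1 + e2)


def div(a, b):
    (s1, e1), (s2, e2) = a, b
    return s1 / s2, e1 - e2     # exp(e1) / exp(e2) = exp(e1 - e2)


def power(a, p):
    s, e = a
    return s ** p, e * p        # exp(e) ** p = exp(e * p)


def value(a):
    s, e = a
    return s * np.exp(e)


x = (np.array([2.0, -3.0]), np.array(1.5))
y = (np.array([0.5, 4.0]), np.array(-0.5))
dense_x, dense_y = value(x), value(y)
print(np.allclose(value(mul(x, y)), dense_x * dense_y))  # True
print(np.allclose(value(div(x, y)), dense_x / dense_y))  # True
print(np.allclose(value(power(x, 2)), dense_x ** 2))     # True
```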
- __add__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) ScaleArray
Element-wise addition.
- __radd__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) ScaleArray
Reversed element-wise addition.
- __sub__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) ScaleArray
Element-wise subtraction.
- __rsub__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) ScaleArray
Reversed element-wise subtraction.
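Unlike multiplication, addition and subtraction cannot simply combine exponents: the operands first need a shared exponent. One common approach is sketched below, with the caveat that the source does not say ScaleArray works exactly this way. Both significands are rescaled to the larger exponent, so every rescaling factor exp(e - m) is at most 1 and cannot overflow. The tuples and helper names are hypothetical, and the exponent is assumed to be real.

```python
import numpy as np

# Hypothetical (significand, exponent) pairs: value = s * exp(e), with real e.


def add(a, b):
    (s1, e1), (s2, e2) = a, b
    m = np.maximum(e1, e2)  # shared exponent of the result
    # exp(e1 - m) and exp(e2 - m) are both <= 1, so this cannot overflow.
    return s1 * np.exp(e1 - m) + s2 * np.exp(e2 - m), m


def sub(a, b):
    s2, e2 = b
    return add(a, (-s2, e2))  # a - b = a + (-b)


# Values near exp(1000) overflow float64 when materialized, but the pairs do not.
x = (np.array([1.0, 2.0]), np.array(1000.0))
y = (np.array([3.0, 1.0]), np.array(999.0))
s, e = add(x, y)
print(e)  # 1000.0
print(np.allclose(s, [1.0 + 3.0 * np.exp(-1.0), 2.0 + np.exp(-1.0)]))  # True

# Moderate exponents: compare against the dense computation.
a = (np.array([1.5, -2.0]), np.array(0.3))
b = (np.array([0.5, 4.0]), np.array(-0.2))
s_ab, e_ab = add(a, b)
print(np.allclose(s_ab * np.exp(e_ab), a[0] * np.exp(a[1]) + b[0] * np.exp(b[1])))  # True
```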
- normalize() ScaleArray
Return a normalized ScaleArray, i.e. one whose significand has a maximum absolute value of 1.
- static from_value(x: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) ScaleArray
Create from a JAX array / Python scalar.
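The following sketch shows one way normalize and from_value could work. It matches the class Note (the maximum absolute significand becomes 1) but is not taken from quantax's source. The idea is to divide the significand by its largest magnitude m and add log(m) to the exponent. Pairs are plain NumPy tuples, the helper names are hypothetical, and this from_value simply starts from exponent 0 before normalizing.

```python
import numpy as np

# Hypothetical (significand, exponent) pairs: value = s * exp(e).


def normalize(a):
    s, e = a
    m = np.max(np.abs(s))
    if m == 0:
        return s, e  # all zeros: nothing to rescale
    # Dividing s by m and adding log(m) to e leaves s * exp(e) unchanged.
    return s / m, e + np.log(m)


def from_value(x):
    x = np.asarray(x, dtype=float)
    return normalize((x, np.zeros(())))  # start from exponent 0, then rescale


s, e = from_value([0.5, -4.0, 2.0])
print(s, e)  # significand scaled by 1/4, exponent log(4)

# The Note's example: (e, 0) normalizes to (1, 1), up to rounding.
print(normalize((np.array([np.e]), np.array(0.0))))
```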
- property shape: Tuple[int, ...]
The shape of the represented array.
- property dtype: dtype
The data type of the represented array.
- property ndim: int
The number of dimensions of the represented array.
- property size: int
The total number of elements in the represented array.
- property nbytes: int
The total number of bytes consumed by the represented array.
- property sharding: Sharding
The sharding of the represented array.
- value() Array
Materialize the dense array value.
- property T: ScaleArray
Transpose the represented array.
- property mT: ScaleArray
Matrix transpose of the represented array.
- conj() ScaleArray
Complex conjugate of the represented value.
- property real: ScaleArray
Real part of the represented array.
- property imag: ScaleArray
Imaginary part of the represented array.
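When the exponent is real, exp(exponent) is a positive real factor. Conjugation and taking the real or imaginary part therefore only need to touch the significand. The sketch below relies on that assumption; the source does not state whether the exponent may be complex, and if it can, this shortcut does not hold. It again uses hypothetical helpers on plain NumPy pairs.

```python
import numpy as np

# Hypothetical (significand, exponent) pairs: value = s * exp(e), assuming real e.


def conj(a):
    s, e = a
    return np.conj(s), e   # conj(s * exp(e)) = conj(s) * exp(e) for real e


def real(a):
    s, e = a
    return np.real(s), e   # Re(s * exp(e)) = Re(s) * exp(e) for real e


def imag(a):
    s, e = a
    return np.imag(s), e   # Im(s * exp(e)) = Im(s) * exp(e) for real e


def value(a):
    s, e = a
    return s * np.exp(e)


x = (np.array([1.0 + 2.0j, -3.0j]), np.array(2.0))
print(np.allclose(value(conj(x)), np.conj(value(x))))  # True
print(np.allclose(value(real(x)), np.real(value(x))))  # True
print(np.allclose(value(imag(x)), np.imag(value(x))))  # True
```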
- astype(dtype) ScaleArray
Cast the represented array to the given dtype.
- sum(axis: int | Tuple[int, ...] | None = None, keepdims: bool = False) ScaleArray
Sum of array elements over a given axis.
- mean(axis: int | Tuple[int, ...] | None = None, keepdims: bool = False) ScaleArray
Mean of array elements over a given axis.
- prod(axis: int | Tuple[int, ...] | None = None, keepdims: bool = False) ScaleArray
Product of array elements over a given axis.
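These reductions can also avoid overflow by working on the leaves. The sketch below assumes the exponent is real and is first broadcast to the significand's shape; the source does not specify how ScaleArray stores it. The three reductions work as follows:

- sum factors out the largest exponent along the reduced axes.
- mean divides that sum by the number of reduced elements.
- prod multiplies the significands and adds the exponents.

keepdims is omitted, and all helper names are hypothetical.

```python
import numpy as np

# Hypothetical (significand, exponent) pairs: value = s * exp(e), with real e.


def reduce_sum(a, axis=None):
    s, e = np.broadcast_arrays(*a)                # give e the same shape as s
    m = np.max(e, axis=axis, keepdims=True)       # largest exponent along the axis
    total = np.sum(s * np.exp(e - m), axis=axis)  # each exp(e - m) <= 1
    return total, np.squeeze(m, axis=axis)


def reduce_mean(a, axis=None):
    total, m = reduce_sum(a, axis=axis)
    n = np.broadcast(*a).size // np.size(total)   # number of elements averaged
    return total / n, m


def reduce_prod(a, axis=None):
    s, e = np.broadcast_arrays(*a)
    # exp(e1) * exp(e2) * ... = exp(e1 + e2 + ...)
    return np.prod(s, axis=axis), np.sum(e, axis=axis)


x = (np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([[0.0, 1.0], [2.0, 3.0]]))
dense = x[0] * np.exp(x[1])
s_sum, e_sum = reduce_sum(x, axis=1)
print(np.allclose(s_sum * np.exp(e_sum), dense.sum(axis=1)))  # True

# 1000 copies of 10.0, each stored as 1 * exp(log 10).
tens = (np.ones(1000), np.full(1000, np.log(10.0)))
s_prod, e_prod = reduce_prod(tens)
print(s_prod, e_prod / np.log(10.0))  # 1.0 and approximately 1000, i.e. 10**1000
with np.errstate(over="ignore"):
    print(np.prod(np.full(1000, 10.0)))  # inf: the dense product overflows float64
```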
- choose(*args, **kwargs) ScaleArray
Apply choose to significand and exponent if possible.
- compress(*args, **kwargs) ScaleArray
Apply compress to significand and exponent if possible.
- copy(*args, **kwargs) ScaleArray
Apply copy to significand and exponent if possible.
- diagonal(*args, **kwargs) ScaleArray
Apply diagonal to significand and exponent if possible.
- flatten(*args, **kwargs) ScaleArray
Apply flatten to significand and exponent if possible.
- ravel(*args, **kwargs) ScaleArray
Apply ravel to significand and exponent if possible.
- repeat(*args, **kwargs) ScaleArray
Apply repeat to significand and exponent if possible.
- reshape(*args, **kwargs) ScaleArray
Apply reshape to significand and exponent if possible.
- squeeze(*args, **kwargs) ScaleArray
Apply squeeze to significand and exponent if possible.
- swapaxes(*args, **kwargs) ScaleArray
Apply swapaxes to significand and exponent if possible.
- take(*args, **kwargs) ScaleArray
Apply take to significand and exponent if possible.
- transpose(*args, **kwargs) ScaleArray
Apply transpose to significand and exponent if possible.
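The shape methods above apply the same operation to both leaves "if possible". The source does not spell out what that means. The sketch below reads it as follows: transform the exponent too when it is an array, and leave it untouched when it is a scalar, since a scalar exponent scales every element equally. The pairs are plain NumPy tuples, and apply_to_leaves is a hypothetical helper, not a quantax function.

```python
import numpy as np

# Hypothetical (significand, exponent) pairs: value = s * exp(e).


def apply_to_leaves(method, a, *args, **kwargs):
    """Call an ndarray method such as 'reshape' on both leaves of a pair."""
    s, e = a
    new_s = getattr(s, method)(*args, **kwargs)
    if np.ndim(e) == 0:
        # A scalar exponent applies to every element, so it can be kept as-is.
        return new_s, e
    # Otherwise give the exponent the significand's shape and transform it identically.
    new_e = getattr(np.broadcast_to(e, s.shape), method)(*args, **kwargs)
    return new_s, new_e


def value(a):
    s, e = a
    return s * np.exp(e)


x = (np.arange(6.0).reshape(2, 3), np.array(0.5))              # scalar exponent
y = (np.arange(6.0).reshape(2, 3), np.array([0.0, 1.0, 2.0]))  # per-column exponent

print(np.allclose(value(apply_to_leaves("reshape", x, 3, 2)), value(x).reshape(3, 2)))  # True
print(np.allclose(value(apply_to_leaves("transpose", y)), value(y).T))                  # True
print(np.allclose(value(apply_to_leaves("take", y, [0, 2], axis=1)),
                  value(y).take([0, 2], axis=1)))                                       # True
```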