quantax.utils.ScaleArray

class quantax.utils.ScaleArray

Array representation with a scale: value = significand * exp(exponent), where exponent is a normalization factor.
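The representation keeps products of very large (or very small) values finite, because the scale lives in the exponent rather than in a float. The sketch below uses plain Python tuples `(significand, exponent)` and hypothetical helpers `scaled_mul` and `scaled_log_abs`. These are not the quantax API. The sketch shows a product that overflows as a dense float but stays finite in scaled form.

```python
import math

# Hypothetical helpers for illustration; not part of quantax.
# A pair (s, e) stands for the value s * exp(e).

def scaled_mul(a, b):
    """Multiply two (significand, exponent) pairs without calling exp()."""
    (sa, ea), (sb, eb) = a, b
    return (sa * sb, ea + eb)

def scaled_log_abs(a):
    """Return log|s * exp(e)| without materializing the value."""
    s, e = a
    return math.log(abs(s)) + e

x = (2.0, 600.0)  # 2 * e**600, about 7.6e260: still finite in float64
y = (3.0, 600.0)  # 3 * e**600

z = scaled_mul(x, y)
print(z)                  # (6.0, 1200.0)
print(scaled_log_abs(z))  # log(6) + 1200, about 1201.79

# The same product formed densely exceeds the float64 maximum (~1.8e308).
dense = (2.0 * math.exp(600.0)) * (3.0 * math.exp(600.0))
print(dense)              # inf
```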

The array is a PyTree with two leaves: significand and exponent. To convert it to a dense JAX array, use arr.value() or jnp.asarray(arr).

Note

The same value can be represented by different (significand, exponent) pairs. For example, (e, 0) and (1, 1) both represent the value e. A canonical form is not enforced, for better performance, but the normalize method can be used to obtain an equivalent ScaleArray whose significand has a maximum absolute value of 1.
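Normalization can be done by dividing the significand by its largest magnitude m and adding log(m) to the exponent. This sketch assumes a single real scalar exponent shared by the whole array. That is an assumption, since the text above does not say whether the exponent is scalar or element-wise. `normalize_pair` is a hypothetical helper, not the library's normalize method.

```python
import numpy as np

def normalize_pair(significand, exponent):
    """Hypothetical sketch: rescale so that max|significand| == 1.

    Assumes one real scalar exponent for the whole array.
    """
    significand = np.asarray(significand)
    m = float(np.max(np.abs(significand)))
    if m == 0.0:  # all zeros: any exponent is valid, leave it unchanged
        return significand, float(exponent)
    return significand / m, float(exponent) + np.log(m)

# (e, 0) and (1, 1) both represent e; both normalize to (1, 1).
sa, ea = normalize_pair(np.e, 0.0)
sb, eb = normalize_pair(1.0, 1.0)

# An array example: the represented value [2, -4] is unchanged.
s, e = normalize_pair([2.0, -4.0], 0.0)
print(s, e)  # significand [0.5, -1.0], exponent log(4), about 1.386
```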

Warning

JAX doesn’t fully support custom array types, so be careful when using ScaleArray. Several possible problems are listed below.

1. Functions like jnp.fn(array) convert custom arrays to jax.Array. To avoid this, call the method array.fn() whenever possible.

2. Expressions like jax_array * customized_array always call jax_array.__mul__(customized_array), which returns a jax.Array. To avoid this, put the custom array on the left: customized_array * jax_array.
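The second point follows from Python's binary-operator protocol. Python calls the left operand's `__mul__` first. It falls back to the right operand's `__rmul__` only if `__mul__` returns `NotImplemented`. The toy classes `Greedy` and `Scaled` are hypothetical stand-ins for jax.Array and ScaleArray.

```python
class Greedy:
    """Stand-in for jax.Array: its __mul__ accepts any operand."""
    def __mul__(self, other):
        return "Greedy.__mul__"

class Scaled:
    """Stand-in for a custom array type such as ScaleArray."""
    def __mul__(self, other):
        return "Scaled.__mul__"

    def __rmul__(self, other):
        return "Scaled.__rmul__"

g, s = Greedy(), Scaled()
print(g * s)  # Greedy.__mul__: Scaled.__rmul__ is never consulted
print(s * g)  # Scaled.__mul__: the custom type stays in control
print(3 * s)  # Scaled.__rmul__: int.__mul__ returned NotImplemented
```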

__array__(dtype=None) → ndarray

Convert to a numpy array.

__jax_array__() → Array

Convert to a JAX array.

__neg__() → ScaleArray

Negate the represented value.

__abs__() → ScaleArray

Absolute value of the represented array.
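These two hooks are what `np.asarray(arr)` and `jnp.asarray(arr)` look for. The sketch below shows only the NumPy side to stay dependency-light. `ToyScaled` is a hypothetical stand-in, not ScaleArray itself. The extra `copy` keyword is accepted because NumPy 2 may pass it to `__array__`.

```python
import math
import numpy as np

class ToyScaled:
    """Hypothetical stand-in for ScaleArray, showing the conversion hook."""

    def __init__(self, significand, exponent):
        self.significand = np.asarray(significand)
        self.exponent = exponent

    def __array__(self, dtype=None, copy=None):
        # Called by np.asarray / np.array; materializes s * exp(e).
        dense = self.significand * np.exp(self.exponent)
        return dense if dtype is None else dense.astype(dtype)

x = ToyScaled([1.0, 2.0], math.log(3.0))
print(np.asarray(x))  # approximately [3. 6.]
```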

__mul__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) → ScaleArray

Element-wise multiplication.

__rmul__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) → ScaleArray

Reversed element-wise multiplication.

__truediv__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) → ScaleArray

Element-wise division.

__rtruediv__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) → ScaleArray

Reversed element-wise division.

__pow__(p: int | float | Array) → ScaleArray

Element-wise power.

__add__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) → ScaleArray

Element-wise addition.

__radd__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) → ScaleArray

Reversed element-wise addition.

__sub__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) → ScaleArray

Element-wise subtraction.

__rsub__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) → ScaleArray

Reversed element-wise subtraction.
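With value = s * exp(x), a power factors as (s * exp(x))**p = s**p * exp(p * x), so only the exponent is scaled by p. Below is a minimal sketch on hypothetical (significand, exponent) tuples, assuming a real exponent. The sign of the significand is left to ordinary `**`.

```python
import math

def scaled_pow(a, p):
    """Hypothetical helper: (s * exp(e))**p == s**p * exp(p * e)."""
    s, e = a
    return (s ** p, p * e)

big = scaled_pow((2.0, 500.0), 3)
print(big)  # (8.0, 1500.0); exp(1500) alone would overflow float64

# Check against the dense value on numbers small enough to materialize.
s, e = scaled_pow((2.0, 0.5), 2)
dense = (2.0 * math.exp(0.5)) ** 2
print(math.isclose(s * math.exp(e), dense))  # True
```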
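Multiplication and division only combine exponents, but addition and subtraction first need both operands on a common scale. One common approach is to shift both terms to the larger exponent m. Each rescaling factor exp(e - m) is then at most 1 and cannot overflow. The helpers below are hypothetical and operate on (significand, exponent) tuples. They are not the quantax implementation.

```python
import math

def scaled_add(a, b):
    """s1*exp(e1) + s2*exp(e2) == (s1*exp(e1-m) + s2*exp(e2-m)) * exp(m)."""
    (sa, ea), (sb, eb) = a, b
    m = max(ea, eb)
    return (sa * math.exp(ea - m) + sb * math.exp(eb - m), m)

def scaled_sub(a, b):
    sb, eb = b
    return scaled_add(a, (-sb, eb))

def scaled_div(a, b):
    (sa, ea), (sb, eb) = a, b
    return (sa / sb, ea - eb)

x = (1.0, 1000.0)  # e**1000 overflows float64 if materialized
print(scaled_add(x, x))                         # (2.0, 1000.0)
print(scaled_sub(x, x))                         # (0.0, 1000.0)
print(scaled_div((6.0, 1000.0), (3.0, 998.0)))  # (2.0, 2.0)

# Small-number check against dense arithmetic: 1 + e.
s, e = scaled_add((1.0, 0.0), (1.0, 1.0))
print(math.isclose(s * math.exp(e), 1.0 + math.e))  # True
```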

normalize() → ScaleArray

Return an equivalent ScaleArray whose significand has a maximum absolute value of 1.

static from_value(x: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) → ScaleArray

Create a ScaleArray from a JAX array or Python scalar.

property shape: Tuple[int, ...]

The shape of the represented array.

property dtype: dtype

The data type of the represented array.

property ndim: int

The number of dimensions of the represented array.

property size: int

The total number of elements in the represented array.

property nbytes: int

The total number of bytes consumed by the represented array.

property sharding: Sharding

The sharding of the represented array.

value() → Array

Materialize the dense array value.

property T: ScaleArray

Transpose the represented array.

property mT: ScaleArray

Matrix transpose of the represented array.

conj() → ScaleArray

Complex conjugate of the represented value.

property real: ScaleArray

Real part of the represented array.

property imag: ScaleArray

Imaginary part of the represented array.
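Materializing computes significand * exp(exponent) directly, so it can overflow even when the scaled representation is finite. The following sketch uses NumPy and a hypothetical `materialize` helper. Dividing two dense values gives inf / inf = nan. Combining first in scaled form and materializing only the ratio stays finite.

```python
import numpy as np

def materialize(significand, exponent):
    """Hypothetical stand-in for value(): significand * exp(exponent)."""
    with np.errstate(over="ignore"):  # let overflow produce inf quietly
        return np.asarray(significand) * np.exp(exponent)

x_sig, x_exp = np.array([1.0, 2.0]), 1000.0
y_sig, y_exp = np.array([4.0, 4.0]), 999.0

# Materialize first, then divide: inf / inf -> nan.
with np.errstate(invalid="ignore"):
    naive = materialize(x_sig, x_exp) / materialize(y_sig, y_exp)

# Divide in scaled form, then materialize the small result.
ratio = materialize(x_sig / y_sig, x_exp - y_exp)
print(naive)  # [nan nan]
print(ratio)  # [0.25, 0.5] * e, about [0.68, 1.36]
```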
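Assume the exponent is real, which the text above does not state. Then exp(exponent) is a positive real factor, and conjugation, real part, and imaginary part can act on the significand alone. The NumPy check below verifies this identity on a small dense example. With a complex exponent, the factor would carry a phase and this shortcut would not hold.

```python
import numpy as np

# Assumption: real exponent, so exp(exponent) is a positive real scale.
sig = np.array([1 + 2j, 3 - 1j])
exponent = 2.0
scale = np.exp(exponent)
dense = sig * scale

# Each operation on the dense value equals the operation on the significand,
# multiplied by the unchanged real scale.
conj_ok = np.allclose(np.conj(dense), np.conj(sig) * scale)
real_ok = np.allclose(dense.real, sig.real * scale)
imag_ok = np.allclose(dense.imag, sig.imag * scale)
print(conj_ok, real_ok, imag_ok)  # True True True
```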

astype(dtype) → ScaleArray

Cast the represented array to the given dtype.

sum(axis: int | Tuple[int, ...] | None = None, keepdims: bool = False) → ScaleArray

Sum of array elements over a given axis.

mean(axis: int | Tuple[int, ...] | None = None, keepdims: bool = False) → ScaleArray

Mean of array elements over a given axis.

prod(axis: int | Tuple[int, ...] | None = None, keepdims: bool = False) → ScaleArray

Product of array elements over a given axis.
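Assume a single scalar exponent x is shared by all elements, which the text above does not state. Then the reductions act mostly on the significand: sum and mean keep x, while a product over n elements picks up exp(x) n times, giving exponent n * x. The hypothetical helpers below handle only `axis=None` or a single integer axis, and they check themselves against the dense computation.

```python
import numpy as np

def scaled_sum(sig, x, axis=None):
    return np.sum(sig, axis=axis), x

def scaled_mean(sig, x, axis=None):
    return np.mean(sig, axis=axis), x

def scaled_prod(sig, x, axis=None):
    sig = np.asarray(sig)
    n = sig.size if axis is None else sig.shape[axis]
    # Each of the n factors contributes one exp(x).
    return np.prod(sig, axis=axis), n * x

sig = np.array([[1.0, 2.0], [3.0, 4.0]])
x = 0.5
dense = sig * np.exp(x)

checks = []
for f, g in [(scaled_sum, np.sum), (scaled_mean, np.mean), (scaled_prod, np.prod)]:
    for axis in (None, 0, 1):
        s, e = f(sig, x, axis=axis)
        checks.append(np.allclose(s * np.exp(e), g(dense, axis=axis)))
print(all(checks))  # True
```

In a long product, the significand product itself can still overflow or underflow. A practical implementation would likely renormalize along the way.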

choose(*args, **kwargs) → ScaleArray

Apply choose to significand and exponent if possible.

compress(*args, **kwargs) → ScaleArray

Apply compress to significand and exponent if possible.

copy(*args, **kwargs) → ScaleArray

Apply copy to significand and exponent if possible.

diagonal(*args, **kwargs) → ScaleArray

Apply diagonal to significand and exponent if possible.

flatten(*args, **kwargs) → ScaleArray

Apply flatten to significand and exponent if possible.

ravel(*args, **kwargs) → ScaleArray

Apply ravel to significand and exponent if possible.

repeat(*args, **kwargs) → ScaleArray

Apply repeat to significand and exponent if possible.

reshape(*args, **kwargs) → ScaleArray

Apply reshape to significand and exponent if possible.

squeeze(*args, **kwargs) → ScaleArray

Apply squeeze to significand and exponent if possible.

swapaxes(*args, **kwargs) → ScaleArray

Apply swapaxes to significand and exponent if possible.

take(*args, **kwargs) → ScaleArray

Apply take to significand and exponent if possible.

transpose(*args, **kwargs) → ScaleArray

Apply transpose to significand and exponent if possible.
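The shape-manipulation methods above act on the significand and, where it makes sense, on the exponent too. The sketch below uses one possible reading of "if possible", which is an assumption and not taken from the quantax source. An exponent with the same shape as the significand is transformed alongside it. A scalar exponent is left alone, since it applies to every element. `apply_to_leaves` is a hypothetical helper.

```python
import numpy as np

def apply_to_leaves(sig, exp, name, *args, **kwargs):
    """Hypothetical helper: call ndarray method `name` on the significand,
    and on the exponent as well when it has the same shape."""
    sig, exp = np.asarray(sig), np.asarray(exp)
    new_sig = getattr(sig, name)(*args, **kwargs)
    if exp.ndim == 0:
        # A scalar exponent scales every element; nothing to reshape.
        return new_sig, exp
    if exp.shape == sig.shape:
        return new_sig, getattr(exp, name)(*args, **kwargs)
    raise NotImplementedError("exponent shape does not match significand")

sig = np.arange(6.0).reshape(2, 3)

# Scalar exponent: only the significand is reshaped.
s1, e1 = apply_to_leaves(sig, 0.5, "reshape", 3, 2)

# Element-wise exponent: both leaves are transposed together.
exp = np.linspace(0.0, 1.0, 6).reshape(2, 3)
s2, e2 = apply_to_leaves(sig, exp, "transpose")
print(s1.shape, s2.shape, e2.shape)  # (3, 2) (3, 2) (3, 2)
```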