quantax.utils.LogArray
- class quantax.utils.LogArray
Log-amplitude representation of JAX arrays: value = sign * exp(logabs), where sign is \(\pm 1\) or a complex phase and logabs is real. Zero is encoded by sign=0, logabs=-inf.
The array is a PyTree with two leaves: sign and logabs. To convert it to a dense JAX array, use arr.value() or jnp.asarray(arr).
Warning
JAX does not fully support custom array types, so be careful when using LogArray. Several possible problems are listed below.
1. Calls like jnp.fn(array) convert custom arrays to jax.Array. To avoid this, call array.fn() whenever possible.
2. Expressions like jax_array * customized_array always call jax_array.__mul__(customized_array), which returns a jax.Array. To avoid this, write customized_array * jax_array instead.
- __array__(dtype=None) ndarray
Convert to a NumPy array.
- __jax_array__() Array
Convert to a JAX array.
- __mul__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) LogArray
Element-wise multiplication.
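Because magnitudes multiply and log-magnitudes add, element-wise multiplication never has to exponentiate anything. The sketch below illustrates the idea on plain (sign, logabs) tuples of JAX arrays; `log_mul` is a hypothetical helper, not part of quantax, and it is not claimed to be how LogArray.__mul__ is implemented.

```python
import jax.numpy as jnp

def log_mul(a, b):
    """Hypothetical helper (not quantax API): multiply two (sign, logabs) pairs."""
    sign_a, log_a = a
    sign_b, log_b = b
    # |a * b| = |a| * |b|, so log|a * b| = log|a| + log|b|;
    # the signs (or unit complex phases) simply multiply.
    return sign_a * sign_b, log_a + log_b

# Each factor is about e**800, already beyond the float64 range,
# yet both factors and their product are easy to hold in log form.
a = (jnp.array(1.0), jnp.array(800.0))
b = (jnp.array(-1.0), jnp.array(800.0))
sign, logabs = log_mul(a, b)  # sign = -1.0, logabs = 1600.0
```

A zero operand (sign 0, logabs -inf) stays zero: the signs multiply to 0 and -inf plus a finite logabs is still -inf.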
- __rmul__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) LogArray
Reversed element-wise multiplication.
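The __mul__/__rmul__ pair follows Python's binary-operator protocol, which is the root of the second pitfall in the warning above: for x * y, Python calls x.__mul__(y) first (unless y's type is a subclass of x's type) and falls back to y.__rmul__(x) only if that call returns NotImplemented. The toy classes below are hypothetical stand-ins, not JAX or quantax types: Dense mimics the warning's description of jax.Array handling the custom operand itself, and Strict shows what happens when the left operand defers.

```python
class Dense:
    """Stand-in for jax.Array (hypothetical): its __mul__ handles any operand itself."""
    def __mul__(self, other):
        return "Dense.__mul__"

class Strict:
    """Stand-in for a type whose __mul__ defers to the other operand."""
    def __mul__(self, other):
        return NotImplemented

class Custom:
    """Stand-in for LogArray (hypothetical)."""
    def __mul__(self, other):
        return "Custom.__mul__"

    def __rmul__(self, other):
        return "Custom.__rmul__"

print(Dense() * Custom())   # Dense.__mul__   (Custom.__rmul__ never runs)
print(Custom() * Dense())   # Custom.__mul__  (custom operand on the left)
print(Strict() * Custom())  # Custom.__rmul__ (left operand returned NotImplemented)
```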
- __truediv__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) LogArray
Element-wise division.
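Division mirrors multiplication: log-magnitudes subtract. For a unit-modulus sign, 1/sign equals its complex conjugate, so the new sign can be formed without dividing. The helper below is a hypothetical sketch on (sign, logabs) tuples, not quantax's implementation, and it deliberately does not handle division by zero.

```python
import jax.numpy as jnp

def log_div(a, b):
    """Hypothetical helper (not quantax API): divide (sign, logabs) pairs; b must be nonzero."""
    sign_a, log_a = a
    sign_b, log_b = b
    # For |sign_b| == 1, 1 / sign_b == conj(sign_b); for real +/-1 signs conj is a no-op.
    return sign_a * jnp.conj(sign_b), log_a - log_b

# -6 / -2 = 3
sign, logabs = log_div((jnp.array(-1.0), jnp.log(6.0)), (jnp.array(-1.0), jnp.log(2.0)))
```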
- __rtruediv__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) LogArray
Reversed element-wise division.
- __add__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) LogArray
Element-wise addition.
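Addition is where the log representation needs care, because a + b cannot be formed from the logabs values alone. One standard approach, shown here as a hypothetical sketch that is not necessarily what LogArray.__add__ does, factors out the larger magnitude m = max(logabs_a, logabs_b), so that sign_a * exp(logabs_a) + sign_b * exp(logabs_b) = exp(m) * t with both exponentials in t bounded by 1.

```python
import jax.numpy as jnp

def log_add(a, b):
    """Hypothetical helper (not quantax API): add two (sign, logabs) pairs in log space."""
    sign_a, log_a = a
    sign_b, log_b = b
    # Factor out the larger magnitude so neither exponential can overflow.
    m = jnp.maximum(log_a, log_b)
    # If both operands are zero, m = -inf and log_a - m would be nan; use 0 instead.
    m = jnp.where(jnp.isfinite(m), m, 0.0)
    t = sign_a * jnp.exp(log_a - m) + sign_b * jnp.exp(log_b - m)
    abs_t = jnp.abs(t)
    # New sign is t / |t| (a unit phase for complex t); an exact zero gets sign 0.
    safe_abs = jnp.where(abs_t == 0, 1.0, abs_t)
    sign = jnp.where(abs_t == 0, 0.0, t / safe_abs)
    # log(0) = -inf, which matches the zero encoding.
    return sign, m + jnp.log(abs_t)

# e**1000 + e**1000 = 2 * e**1000: sign 1, logabs 1000 + log 2
sign, logabs = log_add((jnp.array(1.0), jnp.array(1000.0)), (jnp.array(1.0), jnp.array(1000.0)))
```

Exact cancellation yields sign 0 and logabs -inf, the zero encoding described at the top of this page. Subtraction can reuse the same helper by negating the second sign, since a - b = a + (-b).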
- __radd__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) LogArray
Reversed element-wise addition.
- __sub__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) LogArray
Element-wise subtraction.
- __rsub__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) LogArray
Reversed element-wise subtraction.
- static from_value(x: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) LogArray
Create a LogArray from a JAX array or a Python scalar.
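Conceptually, building the log representation from a dense value splits each entry x into sign = x/|x| and logabs = log|x|, with zeros mapped to sign 0 and logabs -inf as described at the top of this page. The sketch below does this with plain jax.numpy; `to_log` is a hypothetical helper, not quantax's from_value.

```python
import jax.numpy as jnp

def to_log(x):
    """Hypothetical helper (not quantax API): split x so that x == sign * exp(logabs)."""
    x = jnp.asarray(x)
    absx = jnp.abs(x)
    # log(0) = -inf in JAX, so zeros automatically get logabs = -inf.
    logabs = jnp.log(absx)
    # sign = x / |x| (a unit phase for complex x); guard the division at zeros, which get sign 0.
    safe_abs = jnp.where(absx == 0, 1, absx)
    sign = jnp.where(absx == 0, 0, x / safe_abs)
    return sign, logabs

sign, logabs = to_log(jnp.array([-2.0, 0.0, 3.0]))
# sign is [-1., 0., 1.]; logabs is [log 2, -inf, log 3]
```

The round trip sign * exp(logabs) recovers the input, including the zero entry, because 0 * exp(-inf) = 0 * 0 = 0.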
- property shape: Tuple[int, ...]
The shape of the represented array.
- property dtype: dtype
The data type of the represented array.
- property ndim: int
The number of dimensions of the represented array.
- property size: int
The total number of elements in the represented array.
- property nbytes: int
The total number of bytes consumed by the represented array.
- property sharding: Sharding
The sharding of the represented array.
- value() Array
Materialize the dense array value.
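To make the PyTree structure and the materialization step concrete, here is a minimal hypothetical stand-in, MiniLogArray (not quantax's class), registered with jax.tree_util.register_pytree_node_class. It flattens to the two leaves sign and logabs, so it can be passed through jax.jit, and its value() computes sign * exp(logabs).

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class MiniLogArray:
    """Hypothetical stand-in, NOT quantax's LogArray: value = sign * exp(logabs)."""

    def __init__(self, sign, logabs):
        self.sign = sign
        self.logabs = logabs

    def tree_flatten(self):
        # Two leaves, mirroring the sign/logabs split; no static auxiliary data.
        return (self.sign, self.logabs), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def value(self):
        # Materialize the dense array; a large logabs can overflow here.
        return self.sign * jnp.exp(self.logabs)

x = MiniLogArray(jnp.array([1.0, -1.0, 0.0, 1.0]),
                 jnp.array([0.0, 0.6931472, -jnp.inf, 1000.0]))
print(len(jax.tree_util.tree_leaves(x)))  # 2
dense = jax.jit(lambda a: a.value())(x)   # roughly [1, -2, 0, inf]
```

The last entry shows why materialization is deferred: logabs = 1000 is an ordinary float, but exp(1000) overflows to inf once the dense value is formed.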
- sum(axis: int | Tuple[int, ...] | None = None, keepdims: bool = False) LogArray
Sum of array elements over a given axis.
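For real signs, a sum in log space is a signed log-sum-exp, which jax.scipy.special.logsumexp computes directly through its b (per-element weights) and return_sign arguments. The snippet below is an illustrative equivalent on plain arrays, not necessarily how LogArray.sum is implemented.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Entries 2, -3 and 0 in (sign, logabs) form.
sign = jnp.array([1.0, -1.0, 0.0])
logabs = jnp.array([jnp.log(2.0), jnp.log(3.0), -jnp.inf])

# log|sum_i sign_i * exp(logabs_i)| and the sign of that sum, without forming exp(logabs).
sum_logabs, sum_sign = logsumexp(logabs, axis=0, b=sign, return_sign=True)
# 2 - 3 + 0 = -1, so sum_sign is -1 and sum_logabs is close to log(1) = 0
```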
- mean(axis: int | Tuple[int, ...] | None = None, keepdims: bool = False) LogArray
Mean of array elements over a given axis.
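The mean only shifts the log-space sum by a constant: dividing by the number of averaged elements N leaves the sign unchanged and subtracts log N from logabs. Writing \(s_i\) and \(\ell_i\) for the sign and logabs of element i, and \(m = \max_i \ell_i\):

\[
\frac{1}{N}\sum_{i=1}^{N} s_i e^{\ell_i} = s\,e^{\ell},
\qquad
t = \sum_{i=1}^{N} s_i e^{\ell_i - m},
\qquad
s = \frac{t}{|t|},
\qquad
\ell = m + \log|t| - \log N,
\]

with \(s = 0\) and \(\ell = -\infty\) when \(t = 0\), matching the zero encoding.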