quantax.utils.LogArray

class quantax.utils.LogArray

Log-amplitude representation of JAX arrays: value = sign * exp(logabs), where sign is \(\pm 1\) for real arrays or a unit-modulus complex phase for complex arrays, and logabs is real. Zero is encoded as sign=0, logabs=-inf.

The array is a PyTree with two leaves: sign and logabs. To convert it to a dense JAX array, use arr.value() or jnp.asarray(arr).
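A minimal sketch of this encoding in plain JAX, assuming nothing about quantax internals: the helpers `encode` and `decode` are hypothetical stand-ins for `LogArray.from_value` and `LogArray.value()`, and they work on a bare `(sign, logabs)` pair rather than on a LogArray.

```python
import jax.numpy as jnp

def encode(x):
    """Hypothetical stand-in for LogArray.from_value: split x into (sign, logabs)."""
    x = jnp.asarray(x)
    absx = jnp.abs(x)
    # Use a safe denominator so zero entries never divide by zero
    safe_abs = jnp.where(absx == 0, 1, absx)
    # +-1 for real x, a unit-modulus phase for complex x, and 0 for zero entries
    sign = jnp.where(absx == 0, 0, x / safe_abs)
    # log(0) = -inf, which together with sign = 0 encodes zero
    logabs = jnp.log(absx)
    return sign, logabs

def decode(sign, logabs):
    """Hypothetical stand-in for LogArray.value: materialize sign * exp(logabs)."""
    # exp(-inf) = 0, so zero entries decode to 0 * 0 = 0 rather than nan
    return sign * jnp.exp(logabs)

x = jnp.array([-2.0, 0.0, 3.0])
sign, logabs = encode(x)
dense = decode(sign, logabs)
```

Because only the logarithm of the magnitude is stored, magnitudes far outside the range of the dense dtype remain representable until the value is materialized.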

Warning

JAX doesn't fully support custom array types, so be careful when using LogArray. Several possible problems are listed below.

1. Calls like jnp.fn(array) convert a custom array into a jax.Array. To avoid this, call the method array.fn() whenever one exists.

2. Expressions like jax_array * customized_array always call jax_array.__mul__(customized_array), which returns a jax.Array. To avoid this, put the custom array on the left: customized_array * jax_array.
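Point 2 follows from Python's binary-operator protocol: for `a * b`, Python calls `a.__mul__(b)` first and only falls back to `b.__rmul__(a)` if that returns `NotImplemented`. The toy classes below are hypothetical and only illustrate the dispatch order; `Greedy` plays the role of jax.Array, whose `__mul__` accepts the other operand instead of deferring.

```python
class Greedy:
    """Stands in for jax.Array: its __mul__ handles any operand."""
    def __mul__(self, other):
        return "Greedy.__mul__"

class Custom:
    """Stands in for a custom array type such as LogArray."""
    def __mul__(self, other):
        return "Custom.__mul__"

    def __rmul__(self, other):
        return "Custom.__rmul__"

class Deferring:
    """A left operand that gives up, so Python tries Custom.__rmul__."""
    def __mul__(self, other):
        return NotImplemented

left = Greedy() * Custom()        # Greedy.__mul__ succeeds; Custom.__rmul__ is never tried
right = Custom() * Greedy()       # custom type on the left, so its own __mul__ runs
fallback = Deferring() * Custom() # left operand returns NotImplemented, so __rmul__ runs
```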

__array__(dtype=None) ndarray

Convert to a NumPy array.

__jax_array__() Array

Convert to a JAX array.
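These two hooks are the standard NumPy and JAX conversion protocols: `np.asarray(obj)` looks for `__array__`, and `jnp.asarray(obj)` looks for `__jax_array__`. The class below is a hypothetical toy, not quantax's LogArray, that implements both hooks by materializing `sign * exp(logabs)`; its extra `copy` parameter is an assumption added for NumPy 2 compatibility.

```python
import numpy as np
import jax.numpy as jnp

class ToyLogArray:
    """Hypothetical minimal (sign, logabs) container; not quantax's class."""

    def __init__(self, sign, logabs):
        self.sign = jnp.asarray(sign)
        self.logabs = jnp.asarray(logabs)

    def value(self):
        # Dense value: sign * exp(logabs)
        return self.sign * jnp.exp(self.logabs)

    def __jax_array__(self):
        # Hook used by jnp.asarray(obj)
        return self.value()

    def __array__(self, dtype=None, copy=None):
        # Hook used by np.asarray(obj); `copy` is accepted for NumPy 2 callers
        return np.asarray(self.value(), dtype=dtype)

t = ToyLogArray(jnp.array([1.0, -1.0]), jnp.log(jnp.array([2.0, 5.0])))
dense_jax = jnp.asarray(t)  # goes through __jax_array__
dense_np = np.asarray(t)    # goes through __array__
```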

__neg__() LogArray

Negation of the represented value.

__abs__() LogArray

Absolute value of the represented array.
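In log form, negation and absolute value never touch logabs; they only change the sign. A sketch on a bare `(sign, logabs)` pair, using hypothetical helper names:

```python
import jax.numpy as jnp

def neg(sign, logabs):
    # -x: flip the sign; the magnitude (logabs) is unchanged
    return -sign, logabs

def absolute(sign, logabs):
    # |x|: |sign| is 1 for nonzero entries (real or complex) and 0 for zeros
    return jnp.abs(sign), logabs
```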

__mul__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) LogArray

Element-wise multiplication.

__rmul__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) LogArray

Reversed element-wise multiplication.

__truediv__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) LogArray

Element-wise division.

__rtruediv__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) LogArray

Reversed element-wise division.
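Multiplication and division are cheap in log form: signs multiply and log-magnitudes add or subtract, so no exponential is evaluated. The sketch below works on bare `(sign, logabs)` pairs with hypothetical helper names, and it assumes the divisor has no zero entries (zero handling is omitted).

```python
import jax.numpy as jnp

def mul(a, b):
    # (sa * e^la) * (sb * e^lb) = (sa * sb) * e^(la + lb)
    (sa, la), (sb, lb) = a, b
    return sa * sb, la + lb

def div(a, b):
    # For a unit-modulus sign, 1 / sb == conj(sb), so no complex division is needed
    (sa, la), (sb, lb) = a, b
    return sa * jnp.conj(sb), la - lb
```

Since only logabs grows, a product such as \(e^{1000}\cdot e^{1000}\) stays finite in this form even though either factor would overflow as a dense floating-point value.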

__pow__(p: int | float | Array) LogArray

Element-wise power.
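Powers only scale the log-magnitude: \(x^p = \text{sign}^p \cdot e^{p\,\text{logabs}}\). This sketch assumes either an integer p or a complex sign, because a negative real sign raised to a non-integer power is undefined in a real dtype; zero entries with p <= 0 are not handled. The helper name is hypothetical.

```python
import jax.numpy as jnp

def power(sign, logabs, p):
    # x**p = sign**p * exp(p * logabs); the magnitude update is a single multiply.
    # Assumes integer p, or a complex sign so that sign**p takes the principal branch.
    return sign ** p, p * logabs
```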

__add__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) LogArray

Element-wise addition.

__radd__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) LogArray

Reversed element-wise addition.

__sub__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) LogArray

Element-wise subtraction.

__rsub__(other: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) LogArray

Reversed element-wise subtraction.
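Addition is the element-wise operation that needs the most care in log form. A common approach, sketched here with hypothetical helper names and not necessarily what quantax does, factors out the larger magnitude: \(s_a e^{l_a} + s_b e^{l_b} = e^{m}\left(s_a e^{l_a - m} + s_b e^{l_b - m}\right)\) with \(m = \max(l_a, l_b)\), so the remaining exponentials are at most 1 and cannot overflow.

```python
import jax.numpy as jnp

def add(a, b):
    (sa, la), (sb, lb) = a, b
    m = jnp.maximum(la, lb)
    # If both operands are zero, m = -inf and la - m would be nan; shift by 0 instead
    m = jnp.where(jnp.isfinite(m), m, 0.0)
    t = sa * jnp.exp(la - m) + sb * jnp.exp(lb - m)  # |t| <= 2
    abs_t = jnp.abs(t)
    # Re-encode t as (sign, log|t|); exact cancellation gives sign 0, logabs -inf
    sign = jnp.where(abs_t == 0, 0, t / jnp.where(abs_t == 0, 1.0, abs_t))
    return sign, m + jnp.log(abs_t)

def sub(a, b):
    # a - b == a + (-b): negating b only flips its sign
    sb, lb = b
    return add(a, (-sb, lb))
```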

static from_value(x: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | LogArray | ScaleArray) LogArray

Create a LogArray from a JAX array, NumPy array, or Python scalar.

property shape: Tuple[int, ...]

The shape of the represented array.

property dtype: dtype

The data type of the represented array.

property ndim: int

The number of dimensions of the represented array.

property size: int

The total number of elements in the represented array.

property nbytes: int

The total number of bytes consumed by the represented array.

property sharding: Sharding

The sharding of the represented array.

value() Array

Materialize the dense array value.

property T: LogArray

Transpose of the represented array.

property mT: LogArray

Matrix transpose of the represented array.

conj() LogArray

Complex conjugate of the represented array.

property real: LogArray

Real part of the represented array.

property imag: LogArray

Imaginary part of the represented array.
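Conjugation only rotates the phase, while taking the real or imaginary part changes the magnitude: \(\mathrm{Re}(s\,e^{l}) = \mathrm{Re}(s)\,e^{l}\) with \(|\mathrm{Re}(s)| \le 1\), so the factor \(|\mathrm{Re}(s)|\) has to be folded back into logabs (and likewise for the imaginary part). A sketch on bare `(sign, logabs)` pairs with hypothetical helper names:

```python
import jax.numpy as jnp

def conj(sign, logabs):
    # Conjugating the phase conjugates the value; the magnitude is unchanged
    return jnp.conj(sign), logabs

def _reencode_part(r, logabs):
    # r is Re(sign) or Im(sign): a real number with |r| <= 1.
    # New sign is +-1 (or 0), and log|r| moves into logabs (log 0 = -inf encodes zero).
    return jnp.sign(r), logabs + jnp.log(jnp.abs(r))

def real(sign, logabs):
    return _reencode_part(jnp.real(sign), logabs)

def imag(sign, logabs):
    return _reencode_part(jnp.imag(sign), logabs)
```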

astype(dtype) LogArray

Cast the represented array to the given dtype.

sum(axis: int | Tuple[int, ...] | None = None, keepdims: bool = False) LogArray

Sum of array elements over a given axis.

mean(axis: int | Tuple[int, ...] | None = None, keepdims: bool = False) LogArray

Mean of array elements over a given axis.
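Sums can be computed stably with the log-sum-exp trick: subtract the largest logabs along the reduced axes, sum the now-bounded terms, and add the shift back; the mean then subtracts \(\log n\), where n is the number of reduced elements. A sketch on bare `(sign, logabs)` pairs; `lsum` and `lmean` are hypothetical names, and this is one standard technique rather than a description of quantax's code.

```python
import math
import jax.numpy as jnp

def _axes(ndim, axis):
    # Normalize axis=None / int / tuple into a tuple of axes
    if axis is None:
        return tuple(range(ndim))
    return axis if isinstance(axis, tuple) else (axis,)

def lsum(sign, logabs, axis=None, keepdims=False):
    axes = _axes(logabs.ndim, axis)
    # Shift by the largest log-magnitude so every exp() below is at most 1
    m = jnp.max(logabs, axis=axes, keepdims=True)
    m = jnp.where(jnp.isfinite(m), m, 0.0)  # all-zero slices: avoid (-inf) - (-inf) = nan
    t = jnp.sum(sign * jnp.exp(logabs - m), axis=axes, keepdims=True)
    abs_t = jnp.abs(t)
    new_sign = jnp.where(abs_t == 0, 0, t / jnp.where(abs_t == 0, 1.0, abs_t))
    new_logabs = m + jnp.log(abs_t)
    if not keepdims:
        new_sign = jnp.squeeze(new_sign, axis=axes)
        new_logabs = jnp.squeeze(new_logabs, axis=axes)
    return new_sign, new_logabs

def lmean(sign, logabs, axis=None, keepdims=False):
    # mean = sum / n, and dividing by n subtracts log(n) from logabs
    n = math.prod(logabs.shape[a] for a in _axes(logabs.ndim, axis))
    s, l = lsum(sign, logabs, axis, keepdims)
    return s, l - math.log(n)
```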

prod(axis: int | Tuple[int, ...] | None = None, keepdims: bool = False) LogArray

Product of array elements over a given axis.
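The product is the simplest reduction in log form: signs multiply and log-magnitudes add, \(\prod_i s_i e^{l_i} = \left(\prod_i s_i\right) e^{\sum_i l_i}\). A sketch with a hypothetical helper name; the usage line takes the product of 100 copies of 10, which is 1e100 and out of float32 range as a dense value, but only logabs ≈ 230.26 here.

```python
import jax.numpy as jnp

def lprod(sign, logabs, axis=None, keepdims=False):
    # prod_i (s_i * e^{l_i}) = (prod_i s_i) * e^{sum_i l_i}
    # A zero entry (sign 0, logabs -inf) makes the result sign 0 and logabs -inf
    return (jnp.prod(sign, axis=axis, keepdims=keepdims),
            jnp.sum(logabs, axis=axis, keepdims=keepdims))

# 100 factors of 10: only the logarithms are summed, so nothing overflows
s, l = lprod(jnp.ones(100), jnp.full(100, jnp.log(10.0)))
```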

choose(*args, **kwargs) LogArray

Apply choose to sign and logabs component-wise.

compress(*args, **kwargs) LogArray

Apply compress to sign and logabs component-wise.

copy(*args, **kwargs) LogArray

Apply copy to sign and logabs component-wise.

diagonal(*args, **kwargs) LogArray

Apply diagonal to sign and logabs component-wise.

flatten(*args, **kwargs) LogArray

Apply flatten to sign and logabs component-wise.

ravel(*args, **kwargs) LogArray

Apply ravel to sign and logabs component-wise.

repeat(*args, **kwargs) LogArray

Apply repeat to sign and logabs component-wise.

reshape(*args, **kwargs) LogArray

Apply reshape to sign and logabs component-wise.

squeeze(*args, **kwargs) LogArray

Apply squeeze to sign and logabs component-wise.

swapaxes(*args, **kwargs) LogArray

Apply swapaxes to sign and logabs component-wise.

take(*args, **kwargs) LogArray

Apply take to sign and logabs component-wise.

transpose(*args, **kwargs) LogArray

Apply transpose to sign and logabs component-wise.
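Methods that only move or select elements can act on each leaf independently, because the k-th sign always pairs with the k-th logabs. A sketch using `jax.tree_util.tree_map` on a plain `(sign, logabs)` tuple; the tuple standing in for the LogArray PyTree, the equal leaf shapes, and the helper name `componentwise` are assumptions. Arithmetic reductions such as sum do not work this way and need the log-space handling described for those methods.

```python
import jax
import jax.numpy as jnp

def componentwise(fn, pair):
    # Apply the same structural operation to both leaves of a (sign, logabs) pair
    return jax.tree_util.tree_map(fn, pair)

values = jnp.array([[1.0, -2.0, 3.0], [-4.0, 5.0, -6.0]])
pair = (jnp.sign(values), jnp.log(jnp.abs(values)))

reshaped = componentwise(lambda a: a.reshape(3, 2), pair)
transposed = componentwise(lambda a: a.T, pair)
taken = componentwise(lambda a: jnp.take(a, jnp.array([2, 0]), axis=1), pair)
```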