quantax.utils.filter_tree_map

quantax.utils.filter_tree_map(f: Callable, tree: PyTree, *rest: Tuple[PyTree]) -> PyTree

Behaves like jax.tree.map, but with a filter: f is applied only to the array leaves of tree (together with the corresponding leaves of rest).
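To illustrate what "only applies to arrays" means, here is a minimal sketch in plain JAX. It is not Quantax's actual implementation. The names filter_tree_map_sketch and _is_array are hypothetical, and the sketch assumes two things the description above does not state: that "array" means a jax.Array or NumPy array, and that non-array leaves (Python scalars, strings, etc.) are returned unchanged.

```python
from typing import Any, Callable

import jax
import jax.numpy as jnp
import numpy as np


def _is_array(x: Any) -> bool:
    # Assumption: JAX arrays, NumPy arrays and NumPy scalars count as arrays;
    # Python floats/ints, strings, functions, etc. do not.
    return isinstance(x, (jax.Array, np.ndarray, np.generic))


def filter_tree_map_sketch(f: Callable, tree: Any, *rest: Any) -> Any:
    """Hypothetical sketch: map f over the array leaves of tree only.

    Like jax.tree.map, every tree in `rest` must match the structure of `tree`.
    Non-array leaves of `tree` are passed through unchanged (assumed behaviour).
    """
    def apply(x, *xs):
        # Call f only when the leaf of `tree` is an array; otherwise keep it.
        return f(x, *xs) if _is_array(x) else x

    # jax.tree_util.tree_map is the same function exposed as jax.tree.map.
    return jax.tree_util.tree_map(apply, tree, *rest)


# A pytree mixing an array leaf with non-array leaves.
params = {"w": jnp.ones(3), "name": "layer1", "scale": 2.0}

# Single tree: "w" is doubled; "name" and the Python float "scale" are untouched.
doubled = filter_tree_map_sketch(lambda x: 2 * x, params)

# Extra trees in `rest` are zipped leaf-by-leaf, as with jax.tree.map.
summed = filter_tree_map_sketch(lambda a, b: a + b, params, params)
```

Checking the type inside the mapped function, rather than removing non-array leaves first, keeps the output tree structurally identical to the input, so it can be passed straight back to code that expects the original pytree.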