quantax.utils.filter_tree_map

quantax.utils.filter_tree_map(f: Callable, tree: PyTree, *rest: Tuple[PyTree]) → PyTree

The same as jax.tree.map, but with a filter: f is applied only to the leaves of tree that are arrays.
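To illustrate what "applies only to arrays" means, here is a minimal sketch built on jax.tree.map. It is not quantax's actual implementation. The function name filter_tree_map_sketch is hypothetical. The sketch makes two assumptions that the description above does not state: that "array" means a jax.Array or np.ndarray leaf of tree, and that non-array leaves are returned unchanged.

```python
from typing import Any, Callable

import jax
import jax.numpy as jnp
import numpy as np


def filter_tree_map_sketch(f: Callable, tree: Any, *rest: Any) -> Any:
    """Hypothetical illustration of a filtered tree map; not quantax's source."""

    def apply_if_array(leaf, *others):
        # Assumption: a leaf counts as an "array" if it is a JAX or NumPy array.
        if isinstance(leaf, (jax.Array, np.ndarray)):
            # Matching leaves from the `rest` trees are passed as extra arguments.
            return f(leaf, *others)
        # Assumption: non-array leaves (ints, strings, ...) pass through unchanged.
        return leaf

    # jax.tree.map requires `tree` and every tree in `rest` to share a structure.
    return jax.tree.map(apply_if_array, tree, *rest)


params = {"w": jnp.ones(3), "n": 2, "name": "layer"}
out = filter_tree_map_sketch(lambda x: x + 1, params)
# out["w"] is an array of 2.0s; out["n"] and out["name"] are left as-is.
```

With plain jax.tree.map, the same lambda would raise a TypeError on the string leaf ("layer" + 1). The filtered version skips that leaf, so static metadata can live in the same pytree as the arrays.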