Source code for seemps.analysis.comptree.chain

from __future__ import annotations
import numpy as np
from typing import Callable, overload
import dataclasses
from ...tools import make_logger, Logger
from ...typing import Vector, Matrix
from .branch import (
    BranchNode,
    propagate_images,
    get_transitions,
    assemble_cores,
)
from .sparse_mps import SparseMPS


class ChainRoot:
    """
    Final node of a `ChainTree`, holding the last binary function of the chain.

    The node models f(x_L, x_s): x_L is the value handed over by the left part of
    the chain, while x_s is picked from the discretization grid by an index `s`.

    The number of grid points (`N`) fixes the physical dimension of the MPS core
    associated with this node.
    """

    func: Callable
    grid: Vector
    N: int

    def __init__(self, func: Callable, grid: Vector):
        self.func, self.grid = func, grid
        self.N = len(self.grid)

    @overload
    def evaluate(self, x_L: float | None, s: int) -> float: ...

    @overload
    def evaluate(self, x_L: np.ndarray | None, s: np.ndarray) -> np.ndarray: ...

    def evaluate(
        self, x_L: float | np.ndarray | None, s: int | np.ndarray
    ) -> np.ndarray | float:
        """
        Apply the root function to a propagated value and a grid selection.

        Parameters
        ----------
        x_L : float, np.ndarray or None
            Value(s) propagated from the left of the chain.
        s : int or np.ndarray
            Index (or indices) into `grid` selecting x_s.

        Returns
        -------
        float or np.ndarray
            `func(x_L, grid[s])`, or the constant 0 when `x_L` is None.
        """
        if x_L is not None:
            return self.func(x_L, self.grid[s])
        # Without a propagated left value the result is the constant 0.
        return 0


@dataclasses.dataclass
class ChainTree:
    """
    Computation tree with a chain topology describing a multivariate function.

    The encoded function has the nested, chain-like algebraic form

        f_n[...f_2(f_1(None, s_1), s_2), ..., s_n].

    The left nodes (`BranchNode`) provide the intermediate updates f_k for
    k = 1,...,n-1, and the root node (`ChainRoot`) applies the last binary
    function f_n, which closes the chain evaluation.

    This representation can be efficiently loaded into a MPS using :func:`mps_chain_tree`.

    Examples
    --------
    The Heaviside function of a running sum,

        f(x₁,x₂,... ) = Θ(x₁ + x₂ + ...)

    is expressed as a chain-like tree in this way:

    - Every `BranchNode` performs the partial accumulation f_i(x_in, x_s) = x_in + x_s,
      with the domain of x_i as its discretization grid and s running over its support.
    - The `ChainRoot` finishes with the Heaviside step f_n(x_{n-1}, x_s) = Θ(x_{n-1} + x_s).
    """

    left_nodes: list[BranchNode]
    root_node: ChainRoot

    # Derived in __post_init__; declared here so type checkers know about them
    center: int = dataclasses.field(init=False)
    physical_dimensions: list[int] = dataclasses.field(init=False)
    length: int = dataclasses.field(init=False)

    def __post_init__(self):
        # One physical dimension per left node, followed by the root's.
        dimensions = [branch.N for branch in self.left_nodes]
        # The root sits right after the left nodes, so its index is their count.
        self.center = len(dimensions)
        dimensions.append(self.root_node.N)
        self.physical_dimensions = dimensions
        self.length = len(dimensions)


def mps_chain_tree(
    chain_tree: ChainTree, allowed_support: Vector | None = None
) -> SparseMPS:
    """
    Compute the MPS representation of a multivariate function encoded as a `ChainTree`.

    Returns a `SparseMPS`, whose cores are highly sparse and represented as collections
    of CSR matrices. This construction is particularly efficient for functions with a
    reduced support, which are recompressed with a right-to-left binning sweep.

    Source: https://arxiv.org/abs/2206.03832

    Parameters
    ----------
    chain_tree : ChainTree
        Chain-like computation-tree representation of the target function.
    allowed_support : Vector, optional
        Discrete set of admissible function outputs (e.g., `[0, 1]` for a Heaviside
        function). Any array-like is accepted; it is converted to a NumPy array.
        When omitted, no recompression is performed and the root transition is
        built by evaluating the root node on every propagated image value.

    Returns
    -------
    SparseMPS
        Sparse MPS approximation of the target multivariate function.
    """
    if allowed_support is not None:
        # The recompression sweep rounds the root image by fancy-indexing this
        # vector with an integer array, which a plain sequence such as `[0, 1]`
        # does not support.
        allowed_support = np.asarray(allowed_support)

    with make_logger(2) as logger:
        logger("Computing branch images:")
        left_images = propagate_images(chain_tree.left_nodes, logger)

        if allowed_support is not None:
            # Recompression sweep (more efficient)
            logger("Computing grouped transitions:")
            root_transition, left_transitions = _recompress_transitions(
                chain_tree, left_images, allowed_support, logger
            )
        else:
            logger("Computing transitions:")
            left_transitions = get_transitions(
                chain_tree.left_nodes, left_images, logger
            )
            root_node = chain_tree.root_node
            root_transition = {
                (k_L, s): root_node.evaluate(x_L, s)
                for k_L, x_L in enumerate(left_images[-1])
                for s in range(root_node.N)
            }

        logger("Computing MPS cores:")
        left_cores = assemble_cores(left_transitions, logger)

        # Compute root core: a dense (χ_L, N, 1) tensor filled from the
        # (k_L, s) -> value entries of the root transition.
        coords = np.array(list(root_transition.keys()))
        values = np.array(list(root_transition.values()))
        chi_L, N = (coords.max(axis=0) + 1).tolist()
        shape = (chi_L, N, 1)
        root_core = np.zeros(shape)
        root_core[coords[:, 0], coords[:, 1], 0] = values
        logger(f"Root node | Core of shape {shape}.")

    cores = left_cores + [root_core]
    return SparseMPS(cores)
def _recompress_transitions( chain_tree: ChainTree, left_images: list[Vector], tensor_values: Vector, logger: Logger = Logger(), ) -> tuple[dict, list[dict]]: l = chain_tree.length - 1 # Root node transition root_node = chain_tree.root_node x_in = left_images[-1] root_image = root_node.evaluate(x_in[:, np.newaxis], np.arange(root_node.N)) root_image = _round_matrix_to_vector(root_image, tensor_values) root_transition = _build_transition(root_image) x_grouped = _group_x(x_in, root_image) logger(f"Node {l}/{l} | Image compressed ({len(x_in)} -> {len(x_grouped)}).") # Left node transitions left_transitions = [] for i in reversed(range(l)): left_node = chain_tree.left_nodes[i] x_in = left_images[i] left_image = left_node.evaluate(x_in[:, np.newaxis], np.arange(left_node.N)) assert left_image is not None # To keep the type checker happy A = _map_to_group(left_image, x_grouped) transition = _build_transition(A) left_transitions.append(transition) x_grouped = _group_x(x_in, A) logger(f"Node {i}/{l} | Image compressed ({len(x_in)} -> {len(x_grouped)}).") return root_transition, left_transitions[::-1] def _round_matrix_to_vector(matrix: Matrix, vector: Vector) -> Matrix: """Rounds elements in `matrix` to the nearest values in `vector`.""" diff = np.abs(matrix[..., np.newaxis] - vector) indices = np.argmin(diff, axis=-1) return vector[indices] def _group_x(x_in: Vector, image: Matrix) -> tuple[Vector, ...]: """Groups elements of `x_in` based on unique rows in `a`.""" # TODO: Optimize index_map = {} for index, row in enumerate(image): row_tuple = tuple(row) if row_tuple not in index_map: index_map[row_tuple] = np.array([x_in[index]]) else: index_map[row_tuple] = np.append(index_map[row_tuple], x_in[index]) return tuple(index_map.values()) def _build_transition(image: Matrix) -> dict: """Efficiently constructs a dictionary mapping (k, s) to unique values in matrix `image`.""" unique_rows = _get_unique_rows(image) k_indices = np.arange(unique_rows.shape[0]) s_indices = 
np.arange(unique_rows.shape[1]) k_grid, s_grid = np.meshgrid(k_indices, s_indices, indexing="ij") keys = zip(k_grid.reshape(-1), s_grid.reshape(-1)) values = unique_rows.reshape(-1) return {key: val for key, val in zip(keys, values)} def _map_to_group(image: Matrix, x_grouped: tuple[Vector, ...]) -> Matrix: """Maps values in `a` to group indices efficiently.""" bounds = np.array([group[-1] for group in x_grouped]) A = np.searchsorted(bounds, image, side="left") return np.minimum(A, len(x_grouped) - 1) def _get_unique_rows(image: Matrix) -> Matrix: """Gets the unique rows of the matrix `image` without reordering.""" _, idx = np.unique(image, axis=0, return_index=True) return image[np.sort(idx)]