# Source code for seemps.analysis.comptree.binary

from __future__ import annotations
import numpy as np
from typing import Callable, overload
import dataclasses
from ...tools import make_logger
from ...typing import Vector
from .branch import (
    BranchNode,
    propagate_images,
    get_transitions,
    assemble_cores,
)
from .sparse_mps import SparseMPS


class BinaryRoot:
    """
    Root (terminal) node of a `BinaryTree`, holding the final ternary dependence.

    Wraps a ternary function f(x_L, x_s, x_R). The arguments x_L and x_R are the
    values propagated from the left and right subtrees, while x_s is read from the
    node's discretization grid at index s.

    The number of grid points fixes the physical dimension of the associated MPS core.
    """

    def __init__(self, func: Callable, grid: Vector):
        # N is the number of grid points, i.e. the physical dimension of this node.
        self.N = len(grid)
        self.grid = grid
        self.func = func

    @overload
    def evaluate(self, x_L: float | None, s: int, x_R: float | None) -> float: ...

    @overload
    def evaluate(
        self, x_L: np.ndarray | None, s: np.ndarray, x_R: np.ndarray | None
    ) -> np.ndarray: ...

    def evaluate(self, x_L, s, x_R):
        """
        Evaluate the root function at grid index `s` for the branch values `x_L`, `x_R`.

        Returns 0 when either branch value is `None`; otherwise returns
        `func(x_L, grid[s], x_R)`.
        """
        # Identity checks only: comparing an array against None by value would be ambiguous.
        if x_L is None:
            return 0
        if x_R is None:
            return 0
        return self.func(x_L, self.grid[s], x_R)


@dataclasses.dataclass
class BinaryTree:
    """
    Multivariate function expressed as a binary computation tree.

    The encoded function has a binary branching algebraic structure of the form

        f[ g1( g11(...), g12(...)) , g2( g21(...), g22(...) )],

    that is, two chain-like subtrees are evaluated independently and merged at a
    terminal root. Each subtree is a sequence of `BranchNode` objects, and the
    values they aggregate are combined by a `BinaryRoot` through a ternary function.

    Such a tree can be efficiently loaded into a MPS using :func:`mps_binary_tree`.
    """

    left_nodes: list[BranchNode]
    root_node: BinaryRoot
    right_nodes: list[BranchNode]

    # Derived in __post_init__; declared here so the type checker knows about them
    center: int = dataclasses.field(init=False)
    physical_dimensions: list[int] = dataclasses.field(init=False)
    length: int = dataclasses.field(init=False)

    def __post_init__(self):
        # The root sits right after the left branch, so its site index is the
        # number of left nodes.
        self.center = len(self.left_nodes)

        # Physical dimensions in site order: left branch, root, right branch.
        dimensions = [branch.N for branch in self.left_nodes]
        dimensions.append(self.root_node.N)
        dimensions.extend(branch.N for branch in self.right_nodes)

        self.physical_dimensions = dimensions
        self.length = len(dimensions)


def mps_binary_tree(binary_tree: BinaryTree) -> SparseMPS:
    """
    Compute the MPS representation of a multivariate function encoded as a `BinaryTree`.

    Returns a `SparseMPS`, whose cores are highly sparse and represented as
    collections of CSR matrices.

    Source: https://arxiv.org/abs/2206.03832

    Parameters
    ----------
    binary_tree : BinaryTree
        Binary computation-tree representation of the target function.

    Returns
    -------
    SparseMPS
        Sparse MPS approximation of the target multivariate function.
    """
    root_node = binary_tree.root_node

    with make_logger(2) as logger:
        logger("Computing branch images:")
        left_images = propagate_images(binary_tree.left_nodes, logger)
        right_images = propagate_images(binary_tree.right_nodes, logger)

        logger("Computing transitions:")
        left_transitions = get_transitions(binary_tree.left_nodes, left_images, logger)
        right_transitions = get_transitions(
            binary_tree.right_nodes, right_images, logger
        )
        # Evaluate the root on every (left value, grid index, right value)
        # combination, using the last image of each branch. The keys are the
        # indices of the dense root core.
        root_transition = {
            (k_L, s, k_R): root_node.evaluate(x_L, s, x_R)
            for k_L, x_L in enumerate(left_images[-1])
            for s in range(root_node.N)
            for k_R, x_R in enumerate(right_images[-1])
        }

        logger("Computing MPS cores:")
        left_cores = assemble_cores(left_transitions, logger)
        right_cores = assemble_cores(right_transitions, logger)
        # The right branch cores are transposed and their order reversed before
        # being placed after the root core.
        # NOTE(review): presumably the right branch is assembled from the outer end
        # of the chain towards the root -- confirm against `assemble_cores`.
        right_cores = [A.transpose() for A in right_cores][::-1]

        # Compute root core
        coords = np.array(list(root_transition.keys()))
        values = np.array(list(root_transition.values()))
        # Convert the extents to plain ints so that the logged shape reads as
        # "(2, 3, 4)" instead of "(np.int64(2), ...)" under NumPy >= 2.
        chi_L = 1 + int(np.max(coords[:, 0]))
        N = 1 + int(np.max(coords[:, 1]))
        chi_R = 1 + int(np.max(coords[:, 2]))
        shape = (chi_L, N, chi_R)
        root_core = np.zeros(shape)
        # Scatter each value to its (k_L, s, k_R) position in the dense core.
        root_core[tuple(coords.T)] = values
        logger(f"Center core of shape {shape}.")

    cores = left_cores + [root_core] + right_cores
    return SparseMPS(cores)