# Source code for seemps.analysis.comptree.branch

from __future__ import annotations
import numpy as np
from scipy.sparse import lil_array
from typing import Callable, overload
from ...tools import Logger, make_logger
from ...typing import Vector
from .sparse_mps import SparseCore


class BranchNode:
    def __init__(
        self,
        func: Callable,
        grid: Vector,
        binning_tol: float | None = None,
        max_rank: int | None = None,
    ):
        """
        Internal node in a computation-tree.

        A `BranchNode` represents an intermediate functional update

            x_out = func(x_in, x_s),

        where x_in is the value propagated from an upstream node and x_s is obtained by evaluating the
        node's one-dimensional discretization grid at index s. The function `func` is applied to the
        pair (x_in, x_s) to produce the output passed downstream.

        The index s labels the local physical dimension of the MPS core associated with this node.
        Optional compression parameters (`binning_tol`, `max_rank`) control the binning and truncation
        of the output image during MPS construction.

        Parameters
        ----------
        func : Callable
            Binary function implementing the update x_out = func(x_in, x_s).
        grid : Vector
            One-dimensional discretization grid for the local variable. Its length denotes the local
            MPS physical dimension.
        binning_tol : float, optional
            Relative tolerance used to bin nearby output values during image compression.
        max_rank : int, optional
            Maximum allowed number of distinct values (bins) in the compressed output image,
            which bounds the maximum bond dimension of the MPS core generated at this node.
        """
        self.func = func
        self.grid = grid
        self.binning_tol = binning_tol
        self.max_rank = max_rank
        self.N = len(grid)

    @overload
    def evaluate(self, x_in: float | None, s: int) -> float | None: ...

    @overload
    def evaluate(self, x_in: np.ndarray | None, s: np.ndarray) -> np.ndarray | None: ...

    def evaluate(self, x_in, s):
        """
        Apply the node function to an input value and a grid index.

        Returns `None` without calling `func` when `x_in` is `None`; otherwise
        returns `func(x_in, grid[s])`, which may itself be `None`.
        """
        if x_in is None:
            return None
        x_s = self.grid[s]
        return self.func(x_in, x_s)

    def compute_image(
        self, values: Vector, default_tol: float = 1e-4, tol_multiplier: float = 1.25
    ) -> Vector:
        """
        Compute the image of the node function over a set of input values.

        The node function x_out = func(x_in, x_s) is evaluated for all combinations of input values
        and grid indices, and the resulting outputs are collected into a set of sorted unique values.

        The image can be optionally compressed by binning nearby values and by enforcing `max_rank`,
        which prevents a combinatorial growth of image sizes across the tree and bounds the bond
        dimension of the corresponding MPS core.

        Parameters
        ----------
        values : Vector
            Input values (the image of the upstream node).
        default_tol : float, default=1e-4
            Binning tolerance used when `max_rank` is set but `binning_tol` is not.
        tol_multiplier : float, default=1.25
            Factor by which the tolerance is widened on every pass while the image
            still holds more than `max_rank` bins.

        Returns
        -------
        Vector
            Sorted array with the (possibly binned) distinct output values.

        Raises
        ------
        ValueError
            If `max_rank` cannot be reached because the tolerance cannot be widened
            (e.g. a non-positive tolerance or `tol_multiplier <= 1`).
        """
        logger = make_logger(3)
        try:
            # Evaluate every (input value, grid index) pair. Outputs reported as
            # `None` are stored as NaN so that they can be filtered out below.
            image_matrix = np.zeros((len(values), self.N))
            for j, x_in in enumerate(values):
                for s in range(self.N):
                    value = self.evaluate(x_in, s)
                    image_matrix[j, s] = np.nan if value is None else value

            # Flatten, drop invalid entries, and keep the sorted distinct values.
            image = image_matrix.reshape(-1)
            image = image[~np.isnan(image)]
            image = np.unique(image)
            logger(f"\tIncoming image of size {len(image)}.")

            # Compress the image
            if self.binning_tol is not None or self.max_rank is not None:
                binning_tol = (
                    default_tol if self.binning_tol is None else self.binning_tol
                )
                image = self._bin_image(image, binning_tol)
                logger(
                    f"\tImage compressed to {len(image)} bins with tolerance {binning_tol:.3e}."
                )
                if self.max_rank is not None:
                    # A non-empty image can never shrink below one bin, so clamp
                    # the target to avoid spinning forever when max_rank < 1.
                    target_size = max(self.max_rank, 1)
                    while len(image) > target_size:
                        widened_tol = binning_tol * tol_multiplier
                        # The loop only terminates if the tolerance keeps growing.
                        if not widened_tol > binning_tol:
                            raise ValueError(
                                f"Cannot compress image to max_rank={self.max_rank}: "
                                f"tolerance {binning_tol!r} does not grow with "
                                f"tol_multiplier={tol_multiplier!r}."
                            )
                        binning_tol = widened_tol
                        image = self._bin_image(image, binning_tol)
                        logger(
                            f"\tImage compressed to {len(image)} bins with tolerance {binning_tol:.3e}."
                        )
            return image
        finally:
            # Release the logger on both the normal and the error path.
            logger.close()

    @staticmethod
    def _bin_image(image: Vector, binning_tol: float) -> Vector:
        """
        Bin nearby values in a sorted image using a relative tolerance.

        The image is scanned in order. A value joins the current bin when its
        relative deviation from the *first* value of that bin is at most
        `binning_tol`; otherwise the bin is closed and replaced by its mean, and
        the value opens a new bin. A bin whose first value is zero never absorbs
        a different finite value (the relative deviation is infinite), unless
        the tolerance itself is infinite.

        An empty image yields an empty array.
        """
        values = np.asarray(image, dtype=float)
        if values.size == 0:
            return np.array([])

        binned_image = []
        current_bin = [values[0]]
        # A zero reference makes the relative deviation inf/NaN; this is the
        # intended outcome ("do not merge"), so silence the NumPy warnings.
        with np.errstate(divide="ignore", invalid="ignore"):
            for x in values[1:]:
                error = abs((x - current_bin[0]) / current_bin[0])
                if error <= binning_tol:
                    current_bin.append(x)
                else:
                    binned_image.append(np.mean(current_bin))
                    current_bin = [x]
        binned_image.append(np.mean(current_bin))
        return np.array(binned_image)


def propagate_images(
    nodes: list[BranchNode], logger: Logger = Logger()
) -> list[Vector]:
    """
    Helper function to propagate an initial image through a sequence of BranchNodes.

    Starts from the single-valued image `[0.0]` and feeds each node with the
    image produced by the previous one. Returns the list of all images,
    including the initial one, so it has `len(nodes) + 1` entries.
    """
    num_nodes = len(nodes)
    images = [np.array([0.0])]
    for i, node in enumerate(nodes):
        image = node.compute_image(images[-1])
        logger(f"Node {(i + 1)}/{num_nodes} | Image of size {len(image)}.")
        images.append(image)
    return images


def _nearest_index(sorted_values: Vector, x: float) -> int:
    """
    Return the index of the entry of the ascending `sorted_values` closest to `x`.

    Ties are resolved towards the lower index.
    """
    # First position whose value is >= x, clamped to the last valid index.
    idx = int(np.searchsorted(sorted_values, x, side="left"))
    idx = min(idx, len(sorted_values) - 1)
    # The left neighbour may be closer than the insertion position.
    if idx > 0 and (x - sorted_values[idx - 1]) <= (sorted_values[idx] - x):
        idx -= 1
    return idx


def get_transitions(
    nodes: list[BranchNode], images: list[Vector], logger: Logger = Logger()
) -> list[dict]:
    """
    Construct transition mappings for a chain of `BranchNode`s.

    For each node, this function determines how input image indices map to output image indices
    under the node's functional update. Given consecutive images R_in and R_out, the transition
    assigns, for each input index k_in and grid index s, the output index k_out such that

        func(R_in[k_in], grid[s]) ≈ R_out[k_out].

    When no binning is applied, exact matches in R_out are guaranteed. If binning is used, the
    nearest value in R_out is assigned instead. Pairs (k_in, s) for which the node evaluates to
    `None` produce no entry.

    Each node yields a dictionary mapping (k_in, s) → k_out, which is used to assemble the
    corresponding MPS core.
    """
    num_nodes = len(nodes)
    transitions = []
    for i, node in enumerate(nodes):
        R_in = images[i]
        R_out = images[i + 1]

        # Create lookup tables for fast O(1) search
        R_out_lookup = {value: idx for idx, value in enumerate(R_out)}

        transition = {}
        for s in range(node.N):
            for k_in, x_in in enumerate(R_in):
                x_out = node.evaluate(x_in, s)
                if x_out is None:
                    continue
                k_out = R_out_lookup.get(x_out)
                # No exact match (the image was binned): fall back to the
                # closest entry of R_out rather than the next larger one.
                if k_out is None:
                    k_out = _nearest_index(R_out, x_out)
                transition[(k_in, s)] = k_out

        logger(f"Node {(i + 1)}/{num_nodes} | Transition of size {len(transition)}.")
        transitions.append(transition)
    return transitions


def assemble_cores(
    transitions: list[dict], logger: Logger = Logger()
) -> list[SparseCore]:
    """
    Assemble sparse MPS cores from transition mappings.

    For each node, a rank-3 tensor A[r_L, s, r_R] is constructed such that A[k_in, s, k_out] = 1
    whenever the transition mapping contains (k_in, s) → k_out. Each physical slice (fixed s) is
    stored as a sparse CSR matrix, yielding a compact sparse-core representation.

    The three dimensions of every core are inferred from the largest indices that appear in its
    transition, so each transition must contain at least one entry.
    """
    cores = []
    num_cores = len(transitions)
    for i, transition in enumerate(transitions):
        coords = np.array([(k_in, s, k_out) for (k_in, s), k_out in transition.items()])
        χ_L = 1 + np.max(coords[:, 0])
        N = 1 + np.max(coords[:, 1])
        χ_R = 1 + np.max(coords[:, 2])

        # Build one (χ_L, χ_R) matrix per physical index in LIL format, which
        # supports item assignment, then convert to CSR.
        data = [lil_array((χ_L, χ_R)) for _ in range(N)]
        for k_in, s, k_out in coords:
            data[s][k_in, k_out] += 1

        core = SparseCore([matrix.tocsr() for matrix in data])
        logger(f"Node {(i + 1)}/{num_cores} | Core of shape {core.shape}.")
        cores.append(core)
    return cores