from __future__ import annotations
from abc import abstractmethod
from collections.abc import Sequence
from time import perf_counter
import numpy as np
import scipy
import dataclasses
import functools
from typing import TypeAlias
from ...state import MPS
from ...tools import SEED, Logger, make_logger
from ...typing import Vector, Matrix, Tensor3, Tensor4, Natural
from ..evaluation import random_mps_indices, evaluate_mps
from .black_box import BlackBox
@dataclasses.dataclass
class CrossStrategy:
    """
    Abstract dataclass containing the base parameters for tensor cross interpolation.

    See specializations :class:`~seemps.analysis.cross.CrossStrategyDMRG`,
    :class:`~seemps.analysis.cross.CrossStrategyGreedy`, and
    :class:`~seemps.analysis.cross.CrossStrategyMaxvol`.

    Parameters
    ----------
    tol : float, default=1e-8
        Tolerance for the sampled error.
    num_samples : Natural, default=1024
        Number of function samples to evaluate the error.
    error_norm : float, default=np.inf
        L^p norm used for computing the error.
    error_relative : bool, default=False
        Whether to compute the absolute or relative error.
    range_iters : tuple[Natural, Natural], default=(1, 200)
        Range of iterations (half-sweeps) allowed.
    range_max_bonds : tuple[Natural, Natural], default=(1, 1000)
        Range of MPS maximum bond dimension allowed.
    max_time : float | None, default=None
        Maximum computation time allowed.
    max_evals : Natural | None, default=None
        Maximum number of evaluations allowed.
    rng : np.random.Generator, default=`numpy.random.default_rng(seemps.tools.SEED)`
        Random number generator used to initialize the algorithm and sample the error.
    """

    tol: float = 1e-8
    num_samples: Natural = 2**10
    error_norm: float = np.inf
    error_relative: bool = False
    range_iters: tuple[Natural, Natural] = (1, 200)
    range_max_bonds: tuple[Natural, Natural] = (1, 1000)
    max_time: float | None = None
    max_evals: Natural | None = None
    # A fresh generator is built per instance so that strategies do not share
    # (and advance) a single random stream.
    rng: np.random.Generator = dataclasses.field(
        default_factory=lambda: np.random.default_rng(SEED)
    )

    def __post_init__(self) -> None:
        """Sanity-check the sample count and the lower bounds of both ranges."""
        # AssertionError is kept (instead of ValueError) so that existing
        # callers relying on the exception type are unaffected.
        assert self.num_samples > 0, "num_samples must be positive"
        assert self.range_iters[0] > 0, "range_iters lower bound must be positive"
        assert (
            self.range_max_bonds[0] > 0
        ), "range_max_bonds lower bound must be positive"

    # NOTE: the class does not use an ABC metaclass, so this decorator marks
    # intent only; the base implementation simply returns None.
    @abstractmethod
    def make_interpolator(
        self, black_box: BlackBox, initial_points: Matrix | None = None
    ) -> CrossInterpolation:
        """
        Build the interpolation state object used by :func:`cross_interpolation`.

        Parameters
        ----------
        black_box : BlackBox
            The black box to approximate.
        initial_points : Matrix | None, default=None
            Optional initial points forwarded to the interpolator.
        """
# Two-dimensional integer array; used below to annotate stacked multi-indices
# (see `CrossInterpolation.combine_indices`).
IndexMatrix: TypeAlias = np.ndarray[tuple[int, int], np.dtype[np.integer]]
# One-dimensional integer array.
IndexVector: TypeAlias = np.ndarray[tuple[int], np.dtype[np.integer]]
# Anything usable to select entries along one axis: a single integer index,
# an integer index vector, or a slice.
IndexSlice: TypeAlias = np.intp | IndexVector | slice
class CrossInterpolation:
    """
    Auxiliary base class for TCI that keeps track of the information required
    by the interpolation: the black box, the multi-index sets and the current MPS.
    """

    black_box: BlackBox
    sites: int
    # Per-site index sets: columns of the points to the left of the site (I_l),
    # to its right (I_g), and the full range of the site's own index (I_s).
    I_l: list[np.ndarray]
    I_g: list[np.ndarray]
    I_s: list[np.ndarray]
    mps: MPS
    two_sweeps_required: bool
    iteration_range: range

    def __init__(
        self,
        black_box: BlackBox,
        initial_points: Matrix | None = None,
        two_sweeps_required: bool = False,
        two_site_algorithm: bool = True,
    ):
        """
        Parameters
        ----------
        black_box : BlackBox
            Object to interpolate; its `physical_dimensions` fix the number of sites.
        initial_points : Matrix | None, default=None
            Initial points, one per row. If None, a single all-zeros point is used.
        two_sweeps_required : bool, default=False
            Stored as-is for the driver loop.
        two_site_algorithm : bool, default=True
            If True the iteration range skips the last site.
        """
        dimensions = black_box.physical_dimensions
        self.black_box = black_box
        self.sites = len(dimensions)
        if initial_points is None:
            initial_points = np.zeros(self.sites, dtype=int)
        self.I_l, self.I_g = self.points_to_indices(initial_points)
        # Every physical index of each site, arranged as a column of 1-tuples.
        self.I_s = [np.arange(dim).reshape(-1, 1) for dim in dimensions]
        # Starting MPS: bond dimension one, all entries equal to one.
        self.mps = MPS([np.ones((1, dim, 1)) for dim in dimensions])
        self.two_sweeps_required = two_sweeps_required
        num_steps = self.sites - 1 if two_site_algorithm else self.sites
        self.iteration_range = range(num_steps)

    @abstractmethod
    def update(self, k: int, left_to_right: bool) -> None:
        """Update the interpolation at step `k`; the base implementation does nothing."""

    def sample_fiber(self, k: int) -> Tensor3:
        """
        Evaluate the black box on the product of `I_l[k]`, `I_s[k]` and `I_g[k]`
        and return the values as a three-index tensor.
        """
        left, site, right = self.I_l[k], self.I_s[k], self.I_g[k]
        samples = self.black_box[self.combine_indices(left, site, right)]
        shape = (len(left), len(site), len(right))
        return samples.reshape(shape)

    def sample_superblock(self, k: int) -> Tensor4:
        """
        Evaluate the black box on the product of `I_l[k]`, `I_s[k]`, `I_s[k + 1]`
        and `I_g[k + 1]` and return the values as a four-index tensor.
        """
        left, right = self.I_l[k], self.I_g[k + 1]
        site_a, site_b = self.I_s[k], self.I_s[k + 1]
        samples = self.black_box[self.combine_indices(left, site_a, site_b, right)]
        shape = (len(left), len(site_a), len(site_b), len(right))
        return samples.reshape(shape)

    @staticmethod
    def combine_indices(*indices: IndexMatrix, row_major: bool = False) -> IndexMatrix:
        """
        Computes the Cartesian product of a set of multi-indices arrays and arranges the
        result as concatenated indices in column or row-major order.

        Parameters
        ----------
        indices : *np.ndarray
            A variable number of arrays where each array is treated as a set of multi-indices.
        row_major : bool, default=False
            Whether to compute the Cartesian product in row-major order.

        Examples
        --------
        >>> combine_indices(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[0], [1]]))
        array([[1, 2, 3, 0],
               [1, 2, 3, 1],
               [4, 5, 6, 0],
               [4, 5, 6, 1]])
        """

        def product_of_two(first: Matrix, second: Matrix) -> Matrix:
            rows_first, rows_second = first.shape[0], second.shape[0]
            if row_major:
                # The first factor varies fastest.
                head = np.tile(first, (rows_second, 1))
                tail = np.repeat(second, repeats=rows_first, axis=0)
            else:
                # The last factor varies fastest.
                head = np.repeat(first, repeats=rows_second, axis=0)
                tail = np.tile(second, (rows_first, 1))
            return np.hstack((head, tail))

        return functools.reduce(product_of_two, indices)

    @staticmethod
    def points_to_indices(points: Matrix) -> tuple[list[Matrix], list[Matrix]]:
        """
        Split each point into, for every site `k`, the columns before `k` and the
        columns after `k`. A one-dimensional input is treated as a single point.
        """
        grid = points.reshape(1, -1) if points.ndim == 1 else points
        num_sites = grid.shape[1]
        left_sets = [grid[:, :k] for k in range(num_sites)]
        right_sets = [grid[:, k + 1 :] for k in range(num_sites)]
        return left_sets, right_sets
@dataclasses.dataclass
class CrossIterationResults:
    """Values recorded by :class:`CrossResults` for one iteration (half-sweep)."""

    # Sampled error of the MPS.
    error: float
    # Bond dimensions of the MPS.
    bonds: Sequence[int]
    # Time elapsed since the owning CrossResults object was created.
    time: float
    # Number of black-box evaluations reported at recording time.
    evaluations: int

    @property
    def max_bond_dimension(self) -> int:
        """Largest entry of `bonds`."""
        return max(self.bonds)
class CrossResults:
    """
    Container with the results from TCI. Keeps track of values for every
    iteration (half-sweep) and of the most recent MPS.

    Parameters
    ----------
    mps : MPS
        The initial MPS; replaced by the latest one on every call to `update`.

    Attributes
    ----------
    mps : MPS
        The resulting MPS interpolation of the black-box function.
    errors : np.ndarray
        Vector of error values, one per recorded iteration.
    bonds : np.ndarray
        Matrix of intermediate bond dimensions, one row per recorded iteration.
    times : np.ndarray
        Vector of cumulative computation times, measured from construction.
    evals : np.ndarray
        Vector of cumulative function evaluations.
    """

    mps: MPS
    _points: list[CrossIterationResults]
    _start: float

    def __init__(self, mps: MPS):
        self.mps = mps
        self._points = []
        # Reference instant for the elapsed times stored by `update`.
        self._start = perf_counter()

    def _trajectory(self, attribute: str) -> np.ndarray:
        """Gather one attribute of every recorded iteration into an array."""
        return np.asarray([getattr(point, attribute) for point in self._points])

    @property
    def errors(self) -> np.ndarray:
        return self._trajectory("error")

    @property
    def bonds(self) -> np.ndarray:
        return self._trajectory("bonds")

    @property
    def times(self) -> np.ndarray:
        return self._trajectory("time")

    @property
    def evals(self) -> np.ndarray:
        return self._trajectory("evaluations")

    def update(
        self,
        mps: MPS,
        error: float,
        bonds: Sequence[int],
        evaluations: int,
    ):
        """Store `mps` as the current result and append a new iteration record."""
        self.mps = mps
        elapsed = perf_counter() - self._start
        record = CrossIterationResults(error, bonds, elapsed, evaluations)
        self._points.append(record)

    def get_result(self, index: int) -> CrossIterationResults:
        """Return the record at `index` (negative values count from the end)."""
        return self._points[index]
class CrossError:
    """
    Auxiliary class for TCI that computes the sampled error between the function
    and the MPS at every iteration. The error is measured with a sampled L^p norm,
    and the sample points and function values are cached after the first call.
    """

    error_norm: float
    num_samples: int
    error_relative: bool
    rng: np.random.Generator
    mps_indices: Vector | None
    black_box_evals: np.ndarray | None
    norm: float

    def __init__(self, cross_strategy: CrossStrategy):
        # Settings copied from the strategy.
        self.rng = cross_strategy.rng
        self.num_samples = cross_strategy.num_samples
        self.error_norm = cross_strategy.error_norm
        self.error_relative = cross_strategy.error_relative
        # Cache, filled on the first call to `sample_error`.
        self.mps_indices = None
        self.black_box_evals = None
        self.norm = 1.0

    def lp_distance(self, x: Vector) -> float:
        """
        Return the sampled L^p norm of `x` with `p = self.error_norm`: the p-th root
        of the mean of `|x|**p` for finite `p`, and the largest `|x|` otherwise.
        """
        order = self.error_norm
        magnitudes = np.abs(x)
        if not np.isfinite(order):
            return float(np.max(magnitudes))
        mean_power = (1 / len(x)) * np.sum(magnitudes**order)
        return float(mean_power ** (1 / order))

    def sample_error(self, cross: CrossInterpolation) -> float:
        """
        Return the sampled error between `cross.mps` and `cross.black_box`,
        divided by the sampled norm of the black box if `error_relative` is set.
        """
        if self.mps_indices is None:
            black_box = cross.black_box
            # Consider the allowed indices to impose restrictions (e.g. diagonal MPO)
            restrictions = getattr(black_box, "allowed_indices", None)
            self.mps_indices = random_mps_indices(
                black_box.physical_dimensions,
                self.num_samples,
                restrictions,
                self.rng,
            )
            self.black_box_evals = black_box[self.mps_indices].reshape(-1)
            self.norm = self.lp_distance(self.black_box_evals)
        approximation = evaluate_mps(cross.mps, self.mps_indices)
        deviation = self.lp_distance(approximation - self.black_box_evals)
        if self.error_relative:
            return deviation / self.norm
        return deviation
def check_tci_convergence(
    logger: Logger,
    half_sweep: int,
    results: CrossResults,
    cross_strategy: CrossStrategy,
) -> bool:
    """
    Decide whether TCI should stop after the latest iteration and log the outcome.

    Parameters
    ----------
    logger : Logger
        Callable receiving the progress line and the reason for stopping.
    half_sweep : int
        Number of the half-sweep that has just been recorded.
    results : CrossResults
        Trajectories; the last (and, for `half_sweep > 2`, the previous) record is read.
    cross_strategy : CrossStrategy
        Source of the tolerance and of the iteration, bond, time and evaluation limits.

    Returns
    -------
    bool
        True if any stop condition holds. Always False while `half_sweep` or the
        maximum bond dimension are below the lower bounds of their ranges.
    """
    tol = cross_strategy.tol
    max_time = cross_strategy.max_time
    max_evals = cross_strategy.max_evals
    iter_min, iter_max = cross_strategy.range_iters
    bond_min, bond_max = cross_strategy.range_max_bonds

    current = results.get_result(-1)
    maxbond = current.max_bond_dimension
    # The previous record is only consulted from the third half-sweep onwards.
    if half_sweep > 2:
        maxbond_prev = results.get_result(-2).max_bond_dimension
    else:
        maxbond_prev = 0
    evals = current.evaluations
    error = current.error

    if logger:
        logger(
            f"Iteration (half-sweep): {half_sweep:3}/{iter_max}, "
            f"error: {error:1.15e}/{tol:.2e}, "
            f"maxbond: {maxbond:3}/{bond_max}, "
            f"evals: {evals:8}/{max_evals}."
        )

    # Lower bounds take precedence over every stop condition.
    if half_sweep < iter_min or maxbond < bond_min:
        return False

    # Stop conditions in priority order; only the first one that holds is reported.
    stop_conditions = (
        (error <= tol, f"State converged within tolerance {tol}"),
        (
            maxbond - maxbond_prev <= 0,
            f"Max. bond dimension converged with value {maxbond}",
        ),
        (half_sweep >= iter_max, f"Max. iterations reached at {iter_max}"),
        (maxbond >= bond_max, f"Max. bond reached above the threshold {bond_max}"),
        (
            max_time is not None and current.time >= max_time,
            f"Max. time reached above the threshold {max_time}",
        ),
        (
            max_evals is not None and evals >= max_evals,
            f"Max. evals reached above the threshold {max_evals}",
        ),
    )
    for triggered, message in stop_conditions:
        if triggered:
            logger(message)
            return True
    return False
# [docs] (documentation-link label left over from HTML extraction; not part of the code)
def maxvol_square(
    A: Matrix,
    max_iter: int = 10,
    tol: float = 1.05,
) -> tuple[Matrix, Matrix]:
    """
    Returns the row indices I of a tall matrix A of size (n x r) with n > r that give
    a square submatrix of (quasi-)maximum volume (modulus of the submatrix determinant).
    Also returns a matrix of coefficients B such that A ≈ B A[I, :].

    Parameters
    ----------
    A : np.ndarray
        A tall (n x r) matrix with more rows than columns (n > r).
    max_iter : int, default=10
        Maximum number of row swaps allowed.
    tol : float, default=1.05
        Sensitivity of the algorithm: iteration stops once no entry of B exceeds
        this value in modulus.

    Returns
    -------
    I : np.ndarray
        An array of r indices that determine a square submatrix of A with
        (quasi-)maximum volume. If n <= r, all n row indices are returned.
    B : np.ndarray
        An (n x r) matrix of coefficients such that A ≈ B A[I, :]. If n <= r,
        the (n x n) identity is returned.
    """
    n, r = A.shape
    if n <= r:
        # Not a tall matrix: keep every row, with trivial coefficients.
        return np.arange(n, dtype=int), np.eye(n)

    # Starting guess: the first r pivot rows of the LU factorization A = P L U.
    P, L, U = scipy.linalg.lu(A)
    I = P[:, :r].argmax(axis=0)
    # Coefficients B = A (L[:r] U)^-1, obtained from two transposed triangular solves.
    Q = scipy.linalg.solve_triangular(U, A.T, trans=1)
    B = scipy.linalg.solve_triangular(
        L[:r, :], Q, trans=1, unit_diagonal=True, lower=True
    ).T

    for _ in range(max_iter):
        # Locate the entry of B with the largest modulus.
        i, j = np.unravel_index(np.abs(B).argmax(), B.shape)
        pivot = B[i, j]
        if abs(pivot) <= tol:
            break
        # Replace the j-th selected row by row i and apply the rank-one update to B.
        I[j] = i
        row = B[i, :].copy()
        row[j] -= 1.0
        B -= np.outer(B[:, j], row / pivot)
    return I, B
# [docs] (documentation-link label left over from HTML extraction; not part of the code)
def cross_interpolation(
    cross_strategy: CrossStrategy,
    black_box: BlackBox,
    initial_points: Matrix | None = None,
) -> CrossResults:
    """
    Computes the MPS representation of a black-box function using the tensor
    cross-interpolation (TCI) algorithm selected by `cross_strategy`.

    The black-box function can represent several different data structures.
    See `black_box` for usage examples.

    Parameters
    ----------
    cross_strategy : CrossStrategy
        A dataclass containing the parameters of the algorithm.
        See :class:`CrossStrategy`.
    black_box : BlackBox
        The black box to approximate as a MPS.
    initial_points : Matrix | None, default=None
        A collection of initial points used to initialize the algorithm.
        If None, an initial point at the origin is used.

    Returns
    -------
    CrossResults
        An object containing the MPS representation of the black-box function,
        among other useful information.

    Notes
    -----
    The loop performs at most `cross_strategy.range_iters[1] // 2` full sweeps
    (two half-sweeps each), so an odd upper bound is rounded down.
    """
    interpolator = cross_strategy.make_interpolator(black_box, initial_points)
    error_sampler = CrossError(cross_strategy)
    converged = False
    with make_logger(1) as logger:
        results = CrossResults(interpolator.mps)

        def record_half_sweep() -> None:
            # The error is sampled before the evaluation counter is read, so
            # the evaluations spent on sampling are included in the record.
            results.update(
                interpolator.mps,
                error_sampler.sample_error(interpolator),
                interpolator.mps.bond_dimensions(),
                interpolator.black_box.evals,
            )

        for sweep in range(cross_strategy.range_iters[1] // 2):
            # Left-to-right half sweep
            for k in interpolator.iteration_range:
                interpolator.update(k, True)
            record_half_sweep()
            # Convergence is only tested here if one half-sweep is enough.
            if not interpolator.two_sweeps_required:
                converged = check_tci_convergence(
                    logger, 2 * sweep + 1, results, cross_strategy
                )
                if converged:
                    break

            # Right-to-left half sweep
            for k in reversed(interpolator.iteration_range):
                interpolator.update(k, False)
            record_half_sweep()
            converged = check_tci_convergence(
                logger, 2 * sweep + 2, results, cross_strategy
            )
            if converged:
                break

        if not converged:
            logger("Maximum number of iterations reached")
    return results