from __future__ import annotations
import numpy as np
import scipy.linalg
import functools
from typing import TypeAlias
from seemps.state import MPS, Strategy, DEFAULT_STRATEGY
from seemps.state.schmidt import _destructive_svd
from seemps.cython import destructively_truncate_vector, _contract_last_and_first
from seemps.analysis.mesh import Mesh, mesh_to_mps_indices
from seemps.analysis.cross import BlackBoxLoadMPS
from seemps.typing import Vector, Matrix, Tensor3
IndexMatrix: TypeAlias = np.ndarray[tuple[int, int], np.dtype[np.integer]]
class SketchedCross:
"""
Helper class for TT-RSS algorithm. Holds function multi-indices and enables sampling
function fibers, similarly to TCI implementations.
Parameters
----------
black_box : BlackBoxLoadMPS
Callable object that evaluates a function f on an array of MPS multi-indices
considering the MPS quantization and structure given by the `map_matrix` and
`physical_dimensions` attributes.
samples : Matrix
Array of pivot samples x_i ∈ ℝ^m, shape (D, m). These are quantized to indices
and used to form the recursive prefix/suffix multi-index sets.
"""
def __init__(self, black_box: BlackBoxLoadMPS, samples: Matrix):
self.black_box = black_box
self.sites = len(black_box.physical_dimensions)
n = self.sites
mesh_indices = _samples_to_mesh_indices(samples, black_box.mesh)
if black_box.map_matrix is None:
mps_indices = mesh_indices
else:
mps_indices = mesh_to_mps_indices(mesh_indices, black_box.map_matrix)
# Sets of multi-index sets: left (I_l), physical (I_s) and right (I_r).
# Equivalent to the recursive prefix and suffix sets S_k and T_k.
self.I_l = [np.unique(mps_indices[:, :ℓ], axis=0) for ℓ in range(n + 1)] # noqa: E741
self.I_s = [np.arange(s).reshape(-1, 1) for s in black_box.physical_dimensions]
self.I_r = [np.unique(mps_indices[:, ℓ:], axis=0) for ℓ in range(n + 1)] # noqa: E741
@staticmethod
def combine_indices(*indices: IndexMatrix) -> IndexMatrix:
def cartesian_column(A: Matrix, B: Matrix) -> Matrix:
A_repeated = np.repeat(A, repeats=B.shape[0], axis=0)
B_tiled = np.tile(B, (A.shape[0], 1))
return np.hstack((A_repeated, B_tiled))
return functools.reduce(cartesian_column, indices)
def sample_fiber(self, k: int) -> Tensor3:
i_l, i_s, i_r = self.I_l[k], self.I_s[k], self.I_r[k + 1]
mps_indices = self.combine_indices(i_l, i_s, i_r)
return self.black_box[mps_indices].reshape((len(i_l), len(i_s), len(i_r)))
@property
def recursive_sets(self) -> list[tuple[Vector, Vector]]:
β_list: list[Vector] = [np.zeros(self.I_l[1].shape[0], dtype=int)]
x_list: list[Vector] = [self.I_l[1][:, 0]]
for k in range(1, self.sites):
S_k = self.I_l[k]
S_kp1 = self.I_l[k + 1]
index_map = {tuple(row): idx for idx, row in enumerate(S_k)}
β_k = np.array([index_map[tuple(row[:-1])] for row in S_kp1], dtype=int)
x_k = S_kp1[:, -1].astype(int)
β_list.append(β_k)
x_list.append(x_k)
return list(zip(β_list, x_list))
def _samples_to_mesh_indices(samples: Matrix, mesh: Mesh) -> Matrix:
"""
Project continuous sample points onto the nearest nodes of a discretization mesh.
Given a collection of continuous samples of a function defined over a discretization mesh,
this routine maps each sample to the index of the nearest mesh point along each dimension.
The mapping is performed by normalizing the samples to the mesh intervals and rounding to
the closest grid node. The original sample locations are modified, resulting in a discretized
approximation.
"""
samples = np.asarray(samples, dtype=float)
K, m = samples.shape
if mesh.dimension != m:
raise ValueError("Invalid dimensions.")
indices = np.zeros((K, m), dtype=int)
for dim, interval in enumerate(mesh.intervals):
a, b, N = interval.start, interval.stop, interval.size
samples_norm = (samples[:, dim] - a) / (b - a)
indices[:, dim] = np.clip(
np.round(samples_norm * (N - 1)).astype(int), 0, N - 1
)
return indices
def _random_isometry(rows: int, cols: int) -> Matrix:
if cols > rows:
raise ValueError("cols must be <= rows")
A = np.random.randn(rows, cols)
Q, _ = scipy.linalg.qr(A, mode="economic", overwrite_a=True, check_finite=False)
return Q
__all__ = ["tt_rss"]