Source code for seemps.analysis.lagrange

from __future__ import annotations
import numpy as np
from numpy.typing import NDArray
import scipy.sparse
import functools
from scipy.sparse import dok_array, csr_array
from typing import Callable
from ..state import MPS, Strategy, DEFAULT_STRATEGY, simplify
from ..state.schmidt import _destructive_svd
from ..cython import _contract_last_and_first, destructively_truncate_vector
from ..typing import Tensor3, MPSOrder, Vector
from .mesh import Interval, ArrayInterval, Mesh, array_affine


[docs] def mps_lagrange_chebyshev_basic( func: Callable, domain: Interval | Mesh, order: int, use_logs: bool = True, mps_order: MPSOrder = "A", strategy: Strategy = DEFAULT_STRATEGY, ) -> MPS: """ Performs a "basic" Lagrange MPS Chebyshev interpolation of a function. Parameters ---------- func : Callable The function to interpolate. domain : Interval | Mesh The domain where the function is defined. order : int The order of the Chebyshev interpolation. use_logs : bool, default=True Whether to compute the Chebyshev cardinal function using logarithms to avoid overflow. mps_order : MPSOrder, default='A' The order of the MPS cores, either "A" (serial) or "B" (interleaved). strategy : Strategy, default=DEFAULT_STRATEGY The MPS simplification strategy. Returns ------- mps : MPS The MPS corresponding to the naive Chebyshev interpolation. """ mesh = Mesh([domain]) if isinstance(domain, Interval) else domain _validate_mesh(mesh) builder = LagrangeBuilder(order) A_L = builder.build_left_core(func, mesh) A_C = builder.build_center_core(use_logs) A_R = builder.build_right_core(use_logs) cores = [A_L] + builder.build_dense_cores(A_C, A_R, mesh, mps_order)[1:] mps = MPS(cores) return simplify(mps, strategy=strategy)
[docs] def mps_lagrange_chebyshev_rr( func: Callable, domain: Interval | Mesh, order: int, use_logs: bool = True, mps_order: MPSOrder = "A", strategy: Strategy = DEFAULT_STRATEGY, ) -> MPS: """ Performs a Lagrange rank-revealing MPS Chebyshev interpolation of a function. Parameters ---------- func : Callable The function to interpolate. domain : Interval | Mesh The domain where the function is defined. order : int The order of the Chebyshev interpolation. use_logs : bool, default=True Whether to compute the Chebyshev cardinal function using logarithms to avoid overflow. mps_order : MPSOrder, default='A' The order of the MPS cores, either "A" (serial) or "B" (interleaved). strategy : Strategy, default=DEFAULT_STRATEGY The MPS simplification strategy. Returns ------- mps : MPS The MPS corresponding to the rank-revealing Chebyshev interpolation. """ mesh = Mesh([domain]) if isinstance(domain, Interval) else domain _validate_mesh(mesh) builder = LagrangeBuilder(order) A_L = builder.build_left_core(func, mesh) A_C = builder.build_center_core(use_logs) A_R = builder.build_right_core(use_logs) cores = builder.build_dense_cores(A_C, A_R, mesh, mps_order) U_L, R = np.linalg.qr(A_L.reshape((2, -1))) tensors: list[NDArray] = [U_L.reshape(1, 2, 2)] for core in cores[1:-1]: B = _contract_last_and_first(R, core) r1, _, r2 = B.shape ## SVD U, S, V = _destructive_svd(B.reshape(r1 * 2, r2)) destructively_truncate_vector(S, strategy) D = S.size U = U[:, :D] R = S.reshape(D, 1) * V[:D, :] ## tensors.append(U.reshape(r1, 2, -1)) U_R = _contract_last_and_first(R, cores[-1]) tensors.append(U_R) return MPS(tensors)
[docs] def mps_lagrange_chebyshev_lrr( func: Callable, domain: Interval | Mesh, order: int, local_order: int, mps_order: MPSOrder = "A", strategy: Strategy = DEFAULT_STRATEGY, ) -> MPS: """ Performs a local rank-revealing Lagrange MPS Chebyshev interpolation of a function. The intermediate tensors are now sparse, with a number of non-zero elements that is proportional to `local_order`, increasing the efficiency of the interpolation. Parameters ---------- func : Callable The function to interpolate. domain : Interval | Mesh The domain where the function is defined. order : int The order of the Chebyshev interpolation. local_order : int The local order of the Chebyshev interpolation. mps_order : MPSOrder, default='A' The order of the MPS cores, either "A" (serial) or "B" (interleaved). strategy : Strategy, default=DEFAULT_STRATEGY The MPS simplification strategy. Returns ------- mps : MPS The MPS corresponding to the local rank-revealing Chebyshev interpolation. """ # TODO: Perform sparse matrix multiplications mesh = Mesh([domain]) if isinstance(domain, Interval) else domain _validate_mesh(mesh) builder = LagrangeBuilder(order, local_order) A_L = builder.build_left_core(func, mesh) A_C = builder.build_center_sparse_core() A_R = builder.build_right_sparse_core() cores = builder.build_sparse_cores(A_C, A_R, mesh, mps_order) U_L, R = np.linalg.qr(A_L.reshape((2, -1))) tensors: list[NDArray] = [U_L.reshape(1, 2, 2)] for core in cores[1:-1]: B = R @ core r1 = B.shape[0] ## SVD U, S, V = _destructive_svd(B.reshape(r1 * 2, -1)) destructively_truncate_vector(S, strategy) D = S.size U = U[:, :D] R = S.reshape(D, 1) * V[:D, :] ## tensors.append(U.reshape(r1, 2, -1)) U_R = R @ cores[-1] tensors.append(U_R.reshape(-1, 2, 1)) return MPS(tensors)
def _validate_mesh(mesh: Mesh): num_qubits = [int(np.log2(N)) for N in mesh.dimensions] if not all(2**n == N for n, N in zip(num_qubits, mesh.dimensions)): raise ValueError("The mesh must be quantizable in qubits.") if len(set(num_qubits)) != 1: raise ValueError("The qubits per dimension must be constant.") class LagrangeBuilder: """Auxiliar class used to build the tensors required for MPS Lagrange interpolation.""" d: int m: int D: int c: Vector angular_grid: Vector extended_grid: Vector den: Vector log_den: Vector sign_den: Vector def __init__( self, order: int, local_order: int | None = None, ): self.d = order self.m = local_order if local_order else order self.D = order + 1 self.c = np.array( [0.5 * (np.cos(np.pi * i / self.d) + 1) for i in range(self.d + 1)] ) self.angular_grid = np.array([i * np.pi / self.d for i in range(self.d + 1)]) if local_order is not None: self.extended_grid = np.array( [(i * np.pi) / self.d for i in range(-self.d, 2 * self.d + 1)] ) else: self.extended_grid = self.angular_grid # Precompute cardinal terms self.den = self.c[:, np.newaxis] - self.c np.fill_diagonal(self.den, 1) self.log_den = np.log(abs(self.den)) self.sign_den = np.sign(self.den) @functools.lru_cache(maxsize=None) # Unbound cache def angular_index(self, theta: float) -> int: """ Returns the index of the closest point of theta to an equispaced angular grid defined in [0, ..., π]. """ return int(np.argmin(abs(theta - self.angular_grid))) def chebyshev_cardinal(self, x: np.ndarray, j: int, use_logs: bool) -> np.ndarray: """ Evaluates the j-th Chebyshev cardinal function (the Lagrange interpolating polynomial for the Chebyshev-Lobatto nodes) at a given point x. """ num = np.delete(x[:, np.newaxis] - self.c, j, axis=1) if use_logs: # Prevents overflow with np.errstate(divide="ignore"): # Ignore warning of log(0) log_num = np.log(abs(num)) log_den = np.delete(self.log_den[j], j) log_div = np.sum(log_num - log_den, axis=1) sign_num = np.sign(num) sign_den = np.delete(self.sign_den[j], j) sign_div = np.prod(sign_num * sign_den, axis=1) return sign_div * np.exp(log_div) else: den = np.delete(self.den[j], j) return np.prod(num / den, axis=1) def local_chebyshev_cardinal(self, x: float, j: int) -> float: """ Evaluates the j-th local Chebyshev cardinal function at a given point x by means of a local angular Lagrange interpolation on an extended angular grid defined in [-π, ..., 2*π] """ θ = np.arccos(2 * x - 1) idx = self.angular_index(θ) P = 0.0 for γ in range(idx - self.m, idx + self.m + 1): γ_rep = -γ if γ < 0 else self.d - (γ - self.d) if γ > self.d else γ if j == γ_rep: P += self.local_angular_cardinal(θ, γ) return P def local_angular_cardinal(self, θ: float, γ: int) -> float: """ Evaluates the γ-th angular Lagrange interpolating polynomial at a given point θ on an extended angular grid defined in [-π, ..., 2*π]. """ idx = self.angular_index(θ) L = 1 for β in range(idx - self.m, idx + self.m + 1): if β != γ: L *= (θ - self.extended_grid[self.d + β]) / ( self.extended_grid[self.d + γ] - self.extended_grid[self.d + β] ) return L def build_left_core( self, func: Callable, mesh: Mesh, channels_first: bool = True ) -> Tensor3: """ Returns the left-most MPS core required for Chebyshev interpolation. """ m = mesh.dimension A = np.zeros((1, 2, self.D**m)) for σ in [0, 1]: intervals: list[Interval] = [] for i in range(m): a, b = mesh.intervals[i].start, mesh.intervals[i].stop c = (σ + self.c) / 2 if i == 0 else self.c arr = array_affine(c, (0, 1), (a, b)) intervals.append(ArrayInterval(arr)) c_mesh = Mesh(intervals) tensor = c_mesh.to_tensor(channels_first) A[0, σ, :] = func(tensor).reshape(-1) return A def build_center_core(self, use_logs: bool) -> Tensor3: """ Returns the central MPS tensor required for Chebyshev interpolation. """ A = np.zeros((self.D, 2, self.D)) for σ in range(2): for i in range(self.D): A[i, σ, :] = self.chebyshev_cardinal(0.5 * (σ + self.c), i, use_logs) return A def build_right_core(self, use_logs: bool) -> Tensor3: """ Returns the right-most MPS tensor required for Chebyshev interpolation. """ A = np.zeros((self.D, 2, 1)) for σ in range(2): for i in range(self.D): A[i, σ, 0] = self.chebyshev_cardinal( np.array([0.5 * σ]), i, use_logs ).item() return A def build_center_sparse_core(self) -> csr_array: """ Returns the central MPS tensor required for local Chebyshev interpolation. For efficiency, it is represented as a (d+1, 2*(d+1)) sparse matrix (CSR). """ A = dok_array((self.D, 2 * self.D), dtype=np.float64) for σ in range(2): for i in range(self.D): for j, c_j in enumerate(self.c): A[i, σ * self.D + j] = self.local_chebyshev_cardinal( 0.5 * (σ + c_j), i ) return A.tocsr() def build_right_sparse_core(self) -> csr_array: """ Returns the right-most MPS tensor required for local Chebyshev interpolation. For efficiency, it is represented as a (d+1, 2) sparse matrix (CSR). """ A = dok_array((self.D, 2), dtype=np.float64) for σ in range(2): for i in range(self.D): A[i, σ] = self.local_chebyshev_cardinal(0.5 * σ, i) return A.tocsr() @staticmethod def build_dense_cores( A_C: Tensor3, A_R: Tensor3, mesh: Mesh, mps_order: MPSOrder ) -> list[Tensor3]: """ Builds the multidimensional cores on the given mesh and mps_order. """ m = mesh.dimension n = int(np.log2(mesh.dimensions[0])) A_R_kron = [_kron_dense(A_R, m - i, 0) for i in range(m)] if mps_order == "A": A_C_kron = [_kron_dense(A_C, m - i, 0) for i in range(m)] cores = [] for A_C, A_R in zip(A_C_kron, A_R_kron): cores.extend([A_C] * (n - 1) + [A_R]) elif mps_order == "B": A_C_kron = [_kron_dense(A_C, m, i) for i in range(m)] cores = A_C_kron * (n - 1) + A_R_kron return cores @staticmethod def build_sparse_cores( A_C: csr_array, A_R: csr_array, mesh: Mesh, mps_order: MPSOrder ) -> list[csr_array]: """ Builds the multidimensional sparse cores on the given mesh and mps_order. """ m = mesh.dimension n = int(np.log2(mesh.dimensions[0])) A_R_kron = [_kron_sparse(A_R, m - i, 0) for i in range(m)] if mps_order == "A": A_C_kron = [_kron_sparse(A_C, m - i, 0) for i in range(m)] cores = [] for A_C, A_R in zip(A_C_kron, A_R_kron): cores.extend([A_C] * (n - 1) + [A_R]) elif mps_order == "B": A_C_kron = [_kron_sparse(A_C, m, i) for i in range(m)] cores = A_C_kron * (n - 1) + A_R_kron return cores def _kron_dense(A: Tensor3, m: int, i: int) -> Tensor3: """ Take the Kronecker product of the tensor A with identity matrices along m dimensions. The function reshapes A from (i, s, j) to (s, i, j) and back to (i*i, s, j*i) after the operation. """ I = np.eye(A.shape[0]) tensors = [np.swapaxes(A, 0, 1) if j == i else I for j in range(m)] B = tensors[0] for tensor in tensors[1:]: B = np.kron(B, tensor) return np.swapaxes(B, 0, 1) def _kron_sparse(A: csr_array, m: int, i: int) -> csr_array: """ Take the Kronecker product of the sparse tensor A with identity matrices along m dimensions. This operation is implemented converting the CSR matrix to a dense format as a temporary workaround. """ # TODO: Fix without transforming to dense matrices. A_dense = A.toarray().reshape(A.shape[0], 2, A.shape[1] // 2) B = _kron_dense(A_dense, m, i) return scipy.sparse.csr_array(B.reshape(B.shape[0], 2 * B.shape[2])) __all__ = [ "mps_lagrange_chebyshev_basic", "mps_lagrange_chebyshev_rr", "mps_lagrange_chebyshev_lrr", ]