from __future__ import annotations
import numpy as np
from typing import Callable
from abc import ABC, abstractmethod
from ...state import MPS, MPSSum, CanonicalMPS, Strategy, DEFAULT_STRATEGY, simplify
from ...operators import MPO, MPOList, MPOSum, simplify_mpo
from ...typing import Vector
from ...tools import make_logger
from ..mesh import Interval
from ..factories import mps_interval
# Type alias for a scalar function: it receives a `Vector` argument and returns a float.
ScalarFunction = Callable[[Vector], float]
# TODO: Implement polynomial bases with unbounded orthogonality domains (e.g. Hermite with [-∞, ∞])
class PolynomialExpansion(ABC):
    """
    Abstract base class for polynomial expansions.

    A `PolynomialExpansion` represents a truncated expansion

        f(x) ≈ ∑_{k=0}^d c_k P_k(x),

    where {P_k} is a polynomial basis with P_0(x) = 1, generated by a three-term
    recurrence relation.

    The expansion is fully characterized by:

    - The expansion coefficients `coefficients`.
    - The recurrence coefficients (α_k, β_k, γ_k) defining the basis.
    - The orthogonality domain [a, b] of the basis.
    - An affine normalization (σ, μ) such that P_1(x) = σ x + μ.

    Subclasses define concrete polynomial families (e.g. power, Chebyshev, Legendre,
    Hermite) by specifying the recurrence coefficients and projection routines.

    The expansion can be evaluated on tensor-network objects via `to_mps` and `to_mpo`,
    which construct MPS/MPO representations of composed functions f(g(x)) or operator
    functions f(A).

    Attributes
    ----------
    coefficients : Vector
        Expansion coefficients {c_k} of f(x) in the chosen polynomial basis.
    orthogonality_domain : tuple[float, float]
        Real interval on which the basis is orthogonal.
        For example, (-1, 1) for Chebyshev or Legendre, and (-∞, ∞) for Hermite.
        Non-orthogonal bases, such as the monomial basis, use the unbounded
        interval (-∞, ∞).
    affine_fix : tuple[float, float]
        Pair of coefficients (σ, μ) fixing the affine gauge of the basis via
        P_1(x) = σ x + μ.
    """

    coefficients: Vector
    # NOTE: This hits a limitation of Python typing: subclasses cannot refine class attributes.
    orthogonality_domain: tuple[float, float]  # (a, b)
    affine_fix: tuple[float, float]  # (σ, μ)

    def __init__(
        self,
        coefficients: Vector,
        orthogonality_domain: tuple[float, float],
        affine_fix: tuple[float, float] = (1.0, 0.0),
    ):
        """
        Store the expansion data.

        Parameters
        ----------
        coefficients : Vector
            Expansion coefficients c_0, ..., c_d.
        orthogonality_domain : tuple[float, float]
            Interval (a, b) on which the basis is orthogonal.
        affine_fix : tuple[float, float], default=(1.0, 0.0)
            Pair (σ, μ) with P_1(x) = σ x + μ.
        """
        self.coefficients = coefficients
        self.orthogonality_domain = orthogonality_domain
        self.affine_fix = affine_fix

    @abstractmethod
    def recurrence_coefficients(self, k: int) -> tuple[float, float, float]:
        """
        Return the three-term coefficients (α_k, β_k, γ_k) for the recursion

            P_{k+1}(x) = (α_k x + β_k) P_k(x) - γ_k P_{k-1}(x).

        P_1 itself is fixed by `affine_fix` rather than by this recurrence.
        """
        ...

    def rescale_mps(self, mps: MPS) -> MPS:
        """
        Rescale the argument MPS from the approximation domain to the orthogonality
        domain of the basis, if applicable. Subclasses override this to account
        for their own orthogonality domain; the base implementation returns `mps`
        unchanged.
        """
        return mps  # Do nothing

    def rescale_mpo(self, mpo: MPO) -> MPO:
        """
        Rescale the argument MPO from the approximation domain to the orthogonality
        domain of the basis, if applicable. Subclasses override this to account
        for their own orthogonality domain; the base implementation returns `mpo`
        unchanged.
        """
        return mpo  # Do nothing

    def to_mps(
        self,
        argument: Interval | MPS,
        clenshaw: bool = True,
        strategy: Strategy = DEFAULT_STRATEGY,
        rescale_argument: bool = True,
    ) -> MPS:
        """
        Construct the MPS representation of a composed function via a polynomial expansion.

        Given a polynomial expansion of a scalar function f(x) and an argument g(x),
        provided either as an `Interval` or as an `MPS`, this method builds an MPS
        approximation of the composed function f(g(x)).

        Evaluation can be performed either with the Clenshaw recurrence or by direct
        polynomial recursion. When `rescale_argument` is True, the argument is first
        passed through `rescale_mps` so that it lies in the orthogonality domain of
        the basis.

        Parameters
        ----------
        argument : Interval or MPS
            The argument g to which the polynomial expansion is applied.
        clenshaw : bool, default=True
            Whether to use the Clenshaw recurrence for polynomial evaluation.
        strategy : Strategy, default=DEFAULT_STRATEGY
            Simplification strategy for intermediate MPS operations.
        rescale_argument : bool, default=True
            Whether to rescale the argument to the orthogonality domain of the basis,
            if applicable. This uses the `rescale_mps` method defined by the basis
            subclass.

        Returns
        -------
        MPS
            An MPS approximation of f(g(x)).
        """
        return _mps_polynomial_expansion(
            self, argument, clenshaw, strategy, rescale_argument
        )

    def to_mpo(
        self,
        argument: MPO,
        clenshaw: bool = True,
        strategy: Strategy = DEFAULT_STRATEGY,
        rescale_argument: bool = True,
    ) -> MPO:
        """
        Construct the MPO representation of an operator function via a polynomial expansion.

        Given a polynomial expansion of a scalar function f(x) and an operator argument
        A provided as an `MPO`, this method builds an MPO approximation of the operator
        function f(A).

        Evaluation can be performed either with the Clenshaw recurrence or by direct
        polynomial recursion. When `rescale_argument` is True, the argument is first
        passed through `rescale_mpo` so that it lies in the orthogonality domain of
        the basis.

        Parameters
        ----------
        argument : MPO
            The operator argument A to which the polynomial expansion is applied.
        clenshaw : bool, default=True
            Whether to use the Clenshaw recurrence for polynomial evaluation.
        strategy : Strategy, default=DEFAULT_STRATEGY
            Simplification strategy for intermediate MPO operations.
        rescale_argument : bool, default=True
            Whether to rescale the argument to the orthogonality domain of the basis,
            if applicable. This uses the `rescale_mpo` method defined by the basis
            subclass.

        Returns
        -------
        MPO
            An MPO approximation of f(A).
        """
        return _mpo_polynomial_expansion(
            self, argument, clenshaw, strategy, rescale_argument
        )
class PowerExpansion(PolynomialExpansion):
    """
    Polynomial expansion in the monomial basis {1, x, x^2, ...}, corresponding to a
    standard power or Taylor series.

    The recurrence relation is trivial, P_{k+1}(x) = x · P_k(x), with P_1(x) = x
    (default affine gauge σ = 1, μ = 0). The basis is not orthogonal with respect
    to any inner product, so `orthogonality_domain` is set to the unbounded
    interval (-∞, ∞), and `rescale_mps`/`rescale_mpo` are not overridden.

    When evaluated using the Clenshaw recursion, this expansion reduces to Horner's
    method, providing a numerically stable and efficient evaluation of polynomial
    functions in tensor-network form.
    """

    def __init__(self, coefficients: Vector):
        """Create a power series with coefficients c_0, ..., c_d of 1, x, ..., x^d."""
        super().__init__(
            coefficients=coefficients, orthogonality_domain=(-np.inf, np.inf)
        )

    def recurrence_coefficients(self, k: int) -> tuple[float, float, float]:
        """Return (α_k, β_k, γ_k) = (1, 0, 0) for every k, i.e. P_{k+1}(x) = x · P_k(x)."""
        return (1.0, 0.0, 0.0)
def _mps_polynomial_expansion(
    expansion: PolynomialExpansion,
    argument: Interval | MPS,
    clenshaw: bool = True,
    strategy: Strategy = DEFAULT_STRATEGY,
    rescale_argument: bool = True,
) -> MPS:
    """
    Build an MPS approximation of f(g(x)) from a polynomial expansion of f.

    This is the implementation behind `PolynomialExpansion.to_mps`. The constant
    state I (all ones, i.e. P_0) and the argument X are each split into a norm and
    a normalized canonical form. The norms are folded back in as `MPSSum` weights.

    Parameters
    ----------
    expansion : PolynomialExpansion
        Supplies the coefficients c_k, the recurrence coefficients
        (α_k, β_k, γ_k) and the affine gauge (σ, μ) with P_1(x) = σ x + μ.
    argument : Interval | MPS
        The argument g. An `Interval` is converted with `mps_interval`.
    clenshaw : bool, default=True
        If True, use the Clenshaw backward recurrence. Otherwise, build P_k with
        the forward recurrence and accumulate c_k P_k term by term.
    strategy : Strategy, default=DEFAULT_STRATEGY
        Strategy used for canonical forms and every `simplify` call.
    rescale_argument : bool, default=True
        Whether to apply `expansion.rescale_mps` to the argument first.

    Returns
    -------
    MPS
        Approximation of ∑_k c_k P_k(g(x)).

    Raises
    ------
    ValueError
        If `argument` is neither an `Interval` nor an `MPS`, or if
        `expansion.coefficients` is empty.
    """
    logger = make_logger(2)
    try:  # Ensure the logger is closed even if evaluation raises.
        if isinstance(argument, Interval):
            X = mps_interval(argument)
        elif isinstance(argument, MPS):
            X = argument
        else:
            raise ValueError("Either an Interval or an initial MPS must be provided.")
        c = expansion.coefficients
        d = len(c) - 1
        if d < 0:
            raise ValueError(
                "The polynomial expansion must have at least one coefficient."
            )
        # Normalized initial states
        I = MPS([np.ones((1, s, 1)) for s in X.physical_dimensions()])
        norm_I = I.norm()
        I_hat = CanonicalMPS(I, center=0, normalize=True, strategy=strategy)
        X = expansion.rescale_mps(X) if rescale_argument else X
        norm_X = X.norm()
        X_hat = CanonicalMPS(X, center=0, normalize=True, strategy=strategy)
        # Basis coefficients. Entry k = d + 1 provides γ_{d+1} for the first
        # Clenshaw step; entry k = 0 is never read because P_1 comes from affine_fix.
        σ, μ = expansion.affine_fix
        recurrences = [expansion.recurrence_coefficients(l) for l in range(d + 2)]
        if clenshaw:
            # Y_k = c_k I + (α_k X + β_k) Y_{k+1} - γ_{k+1} Y_{k+2}
            logger("MPS Clenshaw evaluation started")
            Y_kp1 = Y_kp2 = I_hat.zero_state()
            # Backward recursion: k = d, d-1, ..., 1 (no iterations when d == 0)
            for k in range(d, 0, -1):
                α_k, β_k, _ = recurrences[k]
                _, _, γ_kp1 = recurrences[k + 1]
                weights = [c[k] * norm_I, α_k * norm_X, -γ_kp1]
                states = [I_hat, X_hat * Y_kp1, Y_kp2]
                if β_k != 0:  # Avoid zero branch when β_k == 0
                    weights.append(β_k)
                    states.append(Y_kp1)
                Y_k = simplify(
                    MPSSum(weights, states, check_args=False), strategy=strategy
                )
                logger(
                    f"MPS Clenshaw step {k + 1}/{d + 1}, maxbond={Y_k.max_bond_dimension()}, error={Y_k.error():6e}"
                )
                Y_kp2, Y_kp1 = Y_kp1, Y_k
            # F = c_0 I + (σ X + μ) * Y_1 - γ_1 Y_2
            _, _, γ_1 = recurrences[1]
            weights = [c[0] * norm_I, σ * norm_X, -γ_1]
            states = [I_hat, X_hat * Y_kp1, Y_kp2]
            if μ != 0:  # Avoid zero branch when μ == 0
                weights.append(μ)
                states.append(Y_kp1)
            F = simplify(MPSSum(weights, states, check_args=False), strategy=strategy)
        else:
            # P_{k+1} = (α_{k} X + β_{k}) P_k - γ_k P_{k-1}
            # F_{k+1} = F_k + c_{k+1} P_{k+1}
            logger("MPS expansion (direct) started")
            P_0 = norm_I * I_hat
            if d == 0:
                # Constant expansion: there is no c_1 P_1 term, so F = c_0 P_0.
                F = simplify(
                    MPSSum([c[0]], [P_0], check_args=False), strategy=strategy
                )
            else:
                P_1 = simplify(
                    σ * norm_X * X_hat
                    if μ == 0  # Avoid zero branch when μ == 0
                    else MPSSum(
                        [σ * norm_X, μ * norm_I], [X_hat, I_hat], check_args=False
                    ),
                    strategy=strategy,
                )
                F = simplify(
                    MPSSum([c[0], c[1]], [P_0, P_1], check_args=False),
                    strategy=strategy,
                )
                # Forward recursion: k = 1, 2, ..., d-1
                P_km1, P_k = P_0, P_1
                for k in range(1, d):
                    α_k, β_k, γ_k = recurrences[k]
                    weights = [α_k * norm_X, -γ_k]
                    states = [X_hat * P_k, P_km1]
                    if β_k != 0:  # Avoid zero branch when β_k == 0
                        weights.append(β_k)
                        states.append(P_k)
                    P_kp1 = simplify(
                        MPSSum(weights, states, check_args=False), strategy=strategy
                    )
                    F = simplify(
                        MPSSum(
                            weights=[1.0, c[k + 1]],
                            states=[F, P_kp1],
                            check_args=False,
                        ),
                        strategy=strategy,
                    )
                    logger(
                        f"MPS expansion step {k + 1}/{d + 1}, maxbond={F.max_bond_dimension()}, error={F.error():6e}"
                    )
                    P_km1, P_k = P_k, P_kp1
    finally:
        logger.close()
    return F
def _mpo_polynomial_expansion(
    expansion: PolynomialExpansion,
    argument: MPO,
    clenshaw: bool = True,
    strategy: Strategy = DEFAULT_STRATEGY,
    rescale_argument: bool = True,
) -> MPO:
    """
    Build an MPO approximation of the operator function f(A) from a polynomial
    expansion of f.

    This is the implementation behind `PolynomialExpansion.to_mpo`. Operator
    products X·Y are represented lazily as `MPOList([X, Y])`, and linear
    combinations as `MPOSum`. Each combination is compressed with `simplify_mpo`.

    Parameters
    ----------
    expansion : PolynomialExpansion
        Supplies the coefficients c_k, the recurrence coefficients
        (α_k, β_k, γ_k) and the affine gauge (σ, μ) with P_1(x) = σ x + μ.
    argument : MPO
        The operator argument A.
    clenshaw : bool, default=True
        If True, use the Clenshaw backward recurrence. Otherwise, build P_k with
        the forward recurrence and accumulate c_k P_k term by term.
    strategy : Strategy, default=DEFAULT_STRATEGY
        Strategy used in every `simplify_mpo` call.
    rescale_argument : bool, default=True
        Whether to apply `expansion.rescale_mpo` to the argument first.

    Returns
    -------
    MPO
        Approximation of ∑_k c_k P_k(A).

    Raises
    ------
    ValueError
        If `expansion.coefficients` is empty.
    """
    logger = make_logger(2)
    try:  # Ensure the logger is closed even if evaluation raises.
        c = expansion.coefficients
        d = len(c) - 1
        if d < 0:
            raise ValueError(
                "The polynomial expansion must have at least one coefficient."
            )
        X = expansion.rescale_mpo(argument) if rescale_argument else argument
        X = simplify_mpo(X, strategy=strategy)
        # Identity and zero operators must match the local dimensions of X
        # (previously hard-coded to 2). Assumes X is square on every site,
        # i.e. A.shape[1] == A.shape[2], as required for powers of X.
        dimensions = [A.shape[1] for A in X]
        I = MPO([np.eye(n).reshape(1, n, n, 1) for n in dimensions])
        σ, μ = expansion.affine_fix
        # Entry k = d + 1 provides γ_{d+1} for the first Clenshaw step;
        # entry k = 0 is never read because P_1 comes from affine_fix.
        recurrences = [expansion.recurrence_coefficients(l) for l in range(d + 2)]
        mpos: list[MPO | MPOList]
        if clenshaw:
            # Y_k = c_k I + (α_k X + β_k) Y_{k+1} - γ_{k+1} Y_{k+2}
            logger("MPO Clenshaw evaluation started")
            Y_kp1 = Y_kp2 = MPO([np.zeros((1, n, n, 1)) for n in dimensions])
            # Backward recursion: k = d, d-1, ..., 1 (no iterations when d == 0)
            for k in range(d, 0, -1):
                α_k, β_k, _ = recurrences[k]
                _, _, γ_kp1 = recurrences[k + 1]
                weights = [c[k], α_k, -γ_kp1]
                mpos = [I, MPOList([X, Y_kp1]), Y_kp2]
                if β_k != 0:  # Avoid zero branch when β_k == 0
                    weights.append(β_k)
                    mpos.append(Y_kp1)
                Y_k = simplify_mpo(MPOSum(mpos, weights), strategy=strategy)
                logger(
                    f"MPO Clenshaw step {k + 1}/{d + 1}, maxbond={Y_k.max_bond_dimension()}"
                )
                Y_kp2, Y_kp1 = Y_kp1, Y_k
            # F = c_0 I + (σ X + μ) * Y_1 - γ_1 Y_2
            _, _, γ_1 = recurrences[1]
            weights = [c[0], σ, -γ_1]
            mpos = [I, MPOList([X, Y_kp1]), Y_kp2]
            if μ != 0:  # Avoid zero branch when μ == 0
                weights.append(μ)
                mpos.append(Y_kp1)
            F = simplify_mpo(MPOSum(mpos, weights), strategy=strategy)
        else:
            # P_{k+1} = (α_{k} X + β_{k}) P_k - γ_k P_{k-1}
            # F_{k+1} = F_k + c_{k+1} P_{k+1}
            logger("MPO expansion (direct) started")
            P_0 = I
            if d == 0:
                # Constant expansion: there is no c_1 P_1 term, so F = c_0 I.
                F = simplify_mpo(
                    MPOSum(weights=[c[0]], mpos=[P_0]), strategy=strategy
                )
            else:
                P_1 = simplify_mpo(
                    MPOSum(weights=[σ], mpos=[X])
                    if μ == 0  # Avoid zero branch when μ == 0
                    else MPOSum(weights=[σ, μ], mpos=[X, I]),
                    strategy=strategy,
                )
                F = simplify_mpo(
                    MPOSum(weights=[c[0], c[1]], mpos=[P_0, P_1]), strategy=strategy
                )
                # Forward recursion: k = 1, 2, ..., d-1
                P_km1, P_k = P_0, P_1
                for k in range(1, d):
                    α_k, β_k, γ_k = recurrences[k]
                    weights = [α_k, -γ_k]
                    mpos = [MPOList([X, P_k]), P_km1]
                    if β_k != 0:  # Avoid zero branch when β_k == 0
                        weights.append(β_k)
                        mpos.append(P_k)
                    P_kp1 = simplify_mpo(MPOSum(mpos, weights), strategy=strategy)
                    F = simplify_mpo(
                        MPOSum(weights=[1.0, c[k + 1]], mpos=[F, P_kp1]),
                        strategy=strategy,
                    )
                    logger(
                        f"MPO expansion step {k + 1}/{d + 1}, maxbond={F.max_bond_dimension()}"
                    )
                    P_km1, P_k = P_k, P_kp1
    finally:
        logger.close()
    return F