Source code for seemps.solve.cgs
from __future__ import annotations
from typing import Callable, Any
from ..state import (
MPS,
MPSSum,
CanonicalMPS,
DEFAULT_TOLERANCE,
DEFAULT_STRATEGY,
Strategy,
simplify,
)
from ..operators import MPO, MPOList, MPOSum
from ..tools import make_logger
[docs]
def cgs_solve(
A: MPO | MPOList | MPOSum,
b: MPS | MPSSum,
guess: MPS | None = None,
maxiter: int = 100,
tolerance: float = DEFAULT_TOLERANCE,
strategy: Strategy = DEFAULT_STRATEGY,
callback: Callable[[MPS, float], Any] | None = None,
) -> tuple[CanonicalMPS, float]:
"""Approximate solution of :math:`A \\psi = b`.
Given the :class:`MPO` `A` and the :class:`MPS` `b`, use the conjugate
gradient method to estimate another MPS that solves the linear system of
equations :math:`A \\psi = b`. Convergence is determined by the
residual :math:`\\Vert{A \\psi - b}\\Vert` being smaller than `tol`.
Parameters
----------
A : MPO | MPOList | MPOSum
Matrix product state that will be inverted
b : MPS | MPSSum
Right-hand side of the equation
maxiter : int, default = 100
Maximum number of iterations
tol : float, default = DEFAULT_TOLERANCE
Error tolerance for the algorithm.
strategy : Strategy, default = DEFAULT_STRATEGY
Truncation strategy for MPS and MPO operations
Returns
-------
MPS
Approximate solution to :math:`A ψ = b`
float
Norm-2 of the residual :math:`\\Vert{A \\psi - b}\\Vert`
"""
normb = b.norm()
if strategy.get_normalize_flag():
strategy = strategy.replace(normalize=False)
x = simplify(b if guess is None else guess, strategy=strategy)
r = b - A @ x
p = simplify(r, strategy=strategy)
residual = r.norm()
with make_logger(2) as logger:
logger(f"CGS algorithm for {maxiter} iterations", flush=True)
for i in range(maxiter):
if residual < tolerance * normb:
logger(
f"CGS converged with residual {residual} below relative tolerance {tolerance}"
)
break
α = residual * residual / A.expectation(p)
x = simplify(MPSSum([1, α], [x, p]), strategy=strategy)
r = b - A @ x
residual, ρold = r.norm(), residual
if callback is not None:
callback(x, residual)
p = simplify(MPSSum([1.0, residual / ρold], [r, p]), strategy=strategy)
logger(f"CGS step {i:5}: |r|^2={residual:5g} tol={tolerance:5g}")
return x, abs(residual)