# Source code for seemps.solve.bicgs
from __future__ import annotations
from ..state import (
MPS,
MPSSum,
CanonicalMPS,
scprod,
DEFAULT_STRATEGY,
Strategy,
simplify,
)
from ..operators import MPO, MPOList, MPOSum
from ..tools import make_logger
# TODO: Write tests for this
[docs]
def bicgs_solve(
    A: MPO | MPOList | MPOSum,
    b: MPS | MPSSum,
    guess: MPS | None = None,
    maxiter: int = 100,
    atol: float = 0.0,
    rtol: float = 1e-5,
    strategy: Strategy = DEFAULT_STRATEGY,
) -> tuple[CanonicalMPS, float]:
    """Approximate solution of :math:`A \\psi = b`.

    Given the :class:`MPO` `A` and the :class:`MPS` `b`, use the stabilized
    biconjugate gradient method (BiCGSTAB) to estimate another MPS that
    solves the linear system of equations :math:`A \\psi = b`. Every
    intermediate vector is passed through :func:`simplify` with the given
    `strategy`.

    Parameters
    ----------
    A : MPO | MPOList | MPOSum
        Matrix product operator that will be inverted
    b : MPS | MPSSum
        Right-hand side of the equation
    guess : MPS | None, default = None
        Initial guess for the solution. When `None`, `b` itself is used as
        the starting point.
    maxiter : int, default = 100
        Maximum number of iterations
    atol, rtol : float
        Absolute and relative tolerance for the convergence of the algorithm.
        The iteration stops as soon as the norm of the (recursively updated)
        residual falls below `max(rtol * norm(b), atol)`. Defaults are
        `rtol=1e-5` and `atol=0`
    strategy : Strategy, default = DEFAULT_STRATEGY
        Truncation strategy to keep bond dimensions in check. Defaults to
        `DEFAULT_STRATEGY`, which is very strict.

    Returns
    -------
    MPS
        Approximate solution to :math:`A \\psi = b`
    float
        Norm of the last residual estimate tracked by the iteration, an
        approximation to :math:`\\Vert{A \\psi - b}\\Vert`. It is updated
        through the BiCGSTAB recurrences, not recomputed from `A @ x - b`.
    """
    normb = b.norm()
    tolerance = max(rtol * normb, atol)
    x = simplify(b if guess is None else guess, strategy=strategy)
    # The "shadow" residual r0 is fixed to the initial residual and is the
    # vector against which all rho / alpha scalar products are taken.
    p = r = r0 = simplify(b - A @ x, strategy)
    norm_r = r0.norm()
    # rho_0 = <r0|r0>, i.e. the *squared* norm of the initial residual.
    rho = r0.norm_squared()
    with make_logger(2) as logger:
        logger(f"BICCGS algorithm for {maxiter} iterations", flush=True)
        if norm_r < tolerance:
            logger(
                f"BICCGS converged with residual {norm_r} below tolerance {tolerance}"
            )
            return x, norm_r
        for _ in range(1, maxiter + 1):
            v = simplify(A @ p, strategy)
            alpha = rho / scprod(r0, v)
            # Half step: h is the candidate solution, s its residual.
            h = simplify(x + alpha * p, strategy)
            s = simplify(r - alpha * v, strategy)
            residual = s.norm()
            if residual < tolerance:
                logger(
                    f"BICCGS converged with residual {residual} below tolerance {tolerance}"
                )
                x = h
                # Report the residual that belongs to the returned solution.
                norm_r = residual
                break
            # Stabilization step: w minimizes the norm of s - w * t.
            t = simplify(A @ s, strategy)
            w = scprod(t, s) / t.norm_squared()
            x = simplify(h + w * s, strategy)
            r = simplify(s - w * t, strategy)
            norm_r = r.norm()
            if norm_r < tolerance:
                logger(
                    f"BICCGS converged with residual {norm_r} below tolerance {tolerance}"
                )
                break
            rho_new = scprod(r0, r)
            beta = (rho_new / rho) * (alpha / w)
            # Keep the sign/phase of <r0|r>: the next beta divides by it.
            rho = rho_new
            p = simplify(r + beta * p - (beta * w) * v, strategy)
    # Reached both after a convergence `break` and when `maxiter` is exhausted.
    return x, norm_r