from abc import ABC, abstractmethod
import torch
from torch_fem.sparse import SparseMatrix
class ImplicitLinearRungeKutta:
    r"""Implicit Runge-Kutta stepper for the linear ODE system

    .. math::
        M(t) \frac{\partial u}{\partial t} = A(t) u + B(t)

    where

    * :math:`M \in \mathbb{R}^{n\times n}` is the left-hand-side (mass) matrix
    * :math:`A \in \mathbb{R}^{n\times n}` is the linear mapping term
    * :math:`B \in \mathbb{R}^{n}` is the source term
    * :math:`u \in \mathbb{R}^{n}` is the unknown

    Given an :math:`s`-stage Butcher tableau :math:`(a, b)` and a step size
    :math:`\tau`, each step solves one block linear system for the stage
    vectors :math:`\mathbf{k}_0, \dots, \mathbf{k}_{s-1}`, with
    :math:`M_i = M(t_0 + c_i\tau)`, :math:`A_i = A(t_0 + c_i\tau)` and
    :math:`B_i = B(t_0 + c_i\tau)`:

    .. math::
        \begin{bmatrix}
        M_0 - A_0\tau a_{0,0} & -A_0\tau a_{0,1} & \cdots & -A_0\tau a_{0,s-1} \\
        -A_1\tau a_{1,0} & M_1 - A_1\tau a_{1,1} & \cdots & -A_1\tau a_{1,s-1} \\
        \vdots & \vdots & \ddots & \vdots \\
        -A_{s-1}\tau a_{s-1,0} & -A_{s-1}\tau a_{s-1,1} & \cdots & M_{s-1} - A_{s-1}\tau a_{s-1,s-1}
        \end{bmatrix}
        \begin{bmatrix}
        \mathbf{k}_0 \\ \mathbf{k}_1 \\ \vdots \\ \mathbf{k}_{s-1}
        \end{bmatrix}
        =
        \begin{bmatrix}
        B_0 + A_0 u \\
        B_1 + A_1 u \\
        \vdots \\
        B_{s-1} + A_{s-1} u
        \end{bmatrix}

    after which the solution is updated as
    :math:`u \leftarrow u + \tau \sum_i b_i \mathbf{k}_i`.
    """
    def __init__(self, a, b):
        """
        Parameters
        ----------
        a : torch.Tensor
            Runge-Kutta matrix of the Butcher tableau, of shape :math:`[s, s]`
        b : torch.Tensor
            weight vector of the Butcher tableau, of shape :math:`[s]`
        """
        assert a.dim() == 2, f"expected a to be a 2D tensor, got {a.dim()}D"
        assert b.dim() == 1, f"expected b to be a 1D tensor, got {b.dim()}D"
        assert a.shape[0] == a.shape[1], f"expected a to be square, got {a.shape}"
        assert a.shape[0] == b.shape[0], f"expected a and b to have the same number of stages, got {a.shape} and {b.shape}"
        # compare with a tolerance: b holds floating-point weights
        assert torch.allclose(b.sum(), torch.ones_like(b.sum())), f"expected b to sum to 1, got {b.sum()}"
        self.a = a
        self.b = b
        self.c = a.sum(dim=1)  # node vector c_i = sum_j a_ij
        self.s = b.shape[0]    # number of stages
        self.__post_init__()
    def __post_init__(self):
        """hook for subclasses to precompute data after
        :class:`torch_fem.ode.ImplicitLinearRungeKutta` is initialized
        """
        pass
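    # A subclass might use __post_init__ to, e.g., assemble time-independent
    # operators once instead of rebuilding them at every stage in step().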
    def forward_M(self, t):
        r"""compute the left-hand-side matrix :math:`M(t)` in

        .. math::
            M(t) \frac{\partial u}{\partial t} = A(t)u + B(t)

        Parameters
        ----------
        t : float
            time

        Returns
        -------
        torch_fem.sparse.SparseMatrix or torch.Tensor or float
            normally a 2D :class:`torch.Tensor` or :class:`torch_fem.sparse.SparseMatrix` of shape :math:`[D, D]`, where :math:`D` is the dimension of the problem;
            if an :obj:`int` or :obj:`float` is returned, :math:`M` is taken to be a diagonal matrix with that value on the diagonal
        """
        return 1.0
    def forward_A(self, t):
        r"""compute the linear mapping term :math:`A(t)` in

        .. math::
            M(t) \frac{\partial u}{\partial t} = A(t)u + B(t)

        Parameters
        ----------
        t : float
            time

        Returns
        -------
        torch_fem.sparse.SparseMatrix or torch.Tensor or float
            2D :class:`torch.Tensor` or :class:`torch_fem.sparse.SparseMatrix` of shape :math:`[D, D]`, where :math:`D` is the dimension of the problem;
            if an :obj:`int` or :obj:`float` is returned, :math:`A` is taken to be a diagonal matrix with that value on the diagonal
        """
        return 1.0
    def forward_B(self, t):
        r"""compute the source term :math:`B(t)` in

        .. math::
            M(t) \frac{\partial u}{\partial t} = A(t)u + B(t)

        Parameters
        ----------
        t : float
            time

        Returns
        -------
        torch.Tensor or float
            1D :class:`torch.Tensor` of shape :math:`[D]`, where :math:`D` is the dimension of the problem;
            if an :obj:`int` or :obj:`float` is returned, :math:`B` is taken to be a constant vector filled with that value
        """
        return 0.0
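    # The three forward_* hooks above are the intended override points. A
    # hypothetical heat-equation subclass might look like this (`mass`,
    # `stiffness`, and `load` are illustrative names, not torch_fem API):
    #
    #     class Heat(ImplicitLinearRungeKutta):
    #         def forward_M(self, t):
    #             return self.mass        # SparseMatrix, shape [D, D]
    #         def forward_A(self, t):
    #             return -self.stiffness  # SparseMatrix, shape [D, D]
    #         def forward_B(self, t):
    #             return self.load        # torch.Tensor, shape [D]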
    def pre_solve_lhs(self, K):
        r"""preprocess a block of the left-hand-side matrix before the
        linear system is solved, for example to apply condensation

        Parameters
        ----------
        K : torch.Tensor or torch_fem.sparse.SparseMatrix
            a block of the left-hand-side matrix

        Returns
        -------
        torch.Tensor or torch_fem.sparse.SparseMatrix
            the block after preprocessing
        """
        return K
    def pre_solve_rhs(self, f):
        r"""preprocess a block of the right-hand-side vector before the
        linear system is solved, for example to apply condensation

        Parameters
        ----------
        f : torch.Tensor
            a block of the right-hand-side vector

        Returns
        -------
        torch.Tensor
            the block after preprocessing
        """
        return f
    def post_solve(self, u):
        r"""postprocess the updated solution after the linear system is
        solved, for example to recover condensed degrees of freedom

        Parameters
        ----------
        u : torch.Tensor
            the updated solution of the current step

        Returns
        -------
        torch.Tensor
            the solution after postprocessing
        """
        return u
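    # The pre/post hooks give subclasses a place to enforce constraints. A
    # hypothetical shape-preserving sketch that re-imposes Dirichlet values
    # (`dirichlet_idx` / `dirichlet_val` are illustrative, not torch_fem API):
    #
    #     def post_solve(self, u):
    #         u = u.clone()
    #         u[self.dirichlet_idx] = self.dirichlet_val
    #         return u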
    def step(self, t0, u0, dt):
        r"""advance the solution by one step from :math:`t_0` to :math:`t_0 + \Delta t`

        Parameters
        ----------
        t0 : float
            initial time
        u0 : torch.Tensor
            initial value of shape :math:`[D]` where :math:`D` is the dimension of the problem
        dt : float
            time step

        Returns
        -------
        torch.Tensor
            the solution at :math:`t_0 + \Delta t`, of shape :math:`[D]`
        """
        assert u0.dim() == 1, f"expected u0 to be a 1D tensor, got {u0.dim()}D"
        a = self.a.type(u0.dtype).to(u0.device)
        b = self.b.type(u0.dtype).to(u0.device)
        c = self.c.type(u0.dtype).to(u0.device)
        D = u0.shape[0]
        h = dt
        ts = t0 + dt * c  # stage times t_i = t0 + c_i * dt
        lhs = [[None for _ in range(self.s)] for _ in range(self.s)]  # s x s blocks of shape [D, D]
        rhs = [None for _ in range(self.s)]  # s blocks of shape [D]
use_sparse = None
for i in range(self.s):
Ai = self.forward_A(ts[i])
Bi = self.forward_B(ts[i])
Mi = self.forward_M(ts[i])
assert isinstance(Ai, (SparseMatrix, torch.Tensor, int, float)) , f"expected A to be SparseMatrix or torch.Tensor or float, got {type(Ai)}"
assert isinstance(Bi, (torch.Tensor, int, float)), f"expected B to be torch.Tensor or float, got {type(Bi)}"
assert isinstance(Mi, (SparseMatrix, torch.Tensor, int, float)) , f"expected M to be SparseMatrix or torch.Tensor or float, got {type(Mi)}"
if i == 0:
use_sparse = isinstance(Mi, SparseMatrix) or isinstance(Ai, SparseMatrix)
            else: # all stages must use the same backend (sparse vs dense)
                if use_sparse: # every matrix must be a SparseMatrix or a scalar
                    assert not isinstance(Ai, torch.Tensor), f"expected A to be SparseMatrix or a scalar, got {type(Ai)}"
                    assert not isinstance(Mi, torch.Tensor), f"expected M to be SparseMatrix or a scalar, got {type(Mi)}"
                else: # every matrix must be a torch.Tensor or a scalar
                    assert not isinstance(Ai, SparseMatrix), f"expected A to be torch.Tensor or a scalar, got {type(Ai)}"
                    assert not isinstance(Mi, SparseMatrix), f"expected M to be torch.Tensor or a scalar, got {type(Mi)}"
# convert Mi, Ai to torch.Tensor or SparseMatrix
if isinstance(Mi, (int, float)):
Mi = float(Mi)
Mi = SparseMatrix.eye(D, value=Mi) if use_sparse else torch.eye(D) * Mi
if isinstance(Ai, (int, float)):
Ai = float(Ai)
Ai = SparseMatrix.eye(D, value=Ai) if use_sparse else torch.eye(D) * Ai
Mi = Mi.type(u0.dtype).to(u0.device)
Ai = Ai.type(u0.dtype).to(u0.device)
            # assemble block row i: (M_i d_ij - h a_ij A_i) k_j = B_i + A_i u0
            for j in range(self.s):
                lhs[i][j] = -h * a[i,j] * Ai
                if i == j:
                    lhs[i][j] = lhs[i][j] + Mi
            rhs[i] = Bi + Ai @ u0
        # pre_solve: let subclasses preprocess each block before assembly
for i in range(self.s):
for j in range(self.s):
lhs[i][j] = self.pre_solve_lhs(lhs[i][j])
rhs[i] = self.pre_solve_rhs(rhs[i])
        # stack the s x s blocks into one (s*D) x (s*D) system
        if use_sparse:
            lhs = SparseMatrix.combine(lhs)
        else:
            lhs = torch.cat([torch.cat(row, dim=1) for row in lhs], dim=0)
        rhs = torch.cat(rhs, dim=0)
        # solve the block linear system for the stacked stage vectors
        if use_sparse:
            k = lhs.solve(rhs)
        else:
            k = torch.linalg.solve(lhs, rhs)
        k = k.reshape(self.s, D)  # one stage vector k_i per row
        u = u0 + h * b @ k        # u <- u0 + h * sum_i b_i k_i
# post_solve
u = self.post_solve(u)
return u
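

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative self-test, not part of torch_fem):
    # integrate du/dt = -u with the implicit midpoint rule (a = [[1/2]],
    # b = [1]) and compare against the exact solution exp(-t). Only
    # forward_A is overridden; M defaults to the identity and B to zero.
    class Decay(ImplicitLinearRungeKutta):
        def forward_A(self, t):
            return -1.0  # scalar return means A(t) = -1 * identity

    stepper = Decay(torch.tensor([[0.5]]), torch.tensor([1.0]))
    u = torch.ones(1)
    t, dt = 0.0, 0.1
    for _ in range(10):
        u = stepper.step(t, u, dt)
        t += dt
    print(u.item(), torch.exp(torch.tensor(-t)).item())  # ~0.3676 vs ~0.3679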