import torch
from .explicit_rungekutta import ExplicitRungeKutta
from .implicit_linear_rungekutta import ImplicitLinearRungeKutta
class ExplicitEuler(ExplicitRungeKutta):
    r"""Forward (explicit) Euler method written as a one-stage Runge-Kutta scheme.

    Butcher tableau:

    .. math::

        \begin{array}{c|c}
        \textbf{c} & \mathfrak{A} \\
        \hline
        & \textbf{b}^\top
        \end{array}
        =
        \begin{array}{c|c}
        0 & 0 \\
        \hline
        & 1
        \end{array}

    A single step of size :math:`\tau` therefore reads

    .. math::

        \Psi^{t,t+\tau}\textbf{u} \approx \textbf{u} + \tau \textbf{f}(t,\textbf{u})

    Examples
    --------
    One step for the scalar growth equation

    .. math::

        \frac{\text{d}u}{\text{d}t} = u

    .. code-block:: python

        import torch
        from torch_fem.ode import ExplicitEuler

        class MyExplicitEuler(ExplicitEuler):
            def forward(self, t, u):
                return u

        u0 = torch.rand(4)
        dt = 0.1
        ut_gt = u0 + dt * u0
        ut_my = MyExplicitEuler().step(0, u0, dt)
        assert torch.allclose(ut_gt, ut_my)
    """

    def __init__(self):
        # One stage: A = [[0]] (the stage never feeds back into itself, so
        # the scheme is explicit) and weights b = [1].
        super().__init__(torch.zeros(1, 1), torch.ones(1))
class ImplicitLinearEuler(ImplicitLinearRungeKutta):
    r"""Backward (implicit) Euler method written as a one-stage Runge-Kutta scheme.

    Butcher tableau:

    .. math::

        \begin{array}{c|c}
        \textbf{c} & \mathfrak{A} \\
        \hline
        & \textbf{b}^\top
        \end{array}
        =
        \begin{array}{c|c}
        1 & 1 \\
        \hline
        & 1
        \end{array}

    A single step of size :math:`\tau` therefore reads

    .. math::

        \Psi^{t,t+\tau}\textbf{u} \approx \textbf{w}\quad \textbf{w}=\textbf{u}+\tau\textbf{f}(t+\tau,\textbf{w})

    Examples
    --------
    One step for the scalar growth equation

    .. math::

        \frac{\text{d}u}{\text{d}t} = u

    .. code-block:: python

        import torch
        from torch_fem.ode import ImplicitLinearEuler

        class MyImplicitLinearEuler(ImplicitLinearEuler):
            pass

        u0 = torch.rand(4).double()
        dt = 0.1
        ut_gt = (1/(1-dt)) * u0
        ut_my = MyImplicitLinearEuler().step(0, u0, dt)
        assert torch.allclose(ut_gt, ut_my), f"expected {ut_gt}, got {ut_my}"
    """

    def __init__(self):
        # One stage: A = [[1]] (the stage depends on itself, which makes the
        # scheme implicit) and weights b = [1].
        super().__init__(torch.ones(1, 1), torch.ones(1))
class MidPointLinearEuler(ImplicitLinearRungeKutta):
    r"""Implicit midpoint rule written as a one-stage Runge-Kutta scheme.

    Butcher tableau:

    .. math::

        \begin{array}{c|c}
        \textbf{c} & \mathfrak{A} \\
        \hline
        & \textbf{b}^\top
        \end{array}
        =
        \begin{array}{c|c}
        \frac{1}{2} & \frac{1}{2} \\
        \hline
        & 1
        \end{array}

    A single step of size :math:`\tau` therefore reads

    .. math::

        \Psi^{t,t+\tau}\textbf{u} \approx \textbf{w}\quad \textbf{w} = \textbf{u} +\tau \textbf{f}\left(t+\frac{\tau}{2},\frac{\textbf{w}+\textbf{u}}{2}\right)

    Examples
    --------
    One step for the scalar growth equation

    .. math::

        \frac{\text{d} u}{\text{d} t} = u

    .. code-block:: python

        import torch
        from torch_fem.ode import MidPointLinearEuler

        class MyMidPointLinearEuler(MidPointLinearEuler):
            pass

        u0 = torch.rand(4)
        dt = 0.1
        ut_gt = ((dt+2)/(2-dt)) * u0
        ut_my = MyMidPointLinearEuler().step(0, u0, dt)
        assert torch.allclose(ut_gt, ut_my), f"expected {ut_gt}, got {ut_my}"
    """

    def __init__(self):
        # One stage: A = [[1/2]] (the stage is evaluated halfway through the
        # step) and weights b = [1].
        half_stage = 0.5 * torch.ones(1, 1)
        super().__init__(half_stage, torch.ones(1))