import torch
class WaveMultiFrequency:
    r"""Multi-frequency wave equation with homogeneous Dirichlet boundary condition.

    .. math::
        u_{tt} = c^2 \Delta u

    where :math:`t \in [0,T],\quad (x_1,x_2) \in [0,1]^2`, with the boundary
    condition :math:`u(t, x_1, x_2) = 0` for :math:`(x_1, x_2) \in \partial [0,1]^2`.

    The initial displacement is a truncated double sine series with coefficients
    ``a`` (see :meth:`initial_condition`). The closed-form solution evaluated by
    :meth:`solution` multiplies each mode by
    :math:`\cos(c \pi t \sqrt{i^2 + j^2})`, so its initial velocity is zero.

    Parameters
    ----------
    a: torch.Tensor, optional
        3D tensor of shape :math:`[N, K, K]` or 2D tensor of shape :math:`[K, K]`,
        where :math:`N` is the number of samples and :math:`K` is the number of
        frequencies per spatial direction. These are the series coefficients
        :math:`a_{ij}`. If ``None``, they are drawn as
        :math:`a \sim \mathrm{Unif}([-1,1]^{K \times K})`.
        Array-likes are converted with :func:`torch.as_tensor`.
    K: int, optional
        The number of frequencies per direction. If ``a`` is not ``None`` this
        parameter is ignored and ``K`` is read from ``a.shape[-1]``; if ``a`` is
        ``None`` it is used to generate the random ``a``.
    c: float, optional
        The wave speed, default is :math:`1.0`.
    r: float, optional
        Decay exponent of the mode amplitudes: mode :math:`(i, j)` is weighted
        by :math:`(i^2 + j^2)^{-r}`. Default is :math:`0.5`.
    """

    def __init__(self, a=None, K=2, c=1.0, r=0.5):
        if a is None:
            assert K is not None, "K should be specified if a is None"
            a = torch.zeros((K, K)).uniform_(-1, 1)
        else:
            a = torch.as_tensor(a)
            K = a.shape[-1]
        # Only the [K, K] (single sample) and [N, K, K] (batch) layouts are
        # supported by the evaluation code below.
        assert a.dim() in (2, 3) and a.shape[-2:] == (K, K), (
            f"the shape of a should be (N, {K}, {K}) or ({K}, {K}), but got {a.shape}"
        )
        self.K = K
        self.a = a
        self.c = c
        self.r = r

    def _series(self, points, t):
        r"""Evaluate the truncated sine series at ``points`` and time ``t``.

        Shared implementation of :meth:`initial_condition` (``t = 0``) and
        :meth:`solution`. Returns a tensor of shape :math:`[|\mathcal V|]` when
        ``self.a`` is 2D, or :math:`[N, |\mathcal V|]` when ``self.a`` is 3D.
        """
        assert points.shape[-1] == 2, f"the shape of points must be [n_points, 2], but got {points.shape}"
        # Elementwise `&` is required: Python `and` on a multi-element tensor raises.
        assert ((points >= 0) & (points <= 1)).all(), f"the points must be in [0,1]^2, but got {points}"
        assert t >= 0, f"t must be non-negative, but got {t}"
        K = self.K
        # Build the frequency grid in floating point so that (i^2 + j^2) ** (-r)
        # is valid for any r (integer tensors reject negative integer powers).
        dtype = points.dtype if points.is_floating_point() else torch.get_default_dtype()
        k = torch.arange(1, K + 1, dtype=dtype, device=points.device)
        i, j = k[:, None], k[None, :]  # (K, 1), (1, K): broadcasts to the 'ij' grid
        k2 = i * i + j * j  # (K, K)
        # Per-mode weight: amplitude decay times the time factor (equal to 1 at t = 0).
        weight = k2 ** (-self.r) * torch.cos(self.c * torch.pi * t * torch.sqrt(k2))
        coeff = self.a.to(points.device) * weight  # (K, K) or (N, K, K)
        x = points[:, 0, None, None]  # (n_points, 1, 1)
        y = points[:, 1, None, None]  # (n_points, 1, 1)
        basis = torch.sin(torch.pi * i * x) * torch.sin(torch.pi * j * y)  # (n_points, K, K)
        if coeff.dim() == 3:
            # Batched coefficients: pair every sample with every point.
            coeff, basis = coeff[:, None], basis[None]  # (N, 1, K, K), (1, n_points, K, K)
        return torch.pi / K / K * (coeff * basis).sum((-2, -1))

    def initial_condition(self, points):
        r"""Generate the wave initial displacement at each point in the domain.

        .. math::
            u(0, x, y, a) = \frac{\pi}{K^2} \sum_{i,j=1}^{K} a_{ij} \cdot (i^2 + j^2)^{-r}
            \sin(\pi i x) \sin(\pi j y)

        Only the displacement is returned. The solution evaluated by
        :meth:`solution` has zero initial velocity.

        Parameters
        ----------
        points: torch.Tensor
            2D tensor of shape :math:`[|\mathcal V|, 2]`, where :math:`|\mathcal V|`
            is the number of vertices; all points must lie in :math:`[0,1]^2`.

        Returns
        -------
        u0: torch.Tensor
            1D tensor of shape :math:`[|\mathcal V|]` if ``a`` is 2D, or 2D tensor of
            shape :math:`[N, |\mathcal V|]` if ``a`` is 3D, where :math:`N` is the
            number of samples and :math:`|\mathcal V|` is the number of vertices.

        Raises
        ------
        AssertionError
            If ``points`` does not have last dimension 2 or lies outside
            :math:`[0,1]^2`.
        """
        return self._series(points, t=0.0)

    def solution(self, points, t=0.1):
        r"""Generate the wave solution at each point in the domain at time ``t``.

        .. math::
            u(t, x, y, a) = \frac{\pi}{K^2} \sum_{i,j=1}^{K} a_{ij} \cdot (i^2 + j^2)^{-r}
            \sin(\pi i x) \sin(\pi j y) \cos(c \pi t \sqrt{i^2 + j^2})

        Parameters
        ----------
        points: torch.Tensor
            2D tensor of shape :math:`[|\mathcal V|, 2]`, where :math:`|\mathcal V|`
            is the number of vertices; all points must lie in :math:`[0,1]^2`.
        t: float
            The (non-negative) time, default is :math:`0.1`.

        Returns
        -------
        ut: torch.Tensor
            1D tensor of shape :math:`[|\mathcal V|]` if ``a`` is 2D, or 2D tensor of
            shape :math:`[N, |\mathcal V|]` if ``a`` is 3D, where :math:`N` is the
            number of samples and :math:`|\mathcal V|` is the number of vertices.

        Raises
        ------
        AssertionError
            If ``points`` is malformed or outside :math:`[0,1]^2`, or ``t < 0``.
        """
        return self._series(points, t)