Source code for torch_fem.dataset.equation.wave

import torch 


class WaveMultiFrequency:
    r"""Multi-frequency wave equation with a homogeneous Dirichlet boundary condition.

    .. math::

        u_{tt} = c^2 \Delta u

    where :math:`t \in [0, T],\quad (x_1, x_2) \in [0, 1]^2`, with the boundary
    condition :math:`u(t, x_1, x_2) = 0` on :math:`\partial [0, 1]^2`, i.e. whenever
    :math:`x_1 \in \{0, 1\}` or :math:`x_2 \in \{0, 1\}`.

    The solution is the truncated sine series

    .. math::

        u(t, x, y; a) = \frac{\pi}{K^2} \sum_{i,j=1}^{K} a_{ij} (i^2 + j^2)^{-r}
        \sin(\pi i x) \sin(\pi j y) \cos(c \pi t \sqrt{i^2 + j^2})

    Parameters
    ----------
    a: torch.Tensor, optional
        3D tensor of shape :math:`[N, K, K]` or 2D tensor of shape :math:`[K, K]`,
        where :math:`N` is the number of samples and :math:`K` is the number of
        frequencies per axis; these are the series coefficients :math:`a_{ij}`.
        If ``None``, a single ``[K, K]`` sample is drawn with
        :math:`a_{ij} \sim Unif([-1, 1])`.
    K: int, optional
        the number of frequencies per axis, default is ``2``. Ignored if ``a``
        is not ``None`` (it is then read from ``a.shape[-1]``); otherwise used
        to generate the random ``a``.
    c: float, optional
        the wave speed, default is :math:`1.0`
    r: float, optional
        decay exponent of the mode amplitudes :math:`(i^2 + j^2)^{-r}`,
        default is :math:`0.5`
    """

    def __init__(self, a=None, K=2, c=1.0, r=0.5):
        if a is None:
            assert K is not None, "K should be specified if a is None"
            a = torch.zeros((K, K)).uniform_(-1, 1)
        else:
            K = a.shape[-1]
            # Only 2D ([K, K]) and 3D ([N, K, K]) coefficient tensors are supported
            # by the broadcasting in ``_series``.
            assert a.dim() in (2, 3) and a.shape[-2:] == (K, K), f"the shape of a should be (N, {K}, {K}) or ({K}, {K}), but got {a.shape}"
        self.K = K
        self.a = a
        self.c = c
        self.r = r

    @staticmethod
    def _check_points(points):
        """Validate that ``points`` is an ``[n_points, 2]`` tensor inside ``[0, 1]^2``."""
        assert points.dim() == 2 and points.shape[-1] == 2, f"the shape of points must be [n_points, 2], but got {points.shape}"
        # Element-wise ``&`` is required: Python's ``and`` would call ``bool()`` on a
        # multi-element tensor and raise instead of checking the range.
        assert ((points <= 1) & (points >= 0)).all(), f"the points must be in [0,1]^2, but got {points}"

    def _series(self, points, t):
        """Evaluate the truncated sine series at ``points`` and time ``t``.

        Returns shape ``[n_points]`` for a 2D ``a`` and ``[N, n_points]`` for a 3D ``a``.
        """
        K = self.K
        # Float mode indices 1..K: keeps ``(i^2 + j^2) ** (-r)`` valid for integer ``r``
        # (int tensors cannot be raised to negative integer powers) and lets
        # ``torch.sqrt`` work on every torch version. Small integers are exact in float.
        k = torch.arange(1, K + 1, device=self.a.device).to(torch.get_default_dtype())
        i = k[:, None].expand(K, K)  # (K, K), i varies along rows
        j = k[None, :].expand(K, K)  # (K, K), j varies along columns
        if self.a.dim() == 2:
            a = self.a[None, :, :]  # (1, K, K)
            i, j = i[None, :, :], j[None, :, :]  # (1, K, K)
            x = points[:, 0][:, None, None]  # (n_points, 1, 1)
            y = points[:, 1][:, None, None]  # (n_points, 1, 1)
        else:
            a = self.a[:, None, :, :]  # (N, 1, K, K)
            i, j = i[None, None, :, :], j[None, None, :, :]  # (1, 1, K, K)
            x = points[:, 0][None, :, None, None]  # (1, n_points, 1, 1)
            y = points[:, 1][None, :, None, None]  # (1, n_points, 1, 1)
        squared_freq = i * i + j * j
        terms = (
            a
            * squared_freq ** (-self.r)
            * torch.sin(torch.pi * i * x)
            * torch.sin(torch.pi * j * y)
            * torch.cos(self.c * torch.pi * t * torch.sqrt(squared_freq))
        )
        # Sum over the (i, j) frequency axes.
        return torch.pi / K / K * terms.sum((-2, -1))

    def initial_condition(self, points):
        r"""Evaluate the initial displacement :math:`u_0` at each point in the domain.

        .. math::

            u(0, x, y; a) = \frac{\pi}{K^2} \sum_{i,j=1}^{K} a_{ij} (i^2 + j^2)^{-r}
            \sin(\pi i x) \sin(\pi j y)

        This is :meth:`solution` evaluated at :math:`t = 0`. The matching initial
        velocity :math:`u_t(0, x, y)` is identically zero (the time factor
        :math:`\cos(c \pi t \sqrt{i^2 + j^2})` has zero derivative at :math:`t = 0`),
        so it is not returned.

        Parameters
        ----------
        points: torch.Tensor
            2D tensor of shape :math:`[|\mathcal V|, 2]`, where :math:`|\mathcal V|`
            is the number of vertices; all points must lie in :math:`[0, 1]^2`

        Returns
        -------
        u0: torch.Tensor
            1D tensor of shape :math:`[|\mathcal V|]` if ``a`` is 2D, or 2D tensor of
            shape :math:`[N, |\mathcal V|]` if ``a`` is 3D, where :math:`N` is the
            number of samples

        Raises
        ------
        AssertionError
            if ``points`` is not of shape ``[n_points, 2]`` or lies outside
            :math:`[0, 1]^2`
        """
        self._check_points(points)
        # cos(0) == 1 exactly, so this equals the series without the time factor.
        return self._series(points, 0.0)

    def solution(self, points, t=0.1):
        r"""Evaluate the wave solution at each point in the domain at time ``t``.

        .. math::

            u(t, x, y; a) = \frac{\pi}{K^2} \sum_{i,j=1}^{K} a_{ij} (i^2 + j^2)^{-r}
            \sin(\pi i x) \sin(\pi j y) \cos(c \pi t \sqrt{i^2 + j^2})

        Parameters
        ----------
        points: torch.Tensor
            2D tensor of shape :math:`[|\mathcal V|, 2]`, where :math:`|\mathcal V|`
            is the number of vertices; all points must lie in :math:`[0, 1]^2`
        t: float
            the time, must be non-negative, default is :math:`0.1`

        Returns
        -------
        ut: torch.Tensor
            1D tensor of shape :math:`[|\mathcal V|]` if ``a`` is 2D, or 2D tensor of
            shape :math:`[N, |\mathcal V|]` if ``a`` is 3D, where :math:`N` is the
            number of samples

        Raises
        ------
        AssertionError
            if ``points`` is not of shape ``[n_points, 2]``, lies outside
            :math:`[0, 1]^2`, or ``t`` is negative
        """
        self._check_points(points)
        assert t >= 0, f"t must be non-negative, but got {t}"
        return self._series(points, t)