import torch
class HeatMultiFrequency:
    r"""Multi-frequency heat equation with a homogeneous (zero) Dirichlet boundary condition.

    .. math::
        \frac{\partial u}{\partial t} = \Delta u

    where :math:`t \in [0,T],\quad (x_1,x_2) \in [-1,1]^2`,
    with the boundary condition :math:`u(t, x_1, x_2) = 0` whenever
    :math:`x_1 = \pm 1` or :math:`x_2 = \pm 1`.

    Parameters
    ----------
    mu: torch.Tensor, optional
        2D tensor of shape :math:`[N, d]` or 1D tensor of shape :math:`[d]`, where :math:`N` is
        the number of samples and :math:`d` is the number of frequencies.
        These are the coefficients :math:`\mu_m` of the sine modes in the solution.
        If ``None``, it is randomly generated as :math:`\mu \sim Unif([-1,1]^d)`.
    d: int, optional
        The number of frequencies. If ``mu`` is not ``None`` this parameter is ignored and
        ``d`` is taken from ``mu.shape[-1]``; if ``mu`` is ``None`` it sets the length of the
        randomly generated ``mu``.
    """

    def __init__(self, mu=None, d=2):
        if mu is None:
            assert d is not None, "d must be provided if mu is None"
            # torch.rand samples Unif([0, 1)); rescale to [-1, 1) as documented above.
            mu = torch.rand(d) * 2 - 1
        else:
            d = mu.shape[-1]
        self.mu = mu
        self.d = d

    def initial_condition(self, points):
        r"""Evaluate the initial condition :math:`u(0, \cdot)` at each point in the domain.

        .. math::
            u(0,x_1,x_2,\mu) = -\frac{1}{d}\sum_{m=1}^d \frac{\mu_m}{\sqrt{m}} \sin(\pi m x_1)\sin(\pi m x_2)

        Parameters
        ----------
        points: torch.Tensor
            2D tensor of shape :math:`[|\mathcal V|, 2]`, where :math:`|\mathcal V|` is the
            number of vertices; all the points must be in :math:`[-1,1]^2`

        Returns
        -------
        torch.Tensor
            1D tensor of shape :math:`[|\mathcal V|]` when ``mu`` is 1D, or 2D tensor of shape
            :math:`[N, |\mathcal V|]` when ``mu`` has shape :math:`[N, d]`

        Raises
        ------
        AssertionError
            If ``points`` is not of shape ``[n_points, 2]`` or lies outside :math:`[-1,1]^2`.
        """
        # At t = 0 every decay factor is exp(0) == 1, so this is exactly the initial condition.
        return self._evaluate(points, 0.0)

    def solution(self, points, t):
        r"""Evaluate the analytic solution of the heat equation at each point in the domain.

        .. math::
            u(t,x_1,x_2,\mu) = -\frac{1}{d}\sum_{m=1}^d \frac{\mu_m}{\sqrt{m}} e^{-2m^2\pi^2t} \sin(\pi m x_1)\sin(\pi m x_2)

        Each mode :math:`\sin(\pi m x_1)\sin(\pi m x_2)` is an eigenfunction of the Laplacian
        with eigenvalue :math:`-2m^2\pi^2` that vanishes on the boundary, hence the decay rate.

        Parameters
        ----------
        points: torch.Tensor
            2D tensor of shape :math:`[|\mathcal V|, 2]`, where :math:`|\mathcal V|` is the
            number of vertices; all the points must be in :math:`[-1,1]^2`
        t: float
            the (non-negative) time

        Returns
        -------
        ut: torch.Tensor
            1D tensor of shape :math:`[|\mathcal V|]` when ``mu`` is 1D, or 2D tensor of shape
            :math:`[N, |\mathcal V|]` when ``mu`` has shape :math:`[N, d]`

        Raises
        ------
        AssertionError
            If ``t`` is negative, or ``points`` is malformed or outside :math:`[-1,1]^2`.
        """
        assert t >= 0, f"t must be non-negative, but got {t}"
        return self._evaluate(points, t)

    def _evaluate(self, points, t):
        r"""Validate ``points`` and sum the decaying sine modes at time ``t``.

        Shared implementation of :meth:`initial_condition` and :meth:`solution`.
        """
        assert points.shape[-1] == 2, f"the shape of points must be [n_points, 2], but got {points.shape}"
        # Element-wise ``&`` is required: Python's ``and`` would call bool() on a
        # multi-element tensor, which raises instead of testing every point.
        assert ((points <= 1) & (points >= -1)).all(), f"the points must be in [-1,1]^2, but got {points}"
        d = self.d
        mu = self.mu.to(points.device)
        # Mode indices m = 1..d, built as floats on the points' device so they
        # broadcast against the inputs wherever those live.
        dtype = points.dtype if points.is_floating_point() else torch.get_default_dtype()
        m = torch.arange(1, d + 1, dtype=dtype, device=points.device)  # (d,)
        x, y = points[:, 0:1], points[:, 1:2]  # (n_points, 1)
        if mu.dim() == 1:
            mu = mu[None, :]  # (1, d) -> result (n_points,)
        else:
            mu = mu[:, None, ...]  # (N, 1, d) -> result (N, n_points)
            x, y = x[None], y[None]  # (1, n_points, 1)
        decay = torch.exp(-2 * torch.pi ** 2 * m ** 2 * t)  # (d,)
        modes = torch.sin(torch.pi * m * x) * torch.sin(torch.pi * m * y) * decay / torch.sqrt(m)
        return -(mu * modes).sum(-1) / d