Source code for torch_fem.assemble.node_assembler


from abc import abstractmethod
import torch 
import torch.nn as nn
import numpy as np
import scipy.sparse
from functools import reduce, partial
import inspect

from .projector import Projector
from ..quadrature import get_quadrature
from ..shape import get_shape_val, get_shape_grad, element_type2order, element_type2dimension
from ..nn import BufferDict

[docs]class NodeAssembler(nn.Module): """The NodeAssembler is used to assemble the operator on the nodes of the mesh The output when calling the NodeAssembler is a vector, which is of shape :math:`[|\\mathcal V|]` or :math:`[|\\mathcal V|\\times H]`, where :math:`|\\mathcal V|` is the number of nodes and :math:`H` is the number of degrees of freedom per node. Attributes ---------- quadrature_weights : BufferDict[str, torch.Tensor] The element type is the key, which should be one of :meth:`torch_fem.shape.element_types`. Each :obj:`element_type` corresponds to a 1D tensor of shape :math:`[Q]`, where :math:`Q` is the number of quadrature points` quadrature_weights of each element type quadrature_points : BufferDict[str, torch.Tensor] The element type is the key, which should be one of :meth:`torch_fem.shape.element_types`. Each :obj:`element_type` corresponds to a 2D tensor of shape :math:`[Q, D]`, where :math:`Q` is the number of quadrature points and :math:`D` is the dimension of the mesh quadrature_points of each element type shape_val : BufferDict[str, torch.Tensor] The element type is the key, which should be one of :meth:`torch_fem.shape.element_types`. Each :obj:`element_type` corresponds to a 2D tensor of shape :math:`[Q, B]`, where :math:`Q` is the number of quadrature points and :math:`B` is the number of basis functions shape_val of each element type projector : BufferDict[str, Projector] The element type is the key, which should be one of :meth:`torch_fem.shape.element_types`. Each :obj:`element_type` corresponds to a projector from element to nodes, each projector is a :meth:`torch_fem.assemble.Projector` object, could be considered as a sparse matrix .. math:: \\mathcal P_e: \\mathbb{R}_{\\text{sparse}}^{|\mathcal C_e| \\times B_e} \\rightarrow \\mathbb{R}^{|\mathcal V|} where :math:`\\mathcal C` is the set of elements, :math:`B` is the number of basis, :math:`\mathcal V` is the set of nodes/vertices/points. projector from element to edge elements : BufferDict[str, torch.Tensor] The element type is the key, which should be one of :meth:`torch_feme.shape.element_types`. Each :obj:`element_type` corresponds to a 2D tensor of shape :math:`[N, B]`, where :math:`N` is the number of elements and :math:`B` is the number of basis functions element connectivity of each element type n_points : int number of points dimension : int dimension of the mesh, either :math:`1` or :math:`2` or :math:`3` element_types : list[str] element types, e.g. :obj:`["triangle6", "quad9"]` """ __autodoc__ = [ '__call__', 'forward', '__post_init__', 'from_assembler', 'from_mesh', ] def __init__(self, quadrature_weights, quadrature_points, shape_val, projector, elements, n_points): super().__init__() element_types = list(quadrature_weights.keys()) dimension = element_type2dimension[element_types[0]] self.quadrature_weights = quadrature_weights self.quadrature_points = quadrature_points self.shape_val = shape_val self.projector = projector self.elements = elements self.dimension = dimension self.element_types = element_types self.n_points = n_points self.__post_init__() def _integrate(self, batch_integral, jxw, n_element, n_basis, use_element_parallel): if not use_element_parallel: error_msg = f"the shape returned by forward function is {batch_integral.shape} which is not supported, should either be [batch_size,{n_basis}] or [batch_size,{n_basis}, dof_per_point]" assert batch_integral.dim() == 2 or batch_integral.dim() == 3, error_msg assert batch_integral.shape[1] == n_basis, error_msg batch_integral = torch.einsum("qi...,eq->ei...", batch_integral, jxw) # [n_element, n_basis, ...] else: error_msg = f"the shape returned by forward function is {batch_integral.shape} which is not supported, should either be [{n_element},batch_size,{n_basis}] or [{n_element},batch_size,{n_basis}, dof_per_point]" assert batch_integral.dim() == 3 or batch_integral.dim() == 4, error_msg assert batch_integral.shape[0] == n_element, error_msg assert batch_integral.shape[2] == n_basis, error_msg batch_integral = torch.einsum("eqb...,eq->eb...", batch_integral, jxw) # [n_element, n_basis, ...] return batch_integral def _build_output(self, integral): return integral.flatten() @property def device(self): return self.quadrature_weights.device @property def dtype(self): return self.quadrature_weights.dtype
[docs] def type(self, dtype): super().__doc__ if dtype == torch.float64: self.double() elif dtype == torch.float32: self.float() else: raise Exception(f"the dtype {dtype} is not supported") return self
def __call__(self, points, func=None,point_data=None, batch_size=None): """ Parameters ---------- points: torch.Tensor 2D tensor of shape :math:`[|\mathcal V|, D]`, where :math:`\mathcal V` is the set of nodes/vertices/points, :math:`D` is the dimension of the domain the coordinates of the points func: function or None, optional the bilinear function, when it's None the forward function will be used, if you want to reuse the same element assembler for different bilinear function, you can pass the bilinear function here point_data: Dict[str, torch.Tensor], optional tensor of shape :math:`[|\mathcal V|, ...]`, where :math:`\mathcal V` is the set of nodes/vertices/points batch_size: int or None, optional the batch size of quadrature points if :obj:`int` is given, the quadrature points will be divided into batches if :obj:`None` is given, the quadrature points will not be divided into batches Returns ------- torch.Tensor a torch.sparse_matrix of shape :math:`[|\\mathcal V|]` or :math:`[|\\mathcal V|\\times H]`, where :math:`|\\mathcal V|` is the number of nodes and :math:`H` is the number of degrees of freedom per node. """ if point_data is None: point_data = {} point_data["x"] = points self = self.type(points.dtype).to(points.device) for key, value in point_data.items(): assert value.shape[0] == points.shape[0], f"the shape of {key} should be [n_point, ...], but got {value.shape}" signature = inspect.signature(self.forward) fn = None use_element_parallel = None integral = None for element_type in self.element_types: element_integral = None n_quadrature = self.quadrature_weights[element_type].shape[0] n_batch = n_quadrature // batch_size if batch_size is not None else 1 n_batch_size = batch_size if batch_size is not None else n_quadrature n_basis = self.shape_val[element_type].shape[1] n_element = self.elements[element_type].shape[0] ele_point_data = {k:v[self.elements[element_type]] for k,v in point_data.items()} ele_coords = points[self.elements[element_type]] # [n_element, n_basis, n_dim] for i in range(n_batch): shape_val = self.shape_val[element_type][i * n_batch_size: (i+1) * n_batch_size] # [batch_size, n_basis] w = self.quadrature_weights[element_type][i * n_batch_size: (i+1) * n_batch_size] # [batch_size] quadrature_points = self.quadrature_points[element_type][i * n_batch_size: (i+1) * n_batch_size] # [batch_size, n_dim] shape_grad, jxw = get_shape_grad(element_type, w, quadrature_points, ele_coords) # [n_element, batch_size, n_basis, n_dim], [n_element, n_batch] # prepare arguments args = [] for key in signature.parameters: if key in ["u", "v"]: args.append(shape_val) elif key in ["gradu", "gradv"]: args.append(shape_grad) elif key in ele_point_data: args.append(torch.einsum("eb...,qb->eqb...",ele_point_data[key], shape_val)) elif key.startswith("grad") and key[4:] in ele_point_data: args.append(torch.einsum("eb...,eqbd->eqb...d",ele_point_data[key[4:]], shape_grad)) else: raise NotImplementedError(f"key {key} is not implemented") # parallel dispatch if fn is None: element_dims = [] quadrature_dims = [] for key in signature.parameters: if key in ["u", "v"]: element_dims.append(None) quadrature_dims.append(0) else: element_dims.append(0) quadrature_dims.append(0) element_dims = tuple(element_dims) quadrature_dims = tuple(quadrature_dims) fn = self.forward if func is None else func if all([x is None for x in element_dims]): # if all is shape_val fn = torch.vmap(fn, in_dims=quadrature_dims) use_element_parallel = False else: fn = torch.vmap( torch.vmap( fn, in_dims = quadrature_dims ), in_dims=element_dims ) use_element_parallel = True batch_integral = fn(*args) # [n_element, batch_size, n_basis, ...] or [batch_size, n_basis, ...] batch_integral = self._integrate(batch_integral, jxw, n_element, n_basis, use_element_parallel) if element_integral is None: element_integral = batch_integral else: element_integral += batch_integral if integral is None: integral = self.projector[element_type](element_integral) # [n_edge, ...] else: integral += self.projector[element_type](element_integral) # [n_edge, ...] return self._build_output(integral) def __post_init__(self): """Override this function to precompute some data after the initialization """ pass def __str__(self): return ( f"{self.__class__.__name__}(\n" f" element_types: {self.element_types}\n" f" n_element: {' '.join(f'{k}:{v.shape[0]}' for k, v in self.elements.items())}\n" f" n_point: {self.n_points}\n" f" n_basis: {' '.join(f'{k}:{v.shape[1]}' for k, v in self.elements.items())}\n" f" n_dim: {self.dimension}\n" f" n_quadrature: {' '.join(f'{k}:{v.shape[0]}' for k, v in self.quadrature_weights.items())}\n" f" forward: \n{inspect.getsource(self.forward)}" f")" ) def __repr__(self): return str(self)
[docs] @abstractmethod def forward(self, *args): """The weak form of the operator, you should override this function. Similar to the :meth:`torch:torch.nn.Module.forward` function, you can use :method: `torch_fem.assemble.ElementAssembler.__call__` to call this function Parameters ---------- u : torch.Tensor, optional 1D tensor shape :math:`[B]`, where :math:`B` is the number of basis v : torch.Tensor, optional 1D tensor shape :math:`[B]`, where :math:`B` is the number of basis gradu : torch.Tensor, optional 2D tensor shape :math:`[B,D]`, where :math:`B` is the number of basis, :math:`D` is the dimension of the dimension gradv : torch.Tensor, optional 2D tensor shape :math:`[B,D]`, where :math:`B` is the number of basis, :math:`D` is the dimension of the dimension x : torch.Tensor, optional 2D tensor shape :math:`[B, D]`, where :math:`B` is the number of basis, :math:`D` is the dimension of the dimension gradx : torch.Tensor, optional 3D tensor shape :math:`[B, D, D]`, where :math:`B` is the number of basis, :math:`D` is the dimension of the dimension **point_data : Dict[str, torch.Tensor], optional The point_data are passed by __call__ if the point data :obj:`"example_key"` passed in is of shape :math:`[|\\mathcal V|, ...]`, then the point data :obj:`"example_key"` passed in will be of shape :math:`[B, ...]`, and the point data :obj:`"gradexample_key"` passed in will be of shape :math:`[B, ..., D]`, where :math:`B` is the number of basis, :math:`D` is the dimension of the dimension Returns ------- torch.Tensor 1D tensor of shape :math:`[B]` or 2D tensor of shape :math:`[B, H]`, where :math:`B` is the number of basis, :math:`H` is the number of degree of freedom per point """ raise NotImplementedError(f"forward is not implemented")
[docs] @classmethod def from_assembler(cls, obj): """Build an NodeAssembler from another :meth:`torch_fem.assemble.NodeAssembler` or :meth:`torch_fem.assemble.ElementAssembler`. It's much faster than :meth:`torch_fem.assemble.NodeAssembler.from_mesh`. When you already have an NodeAssembler or ElementAssembler, you can use this function to build another NodeAssembler sharig the same mesh Parameters ---------- obj: torch_fem.assemble.NodeAssembler or torch_fem.assemble.ElementAssembler an :meth:`torch_fem.assemble.NodeAssembler` or :meth:`torch_fem.assemble.ElementAssembler` object Returns ------- torch_fem.assemble.NodeAssembler the new node_assembler sharing the same mesh """ err_msg = f"the object {obj} should inheritate from NodeAssembler" assert isinstance(obj, NodeAssembler), err_msg return cls( obj.quadrature_weights, obj.quadrature_points, obj.shape_val, obj.projector, obj.elements, obj.n_points )
[docs] @classmethod def from_mesh(cls, mesh, quadrature_order=None): """Build an :meth:`torch_fem.assemble.NodeAssembler` from a mesh :meth:`torch_fem.mesh.Mesh`. It's slower than :meth:`torch_fem.assemble.NodeAssembler.from_assembler`. Because it will precompute the projection matrix $\mathcal P_{\mathcal V}$ Parameters ---------- mesh: torch_fem.mesh.mesh.Mesh a meth:`torch_fem.mesh.Mesh` object quadrature_order: int or None the order should be poisitive integer, if :obj:`None`, the quadrature order will be determined by the :meth:`torch_fem.quadrature.get_quadrature` Returns ------- torch_fem.assemble.NodeAssembler the new node assembler use the topology of the mesh """ elements = mesh.elements() n_points = mesh.points.shape[0] if isinstance(elements, torch.Tensor): elements = {mesh.default_element_type: elements} quadrature_weights = {} quadrature_points = {} shape_val = {} projector = {} for element_type, value in elements.items(): n_element, n_basis = value.shape quadrature_weights[element_type], quadrature_points[element_type] =\ get_quadrature(element_type, quadrature_order) # [n_quadrature], [n_quadrature, n_dim] shape_val[element_type] = get_shape_val(element_type, quadrature_points[element_type]) # [n_quadrature, n_basis] projector[element_type] = Projector( from_ = torch.arange(n_element * n_basis).to(value.device), to_ = value.flatten(), from_shape = (n_element, n_basis), to_shape = (n_points,) ) quadrature_weights = BufferDict(quadrature_weights) quadrature_points = BufferDict(quadrature_points) shape_val = BufferDict(shape_val) projector = BufferDict(projector) elements = BufferDict(elements) return cls(quadrature_weights, quadrature_points, shape_val, projector, elements, n_points)
NodeAssembler.type.__doc__ = nn.Module.type.__doc__