Source code for torch_fem.assemble.element_assembler


from abc import abstractmethod
import torch 
import torch.nn as nn
import numpy as np
import scipy.sparse
from functools import reduce, partial
import inspect
import operator

from .projector import Projector
from ..nn import BufferDict
from ..quadrature import get_quadrature
from ..shape import get_shape_val, get_shape_grad, element_type2order, element_type2dimension
from ..sparse import SparseMatrix


[docs]class ElementAssembler(nn.Module): """ The :obj:`ElementAssembler` is inheritated from :class:`torch:torch.nn.Module`. Therefore, all the operation from :class:`torch:torch.nn.Module` is applicable to :obj:`ElementAssembler` You are not encouraged to build the ElementAssembler directly, instead, you should use :meth:`torch_fem.assemble.ElementAssembler.from_mesh` or :meth:`torch_fem.assemble.ElementAssembler.from_assembler` to build the ElementAssembler from a mesh The output when calling the ElementAssembler is a sparse matrix, which is the global galerkin matrix of shape :math:`\\mathbb R_{\\text{sparse}}^{|\\mathcal V|, |\\mathcal V|}` or :math:`\\mathbb R_{\\text{sparse}}^[|\\mathcal V| \times H, |\\mathcal V| \times H]`, where :math:`H` is the number of degree of freedom per point, :math:`|\\mathcal V|` is the number of points. .. math:: K \overset{\\text{bsr matrix}}{\leftarrow}\hat K_\\text{global} \hat K_{\\text{global}}^{nkl} = \mathcal P_{\mathcal E}^{nhij} \hat K_{\\text{local}}^{hklij} \hat K_{ij} = \\int_\\Omega f(u, v) \\text dv :math:`f` is `forward` function which is defined by inheritating this class * :math:`\\hat K_{\\text{global}}` : non zero value of the global galerkin matrix, :math:`K_{\\text{global}}\\in \\mathbb R^{|\\mathcal E|\\times d\\times d}` * :math:`\\hat K_{\\text{local}}` : local galerkin matrix for each element , :math:`K_{\\text{local}}\\in \\mathbb R^{|\\mathcal C|\\times h\\times h\\times d\\times d}` * :math:`\\mathcal P_{\mathcal E}` : projection (assemble) tensor from :math:`\hat K_{\\text{local}}` to :math:`\hat K_{\\text{global}}, `\\mathcal P_{\\mathcal E} \\in \\mathbb R_{\\text{sparse}}^{|\\mathcal E|\\times |\\mathcal C|\\times h\\times h}` * :math:`\\mathcal C` : elements/cells * :math:`h` : number of basis for each element/cell * :math:`\\mathcal E` : connections for nodes/vertices/points * :math:`\\mathcal V` : nodes/vertices/points Examples: --------- 1. assemble the mass matrix .. math:: M_{ij} = \\int_\\Omega u_i v_j \\text dv .. code-block:: python import torch_fem import torch_fem.functional as F class MassAssembler(torch_fem.ElementAssembler): def forward(self, u, v): return F.dot(u, v) mesh = torch_fem.Mesh.gen_rectangle() assembler = MassAssembler.from_mesh(mesh) M = assembler(mesh.points) 2. assemble the laplace matrix .. math:: K_{ij} = \\int_\\Omega \\nabla u_i \\cdot \\nabla v_j \\text dv .. code-block:: python import torch_fem import torch_fem.functional as F class LaplaceAssembler(torch_fem.ElementAssembler): def forward(self, gradu, gradv): return F.dot(gradu, gradv) mesh = torch_fem.Mesh.gen_circle() assembler = LaplaceAssembler.from_mesh(mesh) K = assembler(mesh.points) Attributes ----------- quadrature_weights : BufferDict[str, torch.Tensor] The element type is the key, which should be one of :meth:`torch_fem.shape.element_types`. Each :obj:`element_type` corresponds to a 1D tensor of shape :math:`[Q]`, where :math:`Q` is the number of quadrature points the quadrature weights of each element type, e.g. :obj:`{"triangle6": torch.tensor([0.5, 0.5])}` quadrature_points : BufferDict[str, torch.Tensor] The element type is the key, which should be one of :meth:`torch_fem.shape.element_types`. Each :obj:`element_type` corresponds to a 2D tensor of shape :math:`[Q, D]`, where :math:`Q` is the number of quadrature points, :math:`D` is the dimension of the domain the quadrature points of each element type, e.g. :obj:`{"triangle6": torch.tensor([[0.5, 0.5], [0.5, 0.0]])}` shape_val : BufferDict[str, torch.Tensor] The element type is the key, which should be one of :meth:`torch_fem.shape.element_types`. Each :obj:`element_type` corresponds to a 2D tensor of shape :math:`[Q, B]`, where :math:`Q` is the number of quadrature points, :math:`B` is the number of basis the shape value of each element type, e.g. :obj:`{"triangle6": torch.tensor([[0.5, 0.5, 0.0], [0.0, 0.5, 0.5]])}` projector : BufferDict[str, Projector] The element type is the key, which should be one of :meth:`torch_fem.shape.element_types`. ach :obj:`element_type` corresponds to a projector from element to edge, each projector is a :meth:`torch_fem.assemble.Projector` object, could be considered as a sparse matrix .. math:: \\mathcal P_e: \\mathbb{R}_{\\text{sparse}}^{|\\mathcal C_e| \\times B_e \\times B_e} \\rightarrow \\mathbb{R}^{|\\mathcal E|} where :math:`\\mathcal C` is the set of elements, :math:`B` is the number of basis, :math:`\\mathcal E` is the set of edges. elements : BufferDict[str, torch.Tensor] The element type is the key, which should be one of :meth:`torch_fem.shape.element_types`. Each :obj:`element_type` corresponds to a 2D tensor of shape :math:`[|\\mathcal C|, B]`, where :math:`\\mathcal C` is the set of elements, :math:`B` is the number of basis the element connectivity of each element type, e.g. :obj:`{"triangle6": torch.tensor([[0, 1, 2], [1, 2, 3]])}` edges : torch.Tensor 2D tensor of shape :math:`[2, |\\mathcal E|]`, where :math:`\\mathcal E` is the set of edges edge connectivity considering all element_types, e.g. :obj:`torch.tensor([[0, 1, 2], [1, 2, 3]])` n_points : int number of points dimension : int dimension of the mesh, either :math:`1` or :math:`2` or :math:`3` element_types : list[str] element types, e.g. :obj:`["triangle6", "quad9"]` """ __autodoc__ = [ '__call__', 'forward', '__post_init__', 'from_assembler', 'from_mesh', ] def __init__(self, quadrature_weights, quadrature_points, shape_val, projector, elements, edges, n_points): super().__init__() element_types = list(quadrature_weights.keys()) dimension = element_type2dimension[element_types[0]] self.quadrature_weights = quadrature_weights self.quadrature_points = quadrature_points self.shape_val = shape_val self.projector = projector self.elements = elements self.register_buffer("edges", edges) self.dimension = dimension self.element_types = list(elements.keys()) self.n_points = n_points self.__post_init__() @property def device(self): """ Returns ------- torch.device the device of the assembler, either :obj:`torch.device("cpu")` or :obj:`torch.device(f"cuda:{x}")` """ return self.quadrature_weights.device @property def dtype(self): """ Returns ------- torch.dtype the data type of the assembler, either :obj:`torch.float32` or :obj:`torch.float64` """ return self.quadrature_weights.dtype
[docs] def type(self, dtype): if dtype == torch.float64: self.double() elif dtype == torch.float32: self.float() else: raise Exception(f"the dtype {dtype} is not supported") return self
def _integrate(self, batch_integral, jxw, n_element, n_basis, use_element_parallel): if use_element_parallel: error_msg = f"the shape returned by forward function is {[*batch_integral.shape]} which is not supported, should either be [{n_element}, batch_size, {n_basis},{n_basis}] or [{n_element}, batch_size,{n_basis},{n_basis}, dof_per_point, dof_per_point]" assert batch_integral.dim() == 4 or batch_integral.dim() == 6, error_msg assert batch_integral.shape[0] == n_element, error_msg assert batch_integral.shape[2] == n_basis, error_msg assert batch_integral.shape[3] == n_basis, error_msg if batch_integral.dim() == 6: assert batch_integral.shape[-1] == batch_integral.shape[-2], error_msg batch_integral = torch.einsum("eqij...,eq->eij...", batch_integral, jxw) else: error_msg = f"the shape returned by forward function is {[*batch_integral.shape]} which is not supported, should either be [batch_size, {n_basis},{n_basis}] or [batch_size,{n_basis},{n_basis}, dof_per_point, dof_per_point]" assert batch_integral.dim() == 3 or batch_integral.dim() == 5, error_msg assert batch_integral.shape[1] == n_basis, error_msg assert batch_integral.shape[2] == n_basis, error_msg if batch_integral.dim() == 5: assert batch_integral.shape[-1] == batch_integral.shape[-2], error_msg batch_integral = torch.einsum("qij...,eq->eij...", batch_integral, jxw) return batch_integral def _build_output(self, integral): """ Parameters: ----------- integral: torch.Tensor of shape [n_edge, ...] the integral of each edge Returns: -------- SparseMatrix """ if integral.dim() == 1: return SparseMatrix(integral, self.edges[0], self.edges[1], shape=(self.n_points, self.n_points)) elif integral.dim() == 3: return SparseMatrix.from_block_coo(integral, self.edges[0], self.edges[1], shape=(self.n_points, self.n_points)) else: raise Exception(f"the shape of integral is supposed to be 1D or 3D, but got {integral.shape}") def __call__(self, points, func=None,point_data=None, batch_size=None): """ Parameters ---------- points: torch.Tensor 2D tensor of shape :math:`[|\mathcal V|, D]`, where :math:`\mathcal V` is the set of nodes/vertices/points, :math:`D` is the dimension of the domain the coordinates of the points func: function or None, optional the bilinear function, when it's None the forward function will be used, if you want to reuse the same element assembler for different bilinear function, you can pass the bilinear function here point_data: Dict[str, torch.Tensor], optional tensor of shape :math:`[|\mathcal V|, ...]`, where :math:`\mathcal V` is the set of nodes/vertices/points batch_size: int or None, optional the batch size of quadrature points if :obj:`int` is given, the quadrature points will be divided into batches if :obj:`None` is given, the quadrature points will not be divided into batches Returns ------- SparseMatrix a torch.sparse_matrix of shape :math:`\\mathbb R_{\\text{sparse}}^{|\\mathcal V|, |\\mathcal V|}` or :math:`\\mathbb R_{\\text{sparse}}^[|\\mathcal V| * H, |\\mathcal V| * H]`, where :math:`H` is the number of degree of freedom per point, :math:`|\\mathcal V|` is the number of points """ if point_data is None: point_data = {} point_data["x"] = points self = self.type(points.dtype).to(points.device) for key, value in point_data.items(): assert value.shape[0] == points.shape[0], f"the shape of {key} should be [n_point, ...], but got {value.shape}" signature = inspect.signature(self.forward) fn = None use_element_parallel = None integral = None for element_type in self.element_types: element_integral = None n_quadrature = self.quadrature_weights[element_type].shape[0] n_batch = n_quadrature // batch_size if batch_size is not None else 1 n_batch_size = batch_size if batch_size is not None else n_quadrature n_basis = self.shape_val[element_type].shape[1] n_element = self.elements[element_type].shape[0] ele_point_data = {k:v[self.elements[element_type]] for k,v in point_data.items()} ele_coords = points[self.elements[element_type]] # [n_element, n_basis, n_dim] for i in range(n_batch): shape_val = self.shape_val[element_type][i * n_batch_size: (i+1) * n_batch_size] # [batch_size, n_basis] w = self.quadrature_weights[element_type][i * n_batch_size: (i+1) * n_batch_size] # [batch_size] quadrature_points = self.quadrature_points[element_type][i * n_batch_size: (i+1) * n_batch_size] # [batch_size, n_dim] shape_grad, jxw = get_shape_grad(element_type, w, quadrature_points, ele_coords) # [n_element, batch_size, n_basis, n_dim], [n_element, n_batch] # prepare arguments args = [] for key in signature.parameters: if key in ["u", "v"]: args.append(shape_val) elif key in ["gradu", "gradv"]: args.append(shape_grad) elif key in ele_point_data: args.append(torch.einsum("eb...,qb->eqb...",ele_point_data[key], shape_val)) elif key.startswith("grad") and key[4:] in ele_point_data: args.append(torch.einsum("eb...,eqbd->eqb...d",ele_point_data[key[4:]], shape_grad)) else: raise NotImplementedError(f"key {key} is not implemented") # parallel dispatch if fn is None: element_dims = [] quadrature_dims = [] for key in signature.parameters: if key in ["u", "v"]: element_dims.append(None) quadrature_dims.append(0) else: element_dims.append(0) quadrature_dims.append(0) element_dims = tuple(element_dims) quadrature_dims = tuple(quadrature_dims) fn = self.forward if func is None else func if all([x is None for x in element_dims]): # if all is shape_val fn = torch.vmap(fn, in_dims=quadrature_dims) use_element_parallel = False else: fn = torch.vmap( torch.vmap( fn, in_dims = quadrature_dims ), in_dims=element_dims ) use_element_parallel = True batch_integral = fn(*args) # [n_element, batch_size, n_basis, n_basis, ...] or [n_batch, batch_size, n_basis, ...] batch_integral = self._integrate(batch_integral, jxw, n_element, n_basis, use_element_parallel) if element_integral is None: element_integral = batch_integral else: element_integral += batch_integral if integral is None: integral = self.projector[element_type](element_integral) # [n_edge, ...] else: integral += self.projector[element_type](element_integral) # [n_edge, ...] return self._build_output(integral)
[docs] @abstractmethod def forward(self, **kwargs): """The weak form of the operator, you should override this function. Similar to the :meth:`torch:torch.nn.Module.forward` function, you can use :meth: `torch_fem.assemble.ElementAssembler.__call__` to call this function Parameters ---------- u : torch.Tensor, optional 1D tensor shape :math:`[B]`, where :math:`B` is the number of basis v : torch.Tensor, optional 1D tensor shape :math:`[B]`, where :math:`B` is the number of basis gradu : torch.Tensor, optional 2D tensor shape :math:`[B,D]`, where :math:`B` is the number of basis, :math:`D` is the dimension of the dimension gradv : torch.Tensor, optional 2D tensor shape :math:`[B,D]`, where :math:`B` is the number of basis, :math:`D` is the dimension of the dimension x : torch.Tensor, optional 2D tensor shape :math:`[B, D]`, where :math:`B` is the number of basis, :math:`D` is the dimension of the dimension gradx : torch.Tensor, optional 3D tensor shape :math:`[B, D, D]`, where :math:`B` is the number of basis, :math:`D` is the dimension of the dimension **point_data : Dict[str, torch.Tensor], optional The point_data are passed by __call__ if the point data :obj:`"example_key"` passed in is of shape :math:`[|\\mathcal V|, ...]`, then the point data :obj:`"example_key"` passed in will be of shape :math:`[B, ...]`, and the point data :obj:`"gradexample_key"` passed in will be of shape :math:`[B, ..., D]`, where :math:`B` is the number of basis, :math:`D` is the dimension of the dimension Returns ------- torch.Tensor 2D tensor of shape :math:`[B,B]` or 4D tensor of shape :math:`[B, B, H, H]`, where :math:`B` is the number of basis, :math:`H` is the number of degree of freedom per point """ raise NotImplementedError(f"forward is not implemented")
def __post_init__(self): """Override this function to precompute some data after the initialization """ pass def __str__(self): return ( f"{self.__class__.__name__}(\n" f" element_types: {self.element_types}\n" f" n_element: {' '.join(f'{k}:{v.shape[0]}' for k, v in self.elements.items())}\n" f" n_point: {self.n_points}\n" f" n_basis: {' '.join(f'{k}:{v.shape[1]}' for k, v in self.elements.items())}\n" f" n_dim: {self.dimension}\n" f" n_quadrature: {' '.join(f'{k}:{v.shape[0]}' for k, v in self.quadrature_weights.items())}\n" f" forward: \n{inspect.getsource(self.forward)}" f")" ) def __repr__(self): return str(self)
[docs] @classmethod def from_assembler(cls, obj): """Build an :meth:`torch_fem.assemble.ElementAssembler` from another :meth:`torch_fem.assemble.ElementAssembler`. It's much faster than :meth:`torch_fem.assemble.ElementAssembler.from_mesh`. When you already have an ElementAssembler, you can use this function to build another ElementAssembler sharig the same mesh Parameters ---------- obj: torch_fem.assemble.ElementAssembler an meth:`torch_fem.assemble.ElementAssembler` object Returns ------- torch_fem.assemble.ElementAssembler the new element assembler sharing the same mesh """ assert isinstance(obj, ElementAssembler), f"obj must be an instance of ElementAssembler, but got {type(obj)}" return cls(obj.quadrature_weights, obj.quadrature_points, obj.shape_val, obj.projector, obj.elements, obj.edges, obj.n_points)
[docs] @classmethod def from_mesh(cls, mesh, quadrature_order=None): """Build an :meth:`torch_fem.assemble.ElementAssembler` from a mesh :meth:`torch_fem.mesh.Mesh`. It's much slower than :meth:`torch_fem.assemble.ElementAssembler.from_assembler`. Because it will precompute the projection matrix $\mathcal P_{\mathcal E}$ Parameters ---------- mesh: torch_fem.mesh.mesh.Mesh a meth:`torch_fem.mesh.Mesh` object quadrature_order: int or None the order should be poisitive integer, if :obj:`None`, the quadrature order will be determined by the :meth:`torch_fem.quadrature.get_quadrature` Returns ------- torch_fem.assemble.ElementAssembler the new element assembler use the topology of the mesh """ elements = mesh.elements() n_points = mesh.points.shape[0] if isinstance(elements, torch.Tensor): elements = {mesh.default_element_type: elements} quadrature_weights = {} quadrature_points = {} shape_val = {} projector = {} elem_u, elem_v = [], [] for element_type, value in elements.items(): n_element, n_basis = value.shape quadrature_weights[element_type], quadrature_points[element_type] =\ get_quadrature(element_type, quadrature_order) # [n_quadrature], [n_quadrature, n_dim] shape_val[element_type] = get_shape_val(element_type, quadrature_points[element_type]) # [n_quadrature, n_basis] for i in range(n_basis): for j in range(n_basis): elem_u.append(value[:, i]) elem_v.append(value[:, j]) elem_u, elem_v = torch.stack(elem_u, -1).flatten(), torch.stack(elem_v, -1).flatten() # [num_elements * num_basis * num_basis] elem_u, elem_v = elem_u.cpu().numpy().copy(), elem_v.cpu().numpy().copy() tmp = scipy.sparse.coo_matrix(( # used to remove duplicated edges np.ones_like(elem_u), # data (elem_u, elem_v), # (row, col) ), shape = (n_points, n_points)).tocsr().tocoo() edge_u, edge_v = tmp.row, tmp.col num_edges = len(edge_u) eids_csr = scipy.sparse.coo_matrix(( np.arange(num_edges), (edge_u, edge_v) ), shape=(n_points, n_points)).tocsr() elem_eids = np.array(eids_csr[elem_u, elem_v].copy()).ravel() ptr = 0 for element_type, value in elements.items(): n_element, n_basis = value.shape elem_eids = np.array(eids_csr[elem_u[ptr:ptr+n_element*n_basis*n_basis], elem_v[ptr:ptr+n_element*n_basis*n_basis]].copy()).ravel() ptr += n_element * n_basis * n_basis projector[element_type] = Projector( from_ = torch.arange(n_element * n_basis * n_basis), to_ = torch.from_numpy(elem_eids), from_shape = (n_element, n_basis, n_basis), to_shape = (num_edges,), ) edges = torch.from_numpy(np.stack([edge_u, edge_v], 0)) quadrature_weights = BufferDict(quadrature_weights) quadrature_points = BufferDict(quadrature_points) shape_val = BufferDict(shape_val) projector = BufferDict(projector) elements = BufferDict({k:v.long() for k,v in elements.items()}) return cls(quadrature_weights, quadrature_points, shape_val, projector, elements, edges, n_points)
ElementAssembler.type.__doc__ = nn.Module.type.__doc__