Source code for torch_fem.functional.assemble_helpers

import torch



def trace(x):
    r"""Trace over the last two axes.

    .. math::

        \text{trace}(A)_{\cdots} = \sum_{i=1}^{D} A_{\cdots ii}

    Parameters
    ----------
    x : torch.Tensor
        :math:`[..., D, D]`, where :math:`D` is the dimension of the matrix

    Returns
    -------
    torch.Tensor
        :math:`[...]`
    """
    # The repeated subscript ``i`` selects the diagonal; omitting it from the
    # output sums that diagonal for every leading batch index.
    return torch.einsum("...ii->...", x)
def dot(a, b, reduce_dim=-1):
    r"""Contract ``a`` and ``b`` over one shared axis for every pair of basis entries.

    With ``reduce_dim=-1`` (default) the last axis is contracted:

    .. math::

        \text{dot}(A, B)_{\cdots ab} = \sum_{i=1}^{D} A_{\cdots ai} B_{\cdots bi}

    With ``reduce_dim=-2`` the second-to-last axis is contracted, and the
    trailing axes of both operands are kept as an outer product:

    .. math::

        \text{dot}(A, B)_{\cdots ab mn} = \sum_{i=1}^{D} A_{\cdots a i m} B_{\cdots b i n}

    Parameters
    ----------
    a : torch.Tensor
        :math:`[..., B_a, D]` when ``reduce_dim=-1``, or
        :math:`[..., B_a, D, M]` when ``reduce_dim=-2``; :math:`B_a` is the
        number of basis entries and :math:`D` the contracted dimension
    b : torch.Tensor
        :math:`[..., B_b, D]` when ``reduce_dim=-1``, or
        :math:`[..., B_b, D, N]` when ``reduce_dim=-2``
    reduce_dim : int, optional
        Axis to contract: ``-1`` (default) or ``-2``.

    Returns
    -------
    torch.Tensor
        :math:`[..., B_a, B_b]` when ``reduce_dim=-1``, or
        :math:`[..., B_a, B_b, M, N]` when ``reduce_dim=-2``

    Raises
    ------
    ValueError
        If ``reduce_dim`` is neither ``-1`` nor ``-2``.
    """
    if reduce_dim == -1:
        return torch.einsum("...ik,...jk->...ij", a, b)
    elif reduce_dim == -2:
        # ``k`` is contracted; the trailing axes ``a``/``b`` survive as an
        # outer product after the two basis axes.
        return torch.einsum("...ika,...jkb->...ijab", a, b)
    else:
        raise ValueError(f"reduce_dim must be -1 or -2, but got {reduce_dim}")
def ddot(a, b):
    r"""Double contraction over the last two axes for every pair of basis entries.

    .. math::

        \text{ddot}(A, B)_{\cdots ab} = \sum_{i=1}^{M} \sum_{j=1}^{N} A_{\cdots aij} B_{\cdots bij}

    Parameters
    ----------
    a : torch.Tensor
        :math:`[..., B_a, M, N]` (typically :math:`M = N = D`), where
        :math:`B_a` is the number of basis entries
    b : torch.Tensor
        :math:`[..., B_b, M, N]`, with the same trailing :math:`M, N` as ``a``

    Returns
    -------
    torch.Tensor
        :math:`[..., B_a, B_b]`
    """
    # Both trailing subscripts ``m`` and ``n`` are absent from the output,
    # so they are summed together (a double sum).
    return torch.einsum("...imn,...jmn->...ij", a, b)
def mul(a, b):
    r"""Outer product over the basis axis.

    .. math::

        \text{mul}(A, B)_{\cdots ij} = A_{\cdots i} B_{\cdots j}

    Parameters
    ----------
    a : torch.Tensor
        :math:`[..., B_a]`, where :math:`B_a` is the number of basis entries
    b : torch.Tensor
        :math:`[..., B_b]`, where :math:`B_b` is the number of basis entries

    Returns
    -------
    torch.Tensor
        :math:`[..., B_a, B_b]`
    """
    # No subscript is shared between ``i`` and ``j``, so nothing is summed:
    # this is a plain per-batch outer product.
    return torch.einsum("...i,...j->...ij", a, b)
def eye(value, dim):
    r"""Diagonal matrix whose diagonal is filled with ``value``.

    .. math::

        \text{eye}(v, D)_{\cdots ij} = \begin{cases}
            v_{\cdots}, & i = j \\
            0, & i \neq j
        \end{cases}

    Parameters
    ----------
    value : torch.Tensor
        :math:`[...]`, the value placed on the diagonal
    dim : int
        :math:`D`, the size of the square matrix

    Returns
    -------
    torch.Tensor
        :math:`[..., D, D]`, same dtype and device as ``value``
    """
    # Repeat each value ``dim`` times along a new trailing axis (an expanded
    # view, no copy), then embed it as the diagonal of a zero-filled matrix.
    # Unlike ``value * torch.eye(dim)`` this never multiplies, so inf/NaN
    # entries do not leak into the off-diagonal zeros.
    diagonal = value.unsqueeze(-1).expand(*value.shape, dim)
    return torch.diag_embed(diagonal)
def sym(a):
    r"""Pairwise average of vector entries, producing a symmetric matrix.

    .. math::

        \text{sym}(A)_{\cdots ij} = \frac{1}{2} (A_{\cdots i} + A_{\cdots j})

    Parameters
    ----------
    a : torch.Tensor
        :math:`[..., D]`, where :math:`D` is the dimension of the vector

    Returns
    -------
    torch.Tensor
        :math:`[..., D, D]`, symmetric in its last two axes
    """
    # [..., D, 1] + [..., 1, D] broadcasts to [..., D, D].
    return 0.5 * (a.unsqueeze(-1) + a.unsqueeze(-2))
def vector(x):
    r"""Stack scalar fields into a vector field along a new last axis.

    .. math::

        \text{vector}(A) = \begin{bmatrix}
            A_{\cdots}^{0} \\ \vdots \\ A_{\cdots}^{n-1}
        \end{bmatrix}

    Parameters
    ----------
    x : Iterable[torch.Tensor]
        :math:`n` tensors, each of the same shape :math:`[...]`

    Returns
    -------
    torch.Tensor
        :math:`[..., n]`, where ``result[..., k]`` equals the ``k``-th tensor of ``x``
    """
    # ``torch.stack`` requires a sequence (list/tuple); materializing lets
    # generators and other iterables be passed as well.
    return torch.stack(list(x), -1)
def matrix(x):
    r"""Assemble a matrix field from a row-major nested list of scalar fields.

    .. math::

        \text{matrix}(A) = \begin{bmatrix}
            A_{\cdots}^{0,0} & \cdots & A_{\cdots}^{0,m-1} \\
            \vdots & \ddots & \vdots \\
            A_{\cdots}^{n-1,0} & \cdots & A_{\cdots}^{n-1,m-1}
        \end{bmatrix}

    where :math:`n` is the number of rows (``len(x)``) and :math:`m` the
    number of columns (the length of each row).

    Parameters
    ----------
    x : Iterable[Iterable[torch.Tensor]]
        rows of tensors, every tensor of the same shape :math:`[...]`;
        ``x[i][j]`` is the entry in row ``i``, column ``j``

    Returns
    -------
    torch.Tensor
        :math:`[..., n, m]`, where ``result[..., i, j] == x[i][j]``
    """
    # Each row becomes [..., m] (columns on the last axis); the rows are then
    # stacked on the second-to-last axis to give [..., n, m].
    return torch.stack([torch.stack(list(row), -1) for row in x], -2)
def transpose(x):
    r"""Swap the two trailing axes of ``x``.

    .. math::

        \text{transpose}(A)_{\cdots ij} = A_{\cdots ji}

    Parameters
    ----------
    x : torch.Tensor
        :math:`[..., M, N]`

    Returns
    -------
    torch.Tensor
        :math:`[..., N, M]`
    """
    last, second_last = -1, -2
    return torch.transpose(x, second_last, last)
def matmul(a, b):
    r"""Matrix product over the last two axes, batched over the leading axes.

    .. math::

        \text{matmul}(A, B)_{\cdots ij} = \sum_{k=1}^{K} A_{\cdots ik} B_{\cdots kj}

    Leading ``...`` axes of the two operands are broadcast against each
    other, following ``torch.einsum`` ellipsis semantics.

    Parameters
    ----------
    a : torch.Tensor
        :math:`[..., M, K]`
    b : torch.Tensor
        :math:`[..., K, N]`

    Returns
    -------
    torch.Tensor
        :math:`[..., M, N]`
    """
    return torch.einsum("...ij,...jk->...ik", a, b)