API Reference

This section provides the complete API documentation for torch-sla.


Core Classes

SparseTensor

The main class for working with sparse matrices. Supports batched operations, automatic differentiation, and multiple backends.

class torch_sla.SparseTensor(values: Tensor, row_indices: Tensor, col_indices: Tensor, shape: Tuple[int, ...], sparse_dim: Tuple[int, int] = (-2, -1))[source]

Bases: object

Wrapper class for PyTorch sparse tensors with batched and block support.

Supports tensors with shape […batch, M, N, …block] where:

  • Leading dimensions […batch] are batch dimensions

  • (M, N) are the sparse matrix dimensions (at sparse_dim positions)

  • Trailing dimensions […block] are block dimensions

Parameters:
  • values (torch.Tensor) – Non-zero values with shape:
    - Simple: [nnz]
    - Batched: […batch, nnz]
    - Block: [nnz, *block_shape]
    - Batched+Block: […batch, nnz, *block_shape]

  • row_indices (torch.Tensor) – Row indices with shape [nnz]. Must be on the same device as values.

  • col_indices (torch.Tensor) – Column indices with shape [nnz]. Must be on the same device as values.

  • shape (Tuple[int, ...]) – Full tensor shape […batch, M, N, *block_shape].

  • sparse_dim (Tuple[int, int], optional) – Which dimensions are sparse (M, N). Default: (-2, -1) meaning last two before any block dimensions.

values

The non-zero values.

Type:

torch.Tensor

row_indices

Row indices of non-zeros.

Type:

torch.Tensor

col_indices

Column indices of non-zeros.

Type:

torch.Tensor

shape

Full tensor shape.

Type:

Tuple[int, …]

sparse_shape

The (M, N) dimensions.

Type:

Tuple[int, int]

batch_shape

The batch dimensions.

Type:

Tuple[int, …]

block_shape

The block dimensions.

Type:

Tuple[int, …]

Examples

1. Simple 2D Sparse Matrix [M, N]

>>> import torch
>>> from torch_sla import SparseTensor
>>>
>>> # Create a 3x3 tridiagonal matrix in COO format
>>> val = torch.tensor([4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0])
>>> row = torch.tensor([0, 0, 1, 1, 1, 2, 2])
>>> col = torch.tensor([0, 1, 0, 1, 2, 1, 2])
>>> A = SparseTensor(val, row, col, (3, 3))
>>> print(A)
SparseTensor(shape=(3, 3), sparse=(3, 3), nnz=7, dtype=torch.float32, device=cpu)
>>>
>>> # Solve Ax = b
>>> b = torch.tensor([1.0, 2.0, 3.0])
>>> x = A.solve(b)

2. Batched Sparse Matrices [B, M, N]

Same sparsity pattern, different values for each batch.

>>> # 4 matrices, each 3x3, same structure
>>> batch_size = 4
>>> val_batch = val.unsqueeze(0).expand(batch_size, -1).clone()  # [4, 7]
>>> for i in range(batch_size):
...     val_batch[i] = val * (1.0 + 0.1 * i)  # Scale each matrix
>>>
>>> A_batch = SparseTensor(val_batch, row, col, (4, 3, 3))
>>> print(A_batch.batch_shape)  # (4,)
>>> print(A_batch.sparse_shape)  # (3, 3)
>>>
>>> # Batched solve
>>> b_batch = torch.randn(4, 3)
>>> x_batch = A_batch.solve(b_batch)  # [4, 3]

3. Multi-Dimensional Batch [B1, B2, M, N]

>>> B1, B2 = 2, 3  # e.g., 2 materials x 3 temperatures
>>> val_batch = val.unsqueeze(0).unsqueeze(0).expand(B1, B2, -1).clone()  # [2, 3, 7]
>>> A_multi = SparseTensor(val_batch, row, col, (B1, B2, 3, 3))
>>> print(A_multi.batch_shape)  # (2, 3)
>>>
>>> b_multi = torch.randn(B1, B2, 3)
>>> x_multi = A_multi.solve(b_multi)  # [2, 3, 3]

4. Block Sparse Matrix [M, N, K, K] (Block Size K)

Each non-zero entry is a KxK dense block instead of a scalar.

>>> # 2x2 block matrix with 2x2 blocks = 4x4 total
>>> block_size = 2
>>> nnz = 3  # 3 non-zero blocks
>>>
>>> # Values: [nnz, K, K] = [3, 2, 2]
>>> val_block = torch.randn(nnz, block_size, block_size)
>>> row_block = torch.tensor([0, 0, 1])  # Block row indices
>>> col_block = torch.tensor([0, 1, 1])  # Block col indices
>>>
>>> # Shape: (num_block_rows, num_block_cols, block_size, block_size)
>>> A_block = SparseTensor(val_block, row_block, col_block, (2, 2, 2, 2))
>>> print(A_block.block_shape)  # (2, 2)
>>> print(A_block.sparse_shape)  # (2, 2) - number of blocks
>>> print(A_block.shape)  # (2, 2, 2, 2) - full shape

5. Batched Block Sparse [B, M, N, K, K]

>>> batch_size = 4
>>> val_batch_block = torch.randn(batch_size, nnz, block_size, block_size)  # [4, 3, 2, 2]
>>> A_batch_block = SparseTensor(val_batch_block, row_block, col_block, (4, 2, 2, 2, 2))
>>> print(A_batch_block.batch_shape)  # (4,)
>>> print(A_batch_block.block_shape)  # (2, 2)

6. Create from Dense Matrix

>>> A_dense = torch.randn(100, 100)
>>> A_dense[A_dense.abs() < 0.5] = 0  # Sparsify
>>> A = SparseTensor.from_dense(A_dense)

7. Create from PyTorch Sparse Tensor

>>> A_torch = torch.randn(100, 100).to_sparse_coo()
>>> A = SparseTensor.from_torch_sparse(A_torch)

8. Property Detection

>>> A = SparseTensor(val, row, col, (3, 3))
>>> A.is_symmetric()  # tensor(True) - returns tensor for batch support
>>> A.is_positive_definite()  # tensor(True)
>>> A.is_positive_definite('cholesky')  # Use Cholesky factorization check

9. Matrix Operations

>>> # Matrix-vector multiply
>>> y = A @ x  # SparseTensor @ dense vector
>>>
>>> # Sparse-sparse multiply (returns SparseTensor with sparse gradients)
>>> C = A @ A
>>>
>>> # Norms
>>> A.norm('fro')  # Frobenius norm
>>>
>>> # Eigenvalues (symmetric matrices)
>>> eigenvalues, eigenvectors = A.eigsh(k=2, which='LM')

10. CUDA Support

>>> A_cuda = A.cuda()
>>> x = A_cuda.solve(b.cuda())  # Uses cuDSS or cuSOLVER
classmethod from_dense(A: Tensor, sparse_dim: Tuple[int, int] = (-2, -1)) SparseTensor[source]

Create SparseTensor from dense tensor.

Parameters:
  • A (torch.Tensor) – Dense tensor with shape […batch, M, N, …block].

  • sparse_dim (Tuple[int, int], optional) – Which dimensions are sparse. Default: (-2, -1).

Returns:

Sparse representation of A.

Return type:

SparseTensor

Examples

>>> A_dense = torch.randn(3, 3)
>>> A_dense[A_dense.abs() < 0.5] = 0
>>> A = SparseTensor.from_dense(A_dense)
classmethod from_torch_sparse(A: Tensor) SparseTensor[source]

Create SparseTensor from PyTorch sparse tensor.

Parameters:

A (torch.Tensor) – PyTorch sparse COO or CSR tensor (2D only).

Returns:

SparseTensor representation.

Return type:

SparseTensor

Examples

>>> A_coo = torch.randn(3, 3).to_sparse_coo()
>>> A = SparseTensor.from_torch_sparse(A_coo)
property shape: Tuple[int, ...]

Full tensor shape […batch, M, N, …block].

property sparse_shape: Tuple[int, int]

The (M, N) sparse matrix dimensions.

property batch_shape: Tuple[int, ...]

The batch dimensions before the sparse dimensions.

property block_shape: Tuple[int, ...]

The block dimensions after the sparse dimensions.

property sparse_dim: Tuple[int, int]

The dimensions that are sparse (M, N).

property ndim: int

Number of dimensions.

property nnz: int

Number of non-zero elements (per batch/block).

property dtype: dtype

Data type of the values.

property device: device

Device of the tensor.

property is_cuda: bool

Whether the tensor is on CUDA.

property is_batched: bool

Whether the tensor has batch dimensions.

property is_block: bool

Whether the tensor has block dimensions.

property batch_size: int

Total number of batch elements (product of batch_shape).

property is_square: bool

Whether the sparse dimensions are square (M == N).

to(device: str | device | None = None, dtype: dtype | None = None) SparseTensor[source]

Move tensor to device and/or convert dtype.

Parameters:
  • device (str or torch.device, optional) – Target device (e.g., ‘cuda’, ‘cpu’, ‘cuda:0’).

  • dtype (torch.dtype, optional) – Target data type (e.g., torch.float32, torch.float64).

Returns:

New SparseTensor on the target device/dtype.

Return type:

SparseTensor

Examples

>>> A = SparseTensor(val, row, col, shape)
>>> A_cuda = A.to('cuda')
>>> A_float32 = A.to(dtype=torch.float32)
>>> A_cuda_float32 = A.to('cuda', torch.float32)
cuda(device: int | None = None) SparseTensor[source]

Move tensor to CUDA device.

Parameters:

device (int, optional) – CUDA device index. Default: current device.

Returns:

Tensor on CUDA.

Return type:

SparseTensor

cpu() SparseTensor[source]

Move tensor to CPU.

Returns:

Tensor on CPU.

Return type:

SparseTensor

float() SparseTensor[source]

Convert to float32.

double() SparseTensor[source]

Convert to float64.

half() SparseTensor[source]

Convert to float16.

to_torch_sparse(batch_idx: Tuple[int, ...] | None = None) Tensor[source]

Convert to PyTorch sparse COO tensor.

Parameters:

batch_idx (Tuple[int, ...], optional) – For batched tensors, which batch element to convert. Default: (0, 0, …) for first batch element.

Returns:

PyTorch sparse COO tensor.

Return type:

torch.Tensor

to_dense(batch_idx: Tuple[int, ...] | None = None) Tensor[source]

Convert to dense tensor.

Parameters:

batch_idx (Tuple[int, ...], optional) – For batched tensors, which batch element to convert.

Returns:

Dense tensor.

Return type:

torch.Tensor

to_csr(batch_idx: Tuple[int, ...] | None = None) Tensor[source]

Convert to CSR format.

Parameters:

batch_idx (Tuple[int, ...], optional) – For batched tensors, which batch element to convert.

Returns:

PyTorch sparse CSR tensor.

Return type:

torch.Tensor
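For a single (non-batched) matrix, the COO-to-CSR conversion can be illustrated with plain PyTorch (a sketch of what the conversion produces, not the torch-sla code path):

```python
import torch

# Build a small COO matrix and convert it to CSR with plain PyTorch,
# mirroring what to_csr() produces for a single matrix.
val = torch.tensor([4.0, -1.0, -1.0, 4.0])
row = torch.tensor([0, 0, 1, 1])
col = torch.tensor([0, 1, 0, 1])

A_coo = torch.sparse_coo_tensor(torch.stack([row, col]), val, (2, 2))
A_csr = A_coo.to_sparse_csr()

# crow_indices[i+1] - crow_indices[i] = number of non-zeros in row i
print(A_csr.crow_indices())  # tensor([0, 2, 4])
print(A_csr.col_indices())   # tensor([0, 1, 0, 1])
```

CSR replaces per-entry row indices with compressed row pointers, which is why it is the preferred input format for most direct solver backends.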

partition(num_partitions: int, coords: Tensor | None = None, partition_method: str = 'auto', verbose: bool = False) DSparseTensor[source]

Partition into a distributed sparse tensor.

Creates a DSparseTensor with automatic domain decomposition. This is useful for distributed computing and parallel solvers.

Parameters:
  • num_partitions (int) – Number of partitions to create

  • coords (torch.Tensor, optional) – Node coordinates for geometric partitioning [num_nodes, dim]. Required for ‘rcb’ and ‘slicing’ methods.

  • partition_method (str) – Partitioning method:
    - ‘auto’: Auto-select (uses ‘rcb’ if coords provided, else ‘metis’)
    - ‘metis’: Graph-based partitioning (requires pymetis)
    - ‘rcb’: Recursive Coordinate Bisection (requires coords)
    - ‘slicing’: Simple coordinate slicing (requires coords)
    - ‘simple’: Simple 1D partitioning by node index

  • verbose (bool) – Whether to print partition info

Returns:

Distributed sparse tensor with the specified partitions

Return type:

DSparseTensor

Example

>>> A = SparseTensor(val, row, col, shape)
>>> D = A.partition(num_partitions=4)
>>> for i in range(4):
...     partition = D[i]
...     y = partition.matvec(x_local)

Notes

  • Use D.to_sparse_tensor() to gather back to a SparseTensor

  • For distributed training, use partition_for_rank() instead

partition_for_rank(rank: int, world_size: int, coords: Tensor | None = None, partition_method: str = 'simple', verbose: bool = False) DSparseMatrix[source]

Get partition for a specific rank in distributed environment.

This is the recommended API for multi-process distributed computing. Each rank calls this method with its own rank ID to get its local partition. The partitioning is deterministic and consistent across all ranks.

Parameters:
  • rank (int) – This process’s rank (0 to world_size-1)

  • world_size (int) – Total number of processes

  • coords (torch.Tensor, optional) – Node coordinates for geometric partitioning

  • partition_method (str) – Partitioning method (‘simple’, ‘metis’, ‘rcb’, ‘slicing’)

  • verbose (bool) – Print partition info

Returns:

Local partition for this rank

Return type:

DSparseMatrix

Example

>>> # In multi-process code:
>>> A = SparseTensor(val, row, col, shape)
>>> partition = A.partition_for_rank(rank, world_size)
>>> y_local = partition.matvec(x_local)

Notes

  • This uses DSparseTensor.from_global_distributed() internally, which broadcasts partition IDs from rank 0 for consistency.

  • Requires torch.distributed to be initialized.

T() SparseTensor[source]

Transpose the sparse dimensions.

Returns:

Transposed tensor with row/col indices swapped.

Return type:

SparseTensor
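In COO form, transposing amounts to swapping the row and column index arrays. This can be verified with plain PyTorch (a sketch independent of torch-sla):

```python
import torch

# Transposing a COO matrix = swapping row and column indices.
val = torch.tensor([1.0, 2.0, 3.0])
row = torch.tensor([0, 0, 1])
col = torch.tensor([0, 2, 1])

A  = torch.sparse_coo_tensor(torch.stack([row, col]), val, (2, 3)).to_dense()
At = torch.sparse_coo_tensor(torch.stack([col, row]), val, (3, 2)).to_dense()

# The swapped-index matrix equals the dense transpose.
assert torch.equal(At, A.T)
```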

flatten_blocks() SparseTensor[source]

Flatten block dimensions into the sparse (M, N) dimensions.

For a block-sparse tensor with shape […batch, M, N, *block_shape], this creates a new tensor with shape […batch, M*block_M, N*block_N] where each block entry becomes multiple scalar entries.

Returns:

Flattened tensor without block dimensions.

Return type:

SparseTensor

Example

>>> # Block sparse: shape (10, 10, 2, 2), block_shape=(2, 2)
>>> A = SparseTensor(val, row, col, (10, 10, 2, 2))
>>> A_flat = A.flatten_blocks()
>>> print(A_flat.shape)  # (20, 20)
>>> print(A_flat.nnz)    # nnz * 4 (each block has 4 elements)

Notes

  • Only works for 2D block shapes (block_M, block_N).

  • Use unflatten_blocks(block_shape) to reverse this operation.

  • The flattened tensor’s sparsity pattern may have duplicates that need to be coalesced.
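The index arithmetic behind block flattening can be sketched with plain tensors (a sketch of the idea, not the library's implementation): a block entry at block position (r, c) with block shape (KM, KN) expands to KM*KN scalar entries at (r*KM + i, c*KN + j).

```python
import torch

# Sketch of the flatten_blocks() index arithmetic: each block entry
# (r, c) with a KM x KN block expands to scalar entries at
# (r*KM + i, c*KN + j).
KM, KN = 2, 2
val = torch.arange(1.0, 13.0).reshape(3, KM, KN)   # 3 blocks of 2x2
row = torch.tensor([0, 0, 1])                       # block row indices
col = torch.tensor([0, 1, 1])                       # block col indices

i, j = torch.meshgrid(torch.arange(KM), torch.arange(KN), indexing='ij')
flat_row = (row[:, None, None] * KM + i).reshape(-1)   # [3 * KM * KN]
flat_col = (col[:, None, None] * KN + j).reshape(-1)
flat_val = val.reshape(-1)

A_flat = torch.sparse_coo_tensor(
    torch.stack([flat_row, flat_col]), flat_val, (2 * KM, 2 * KN)
).to_dense()
print(A_flat.shape)  # torch.Size([4, 4])
```

The (2, 2, 2, 2) block tensor becomes a 4x4 scalar matrix; the absent block (1, 0) leaves a 2x2 zero region.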

unflatten_blocks(block_shape: Tuple[int, int]) SparseTensor[source]

Restore block structure from a flattened tensor.

This is the inverse of flatten_blocks(). It groups scalar entries back into block entries.

Parameters:

block_shape (Tuple[int, int]) – The (block_M, block_N) dimensions to create. M and N must be divisible by block_M and block_N respectively.

Returns:

Block-sparse tensor with the specified block shape.

Return type:

SparseTensor

Example

>>> A_flat = SparseTensor(val, row, col, (20, 20))
>>> A_block = A_flat.unflatten_blocks((2, 2))
>>> print(A_block.shape)  # (10, 10, 2, 2)
>>> print(A_block.block_shape)  # (2, 2)

Notes

  • Requires that the sparsity pattern is block-aligned.

  • All block entries must be present (dense within each block).

  • For sparse blocks, use to_block_sparse() instead.

is_symmetric(atol: float = 1e-08, rtol: float = 1e-05, force_recompute: bool = False) Tensor[source]

Check if the matrix is symmetric (A == A^T).

For batched tensors, checks each matrix independently and returns a boolean tensor with shape matching the batch dimensions.

Parameters:
  • atol (float, optional) – Absolute tolerance for comparison. Default: 1e-8.

  • rtol (float, optional) – Relative tolerance for comparison. Default: 1e-5.

  • force_recompute (bool, optional) – If True, recompute even if cached. Default: False.

Returns:

Boolean tensor with shape:
  - [] (scalar) for non-batched tensors
  - [*batch_shape] for batched tensors

Return type:

torch.Tensor

Examples

>>> A = SparseTensor(val, row, col, (3, 3))
>>> A.is_symmetric()  # tensor(True) or tensor(False)
>>> A_batch = SparseTensor(val_batch, row, col, (4, 3, 3))
>>> A_batch.is_symmetric()  # tensor([True, True, True, True])
is_positive_definite(method: Literal['gershgorin', 'cholesky', 'eigenvalue'] = 'gershgorin', force_recompute: bool = False) Tensor[source]

Check if the matrix is positive definite.

For batched tensors, checks each matrix independently and returns a boolean tensor with shape matching the batch dimensions.

Parameters:
  • method ({"gershgorin", "cholesky", "eigenvalue"}, optional) – Method for checking:
    - “gershgorin”: Fast check using Gershgorin circles (sufficient but not necessary)
    - “cholesky”: Try Cholesky decomposition (necessary and sufficient, slower)
    - “eigenvalue”: Check smallest eigenvalues (necessary and sufficient, slowest)
    Default: “gershgorin”.

  • force_recompute (bool, optional) – If True, recompute even if cached. Default: False.

Returns:

Boolean tensor with shape:
  - [] (scalar) for non-batched tensors
  - [*batch_shape] for batched tensors

Return type:

torch.Tensor

Examples

>>> A = SparseTensor(val, row, col, (3, 3))
>>> A.is_positive_definite()  # tensor(True) or tensor(False)
>>> A.is_positive_definite(method="cholesky")  # More accurate check
>>> A_batch = SparseTensor(val_batch, row, col, (4, 3, 3))
>>> A_batch.is_positive_definite()  # tensor([True, True, True, True])
connected_components() Tuple[Tensor, int][source]

Find connected components of the graph represented by this sparse matrix.

Uses union-find algorithm for efficiency. Treats the matrix as an undirected graph adjacency matrix.

Returns:

  • labels (torch.Tensor) – Component label for each node, shape [N]. Labels are in range [0, num_components).

  • num_components (int) – Number of connected components.

Notes

  • Only works for non-batched 2D matrices

  • Matrix is treated as undirected (edges in either direction count)

  • Self-loops are ignored for connectivity

Examples

>>> # Block diagonal matrix with 3 components
>>> A = SparseTensor(val, row, col, (100, 100))
>>> labels, num_comp = A.connected_components()
>>> print(f"Found {num_comp} components")
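The union-find idea behind this method can be sketched in plain Python (a sketch of the algorithm, not the library's implementation): treat each COO entry as an undirected edge, union its endpoints, and label nodes by their root.

```python
# Minimal union-find over COO edges, sketching the idea behind
# connected_components() (not the torch-sla implementation).
def connected_components(row, col, n):
    parent = list(range(n))

    def find(x):                      # find root with path halving
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for r, c in zip(row, col):        # edges are treated as undirected
        if r != c:                    # self-loops ignored for connectivity
            parent[find(r)] = find(c)

    roots = sorted({find(i) for i in range(n)})
    labels = {r: k for k, r in enumerate(roots)}
    return [labels[find(i)] for i in range(n)], len(roots)

# Two components: {0, 1} and {2, 3}
labels, num = connected_components([0, 1, 2], [1, 0, 3], 4)
print(labels, num)  # [0, 0, 1, 1] 2
```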
has_isolated_components() bool[source]

Check if the matrix has multiple connected components.

Returns:

True if matrix has more than one connected component.

Return type:

bool

Examples

>>> A = SparseTensor(val, row, col, (100, 100))
>>> if A.has_isolated_components():
...     components = A.to_connected_components()
to_connected_components() SparseTensorList[source]

Split the matrix into a list of connected component subgraphs.

Each component becomes a separate SparseTensor with reindexed nodes.

Returns:

List of SparseTensors, one per connected component.

Return type:

SparseTensorList

Notes

  • Each component’s nodes are reindexed from 0

  • Original node indices can be recovered from the mapping

Examples

>>> A = SparseTensor(val, row, col, (100, 100))
>>> components = A.to_connected_components()
>>> print(f"Split into {len(components)} components")
>>> for i, comp in enumerate(components):
...     print(f"  Component {i}: {comp.shape}")
solve(b: Tensor, backend: Literal['scipy', 'eigen', 'pytorch', 'cusolver', 'cudss', 'auto'] = 'auto', method: Literal['auto', 'superlu', 'umfpack', 'lu', 'qr', 'cholesky', 'ldlt', 'cg', 'bicgstab', 'gmres', 'lgmres', 'minres', 'qmr'] = 'auto', atol: float = 1e-10, maxiter: int = 10000, tol: float = 1e-12) Tensor[source]

Solve the sparse linear system Ax = b.

Automatically handles batched tensors: if A is […batch, M, N] and b is […batch, M], returns x with shape […batch, N].

Parameters:
  • b (torch.Tensor) – Right-hand side vector(s). Shape:
    - Non-batched: [M] or [M, K] for multiple RHS
    - Batched: […batch, M] or […batch, M, K]

  • backend ({"auto", "scipy", "eigen", "cusolver", "cudss"}, optional) – Solver backend. Default: “auto” (selects based on device).
    - “scipy”: Uses SciPy’s sparse solvers (CPU only)
    - “eigen”: Uses Eigen C++ library (CPU only)
    - “cusolver”: Uses NVIDIA cuSOLVER (CUDA only)
    - “cudss”: Uses NVIDIA cuDSS (CUDA only)

  • method (str, optional) – Solver method. Default: “auto” (selects based on matrix properties).
    - Direct methods: “superlu”, “umfpack”, “lu”, “qr”, “cholesky”, “ldlt”
    - Iterative methods: “cg”, “bicgstab”, “gmres”, “minres”

  • atol (float, optional) – Absolute tolerance for iterative solvers. Default: 1e-10.

  • maxiter (int, optional) – Maximum iterations for iterative solvers. Default: 10000.

  • tol (float, optional) – Relative tolerance for direct solvers. Default: 1e-12.

Returns:

Solution x with same batch shape as b.

Return type:

torch.Tensor

Examples

>>> # Simple solve
>>> A = SparseTensor(val, row, col, (3, 3))
>>> b = torch.randn(3)
>>> x = A.solve(b)
>>> # Batched solve
>>> A_batch = SparseTensor(val_batch, row, col, (4, 3, 3))
>>> b_batch = torch.randn(4, 3)
>>> x_batch = A_batch.solve(b_batch)
>>> # Specify backend
>>> x = A.solve(b, backend='scipy', method='cg')
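For reference, the conjugate gradient iteration behind method='cg' can be sketched in a few lines of plain PyTorch (a textbook sketch for symmetric positive definite systems, not the torch-sla solver):

```python
import torch

# Textbook conjugate gradient for SPD systems, sketching method='cg'
# (not the torch-sla implementation).
def cg(A, b, tol=1e-10, maxiter=1000):
    x = torch.zeros_like(b)
    r = b - A @ x                  # residual
    p = r.clone()                  # search direction
    rs = r @ r
    for _ in range(maxiter):
        Ap = A @ p
        alpha = rs / (p @ Ap)      # optimal step along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if rs_new.sqrt() < tol:
            break
        p = r + (rs_new / rs) * p  # A-conjugate update
        rs = rs_new
    return x

A = torch.tensor([[4.0, -1.0], [-1.0, 3.0]])   # SPD
b = torch.tensor([1.0, 2.0])
x = cg(A, b)
assert torch.allclose(A @ x, b, atol=1e-6)
```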
solve_batch(values: Tensor, b: Tensor, backend: Literal['scipy', 'eigen', 'pytorch', 'cusolver', 'cudss', 'auto'] = 'auto', method: Literal['auto', 'superlu', 'umfpack', 'lu', 'qr', 'cholesky', 'ldlt', 'cg', 'bicgstab', 'gmres', 'lgmres', 'minres', 'qmr'] = 'auto', atol: float = 1e-10, maxiter: int = 10000, tol: float = 1e-12) Tensor[source]

Solve with different values but same sparsity structure.

This is efficient when you have the same structure but different values (e.g., time-stepping, optimization, parameter sweeps).

Parameters:
  • values (torch.Tensor) – Matrix values. Shape […batch, nnz] where … are batch dimensions. All matrices share the same row_indices and col_indices.

  • b (torch.Tensor) – Right-hand side. Shape […batch, M].

  • backend ({"auto", "scipy", "eigen", "cusolver", "cudss"}, optional) – Solver backend. See solve() for details. Default: “auto”.

  • method (str, optional) – Solver method. See solve() for details. Default: “auto”.

  • atol (float, optional) – Absolute tolerance for iterative solvers. Default: 1e-10.

  • maxiter (int, optional) – Maximum iterations for iterative solvers. Default: 10000.

  • tol (float, optional) – Relative tolerance. Default: 1e-12.

Returns:

Solution x with shape […batch, N].

Return type:

torch.Tensor

Examples

>>> # Template matrix
>>> A = SparseTensor(val, row, col, (10, 10))
>>> # Batch of different values
>>> val_batch = torch.stack([val * (1 + 0.1*i) for i in range(4)])  # [4, nnz]
>>> b_batch = torch.randn(4, 10)
>>> # Solve all at once
>>> x_batch = A.solve_batch(val_batch, b_batch)  # [4, 10]
nonlinear_solve(residual_fn, u0: Tensor, *params, method: Literal['newton', 'picard', 'anderson'] = 'newton', tol: float = 1e-06, atol: float = 1e-10, max_iter: int = 50, line_search: bool = True, verbose: bool = False, linear_solver: Literal['scipy', 'eigen', 'pytorch', 'cusolver', 'cudss', 'auto'] = 'pytorch', linear_method: Literal['auto', 'superlu', 'umfpack', 'lu', 'qr', 'cholesky', 'ldlt', 'cg', 'bicgstab', 'gmres', 'lgmres', 'minres', 'qmr'] = 'cg') Tensor[source]

Solve nonlinear equation F(u, A, θ) = 0 with adjoint-based gradients.

The SparseTensor A is automatically passed as the first parameter to the residual function, enabling gradients to flow through A’s values.

Parameters:
  • residual_fn (Callable) – Function F(u, A, *params) -> residual tensor.
    - u: Current solution estimate
    - A: This SparseTensor (passed automatically)
    - *params: Additional parameters with requires_grad=True

  • u0 (torch.Tensor) – Initial guess for solution.

  • *params (torch.Tensor) – Additional parameters (e.g., boundary conditions, coefficients). Tensors with requires_grad=True will receive gradients.

  • method ({'newton', 'picard', 'anderson'}, optional) – Nonlinear solver method:
    - ‘newton’: Newton-Raphson with line search (default, fast)
    - ‘picard’: Fixed-point iteration (simple, slow)
    - ‘anderson’: Anderson acceleration (memory efficient)

  • tol (float, optional) – Relative convergence tolerance. Default: 1e-6.

  • atol (float, optional) – Absolute convergence tolerance. Default: 1e-10.

  • max_iter (int, optional) – Maximum nonlinear iterations. Default: 50.

  • line_search (bool, optional) – Use Armijo line search for Newton. Default: True.

  • verbose (bool, optional) – Print convergence information. Default: False.

  • linear_solver (str, optional) – Backend for linear solves. Default: ‘pytorch’.

  • linear_method (str, optional) – Method for linear solves. Default: ‘cg’.

Returns:

Solution u* satisfying F(u*, A, θ) ≈ 0.

Return type:

torch.Tensor

Examples

>>> # Nonlinear PDE: A @ u + u² = f
>>> def residual(u, A, f):
...     return A @ u + u**2 - f
...
>>> A = SparseTensor(val, row, col, (n, n))
>>> f = torch.randn(n, requires_grad=True)
>>> u0 = torch.zeros(n)
>>>
>>> u = A.nonlinear_solve(residual, u0, f, method='newton')
>>>
>>> # Gradients flow via adjoint method
>>> loss = u.sum()
>>> loss.backward()
>>> print(f.grad)  # ∂u/∂f
>>> print(A.values.grad)  # ∂u/∂A (if A.values.requires_grad)
>>> # Nonlinear elasticity: K(u) @ u = F
>>> def residual_elasticity(u, K, F, material):
...     # K depends on displacement through material nonlinearity
...     return K @ u - F + material * u**3
...
>>> u = K.nonlinear_solve(residual_elasticity, u0, F, material)
norm(ord: Literal['fro', 1, 2] = 'fro') Tensor[source]

Compute matrix norm.

For batched tensors, returns norm for each batch element.

Parameters:

ord ({'fro', 1, 2}, optional) – Norm type:
  - ‘fro’: Frobenius norm (default)
  - 1: Maximum absolute column sum
  - 2: Spectral norm (largest singular value)

Returns:

Norm value(s). Shape [] for non-batched, [*batch_shape] for batched.

Return type:

torch.Tensor

Examples

>>> A = SparseTensor(val, row, col, (3, 3))
>>> A.norm('fro')  # tensor(7.2111) - sqrt(52) for the tridiagonal example
>>> A_batch = SparseTensor(val_batch, row, col, (4, 3, 3))
>>> A_batch.norm('fro')  # shape [4], one norm per batch element
spy(batch_idx: Tuple[int, ...] | None = None, ax=None, title: str | None = None, cmap: str = 'viridis', show_grid: bool = True, grid_color: str = '#cccccc', grid_linewidth: float = 0.5, show_colorbar: bool = True, figsize: Tuple[float, float] = (8, 8), save_path: str | None = None, dpi: int = 150)[source]

Visualize the sparsity pattern with values shown as color intensity.

Creates a spy plot where each matrix element is rendered as a pixel. Non-zero elements are colored with intensity proportional to the absolute value, while zero elements are shown as white. This provides a pixel-perfect visualization without overlapping markers.

Parameters:
  • batch_idx (Tuple[int, ...], optional) – For batched tensors, which batch element to visualize. Required if the tensor is batched.

  • ax (matplotlib.axes.Axes, optional) – Axes to plot on. If None, creates a new figure.

  • title (str, optional) – Plot title. Defaults to showing matrix info.

  • cmap (str, optional) – Colormap for values. Default: ‘viridis’. Other options: ‘plasma’, ‘hot’, ‘coolwarm’, ‘Greys’, etc.

  • show_grid (bool, optional) – Whether to show grid lines (only for matrices <= 30x30). Default: True.

  • grid_color (str, optional) – Color of grid lines. Default: ‘#cccccc’ (light gray).

  • grid_linewidth (float, optional) – Width of grid lines. Default: 0.5.

  • show_colorbar (bool, optional) – Whether to show colorbar for values. Default: True.

  • figsize (Tuple[float, float], optional) – Figure size in inches. Default: (8, 8).

  • save_path (str, optional) – If provided, save figure to this path.

  • dpi (int, optional) – DPI for saved figure. Default: 150.

Returns:

ax – The axes object with the plot.

Return type:

matplotlib.axes.Axes

Examples

>>> A = SparseTensor(val, row, col, (100, 100))
>>> A.spy()  # Basic spy plot
>>> A.spy(cmap='hot', show_grid=False)  # Custom colormap, no grid
>>> A.spy(save_path='matrix.png')  # Save to file
>>> # For batched tensor
>>> A_batch = SparseTensor(val_batch, row, col, (4, 100, 100))
>>> A_batch.spy(batch_idx=(0,))  # Visualize first batch element
eigs(k: int = 6, which: str = 'LM', sigma: float | None = None, return_eigenvectors: bool = True) Tuple[Tensor, Tensor | None][source]

Compute k eigenvalues and eigenvectors.

For batched tensors, computes for each batch element. For CUDA tensors, uses LOBPCG algorithm.

Parameters:
  • k (int, optional) – Number of eigenvalues to compute. Default: 6.

  • which ({"LM", "SM", "LR", "SR", "LA", "SA"}, optional) – Which eigenvalues to find:
    - “LM”: Largest magnitude (default)
    - “SM”: Smallest magnitude
    - “LR”/“SR”: Largest/smallest real part
    - “LA”/“SA”: Largest/smallest algebraic (for symmetric)

  • sigma (float, optional) – Find eigenvalues near sigma (shift-invert mode).

  • return_eigenvectors (bool, optional) – Whether to return eigenvectors. Default: True.

Returns:

  • eigenvalues (torch.Tensor) – Shape [k] for non-batched, [*batch_shape, k] for batched.

  • eigenvectors (torch.Tensor or None) – Shape [M, k] for non-batched, [*batch_shape, M, k] for batched. None if return_eigenvectors is False.

Notes

Gradient Support:

  • Both CPU and CUDA: Fully differentiable via adjoint method

  • Uses O(1) graph nodes regardless of iteration count

  • For symmetric matrices, prefer eigsh() for efficiency

Warning: For non-symmetric matrices with complex eigenvalues, gradient computation is only supported for the real part.

Examples

>>> A = SparseTensor(val.requires_grad_(True), row, col, (n, n))
>>> eigenvalues, eigenvectors = A.eigs(k=3)
>>> loss = eigenvalues.real.sum()  # For complex eigenvalues
>>> loss.backward()
eigsh(k: int = 6, which: str = 'LM', sigma: float | None = None, return_eigenvectors: bool = True) Tuple[Tensor, Tensor | None][source]

Compute k eigenvalues for symmetric matrices.

More efficient than eigs() for symmetric matrices.

Parameters:
  • k (int, optional) – Number of eigenvalues to compute. Default: 6.

  • which ({"LM", "SM", "LA", "SA"}, optional) – Which eigenvalues to find:
    - “LM”: Largest magnitude (default)
    - “SM”: Smallest magnitude
    - “LA”/“SA”: Largest/smallest algebraic

  • sigma (float, optional) – Find eigenvalues near sigma.

  • return_eigenvectors (bool, optional) – Whether to return eigenvectors. Default: True.

Returns:

  • eigenvalues (torch.Tensor) – Shape [k] for non-batched, [*batch_shape, k] for batched.

  • eigenvectors (torch.Tensor or None) – Shape [M, k] for non-batched, [*batch_shape, M, k] for batched.

Notes

Gradient Support:

  • Both CPU and CUDA: Fully differentiable via adjoint method

  • Uses O(1) graph nodes regardless of iteration count

  • Gradient computed as: ∂L/∂A = Σ_i (∂L/∂λ_i) * v_i @ v_i.T

Examples

>>> A = SparseTensor(val.requires_grad_(True), row, col, (n, n))
>>> eigenvalues, eigenvectors = A.eigsh(k=3)
>>> loss = eigenvalues.sum()
>>> loss.backward()  # Computes ∂loss/∂val
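The gradient formula ∂L/∂A = Σ_i (∂L/∂λ_i) · v_i v_iᵀ can be checked against autograd on a small dense symmetric matrix (a sketch using torch.linalg.eigh, independent of torch-sla):

```python
import torch

# Verify dλ_0/dA = v_0 v_0^T on a small dense symmetric matrix.
A = torch.tensor([[4.0, -1.0], [-1.0, 3.0]], requires_grad=True)
eigenvalues, eigenvectors = torch.linalg.eigh(A)

eigenvalues[0].backward()          # loss = smallest eigenvalue
v = eigenvectors[:, 0].detach()
expected = torch.outer(v, v)       # adjoint formula: v_0 v_0^T

assert torch.allclose(A.grad, expected, atol=1e-6)
```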
svd(k: int = 6) Tuple[Tensor, Tensor, Tensor][source]

Compute truncated SVD.

Parameters:

k (int, optional) – Number of singular values to compute. Default: 6.

Returns:

  • U (torch.Tensor) – Left singular vectors. Shape [M, k] or [*batch_shape, M, k].

  • S (torch.Tensor) – Singular values. Shape [k] or [*batch_shape, k].

  • Vt (torch.Tensor) – Right singular vectors. Shape [k, N] or [*batch_shape, k, N].

Notes

Gradient Support:

  • CUDA: Fully differentiable (uses power iteration with PyTorch operations)

  • CPU: NOT differentiable (uses SciPy which breaks gradient chain)

For differentiable SVD on CPU, use A.to_dense() and torch.linalg.svd().
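The dense fallback suggested above looks like this in plain PyTorch (a sketch; torch.diag here stands in for A.to_dense()):

```python
import torch

# CPU-differentiable SVD fallback: densify, then use torch.linalg.svd.
val = torch.tensor([3.0, 4.0], requires_grad=True)
A_dense = torch.diag(val)              # stands in for A.to_dense()

U, S, Vt = torch.linalg.svd(A_dense)
S.sum().backward()                     # nuclear norm of A_dense

print(S)         # singular values, descending: tensor([4., 3.])
print(val.grad)  # tensor([1., 1.]) since both entries are positive
```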

condition_number(ord: int = 2) Tensor[source]

Estimate condition number.

Parameters:

ord (int, optional) – Norm order for condition number. Default: 2 (spectral).

Returns:

Condition number. Shape [] or [*batch_shape].

Return type:

torch.Tensor
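The spectral (ord=2) condition number is the ratio of the largest to the smallest singular value; a plain-PyTorch sketch of what this estimates:

```python
import torch

# Spectral condition number = sigma_max / sigma_min, checked densely
# against torch.linalg.cond (plain PyTorch, independent of torch-sla).
A = torch.tensor([[4.0, 0.0], [0.0, 2.0]])
S = torch.linalg.svdvals(A)
cond = S.max() / S.min()

assert torch.isclose(cond, torch.linalg.cond(A, p=2))
print(cond)  # tensor(2.)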

det() Tensor[source]

Compute determinant of the sparse matrix with gradient support.

Uses LU decomposition (CPU) or dense conversion (CUDA) to compute the determinant efficiently. Supports automatic differentiation via the adjoint method.

Returns:

Determinant value. Shape [] for single matrix or [*batch_shape] for batched.

Return type:

torch.Tensor

Raises:

ValueError – If matrix is not square

Notes

  • Only square matrices have determinants

  • For large matrices, determinant values can overflow/underflow

  • Consider using log-determinant for numerical stability in such cases

  • Supports both CPU (via SciPy) and CUDA (via torch.linalg.det)

  • For batched tensors, computes determinant independently for each batch

  • Fully differentiable: gradients computed via adjoint method

  • Gradient formula: ∂det(A)/∂A = det(A) * (A^{-1})^T
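The gradient formula can be verified against autograd on a small dense matrix (a sketch using torch.linalg, independent of torch-sla):

```python
import torch

# Check d det(A)/dA = det(A) * inv(A).T with torch autograd.
A = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
det = torch.linalg.det(A)
det.backward()

expected = det.detach() * torch.linalg.inv(A.detach()).T
assert torch.allclose(A.grad, expected, atol=1e-6)
print(det)  # tensor(-2.)
```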

Performance Warning

CUDA performance is significantly slower than CPU for sparse matrices!

  • CPU: Uses sparse LU decomposition (O(nnz^1.5)), ~0.3-0.8ms for n=10-1000

  • CUDA: Converts to dense (O(n²) memory + O(n³) compute), ~0.2-2.5ms

The CUDA version requires converting the sparse matrix to dense format because cuSOLVER/cuDSS don’t expose determinant computation for sparse matrices. This makes it inefficient for large sparse matrices.

Recommendation: For sparse matrices, use .cpu().det().cuda() instead:

>>> # Slow: CUDA with dense conversion
>>> det_slow = A_cuda.det()  # ~2.5ms for n=1000
>>>
>>> # Fast: CPU with sparse LU
>>> det_fast = A_cuda.cpu().det()  # ~0.8ms for n=1000
>>> det_fast = det_fast.cuda()  # Move result back if needed

Examples

>>> # Simple 2x2 matrix
>>> val = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
>>> row = torch.tensor([0, 0, 1, 1])
>>> col = torch.tensor([0, 1, 0, 1])
>>> A = SparseTensor(val, row, col, (2, 2))
>>> det = A.det()
>>> print(det)  # Should be -2.0
>>> det.backward()
>>> print(val.grad)  # Gradient w.r.t. matrix values
>>>
>>> # CUDA support
>>> A_cuda = A.cuda()
>>> det_cuda = A_cuda.det()
>>>
>>> # Batched matrices
>>> val_batch = val.unsqueeze(0).expand(3, -1).clone()
>>> A_batch = SparseTensor(val_batch, row, col, (3, 2, 2))
>>> det_batch = A_batch.det()
>>> print(det_batch.shape)  # torch.Size([3])
lu() LUFactorization[source]

Compute LU decomposition for repeated solves.

Returns:

Factorization object with solve() method.

Return type:

LUFactorization

Examples

>>> A = SparseTensor(val, row, col, (10, 10))
>>> lu = A.lu()
>>> x1 = lu.solve(b1)
>>> x2 = lu.solve(b2)  # Reuses factorization
sum(axis: int | Tuple[int, ...] | None = None, keepdim: bool = False) Tensor | SparseTensor[source]

Sum of sparse tensor elements over specified axis.

Parameters:
  • axis (int, tuple of ints, or None) –

    Axis or axes along which to sum. Axes correspond to:

    - Batch dimensions: […batch] at the beginning
    - Sparse dimensions: (M, N) at sparse_dim positions
    - Block dimensions: […block] at the end

    If None, sum over all elements (returns scalar tensor).

  • keepdim (bool) – Whether to keep the reduced dimensions.

Returns:

  • If reducing over sparse dimensions: returns dense tensor

  • If reducing over batch/block dimensions only: returns SparseTensor

  • If axis=None: returns scalar tensor

Return type:

torch.Tensor or SparseTensor

Examples

>>> # Shape: [batch=2, M=10, N=10, block=3]
>>> A = SparseTensor(val, row, col, (2, 10, 10, 3))
>>>
>>> A.sum()           # Scalar: sum all elements
>>> A.sum(axis=0)     # Sum over batch -> [10, 10, 3]
>>> A.sum(axis=1)     # Sum over M (rows) -> [2, 10, 3] (dense)
>>> A.sum(axis=2)     # Sum over N (cols) -> [2, 10, 3] (dense)
>>> A.sum(axis=3)     # Sum over block -> SparseTensor [2, 10, 10]
>>> A.sum(axis=(1,2)) # Sum over M and N -> [2, 3] (dense)
mean(axis: int | Tuple[int, ...] | None = None, keepdim: bool = False) Tensor | SparseTensor[source]

Mean of sparse tensor elements over specified axis.

Note: For sparse dimensions, this computes mean of non-zero values only, NOT the mean over all M*N elements. For full mean, use to_dense().mean().

Parameters:
  • axis (int, tuple of ints, or None) – Axis or axes along which to compute mean.

  • keepdim (bool) – Whether to keep the reduced dimensions.

Returns:

Mean values.

Return type:

torch.Tensor or SparseTensor

Examples

>>> A = SparseTensor(val, row, col, (10, 10))
>>> A.mean()           # Mean of all non-zero values
>>> A.mean(axis=0)     # Mean over batch dimension
prod(axis: int | Tuple[int, ...] | None = None, keepdim: bool = False) Tensor | SparseTensor[source]

Product of sparse tensor elements over specified axis.

Warning: For sparse matrices, zero elements are not included in the product. This means prod() computes the product of non-zero values only.

Parameters:
  • axis (int, tuple of ints, or None) – Axis or axes along which to compute product.

  • keepdim (bool) – Whether to keep the reduced dimensions.

Returns:

Product values.

Return type:

torch.Tensor or SparseTensor

Examples

>>> A = SparseTensor(val, row, col, (10, 10))
>>> A.prod()           # Product of all non-zero values
>>> A.prod(axis=0)     # Product over batch dimension
max(axis: int | Tuple[int, ...] | None = None, keepdim: bool = False) Tensor | SparseTensor[source]

Max of non-zero values over specified axis.

min(axis: int | Tuple[int, ...] | None = None, keepdim: bool = False) Tensor | SparseTensor[source]

Min of non-zero values over specified axis.

abs() SparseTensor[source]

Element-wise absolute value.

sqrt() SparseTensor[source]

Element-wise square root.

square() SparseTensor[source]

Element-wise square.

exp() SparseTensor[source]

Element-wise exponential.

log() SparseTensor[source]

Element-wise natural logarithm.

log10() SparseTensor[source]

Element-wise base-10 logarithm.

log2() SparseTensor[source]

Element-wise base-2 logarithm.

sin() SparseTensor[source]

Element-wise sine.

cos() SparseTensor[source]

Element-wise cosine.

tan() SparseTensor[source]

Element-wise tangent.

sinh() SparseTensor[source]

Element-wise hyperbolic sine.

cosh() SparseTensor[source]

Element-wise hyperbolic cosine.

tanh() SparseTensor[source]

Element-wise hyperbolic tangent.

sigmoid() SparseTensor[source]

Element-wise sigmoid.

relu() SparseTensor[source]

Element-wise ReLU.

clamp(min: float | None = None, max: float | None = None) SparseTensor[source]

Element-wise clamp.

sign() SparseTensor[source]

Element-wise sign.

floor() SparseTensor[source]

Element-wise floor.

ceil() SparseTensor[source]

Element-wise ceil.

round() SparseTensor[source]

Element-wise round.

reciprocal() SparseTensor[source]

Element-wise reciprocal (1/x).

pow(exponent: float | int | Tensor) SparseTensor[source]

Element-wise power.

logical_not() SparseTensor[source]

Element-wise logical NOT.

logical_and(other: SparseTensor) SparseTensor[source]

Element-wise logical AND.

logical_or(other: SparseTensor) SparseTensor[source]

Element-wise logical OR.

logical_xor(other: SparseTensor) SparseTensor[source]

Element-wise logical XOR.

isnan() SparseTensor[source]

Element-wise isnan check.

isinf() SparseTensor[source]

Element-wise isinf check.

isfinite() SparseTensor[source]

Element-wise isfinite check.

detach() SparseTensor[source]

Detach from computation graph. Preserves subclass type.

requires_grad_(requires_grad: bool = True) SparseTensor[source]

Enable/disable gradient tracking.

property requires_grad: bool

Whether gradient tracking is enabled.

property grad: Tensor | None

Gradient of values if available.

clone() SparseTensor[source]

Create a copy of this SparseTensor. Preserves subclass type.

contiguous() SparseTensor[source]

Make values contiguous in memory. Preserves subclass type.

save(path: str | PathLike, metadata: Dict[str, str] | None = None) None[source]

Save SparseTensor to safetensors format.

Parameters:
  • path (str or PathLike) – Output file path (should end with .safetensors).

  • metadata (dict, optional) – Additional metadata to store.

Example

>>> A = SparseTensor(val, row, col, (100, 100))
>>> A.save("matrix.safetensors")
classmethod load(path: str | PathLike, device: str | device = 'cpu') SparseTensor[source]

Load SparseTensor from safetensors format.

Parameters:
  • path (str or PathLike) – Input file path.

  • device (str or torch.device) – Device to load tensors to.

Returns:

The loaded sparse tensor.

Return type:

SparseTensor

Example

>>> A = SparseTensor.load("matrix.safetensors", device="cuda")
save_distributed(directory: str | PathLike, num_partitions: int, partition_method: str = 'simple', coords: Tensor | None = None, verbose: bool = False) None[source]

Save as partitioned files for distributed loading.

Creates a directory with metadata and per-partition files. Each rank can then load only its own partition.

Parameters:
  • directory (str or PathLike) – Output directory path.

  • num_partitions (int) – Number of partitions to create.

  • partition_method (str) – ‘simple’, ‘metis’, or ‘geometric’.

  • coords (torch.Tensor, optional) – Node coordinates for geometric partitioning.

  • verbose (bool) – Print progress.

Example

>>> A.save_distributed("matrix_dist", num_partitions=4)
>>> # Each rank loads its partition:
>>> partition = DSparseMatrix.load("matrix_dist", rank)

SparseTensorList

Container for multiple sparse matrices with different sparsity patterns. Useful for batched operations on heterogeneous graphs.

class torch_sla.SparseTensorList(tensors: List[SparseTensor])[source]

Bases: object

A list of SparseTensors with different structures.

Provides a unified interface for batch operations on matrices with different sparsity patterns. Unlike batched SparseTensor (which requires same structure), SparseTensorList allows each matrix to have different shape and sparsity pattern.

tensors

List of SparseTensor objects.

Type:

List[SparseTensor]

shapes

List of shapes for each tensor.

Type:

List[Tuple[int, …]]

device

Device (from first tensor).

Type:

torch.device

dtype

Data type (from first tensor).

Type:

torch.dtype

Examples

>>> # Create matrices with different sizes
>>> A1 = SparseTensor(val1, row1, col1, (10, 10))
>>> A2 = SparseTensor(val2, row2, col2, (20, 20))
>>> A3 = SparseTensor(val3, row3, col3, (30, 30))
>>> # Create list
>>> matrices = SparseTensorList([A1, A2, A3])
>>> print(matrices.shapes)  # [(10, 10), (20, 20), (30, 30)]
>>> # Batch solve
>>> x_list = matrices.solve([b1, b2, b3])
>>> # Check properties for all
>>> is_sym = matrices.is_symmetric()  # [tensor(True), tensor(True), tensor(True)]
classmethod from_coo_list(matrices: List[Tuple[Tensor, Tensor, Tensor, Tuple[int, ...]]]) SparseTensorList[source]

Create from list of COO data tuples.

Parameters:

matrices (List[Tuple]) – List of (values, row_indices, col_indices, shape) tuples.

Returns:

List of SparseTensors.

Return type:

SparseTensorList

Examples

>>> data = [
...     (val1, row1, col1, (10, 10)),
...     (val2, row2, col2, (20, 20)),
... ]
>>> matrices = SparseTensorList.from_coo_list(data)
classmethod from_torch_sparse_list(A_list: List[Tensor]) SparseTensorList[source]

Create from list of PyTorch sparse tensors.

Parameters:

A_list (List[torch.Tensor]) – List of PyTorch sparse COO tensors.

Returns:

List of SparseTensors.

Return type:

SparseTensorList

property shapes: List[Tuple[int, ...]]

List of shapes for each tensor.

property device: device

Device of the first tensor.

property dtype: dtype

Data type of the first tensor.

to(device: str | device) SparseTensorList[source]

Move all tensors to device.

Parameters:

device (str or torch.device) – Target device.

Returns:

New list with tensors on target device.

Return type:

SparseTensorList

cuda() SparseTensorList[source]

Move all tensors to CUDA.

cpu() SparseTensorList[source]

Move all tensors to CPU.

sum(axis: int | None = None) List[Tensor] | Tensor[source]

Sum values in each matrix.

Parameters:

axis (int, optional) – If None: sum all values in each matrix, return List[scalar]. If 0: sum over rows for each matrix. If 1: sum over columns for each matrix.

Returns:

If axis is None: List of scalar tensors (one per matrix). If axis is 0 or 1: List of 1D tensors.

Return type:

List[torch.Tensor] or torch.Tensor

Examples

>>> matrices = SparseTensorList([A1, A2, A3])
>>> totals = matrices.sum()  # [sum(A1), sum(A2), sum(A3)]
>>> row_sums = matrices.sum(axis=1)  # [A1.sum(1), A2.sum(1), ...]
mean(axis: int | None = None) List[Tensor][source]

Mean of values in each matrix.

Parameters:

axis (int, optional) – Same as sum().

Returns:

List of mean values/vectors.

Return type:

List[torch.Tensor]

max() List[Tensor][source]

Maximum value in each matrix.

min() List[Tensor][source]

Minimum value in each matrix.

abs() SparseTensorList[source]

Absolute value of all elements.

clamp(min: float | None = None, max: float | None = None) SparseTensorList[source]

Clamp values in all matrices.

pow(exponent: float) SparseTensorList[source]

Element-wise power.

sqrt() SparseTensorList[source]

Element-wise square root.

exp() SparseTensorList[source]

Element-wise exponential.

log() SparseTensorList[source]

Element-wise natural logarithm.

solve(b_list: List[Tensor], **kwargs) List[Tensor][source]

Solve linear systems for all matrices.

Parameters:
  • b_list (List[torch.Tensor]) – List of right-hand side vectors, one per matrix.

  • **kwargs – Additional arguments passed to SparseTensor.solve().

Returns:

List of solutions.

Return type:

List[torch.Tensor]

Examples

>>> matrices = SparseTensorList([A1, A2, A3])
>>> x_list = matrices.solve([b1, b2, b3])
is_symmetric(**kwargs) List[Tensor][source]

Check symmetry for all matrices.

Parameters:

**kwargs – Arguments passed to SparseTensor.is_symmetric().

Returns:

List of boolean tensors.

Return type:

List[torch.Tensor]

is_positive_definite(**kwargs) List[Tensor][source]

Check positive definiteness for all matrices.

Parameters:

**kwargs – Arguments passed to SparseTensor.is_positive_definite().

Returns:

List of boolean tensors.

Return type:

List[torch.Tensor]

norm(ord: Literal['fro', 1, 2] = 'fro') List[Tensor][source]

Compute norms for all matrices.

Parameters:

ord ({'fro', 1, 2}) – Norm type.

Returns:

List of norm values.

Return type:

List[torch.Tensor]

eigs(k: int = 6, **kwargs) List[Tuple[Tensor, Tensor | None]][source]

Compute eigenvalues for all matrices.

Parameters:
  • k (int) – Number of eigenvalues.

  • **kwargs – Additional arguments.

Returns:

List of (eigenvalues, eigenvectors) tuples.

Return type:

List[Tuple[torch.Tensor, Optional[torch.Tensor]]]

eigsh(k: int = 6, **kwargs) List[Tuple[Tensor, Tensor | None]][source]

Compute eigenvalues for symmetric matrices.

Parameters:
  • k (int) – Number of eigenvalues.

  • **kwargs – Additional arguments.

Returns:

List of (eigenvalues, eigenvectors) tuples.

Return type:

List[Tuple[torch.Tensor, Optional[torch.Tensor]]]

svd(k: int = 6) List[Tuple[Tensor, Tensor, Tensor]][source]

Compute SVD for all matrices.

Parameters:

k (int) – Number of singular values.

Returns:

List of (U, S, Vt) tuples.

Return type:

List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]

condition_number(ord: int = 2) List[Tensor][source]

Compute condition numbers for all matrices.

Parameters:

ord (int) – Norm order.

Returns:

List of condition numbers.

Return type:

List[torch.Tensor]

det() List[Tensor][source]

Compute determinants for all matrices.

Returns:

List of determinant values.

Return type:

List[torch.Tensor]

Examples

>>> matrices = SparseTensorList([A1, A2, A3])
>>> dets = matrices.det()
>>> print([d.item() for d in dets])
spy(indices: List[int] | None = None, ncols: int = 3, figsize: Tuple[float, float] | None = None, **kwargs)[source]

Visualize sparsity patterns for multiple matrices in a grid.

Parameters:
  • indices (List[int], optional) – Which matrices to visualize. Default: all.

  • ncols (int, optional) – Number of columns in subplot grid. Default: 3.

  • figsize (Tuple[float, float], optional) – Figure size. Auto-computed if None.

  • **kwargs – Additional arguments passed to SparseTensor.spy().

Returns:

fig – The figure object.

Return type:

matplotlib.figure.Figure

Examples

>>> matrices = SparseTensorList([A1, A2, A3, A4])
>>> matrices.spy()  # Visualize all in grid
>>> matrices.spy(indices=[0, 2])  # Visualize specific ones
to_block_diagonal() SparseTensor[source]

Merge all matrices into a single block-diagonal SparseTensor.

Creates a sparse matrix where each input matrix appears as a block on the diagonal: diag(A1, A2, …, An).

Returns:

Block-diagonal matrix with shape (sum(M_i), sum(N_i)).

Return type:

SparseTensor

Notes

The resulting matrix has the structure:

   [A1   0   0  ...]
   [ 0  A2   0  ...]
   [ 0   0  A3  ...]
   [...            ]

Examples

>>> A1 = SparseTensor(val1, row1, col1, (10, 10))
>>> A2 = SparseTensor(val2, row2, col2, (20, 20))
>>> stl = SparseTensorList([A1, A2])
>>> A_block = stl.to_block_diagonal()  # Shape (30, 30)
classmethod from_block_diagonal(sparse: SparseTensor, sizes: List[Tuple[int, int]]) SparseTensorList[source]

Split a block-diagonal SparseTensor into a list of matrices.

Parameters:
  • sparse (SparseTensor) – Block-diagonal matrix to split.

  • sizes (List[Tuple[int, int]]) – List of (rows, cols) for each block. Must sum to sparse.shape.

Returns:

List of extracted blocks.

Return type:

SparseTensorList

Examples

>>> A_block = SparseTensor(val, row, col, (30, 30))
>>> stl = SparseTensorList.from_block_diagonal(A_block, [(10, 10), (20, 20)])
>>> print(len(stl))  # 2
property block_sizes: List[Tuple[int, int]]

Get the (rows, cols) size of each matrix.

Returns:

List of (M, N) tuples.

Return type:

List[Tuple[int, int]]

property total_nnz: int

Total number of non-zeros across all matrices.

property total_shape: Tuple[int, int]

Shape of the block-diagonal representation.

partition(num_partitions: int, threshold: int = 1000, partition_method: str = 'auto', device: str | device | None = None, verbose: bool = False) DSparseTensorList[source]

Create distributed version for parallel computing.

Parameters:
  • num_partitions (int) – Number of partitions (typically = world_size).

  • threshold (int) – Graphs with nodes >= threshold are partitioned across ranks. Smaller graphs are assigned whole to individual ranks.

  • partition_method (str) – Method for partitioning large graphs: ‘metis’, ‘simple’, ‘auto’.

  • device (torch.device, optional) – Target device.

  • verbose (bool) – Print partition info.

Returns:

Distributed sparse tensor list.

Return type:

DSparseTensorList

Notes

Hybrid Strategy:

  • Small graphs (< threshold nodes): Assigned whole to ranks round-robin. Zero edge cuts, no halo exchange needed.

  • Large graphs (>= threshold nodes): Partitioned across all ranks. Uses halo exchange for boundary nodes.

This is optimal for molecular datasets where most molecules are small but some (proteins, polymers) can be very large.

Examples

>>> stl = SparseTensorList([A1, A2, A3, ...])
>>> dstl = stl.partition(num_partitions=4, threshold=1000)
>>> y_list = dstl @ x_list  # Distributed matmul

LUFactorization

LU factorization for efficient repeated solves with the same matrix.

class torch_sla.LUFactorization(lu_factor, shape: Tuple[int, int], dtype: dtype, device: device)[source]

Bases: object

LU factorization wrapper for efficient repeated solves.

Created by SparseTensor.lu().

Parameters:
  • lu_factor – Backend LU factorization object.

  • shape (Tuple[int, int]) – Matrix shape (M, N).

  • dtype (torch.dtype) – Data type of the matrix values.

  • device (torch.device) – Device the factorization lives on.

Examples

>>> A = SparseTensor(val, row, col, (10, 10))
>>> lu = A.lu()
>>> x1 = lu.solve(b1)  # First solve
>>> x2 = lu.solve(b2)  # Much faster - reuses factorization
solve(b: Tensor) Tensor[source]

Solve Ax = b using the cached factorization.

Parameters:

b (torch.Tensor) – Right-hand side vector.

Returns:

Solution x.

Return type:

torch.Tensor


Distributed Classes

DSparseTensor

Distributed sparse tensor with domain decomposition support. Uses halo exchange for communication between partitions.

class torch_sla.DSparseTensor(values: Tensor, row_indices: Tensor, col_indices: Tensor, shape: Tuple[int, int], num_partitions: int, coords: Tensor | None = None, partition_method: str = 'auto', device: str | device | None = None, verbose: bool = True)[source]

Bases: object

Distributed Sparse Tensor with automatic partitioning and halo exchange.

A Pythonic wrapper that provides a unified interface for distributed sparse matrix operations. Supports indexing to access individual partitions.

Parameters:
  • values (torch.Tensor) – Non-zero values [nnz]

  • row_indices (torch.Tensor) – Row indices [nnz]

  • col_indices (torch.Tensor) – Column indices [nnz]

  • shape (Tuple[int, int]) – Matrix shape (m, n)

  • num_partitions (int) – Number of partitions to create

  • coords (torch.Tensor, optional) – Node coordinates for geometric partitioning [num_nodes, dim]

  • partition_method (str) – Partitioning method: ‘metis’, ‘rcb’, ‘slicing’, ‘simple’

  • device (str or torch.device) – Device for the matrix data

  • verbose (bool) – Whether to print partition info

Example

>>> import torch
>>> from torch_sla import DSparseTensor
>>>
>>> # Create distributed tensor with 4 partitions
>>> A = DSparseTensor(val, row, col, shape, num_partitions=4)
>>>
>>> # Access individual partitions
>>> A0 = A[0]  # First partition
>>> A1 = A[1]  # Second partition
>>>
>>> # Iterate over partitions
>>> for partition in A:
>>>     x = partition.solve(b_local)
>>>
>>> # Properties
>>> print(A.num_partitions)  # 4
>>> print(A.shape)           # Global shape
>>> print(len(A))            # 4
>>>
>>> # Move to CUDA
>>> A_cuda = A.cuda()
>>>
>>> # Local halo exchange (for testing)
>>> x_list = [torch.zeros(A[i].num_local) for i in range(4)]
>>> A.halo_exchange_local(x_list)
classmethod from_sparse_tensor(sparse_tensor: SparseTensor, num_partitions: int, coords: Tensor | None = None, partition_method: str = 'auto', device: str | device | None = None, verbose: bool = True) DSparseTensor[source]

Create DSparseTensor from a SparseTensor.

Parameters:
  • sparse_tensor (SparseTensor) – Input sparse tensor (must be 2D, not batched)

  • num_partitions (int) – Number of partitions

  • coords (torch.Tensor, optional) – Node coordinates for geometric partitioning

  • partition_method (str) – Partitioning method

  • device (str or torch.device, optional) – Target device (defaults to sparse_tensor’s device)

  • verbose (bool) – Whether to print partition info

Returns:

Distributed sparse tensor

Return type:

DSparseTensor

classmethod from_torch_sparse(A: Tensor, num_partitions: int, **kwargs) DSparseTensor[source]

Create DSparseTensor from PyTorch sparse tensor.

classmethod from_global_distributed(values: Tensor, row_indices: Tensor, col_indices: Tensor, shape: Tuple[int, int], rank: int, world_size: int, coords: Tensor | None = None, partition_method: str = 'auto', device: str | device | None = None, verbose: bool = True) DSparseMatrix[source]

Create local partition in a distributed-safe manner.

This method ensures that all ranks compute the same partition assignment by having rank 0 compute the partition IDs and broadcasting to all ranks.

Parameters:
  • values (torch.Tensor) – Global non-zero values [nnz]

  • row_indices (torch.Tensor) – Global row indices [nnz]

  • col_indices (torch.Tensor) – Global column indices [nnz]

  • shape (Tuple[int, int]) – Global matrix shape (M, N)

  • rank (int) – Current process rank

  • world_size (int) – Total number of processes

  • coords (torch.Tensor, optional) – Node coordinates for geometric partitioning [num_nodes, dim]

  • partition_method (str) – Partitioning method: ‘metis’, ‘rcb’, ‘slicing’, ‘simple’

  • device (str or torch.device, optional) – Target device

  • verbose (bool) – Whether to print partition info

Returns:

Local partition matrix for this rank

Return type:

DSparseMatrix

Example

>>> import torch.distributed as dist
>>>
>>> # In each process:
>>> rank = dist.get_rank()
>>> world_size = dist.get_world_size()
>>>
>>> local_matrix = DSparseTensor.from_global_distributed(
...     val, row, col, shape,
...     rank=rank, world_size=world_size
... )
classmethod from_device_mesh(values: Tensor, row_indices: Tensor, col_indices: Tensor, shape: Tuple[int, int], device_mesh: DeviceMesh, coords: Tensor | None = None, partition_method: str = 'simple', placement: str = 'shard_rows', verbose: bool = False) DSparseMatrix[source]

Create local partition using PyTorch DeviceMesh.

This is the recommended method for distributed training with PyTorch’s DTensor ecosystem. Each rank receives only its local partition.

Parameters:
  • values (torch.Tensor) – Global non-zero values [nnz] (same on all ranks)

  • row_indices (torch.Tensor) – Global row indices [nnz]

  • col_indices (torch.Tensor) – Global column indices [nnz]

  • shape (Tuple[int, int]) – Global matrix shape (M, N)

  • device_mesh (DeviceMesh) – PyTorch DeviceMesh specifying device topology

  • coords (torch.Tensor, optional) – Node coordinates for geometric partitioning

  • partition_method (str) – Partitioning method: ‘metis’, ‘rcb’, ‘simple’. Default is ‘simple’ for determinism in the distributed setting.

  • placement (str) – How to distribute: ‘shard_rows’, ‘shard_cols’, ‘replicate’

  • verbose (bool) – Whether to print partition info

Returns:

Local partition for this rank

Return type:

DSparseMatrix

Example

>>> from torch.distributed.device_mesh import init_device_mesh
>>> from torch_sla import DSparseTensor
>>>
>>> # Initialize 4-GPU device mesh
>>> mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("dp",))
>>>
>>> # Create distributed sparse tensor (each rank gets its partition)
>>> local_matrix = DSparseTensor.from_device_mesh(
...     val, row, col, shape,
...     device_mesh=mesh,
...     partition_method='simple'
... )
>>>
>>> # Local operations
>>> y_local = local_matrix.matvec(x_local)
>>> x_local = local_matrix.solve(b_local)
property shape: Tuple[int, int]

Global matrix shape.

property num_partitions: int

Number of partitions.

property device: device

Device of the matrix data.

property dtype: dtype

Data type of matrix values.

property nnz: int

Total number of non-zeros.

property partition_ids: Tensor

Partition assignment for each node.

property is_cuda: bool

Check if matrix is on CUDA.

to(device: str | device) DSparseTensor[source]

Move all partitions to a different device.

Parameters:

device (str or torch.device) – Target device

Returns:

New distributed tensor on target device

Return type:

DSparseTensor

cuda(device: int | None = None) DSparseTensor[source]

Move to CUDA device.

cpu() DSparseTensor[source]

Move to CPU.

halo_exchange_local(x_list: List[Tensor]) None[source]

Local halo exchange for single-process simulation.

Exchanges halo values between all partitions locally. Useful for testing without actual distributed setup.

Parameters:

x_list (List[torch.Tensor]) – List of local vectors, one per partition. Each vector is modified in-place to update halo values.

matvec_all(x_list: List[Tensor], exchange_halo: bool = True) List[Tensor][source]

Matrix-vector multiply on all partitions.

Performs y = A @ x for each partition, with optional halo exchange.

Parameters:
  • x_list (List[torch.Tensor]) – List of local vectors, one per partition. Each vector should have size = num_owned + num_halo for that partition.

  • exchange_halo (bool) – Whether to perform halo exchange before multiplication. Default True.

Returns:

List of result vectors, one per partition. Each result has size = num_owned (only owned nodes have valid results).

Return type:

List[torch.Tensor]

Example

>>> D = SparseTensor(val, row, col, shape).partition(4)
>>> x_local = D.scatter_local(x_global)
>>> y_local = D.matvec_all(x_local)
>>> y_global = D.gather_global(y_local)
solve_all(b_list: List[Tensor], **kwargs) List[Tensor][source]

Solve on all partitions (subdomain solves).

NOTE: This performs LOCAL subdomain solves, NOT a global distributed solve. Each partition solves its own local system independently. For a true distributed solve, use solve_distributed().

Parameters:
  • b_list (List[torch.Tensor]) – List of local RHS vectors, one per partition

  • **kwargs – Additional arguments passed to each partition’s solve method

Returns:

List of solution vectors, one per partition

Return type:

List[torch.Tensor]

solve_distributed(b_global: Tensor | DTensor, method: str = 'cg', atol: float = 1e-10, maxiter: int = 1000, verbose: bool = False) Tensor | DTensor[source]

Distributed solve: find x such that A @ x = b using all partitions.

This performs a TRUE distributed solve where all partitions collaborate to solve the global system. Uses distributed CG with global reductions.

Parameters:
  • b_global (torch.Tensor or DTensor) –

    Global RHS vector [N].

    - If torch.Tensor: treated as a global vector
    - If DTensor: distributed input/output is handled automatically

  • method (str) – Solver method: ‘cg’ (Conjugate Gradient)

  • atol (float) – Absolute tolerance for convergence

  • maxiter (int) – Maximum iterations

  • verbose (bool) – Print convergence info

Returns:

Global solution vector [N]. Returns DTensor if input is DTensor, otherwise torch.Tensor.

Return type:

torch.Tensor or DTensor

Example

>>> D = A.partition(num_partitions=4)
>>> x = D.solve_distributed(b)  # Distributed CG solve
>>> residual = torch.norm(A @ x - b)
>>> # With DTensor input
>>> from torch.distributed.tensor import DTensor, Replicate
>>> b_dt = DTensor.from_local(b_local, mesh, [Replicate()])
>>> x_dt = D.solve_distributed(b_dt)  # Returns DTensor
gather_global(x_list: List[Tensor]) Tensor[source]

Gather local vectors to global vector.

Parameters:

x_list (List[torch.Tensor]) – List of local vectors, one per partition

Returns:

Global vector

Return type:

torch.Tensor

scatter_local(x_global: Tensor) List[Tensor][source]

Scatter global vector to local vectors.

Parameters:

x_global (torch.Tensor) – Global vector

Returns:

List of local vectors (with halo values filled)

Return type:

List[torch.Tensor]

to_sparse_tensor() SparseTensor[source]

Gather all partitions into a single SparseTensor.

This creates a global SparseTensor from the distributed data. Useful for verification, debugging, or when you need to perform operations that require the full matrix.

Returns:

Global sparse tensor containing all data

Return type:

SparseTensor

Example

>>> D = DSparseTensor(val, row, col, shape, num_partitions=4)
>>> A = D.to_sparse_tensor()  # Gather to global SparseTensor
>>> x = A.solve(b)  # Solve on the full matrix
gather() SparseTensor

Alias of to_sparse_tensor(): gather all partitions into a single global SparseTensor.

Returns:

Global sparse tensor containing all data

Return type:

SparseTensor

Example

>>> D = DSparseTensor(val, row, col, shape, num_partitions=4)
>>> A = D.gather()  # Gather to global SparseTensor
>>> x = A.solve(b)  # Solve on the full matrix
to_list() DSparseTensorList[source]

Split into DSparseTensorList based on connected components.

If the matrix has isolated subgraphs (block-diagonal structure), splits it into separate distributed matrices, one per component.

Returns:

List of distributed matrices, one per connected component.

Return type:

DSparseTensorList

Notes

This is useful when you have a block-diagonal matrix representing multiple independent graphs and want to process them separately.

Examples

>>> D = DSparseTensor(val, row, col, shape, num_partitions=4)
>>> if D.has_isolated_components():
...     dstl = D.to_list()  # Split into components
has_isolated_components() bool[source]

Check if the matrix has multiple connected components.

Returns:

True if matrix has more than one connected component.

Return type:

bool

classmethod from_list(dstl: DSparseTensorList, verbose: bool = False) DSparseTensor[source]

Merge DSparseTensorList into a single block-diagonal DSparseTensor.

Parameters:
  • dstl (DSparseTensorList) – List of distributed matrices to merge.

  • verbose (bool) – Print info.

Returns:

Block-diagonal distributed matrix.

Return type:

DSparseTensor

Examples

>>> dstl = DSparseTensorList.from_sparse_tensor_list(stl, 4)
>>> D = DSparseTensor.from_list(dstl)  # Merge to block-diagonal
scatter_to_dtensor(x_global: Tensor, device_mesh: DeviceMesh, shard_dim: int = 0) DTensor[source]

Convert a global tensor to a sharded DTensor aligned with matrix partitioning.

This creates a DTensor where each rank holds the portion of the vector corresponding to its owned nodes in the matrix partitioning.

Parameters:
  • x_global (torch.Tensor) – Global vector of shape [N]

  • device_mesh (DeviceMesh) – PyTorch DeviceMesh for distribution

  • shard_dim (int) – Dimension to shard (default 0 for vectors)

Returns:

Sharded DTensor with local data for this rank

Return type:

DTensor

Example

>>> mesh = init_device_mesh("cuda", (4,))
>>> x_global = torch.randn(N)
>>> x_dt = D.scatter_to_dtensor(x_global, mesh)
gather_from_dtensor(x_dtensor: DTensor) Tensor[source]

Convert a DTensor to a global tensor.

Parameters:

x_dtensor (DTensor) – Distributed tensor

Returns:

Full global tensor

Return type:

torch.Tensor

Example

>>> x_global = D.gather_from_dtensor(x_dt)
to_dtensor(x: Tensor, device_mesh: DeviceMesh, replicate: bool = True) DTensor[source]

Convert a tensor to DTensor with specified placement.

Parameters:
  • x (torch.Tensor) – Input tensor

  • device_mesh (DeviceMesh) – PyTorch DeviceMesh

  • replicate (bool) – If True, create a replicated DTensor (same data on all ranks). If False, create a sharded DTensor (data is split).

Returns:

Resulting DTensor

Return type:

DTensor

Example

>>> mesh = init_device_mesh("cuda", (4,))
>>> x_dt = D.to_dtensor(x, mesh, replicate=True)
property supports_dtensor: bool

Check if DTensor operations are available.

eigsh(k: int = 6, which: str = 'LM', sigma: float | None = None, return_eigenvectors: bool = True, maxiter: int = 1000, tol: float = 1e-08) Tuple[Tensor, Tensor | None][source]

Compute k eigenvalues for symmetric matrices using distributed LOBPCG.

This is a TRUE distributed algorithm - no data gather required. Uses distributed matvec with global QR decomposition.

Parameters:
  • k (int, optional) – Number of eigenvalues to compute. Default: 6.

  • which ({"LM", "SM", "LA", "SA"}, optional) – Which eigenvalues to find:
    - "LM"/"LA": Largest (default)
    - "SM"/"SA": Smallest

  • sigma (float, optional) – Find eigenvalues near sigma (not yet supported).

  • return_eigenvectors (bool, optional) – Whether to return eigenvectors. Default: True.

  • maxiter (int, optional) – Maximum LOBPCG iterations. Default: 1000.

  • tol (float, optional) – Convergence tolerance. Default: 1e-8.

Returns:

  • eigenvalues (torch.Tensor) – Shape [k].

  • eigenvectors (torch.Tensor or None) – Shape [N, k] if return_eigenvectors is True.

Notes

Distributed Algorithm:

  • Uses distributed LOBPCG (Locally Optimal Block PCG)

  • Only requires distributed matvec + global reductions

  • Memory: O(N * k) per node for eigenvectors

  • Communication: O(k^2) per iteration for Rayleigh-Ritz

Gradient Support:

  • Gradients flow through the distributed matvec operations

  • O(iterations) graph nodes (not O(1) like adjoint)
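As a single-process illustration of the structure behind this algorithm (plain numpy, simple subspace iteration rather than full LOBPCG; not torch-sla code), note that the only operation involving A is a matvec, while the QR and the k x k Rayleigh-Ritz problem are small "global" computations:

```python
import numpy as np

def subspace_iteration_eigsh(A, k=2, maxiter=200, tol=1e-10):
    """Largest-k eigenpairs of a symmetric A via blocked subspace iteration.

    Mirrors the access pattern of distributed LOBPCG: A is only ever
    applied to a tall block of vectors; everything else is k-dimensional.
    """
    n = A.shape[0]
    rng = np.random.default_rng(0)
    V, _ = np.linalg.qr(rng.standard_normal((n, k)))
    w_old = np.zeros(k)
    for _ in range(maxiter):
        V, _ = np.linalg.qr(A @ V)           # power step + re-orthogonalize
        w, S = np.linalg.eigh(V.T @ A @ V)   # k x k Rayleigh-Ritz problem
        V = V @ S                            # rotate basis to Ritz vectors
        if np.max(np.abs(w - w_old)) < tol:
            break
        w_old = w
    return w, V

# Small SPD test matrix: 1D Laplacian
A = 2.0 * np.eye(6) - np.eye(6, k=1) - np.eye(6, k=-1)
w, V = subspace_iteration_eigsh(A, k=2)
```

In the distributed setting, A @ V becomes the distributed matvec and the V.T @ A @ V reduction becomes a global all_reduce, which is where the O(k^2) per-iteration communication quoted above comes from.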

eigs(k: int = 6, which: str = 'LM', sigma: float | None = None, return_eigenvectors: bool = True, maxiter: int = 1000, tol: float = 1e-08) Tuple[Tensor, Tensor | None][source]

Compute k eigenvalues using distributed LOBPCG.

For symmetric matrices, this is equivalent to eigsh(). For non-symmetric matrices, it currently falls back to eigsh() (i.e., symmetry is assumed).

Parameters:
  • k (int, optional) – Number of eigenvalues to compute. Default: 6.

  • which (str, optional) – Which eigenvalues to find.

  • sigma (float, optional) – Find eigenvalues near sigma.

  • return_eigenvectors (bool, optional) – Whether to return eigenvectors. Default: True.

  • maxiter (int, optional) – Maximum iterations. Default: 1000.

  • tol (float, optional) – Convergence tolerance. Default: 1e-8.

Returns:

  • eigenvalues (torch.Tensor) – Shape [k].

  • eigenvectors (torch.Tensor or None) – Shape [N, k] if return_eigenvectors is True.

svd(k: int = 6, maxiter: int = 1000, tol: float = 1e-08) Tuple[Tensor, Tensor, Tensor][source]

Compute truncated SVD using distributed power iteration.

Computes eigenpairs of A^T @ A to obtain V and the singular values, then recovers U from A @ V.

Parameters:
  • k (int, optional) – Number of singular values to compute. Default: 6.

  • maxiter (int, optional) – Maximum iterations. Default: 1000.

  • tol (float, optional) – Convergence tolerance. Default: 1e-8.

Returns:

  • U (torch.Tensor) – Left singular vectors. Shape [M, k].

  • S (torch.Tensor) – Singular values. Shape [k].

  • Vt (torch.Tensor) – Right singular vectors. Shape [k, N].

Notes

Distributed Algorithm:

  • Computes eigenvalues of A^T @ A using distributed LOBPCG

  • No data gather required
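The A^T @ A route can be sketched in dense numpy (illustrative only; the helper name is hypothetical and this is not the distributed implementation):

```python
import numpy as np

def truncated_svd_via_gram(A, k):
    """Top-k SVD via the eigendecomposition of the Gram matrix A^T @ A."""
    w, V = np.linalg.eigh(A.T @ A)        # eigenvalues in ascending order
    idx = np.argsort(w)[::-1][:k]         # indices of the k largest
    V = V[:, idx]                         # right singular vectors
    S = np.sqrt(np.clip(w[idx], 0.0, None))
    U = (A @ V) / S                       # recover left singular vectors
    return U, S, V.T

rng = np.random.default_rng(0)
A = rng.standard_normal((8, 5))
U, S, Vt = truncated_svd_via_gram(A, k=3)
```

In the distributed version, the Gram-matrix eigenproblem is handled by the distributed LOBPCG described above, so no gather of A is required.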

norm(ord: Literal['fro', 1, 2] = 'fro') Tensor[source]

Compute matrix norm (distributed).

For Frobenius norm, computed locally and aggregated. For spectral norm, uses distributed SVD.

Parameters:

ord ({'fro', 1, 2}) – Type of norm:
  - 'fro': Frobenius norm (distributed sum)
  - 1: Maximum column sum
  - 2: Spectral norm (largest singular value via distributed SVD)

Returns:

Scalar tensor containing the norm value.

Return type:

torch.Tensor

condition_number(ord: int = 2) Tensor[source]

Estimate condition number using distributed SVD.

Parameters:

ord (int, optional) – Norm order. Default: 2 (spectral).

Returns:

Condition number estimate (σ_max / σ_min).

Return type:

torch.Tensor

det() Tensor[source]

Compute determinant of the distributed sparse matrix.

WARNING: This operation requires gathering the full matrix to compute the determinant, as determinant is a global property that cannot be computed in a truly distributed manner without full matrix information.

The determinant is computed by:

1. Gathering all partitions into a global SparseTensor
2. Computing the determinant using LU decomposition (CPU) or torch.linalg.det (CUDA)

Returns:

Determinant value (scalar tensor).

Return type:

torch.Tensor

Raises:

ValueError – If matrix is not square

Notes

  • Only square matrices have determinants

  • This method gathers all data, so use with caution for large matrices

  • Supports gradient computation via autograd

  • For very large matrices, consider using log-determinant or other approximations instead

Examples

>>> import torch
>>> from torch_sla import DSparseTensor
>>>
>>> # Create distributed sparse matrix
>>> val = torch.tensor([4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0])
>>> row = torch.tensor([0, 0, 1, 1, 1, 2, 2])
>>> col = torch.tensor([0, 1, 0, 1, 2, 1, 2])
>>> D = DSparseTensor(val, row, col, (3, 3), num_partitions=2)
>>>
>>> # Compute determinant (gathers to single node)
>>> det = D.det()
>>> print(det)
>>>
>>> # With gradient support
>>> val = val.requires_grad_(True)
>>> D = DSparseTensor(val, row, col, (3, 3), num_partitions=2)
>>> det = D.det()
>>> det.backward()
>>> print(val.grad)  # Gradient w.r.t. matrix values
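The log-determinant alternative suggested in the Notes matters because the determinant overflows float64 long before a matrix is "large"; a numpy illustration of the effect (not torch-sla API):

```python
import numpy as np

# A well-conditioned 600x600 SPD matrix whose determinant overflows float64,
# while its log-determinant is a perfectly ordinary number.
n = 600
A = 4.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)

sign, logabsdet = np.linalg.slogdet(A)   # finite: roughly n * log(2 + sqrt(3))
with np.errstate(over="ignore"):
    det = sign * np.exp(logabsdet)       # overflows to inf
```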
T() DSparseTensor[source]

Transpose the distributed sparse tensor.

Returns a new DSparseTensor with swapped row/column indices.

Returns:

Transposed matrix.

Return type:

DSparseTensor

to_dense() Tensor[source]

Convert to dense tensor.

WARNING: This gathers all data to a single node. Only use for small matrices or debugging.

Returns:

Dense matrix of shape (M, N).

Return type:

torch.Tensor

is_symmetric(atol: float = 1e-08, rtol: float = 1e-05) Tensor[source]

Check if matrix is symmetric.

Computed distributedly by comparing values with the transpose.

Parameters:
  • atol (float) – Absolute tolerance for symmetry check.

  • rtol (float) – Relative tolerance for symmetry check.

Returns:

Boolean scalar tensor.

Return type:

torch.Tensor

is_positive_definite() Tensor[source]

Check if matrix is positive definite.

Uses distributed eigenvalue computation.

Returns:

Boolean scalar tensor.

Return type:

torch.Tensor

lu()[source]

Compute LU decomposition.

WARNING: LU is inherently not distributed-friendly. This gathers data to a single node.

For distributed solves, use solve_distributed() with iterative methods.

Returns:

Factorization object with solve() method.

Return type:

LUFactorization

spy(**kwargs)[source]

Visualize sparsity pattern.

Gathers data for visualization.

Parameters:

**kwargs – Arguments passed to SparseTensor.spy().

nonlinear_solve(residual_fn, u0: Tensor, *params, method: str = 'newton', tol: float = 1e-06, atol: float = 1e-10, max_iter: int = 50, line_search: bool = True, verbose: bool = False) Tensor[source]

Solve nonlinear equation F(u, D, *params) = 0 using distributed Newton-Krylov.

Uses Jacobian-free Newton-Krylov with distributed CG for linear solves.

Parameters:
  • residual_fn (callable) – Function F(u, D, *params) -> residual tensor. D is this DSparseTensor.

  • u0 (torch.Tensor) – Initial guess (global vector).

  • *params (torch.Tensor) – Additional parameters.

  • method (str) – Solver method:
    - 'newton': Newton-Krylov with distributed CG
    - 'picard': Fixed-point iteration

  • tol (float) – Relative tolerance.

  • atol (float) – Absolute tolerance.

  • max_iter (int) – Maximum outer iterations.

  • line_search (bool) – Use Armijo line search.

  • verbose (bool) – Print convergence info.

Returns:

Solution u such that F(u, D, *params) ≈ 0.

Return type:

torch.Tensor

Notes

Distributed Algorithm:

  • Uses Jacobian-free Newton-Krylov (JFNK)

  • Linear solves use distributed CG

  • Jacobian-vector products computed via finite differences
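A single-process numpy sketch of the JFNK idea (illustrative; function names are hypothetical, and the distributed CG and line search are omitted). Each Newton system is solved with only Jacobian-vector products, approximated by finite differences:

```python
import numpy as np

def jfnk_solve(F, u0, newton_tol=1e-9, max_newton=20, fd_eps=1e-7):
    """Newton iteration in which each linear system J(u) du = -F(u)
    is solved by CG using finite-difference Jacobian-vector products
    (no Jacobian matrix is ever formed)."""
    u = u0.copy()
    for _ in range(max_newton):
        r = F(u)
        rnorm = np.linalg.norm(r)
        if rnorm < newton_tol:
            break

        def Jv(v):  # matrix-free J @ v, with the usual ||v|| scaling
            eps = fd_eps / max(np.linalg.norm(v), 1e-30)
            return (F(u + eps * v) - r) / eps

        du = np.zeros_like(u)
        res = -r.copy()                  # CG residual for J du = -r
        p = res.copy()
        rs = res @ res
        for _ in range(5 * u.size):      # inner (unpreconditioned) CG
            Ap = Jv(p)
            alpha = rs / (p @ Ap)
            du = du + alpha * p
            res = res - alpha * Ap
            rs_new = res @ res
            if np.sqrt(rs_new) < 1e-6 * rnorm:
                break
            p = res + (rs_new / rs) * p
            rs = rs_new
        u = u + du
    return u

# Mildly nonlinear SPD system A @ u + u**3 = b; J = A + 3*diag(u**2) is SPD
n = 8
A = 2.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)
b = np.ones(n)
F = lambda u: A @ u + u**3 - b
u = jfnk_solve(F, np.zeros(n))
```

Because the only Jacobian access is the product Jv, the distributed solver can reuse its matvec and halo-exchange machinery for the inner CG.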

save(directory: str | PathLike, verbose: bool = False) None[source]

Save DSparseTensor to disk.

Creates a directory with metadata and per-partition files.

Parameters:
  • directory (str or PathLike) – Output directory.

  • verbose (bool) – Print progress.

Example

>>> D = A.partition(num_partitions=4)
>>> D.save("matrix_dist")
classmethod load(directory: str | PathLike, device: str | device = 'cpu') DSparseTensor[source]

Load a complete DSparseTensor from disk.

Parameters:
  • directory (str or PathLike) – Directory containing saved data.

  • device (str or torch.device) – Device to load to.

Returns:

The loaded distributed sparse tensor.

Return type:

DSparseTensor

Example

>>> D = DSparseTensor.load("matrix_dist", device="cuda")

DSparseMatrix

Distributed sparse matrix designed for large-scale CFD/FEM computations. Provides domain decomposition with halo exchange.

class torch_sla.DSparseMatrix(partition: Partition, local_values: Tensor, local_row: Tensor, local_col: Tensor, local_shape: Tuple[int, int], global_shape: Tuple[int, int], num_partitions: int, device: str | device = 'cpu', verbose: bool = True)[source]

Bases: object

Distributed Sparse Matrix with halo exchange support.

Designed for large-scale CFD/FEM computations following industrial practices from Ansys, OpenFOAM, etc.

The matrix is partitioned across multiple processes/GPUs, with automatic halo (ghost) node management for parallel iterative solvers.

Supports both CPU and CUDA devices.

partition

Local partition information

Type:

Partition

local_values

Non-zero values for local portion of matrix

Type:

torch.Tensor

local_row

Local row indices

Type:

torch.Tensor

local_col

Local column indices

Type:

torch.Tensor

local_shape

Shape of local matrix (including halo)

Type:

Tuple[int, int]

global_shape

Shape of global matrix

Type:

Tuple[int, int]

device

Device where the matrix data resides (cpu or cuda)

Type:

torch.device

Example

>>> # Create distributed matrix on CPU
>>> A = DSparseMatrix.from_global(val, row, col, shape, num_parts=4, my_part=0, device='cpu')
>>>
>>> # Create distributed matrix on CUDA
>>> A_cuda = DSparseMatrix.from_global(val, row, col, shape, num_parts=4, my_part=0, device='cuda')
>>>
>>> # Distributed matrix-vector product with halo exchange
>>> y = A.matvec(x)  # Automatically handles halo exchange
>>>
>>> # Explicit halo exchange
>>> A.halo_exchange(x)  # Update halo values in x
to(device: str | device) DSparseMatrix[source]

Move the distributed matrix to a different device.

Parameters:

device (str or torch.device) – Target device (‘cpu’, ‘cuda’, ‘cuda:0’, etc.)

Returns:

New distributed matrix on the target device

Return type:

DSparseMatrix

cuda(device: int | None = None) DSparseMatrix[source]

Move to CUDA device

cpu() DSparseMatrix[source]

Move to CPU

property is_cuda: bool

Check if matrix is on CUDA

classmethod from_global(values: Tensor, row: Tensor, col: Tensor, shape: Tuple[int, int], num_partitions: int, my_partition: int, partition_ids: Tensor | None = None, coords: Tensor | None = None, device: str | device = 'cpu', verbose: bool = True) DSparseMatrix[source]

Create distributed matrix from global COO data.

Parameters:
  • values (torch.Tensor) – Global non-zero values in COO format

  • row (torch.Tensor) – Global row indices

  • col (torch.Tensor) – Global column indices

  • shape (Tuple[int, int]) – Global matrix shape

  • num_partitions (int) – Number of partitions

  • my_partition (int) – This process’s partition ID (0 to num_partitions-1)

  • partition_ids (torch.Tensor, optional) – Pre-computed partition assignments. If None, computed automatically.

  • coords (torch.Tensor, optional) – Node coordinates for geometric partitioning [num_nodes, dim]

  • device (str or torch.device) – Device for local data (‘cpu’, ‘cuda’, ‘cuda:0’, etc.)

  • verbose (bool) – Whether to print partition info

Returns:

Local portion of the distributed matrix

Return type:

DSparseMatrix

property num_owned: int

Number of owned (non-halo) nodes

property num_halo: int

Number of halo/ghost nodes

property num_local: int

Total local nodes (owned + halo)

property nnz: int

Number of non-zeros in local matrix

property dtype: dtype

Data type of matrix values

halo_exchange(x: Tensor, async_op: bool = False) Tensor | None[source]

Exchange halo/ghost values with neighbors.

This is the core operation for parallel iterative methods. Updates the halo portion of x with values from neighboring partitions.

Parameters:
  • x (torch.Tensor) – Local vector [num_local] with owned values filled in. Halo values will be updated.

  • async_op (bool) – If True, return immediately and return a future.

Returns:

x – Vector with updated halo values (same tensor, modified in-place)

Return type:

torch.Tensor

Example

>>> # During iterative solve
>>> for iteration in range(max_iter):
>>>     # Compute local update
>>>     x_new = local_gauss_seidel_step(A_local, x, b)
>>>
>>>     # Exchange boundary values
>>>     A.halo_exchange(x_new)
>>>
>>>     # Check convergence using owned nodes only
>>>     residual = compute_residual(A_local, x_new, b)
halo_exchange_local(x_list: List[Tensor]) None[source]

Local halo exchange for single-process multi-partition simulation.

Useful for testing/debugging without actual distributed setup.

Parameters:

x_list (List[torch.Tensor]) – List of local vectors, one per partition
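For intuition, a two-partition, single-process sketch of a halo exchange on a 1D chain (plain numpy; the layout and index names are illustrative, not the Partition class):

```python
import numpy as np

# Two partitions of a 1D chain 0..7; each stores [owned..., halo] locally.
# Partition 0 owns nodes 0-3 with a halo copy of node 4; partition 1 owns
# nodes 4-7 with a halo copy of node 3.
x_global = np.arange(8, dtype=float)
x0 = np.concatenate([x_global[0:4], [0.0]])   # owned 0-3, empty halo slot
x1 = np.concatenate([x_global[4:8], [0.0]])   # owned 4-7, empty halo slot

def halo_exchange_local(x_list, send, recv):
    """Copy owned boundary values into the neighbour's halo slot.

    send[p]: local index whose value partition p sends;
    recv[p]: local halo index that partition p fills from its neighbour.
    """
    # Buffer the sends first so the exchange is order-independent
    buf = {p: x_list[p][send[p]] for p in send}
    x_list[0][recv[0]] = buf[1]   # partition 0 receives from partition 1
    x_list[1][recv[1]] = buf[0]   # partition 1 receives from partition 0

send = {0: 3, 1: 0}    # p0 sends owned node 3; p1 sends owned node 4
recv = {0: 4, 1: 4}    # both halo slots sit at local index 4
halo_exchange_local([x0, x1], send, recv)
```

In a real distributed run, the buffered copies become point-to-point sends and receives between ranks.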

matvec(x: Tensor, exchange_halo: bool = True) Tensor[source]

Local matrix-vector product y = A_local @ x.

Parameters:
  • x (torch.Tensor) – Local vector [num_local]

  • exchange_halo (bool) – If True, perform halo exchange before multiplication

Returns:

y – Result vector [num_local]

Return type:

torch.Tensor

matvec_overlap(x: Tensor) Tensor[source]

Matrix-vector product with communication-computation overlap.

This optimized version overlaps halo communication with computation:

1. Start async halo exchange
2. Compute interior part (rows that don't depend on halo)
3. Wait for halo exchange to complete
4. Compute boundary part (rows that depend on halo)
5. Combine results

Note: This is only beneficial in true distributed settings where there is actual network latency to hide. In single-process mode, this falls back to regular matvec.

Parameters:

x (torch.Tensor) – Local vector [num_local]

Returns:

y – Result vector [num_local]

Return type:

torch.Tensor
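The interior/boundary row split can be sketched as follows (illustrative dense numpy; the real implementation works on the local sparse matrix):

```python
import numpy as np

# Local matrix: 5 owned rows, 6 local columns (5 owned + 1 halo).
# Rows that reference halo columns are "boundary"; the rest are "interior"
# and can be computed while halo values are still in flight.
n_owned, n_halo = 5, 1
A = np.zeros((n_owned, n_owned + n_halo))
for i in range(n_owned):                 # tridiagonal local stencil
    A[i, i] = 2.0
    if i + 1 < n_owned + n_halo:
        A[i, i + 1] = -1.0
    if i - 1 >= 0:
        A[i, i - 1] = -1.0

halo_cols = np.arange(n_owned, n_owned + n_halo)
boundary_rows = np.unique(np.nonzero(A[:, halo_cols])[0])
interior_rows = np.setdiff1d(np.arange(n_owned), boundary_rows)

x = np.arange(n_owned + n_halo, dtype=float)
y = np.empty(n_owned)
y[interior_rows] = A[interior_rows] @ x  # overlapped with comm in real code
# ... halo exchange completes here ...
y[boundary_rows] = A[boundary_rows] @ x
```

Because interior rows reference no halo columns, their partial product is valid before the exchange completes; that is exactly the work used to hide communication latency.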

halo_exchange_async(x: Tensor)[source]

Start asynchronous halo exchange.

Returns a handle that can be passed to _wait_halo_exchange().

solve(b: Tensor, method: str = 'cg', preconditioner: str = 'jacobi', atol: float = 1e-10, rtol: float = 1e-06, maxiter: int = 1000, verbose: bool = False, distributed: bool = True, overlap: bool = False, use_cache: bool = True) Tensor[source]

Solve linear system Ax = b.

Optimizations enabled by default:

  • CSR cache: avoids repeated COO->CSR conversion (use_cache=True)

  • Jacobi preconditioner: ~5% speedup for Poisson-like problems

Parameters:
  • b (torch.Tensor) – Right-hand side. Shape [num_owned] for owned nodes only.

  • method (str) – Solver method: ‘cg’ (default), ‘jacobi’, ‘gauss_seidel’

  • preconditioner (str) – Preconditioner for CG: ‘none’, ‘jacobi’ (default), ‘ssor’, ‘ic0’, ‘polynomial’

  • atol (float) – Absolute tolerance for convergence

  • rtol (float) – Relative tolerance for convergence (|r| < rtol * |b|)

  • maxiter (int) – Maximum iterations

  • verbose (bool) – Print convergence info (rank 0 only for distributed)

  • distributed (bool, default=True) – If True (default): Solve the GLOBAL system using distributed algorithms with all_reduce for global dot products. If False: Solve only the LOCAL subdomain problem (useful as preconditioner in domain decomposition methods).

  • overlap (bool, default=False) – If True: Overlap communication with computation. Note: Only beneficial for slow interconnects (InfiniBand, Ethernet). For NVLink, synchronous communication is faster.

  • use_cache (bool, default=True) – If True (default): Cache CSR format and diagonal for reuse. Provides ~2% speedup and ~27% memory reduction.

Returns:

x – Solution for owned nodes, shape [num_owned]

Return type:

torch.Tensor

Examples

>>> # Distributed solve (default) - all ranks cooperate
>>> x = local_matrix.solve(b_owned)
>>> # Local subdomain solve - no global communication
>>> x = local_matrix.solve(b_owned, distributed=False)
>>> # With different preconditioner
>>> x = local_matrix.solve(b_owned, preconditioner='ssor')
>>> # Disable caching (for memory-constrained cases)
>>> x = local_matrix.solve(b_owned, use_cache=False)
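The default Jacobi preconditioner is simply division by the matrix diagonal. A minimal preconditioned-CG sketch in numpy (single-process illustration; in the distributed solver, the dot products become all_reduce calls):

```python
import numpy as np

def pcg_jacobi(A, b, rtol=1e-8, maxiter=1000):
    """Jacobi-preconditioned CG: M = diag(A)."""
    Minv = 1.0 / np.diag(A)
    x = np.zeros_like(b)
    r = b - A @ x
    z = Minv * r                 # apply the preconditioner
    p = z.copy()
    rz = r @ z
    for _ in range(maxiter):
        Ap = A @ p
        alpha = rz / (p @ Ap)    # in distributed CG: global all_reduce
        x += alpha * p
        r -= alpha * Ap
        if np.linalg.norm(r) < rtol * np.linalg.norm(b):
            break
        z = Minv * r
        rz_new = r @ z
        p = z + (rz_new / rz) * p
        rz = rz_new
    return x

n = 20
A = 4.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)   # SPD, Poisson-like
b = np.ones(n)
x = pcg_jacobi(A, b)
```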
eigsh(k: int = 6, which: str = 'LM', maxiter: int = 200, tol: float = 1e-08, verbose: bool = False, distributed: bool = True) Tuple[Tensor, Tensor][source]

Compute k eigenvalues of symmetric matrix.

Parameters:
  • k (int) – Number of eigenvalues to compute

  • which (str) – Which eigenvalues: “LM” (largest magnitude), “SM” (smallest magnitude)

  • maxiter (int) – Maximum iterations

  • tol (float) – Convergence tolerance

  • verbose (bool) – Print convergence info (rank 0 only)

  • distributed (bool, default=True) – If True (default): Use distributed LOBPCG with global reductions. If False: Gather to single SparseTensor and compute locally (not recommended for large matrices).

Returns:

  • eigenvalues (torch.Tensor) – k eigenvalues, shape [k]

  • eigenvectors_owned (torch.Tensor) – Eigenvectors for owned nodes only, shape [num_owned, k]

gather_global(x_local: Tensor) Tensor | None[source]

Gather local vectors to global vector (on rank 0).

Parameters:

x_local (torch.Tensor) – Local vector [num_owned]

Returns:

x_global – Global vector on rank 0, None on other ranks

Return type:

torch.Tensor or None

det() Tensor[source]

Compute determinant of the distributed sparse matrix.

NOTE: DSparseMatrix represents a single partition. To compute the determinant of the full global matrix, you need to use DSparseTensor which manages all partitions, or manually gather all partitions.

This method raises an error to guide users to the correct approach.

Raises:

NotImplementedError – DSparseMatrix is a single partition. Use DSparseTensor.det() instead.

Examples

>>> # Correct way: Use DSparseTensor
>>> from torch_sla import DSparseTensor
>>> D = DSparseTensor(val, row, col, shape, num_partitions=4)
>>> det = D.det()  # This works
>>>
>>> # If you have individual DSparseMatrix partitions, you need to
>>> # reconstruct the global matrix first
classmethod load(directory: str | PathLike, rank: int, world_size: int | None = None, device: str | device = 'cpu') DSparseMatrix[source]

Load a partition from disk for the given rank.

Each rank should call this with its own rank to load only its partition.

Parameters:
  • directory (str or PathLike) – Directory containing partitioned data.

  • rank (int) – Rank of this process.

  • world_size (int, optional) – Total number of processes (must match num_partitions).

  • device (str or torch.device) – Device to load tensors to.

Returns:

The partition for this rank.

Return type:

DSparseMatrix

Example

>>> rank = dist.get_rank()
>>> world_size = dist.get_world_size()
>>> partition = DSparseMatrix.load("matrix_dist", rank, world_size, "cuda")

Partition

Dataclass representing a single partition/subdomain for distributed computing.

class torch_sla.Partition(partition_id: int, local_nodes: Tensor, owned_nodes: Tensor, halo_nodes: Tensor, neighbor_partitions: List[int], send_indices: Dict[int, Tensor], recv_indices: Dict[int, Tensor], global_to_local: Tensor, local_to_global: Tensor)[source]

Represents a single partition/subdomain

partition_id: int
local_nodes: Tensor
owned_nodes: Tensor
halo_nodes: Tensor
neighbor_partitions: List[int]
send_indices: Dict[int, Tensor]
recv_indices: Dict[int, Tensor]
global_to_local: Tensor
local_to_global: Tensor
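For intuition, here are plausible field values (hypothetical, plain Python rather than tensors) for one partition of a 6-node 1D chain split in two:

```python
# Global nodes 0..5; partition 0 owns {0, 1, 2}, partition 1 owns {3, 4, 5}.
# Partition 0 additionally keeps a halo/ghost copy of neighbouring node 3.
partition0 = {
    "partition_id": 0,
    "owned_nodes": [0, 1, 2],          # global ids this rank updates
    "halo_nodes": [3],                 # ghost copies of neighbour-owned nodes
    "local_nodes": [0, 1, 2, 3],       # owned followed by halo, in local order
    "neighbor_partitions": [1],
    "send_indices": {1: [2]},          # local index of node sent to rank 1
    "recv_indices": {1: [3]},          # local halo slot filled from rank 1
    "local_to_global": [0, 1, 2, 3],   # local index -> global id
}
# The inverse mapping stored as global_to_local:
global_to_local = {g: l for l, g in enumerate(partition0["local_to_global"])}
```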

Linear Solve Functions

spsolve

torch_sla.spsolve(val: Tensor, row: Tensor, col: Tensor, shape: Tuple[int, int], b: Tensor, backend: Literal['scipy', 'eigen', 'pytorch', 'cusolver', 'cudss', 'auto'] = 'auto', method: Literal['auto', 'superlu', 'umfpack', 'lu', 'qr', 'cholesky', 'ldlt', 'cg', 'bicgstab', 'gmres', 'lgmres', 'minres', 'qmr'] = 'auto', atol: float = 1e-10, maxiter: int = 10000, tol: float = 1e-12, matrix_type: str = 'general', is_symmetric: bool = False, is_spd: bool = False, preconditioner: str = 'jacobi', mixed_precision: bool = False) Tensor[source]

Solve the Sparse Linear Equation Ax = b with gradient support.

Supports multiple backends for CPU and CUDA tensors.

Parameters:
  • val (torch.Tensor) – [nnz] Non-zero values of sparse matrix A in COO format

  • row (torch.Tensor) – [nnz] Row indices

  • col (torch.Tensor) – [nnz] Column indices

  • shape (Tuple[int, int]) – (m, n) Shape of sparse matrix A

  • b (torch.Tensor) – [m] Right-hand side vector

  • backend (str, optional) – Backend to use:
    - 'auto': Auto-select based on device and problem size (default)
    - 'scipy': SciPy (CPU only, uses SuperLU/UMFPACK)
    - 'eigen': Eigen C++ (CPU only, iterative)
    - 'pytorch': PyTorch-native (CPU & CUDA, iterative); best for large problems
    - 'cusolver': NVIDIA cuSOLVER (CUDA only, direct)
    - 'cudss': NVIDIA cuDSS (CUDA only, direct)

  • method (str, optional) – Solver method. Available methods depend on backend:
    - 'auto': Auto-select based on matrix properties
    - 'superlu', 'umfpack': Direct solvers (scipy)
    - 'cg', 'bicgstab', 'gmres': Iterative solvers
    - 'lu', 'qr', 'cholesky', 'ldlt': Direct solvers (CUDA)

  • atol (float, optional) – Absolute tolerance for iterative solvers, by default 1e-10

  • maxiter (int, optional) – Maximum iterations for iterative solvers, by default 10000

  • tol (float, optional) – Tolerance for direct solvers, by default 1e-12

  • matrix_type (str, optional) – Matrix type for cuDSS: ‘general’, ‘symmetric’, ‘spd’, by default “general”

  • is_symmetric (bool, optional) – Hint that matrix is symmetric (for auto method selection)

  • is_spd (bool, optional) – Hint that matrix is symmetric positive definite

Returns:

[n] Solution vector x

Return type:

torch.Tensor

Examples

>>> import torch
>>> from torch_sla import spsolve
>>>
>>> # Create a simple SPD matrix
>>> val = torch.tensor([4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0], dtype=torch.float64)
>>> row = torch.tensor([0, 0, 1, 1, 1, 2, 2], dtype=torch.int64)
>>> col = torch.tensor([0, 1, 0, 1, 2, 1, 2], dtype=torch.int64)
>>> shape = (3, 3)
>>> b = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
>>>
>>> # Auto-select backend and method
>>> x = spsolve(val, row, col, shape, b)
>>>
>>> # Specify backend and method
>>> x = spsolve(val, row, col, shape, b, backend='scipy', method='superlu')
>>>
>>> # On CUDA
>>> val_cuda = val.cuda()
>>> row_cuda = row.cuda()
>>> col_cuda = col.cuda()
>>> b_cuda = b.cuda()
>>> x_cuda = spsolve(val_cuda, row_cuda, col_cuda, shape, b_cuda, backend='cudss', method='lu')
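The gradient support follows the standard adjoint identities for x = A^{-1} b; a dense numpy sketch of those identities (illustrative — the library applies the same identities to the sparse values):

```python
import numpy as np

# For x = solve(A, b) and a scalar loss L(x), the adjoint identities are:
#   A^T lambda = dL/dx,    dL/db = lambda,    dL/dA = -lambda x^T
rng = np.random.default_rng(0)
n = 5
A = rng.standard_normal((n, n)) + n * np.eye(n)   # well conditioned
b = rng.standard_normal(n)
x = np.linalg.solve(A, b)

g = rng.standard_normal(n)          # dL/dx for the linear loss L = g @ x
lam = np.linalg.solve(A.T, g)       # one extra solve, with the transpose
dL_db = lam
dL_dA = -np.outer(lam, x)           # gradient w.r.t. the matrix entries

# Finite-difference check of one component of dL/db
eps = 1e-6
b_pert = b.copy()
b_pert[2] += eps
fd = (g @ np.linalg.solve(A, b_pert) - g @ x) / eps
```

The key cost property: the backward pass is a single solve with A^T, regardless of how many parameters receive gradients.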

spsolve_coo

torch_sla.spsolve_coo(A: Tensor, b: Tensor, **kwargs) Tensor[source]

Solve Ax = b where A is a sparse COO tensor

Parameters:
  • A (torch.Tensor) – Sparse COO tensor representing the matrix

  • b (torch.Tensor) – Right-hand side vector

  • **kwargs – Additional arguments passed to spsolve()

Returns:

Solution vector x

Return type:

torch.Tensor

spsolve_csr

torch_sla.spsolve_csr(A: Tensor, b: Tensor, **kwargs) Tensor[source]

Solve Ax = b where A is a sparse CSR tensor

Parameters:
  • A (torch.Tensor) – Sparse CSR tensor representing the matrix

  • b (torch.Tensor) – Right-hand side vector

  • **kwargs – Additional arguments passed to spsolve()

Returns:

Solution vector x

Return type:

torch.Tensor


Batch Solve Functions

spsolve_batch_same_layout

torch_sla.spsolve_batch_same_layout(val_batch: Tensor, row: Tensor, col: Tensor, shape: Tuple[int, int], b_batch: Tensor, method: Literal['cg', 'bicgstab', 'cusolver_qr', 'cusolver_cholesky', 'cusolver_lu', 'cudss', 'cudss_lu', 'cudss_cholesky', 'cudss_ldlt'] = 'bicgstab', atol: float = 1e-10, maxiter: int = 10000) Tensor[source]

Batch solve sparse linear systems with the SAME sparsity pattern.

Deprecated: use SparseTensor.decompose().solve() instead for a more Pythonic interface:

>>> A = SparseTensor(val, row, col, shape)
>>> decomp = A.decompose(method='superlu')
>>> x_batch = decomp.solve(val_batch, b_batch)

All matrices A_i share the same (row, col) structure but have different values. This is efficient when the sparsity pattern is fixed (e.g., FEM with fixed mesh).

Solves: A_i @ x_i = b_i for i = 0, 1, …, batch_size-1

Parameters:
  • val_batch (torch.Tensor) – [batch_size, nnz] Non-zero values for each matrix

  • row (torch.Tensor) – [nnz] Row indices (shared across batch)

  • col (torch.Tensor) – [nnz] Column indices (shared across batch)

  • shape (Tuple[int, int]) – (m, n) Shape of each sparse matrix

  • b_batch (torch.Tensor) – [batch_size, m] Right-hand side vectors

  • method (str) – Solver method (same options as spsolve)

  • atol (float) – Absolute tolerance for iterative solvers

  • maxiter (int) – Maximum iterations for iterative solvers

Returns:

[batch_size, n] Solution vectors

Return type:

torch.Tensor

Example

>>> import torch
>>> from torch_sla import spsolve_batch_same_layout
>>>
>>> batch_size = 10
>>> n = 100
>>> nnz = 500
>>>
>>> # Same sparsity pattern, different values
>>> row = torch.randint(0, n, (nnz,))
>>> col = torch.randint(0, n, (nnz,))
>>> val_batch = torch.randn(batch_size, nnz, dtype=torch.float64)
>>> b_batch = torch.randn(batch_size, n, dtype=torch.float64)
>>>
>>> x_batch = spsolve_batch_same_layout(val_batch, row, col, (n, n), b_batch)

spsolve_batch_different_layout

torch_sla.spsolve_batch_different_layout(matrices: List[Tuple[Tensor, Tensor, Tensor, Tuple[int, int]]], b_list: List[Tensor], method: Literal['cg', 'bicgstab', 'cusolver_qr', 'cusolver_cholesky', 'cusolver_lu', 'cudss', 'cudss_lu', 'cudss_cholesky', 'cudss_ldlt'] = 'bicgstab', atol: float = 1e-10, maxiter: int = 10000) List[Tensor][source]

Batch solve sparse linear systems with DIFFERENT sparsity patterns.

Deprecated: use SparseTensorList.solve() instead for a more Pythonic interface:

>>> matrices = SparseTensorList([A1, A2, A3])
>>> x_list = matrices.solve([b1, b2, b3])

Each matrix can have a different structure. This is useful when dealing with heterogeneous problems or adaptive mesh refinement.

Parameters:
  • matrices (List[Tuple[val, row, col, shape]]) – List of sparse matrices, each as (values, row_indices, col_indices, shape)

  • b_list (List[torch.Tensor]) – List of right-hand side vectors

  • method (str) – Solver method (same options as spsolve)

  • atol (float) – Absolute tolerance for iterative solvers

  • maxiter (int) – Maximum iterations for iterative solvers

Returns:

List of solution vectors

Return type:

List[torch.Tensor]

Example

>>> import torch
>>> from torch_sla import spsolve_batch_different_layout
>>>
>>> # Different matrices with different sizes/patterns
>>> matrices = []
>>> b_list = []
>>> for n in [50, 100, 150]:
...     nnz = n * 5
...     val = torch.randn(nnz, dtype=torch.float64)
...     row = torch.randint(0, n, (nnz,))
...     col = torch.randint(0, n, (nnz,))
...     matrices.append((val, row, col, (n, n)))
...     b_list.append(torch.randn(n, dtype=torch.float64))
>>>
>>> x_list = spsolve_batch_different_layout(matrices, b_list)

ParallelBatchSolver

class torch_sla.ParallelBatchSolver(row: Tensor, col: Tensor, shape: Tuple[int, int], method: Literal['cg', 'bicgstab', 'cusolver_qr', 'cusolver_cholesky', 'cusolver_lu', 'cudss', 'cudss_lu', 'cudss_cholesky', 'cudss_ldlt'] = 'bicgstab', device: str | None = None)[source]

Bases: object

High-performance parallel batch solver.

This class pre-analyzes the sparsity pattern and caches factorization information for repeated solves with the same structure.

Example

>>> solver = ParallelBatchSolver(row, col, shape, method='cudss_lu')
>>>
>>> # Solve multiple batches efficiently
>>> for val_batch, b_batch in data_loader:
...     x_batch = solver.solve(val_batch, b_batch)
solve(val_batch: Tensor, b_batch: Tensor, atol: float = 1e-10, maxiter: int = 10000) Tensor[source]

Solve batch of linear systems.

Parameters:
  • val_batch (torch.Tensor) – [batch_size, nnz] Matrix values

  • b_batch (torch.Tensor) – [batch_size, m] Right-hand side vectors

  • atol (float) – Tolerance for iterative solvers

  • maxiter (int) – Maximum iterations

Returns:

[batch_size, n] Solution vectors

Return type:

torch.Tensor


Nonlinear Solve

nonlinear_solve

torch_sla.nonlinear_solve(residual_fn: Callable, u0: Tensor, *params, jacobian_fn: Callable | None = None, method: str = 'newton', tol: float = 1e-06, atol: float = 1e-10, max_iter: int = 50, line_search: bool = True, verbose: bool = False, linear_solver: str = 'pytorch', linear_method: str = 'cg') Tensor[source]

Solve nonlinear equation F(u, θ) = 0 with adjoint-based gradients.

Parameters:
  • residual_fn – Function F(u, *params) -> residual tensor

  • u0 – Initial guess for solution

  • *params – Parameters θ (tensors with requires_grad=True for gradient computation)

  • jacobian_fn – Optional function J(u, *params) -> (val, row, col, shape) Returns sparse Jacobian in COO format. If None, uses autograd.

  • method – Nonlinear solver method:
    - 'newton': Newton-Raphson with optional line search (default)
    - 'picard': Fixed-point iteration
    - 'anderson': Anderson acceleration

  • tol – Relative convergence tolerance

  • atol – Absolute convergence tolerance

  • max_iter – Maximum number of nonlinear iterations

  • line_search – Use Armijo line search for Newton (default: True)

  • verbose – Print convergence information

  • linear_solver – Backend for linear solves (‘pytorch’, ‘scipy’, ‘cudss’)

  • linear_method – Method for linear solves (‘cg’, ‘bicgstab’, ‘lu’)

Returns:

Solution tensor satisfying F(u, θ) ≈ 0

Return type:

u

Example

>>> def residual(u, A_val, b):
...     # Rebuild sparse A from its current values; in a truly nonlinear
...     # problem, A_val would itself depend on u
...     A = torch.sparse_coo_tensor(torch.stack([row, col]), A_val, (n, n))
...     return torch.sparse.mm(A, u.unsqueeze(1)).squeeze() - b
...
>>> u0 = torch.zeros(n, requires_grad=False)
>>> A_val = torch.randn(nnz, requires_grad=True)
>>> b = torch.randn(n, requires_grad=True)
>>>
>>> u = nonlinear_solve(residual, u0, A_val, b, method='newton')
>>> loss = some_loss(u)
>>> loss.backward()  # Computes ∂L/∂A_val and ∂L/∂b via adjoint

adjoint_solve

torch_sla.adjoint_solve(residual_fn: Callable, u0: Tensor, *params, jacobian_fn: Callable | None = None, method: str = 'newton', tol: float = 1e-06, atol: float = 1e-10, max_iter: int = 50, line_search: bool = True, verbose: bool = False, linear_solver: str = 'pytorch', linear_method: str = 'cg') Tensor

Solve nonlinear equation F(u, θ) = 0 with adjoint-based gradients.

Parameters:
  • residual_fn – Function F(u, *params) -> residual tensor

  • u0 – Initial guess for solution

  • *params – Parameters θ (tensors with requires_grad=True for gradient computation)

  • jacobian_fn – Optional function J(u, *params) -> (val, row, col, shape) Returns sparse Jacobian in COO format. If None, uses autograd.

  • method – Nonlinear solver method:
    - 'newton': Newton-Raphson with optional line search (default)
    - 'picard': Fixed-point iteration
    - 'anderson': Anderson acceleration

  • tol – Relative convergence tolerance

  • atol – Absolute convergence tolerance

  • max_iter – Maximum number of nonlinear iterations

  • line_search – Use Armijo line search for Newton (default: True)

  • verbose – Print convergence information

  • linear_solver – Backend for linear solves (‘pytorch’, ‘scipy’, ‘cudss’)

  • linear_method – Method for linear solves (‘cg’, ‘bicgstab’, ‘lu’)

Returns:

Solution tensor satisfying F(u, θ) ≈ 0

Return type:

u

Example

>>> def residual(u, A_val, b):
...     # Rebuild sparse A from its current values; in a truly nonlinear
...     # problem, A_val would itself depend on u
...     A = torch.sparse_coo_tensor(torch.stack([row, col]), A_val, (n, n))
...     return torch.sparse.mm(A, u.unsqueeze(1)).squeeze() - b
...
>>> u0 = torch.zeros(n, requires_grad=False)
>>> A_val = torch.randn(nnz, requires_grad=True)
>>> b = torch.randn(n, requires_grad=True)
>>>
>>> u = adjoint_solve(residual, u0, A_val, b, method='newton')
>>> loss = some_loss(u)
>>> loss.backward()  # Computes ∂L/∂A_val and ∂L/∂b via adjoint

NonlinearSolveAdjoint

class torch_sla.NonlinearSolveAdjoint(*args, **kwargs)[source]

Bases: Function

Adjoint-based nonlinear solver with automatic differentiation.

Uses implicit differentiation to compute gradients without storing intermediate Jacobians. Memory-efficient for large-scale problems.

static forward(ctx, u0: Tensor, num_params: int, *args) Tensor[source]

Forward pass: solve F(u, θ) = 0 for u.

Parameters:
  • u0 – Initial guess for solution

  • num_params – Number of parameter tensors

  • *args – First num_params elements are param tensors, last is config dict

Returns:

Solution satisfying F(u, θ) ≈ 0

Return type:

torch.Tensor

static backward(ctx, grad_u: Tensor)[source]

Backward pass using adjoint method.

Computes ∂L/∂θ = -λᵀ · ∂F/∂θ where (∂F/∂u)ᵀ · λ = grad_u

Returns:

(grad_u0, grad_num_params, *grad_params, grad_config)

Return type:

Tuple of gradients
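
The adjoint computation above can be sketched on a toy problem. Below is a minimal pure-Python illustration (not the library's implementation) for the elementwise residual F(u, θ, b) = θ·u − b, chosen so that both Jacobians are diagonal and the adjoint solve reduces to elementwise division; adjoint_gradients is a hypothetical helper name.

```python
def adjoint_gradients(theta, b, grad_u):
    # Forward solve F(u) = 0  =>  u = b / theta (elementwise)
    u = [bi / ti for bi, ti in zip(b, theta)]
    # Adjoint solve (dF/du)^T lam = grad_u; dF/du = diag(theta)
    lam = [gi / ti for gi, ti in zip(grad_u, theta)]
    # dL/dtheta = -lam^T dF/dtheta = -lam * u  (dF/dtheta = diag(u))
    grad_theta = [-li * ui for li, ui in zip(lam, u)]
    # dL/db = -lam^T dF/db = +lam  (dF/db = -I)
    grad_b = lam
    return u, grad_theta, grad_b
```

Because the gradients come from one extra linear solve rather than from unrolling the nonlinear iterations, memory usage is independent of the iteration count.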


Persistence (I/O)

safetensors Format

save_sparse

torch_sla.save_sparse(tensor: SparseTensor, path: str | Path, metadata: Dict[str, str] | None = None) None[source]

Save a SparseTensor to safetensors format.

Parameters:
  • tensor (SparseTensor) – The sparse tensor to save.

  • path (str or Path) – Output file path (should end with .safetensors).

  • metadata (dict, optional) – Additional metadata to store in the file.

Example

>>> A = SparseTensor(val, row, col, (100, 100))
>>> save_sparse(A, "matrix.safetensors")

load_sparse

torch_sla.load_sparse(path: str | Path, device: str | device = 'cpu') SparseTensor[source]

Load a SparseTensor from safetensors format.

Parameters:
  • path (str or Path) – Input file path.

  • device (str or torch.device) – Device to load tensors to.

Returns:

The loaded sparse tensor.

Return type:

SparseTensor

Example

>>> A = load_sparse("matrix.safetensors", device="cuda")

save_distributed

torch_sla.save_distributed(tensor: SparseTensor, directory: str | Path, num_partitions: int, partition_method: str = 'simple', coords: Tensor | None = None, verbose: bool = False) None[source]

Save a SparseTensor as partitioned files for distributed loading.

Creates a directory with:
  - metadata.json: Global metadata and partition info
  - partition_0.safetensors, partition_1.safetensors, …: Per-partition data

Parameters:
  • tensor (SparseTensor) – The global sparse tensor to partition and save.

  • directory (str or Path) – Output directory path.

  • num_partitions (int) – Number of partitions to create.

  • partition_method (str) – Partitioning method: ‘simple’, ‘metis’, or ‘geometric’.

  • coords (torch.Tensor, optional) – Node coordinates for geometric partitioning.

  • verbose (bool) – Print progress information.

Example

>>> A = SparseTensor(val, row, col, (1000, 1000))
>>> save_distributed(A, "matrix_dist", num_partitions=4)
# Creates:
#   matrix_dist/metadata.json
#   matrix_dist/partition_0.safetensors
#   matrix_dist/partition_1.safetensors
#   matrix_dist/partition_2.safetensors
#   matrix_dist/partition_3.safetensors

load_partition

torch_sla.load_partition(directory: str | Path, rank: int, world_size: int | None = None, device: str | device = 'cpu') DSparseMatrix[source]

Load a single partition for the given rank.

Each rank loads only its own partition, enabling efficient distributed loading.

Parameters:
  • directory (str or Path) – Directory containing partitioned data.

  • rank (int) – Rank of this process.

  • world_size (int, optional) – Total number of processes (must match num_partitions). If None, reads from metadata.

  • device (str or torch.device) – Device to load tensors to.

Returns:

The partition for this rank.

Return type:

DSparseMatrix

Example

>>> # In distributed context
>>> rank = dist.get_rank()
>>> world_size = dist.get_world_size()
>>> partition = load_partition("matrix_dist", rank, world_size, device="cuda")

load_metadata

torch_sla.load_metadata(directory: str | Path) Dict[source]

Load metadata from a distributed sparse tensor directory.

Parameters:

directory (str or Path) – Directory containing partitioned data.

Returns:

Metadata including shape, dtype, num_partitions, etc.

Return type:

dict

Example

>>> meta = load_metadata("matrix_dist")
>>> print(f"Shape: {meta['shape']}, Partitions: {meta['num_partitions']}")

load_sparse_as_partition

torch_sla.load_sparse_as_partition(path: str | Path, rank: int, world_size: int, partition_method: str = 'simple', coords: Tensor | None = None, device: str | device = 'cpu') DSparseMatrix[source]

Load a SparseTensor file and return only this rank’s partition.

This allows distributed reading of a single SparseTensor file, where each rank loads the full file but only keeps its partition.

For very large matrices, use save_distributed() instead to avoid loading the full matrix on each rank.

Parameters:
  • path (str or Path) – Path to SparseTensor file (.safetensors).

  • rank (int) – Rank of this process.

  • world_size (int) – Total number of processes.

  • partition_method (str) – ‘simple’, ‘metis’, or ‘geometric’.

  • coords (torch.Tensor, optional) – Node coordinates for geometric partitioning.

  • device (str or torch.device) – Device to load partition to.

Returns:

This rank’s partition of the matrix.

Return type:

DSparseMatrix

Example

>>> # Each rank calls this:
>>> rank = dist.get_rank()
>>> world_size = dist.get_world_size()
>>> partition = load_sparse_as_partition("matrix.safetensors", rank, world_size)

load_distributed_as_sparse

torch_sla.load_distributed_as_sparse(directory: str | Path, device: str | device = 'cpu') SparseTensor[source]

Load a distributed/partitioned save as a single SparseTensor.

This gathers all partitions into one SparseTensor. Useful when you have partitioned data but want to use it on a single node.

Parameters:
  • directory (str or Path) – Directory containing partitioned data (from save_distributed or DSparseTensor.save).

  • device (str or torch.device) – Device to load to.

Returns:

The complete sparse tensor.

Return type:

SparseTensor

Example

>>> # Load partitioned data as single SparseTensor
>>> A = load_distributed_as_sparse("matrix_dist", device="cuda")

save_dsparse

torch_sla.save_dsparse(tensor: DSparseTensor, directory: str | Path, verbose: bool = False) None[source]

Save a DSparseTensor to disk.

Parameters:
  • tensor (DSparseTensor) – The distributed sparse tensor to save.

  • directory (str or Path) – Output directory.

  • verbose (bool) – Print progress.

load_dsparse

torch_sla.load_dsparse(directory: str | Path, device: str | device = 'cpu') DSparseTensor[source]

Load a complete DSparseTensor from disk.

Parameters:
  • directory (str or Path) – Directory containing saved data.

  • device (str or torch.device) – Device to load to.

Returns:

The loaded distributed sparse tensor.

Return type:

DSparseTensor

Matrix Market Format

save_mtx

torch_sla.save_mtx(tensor: SparseTensor, path: str | Path, comment: str = '', field: str = 'real', symmetry: str = 'general') None[source]

Save a SparseTensor to Matrix Market (.mtx) format.

Parameters:
  • tensor (SparseTensor) – The sparse tensor to save.

  • path (str or Path) – Output file path (should end with .mtx).

  • comment (str, optional) – Comment to include in the header.

  • field (str, optional) – Field type: ‘real’, ‘complex’, ‘integer’, or ‘pattern’. Default: ‘real’.

  • symmetry (str, optional) – Symmetry type: ‘general’, ‘symmetric’, ‘skew-symmetric’, or ‘hermitian’. Default: ‘general’.

Example

>>> A = SparseTensor(val, row, col, (100, 100))
>>> save_mtx(A, "matrix.mtx")
>>> save_mtx(A, "matrix.mtx", symmetry="symmetric")
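
For reference, the ‘real’/‘general’ case follows the standard Matrix Market layout: a %%MatrixMarket banner, optional % comment lines, a size line, then 1-based "row col value" triplets. A minimal sketch of that layout (illustrative formatting, not the library's exact writer; mtx_text_sketch is a hypothetical name):

```python
def mtx_text_sketch(shape, row, col, val, comment=""):
    # Banner: object / format / field / symmetry
    lines = ["%%MatrixMarket matrix coordinate real general"]
    if comment:
        lines.append(f"% {comment}")
    # Size line: rows, cols, number of stored entries
    lines.append(f"{shape[0]} {shape[1]} {len(val)}")
    for r, c, v in zip(row, col, val):
        lines.append(f"{r + 1} {c + 1} {v}")  # Matrix Market is 1-based
    return "\n".join(lines) + "\n"
```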

load_mtx

torch_sla.load_mtx(path: str | Path, dtype: dtype | None = None, device: str | device = 'cpu') SparseTensor[source]

Load a SparseTensor from Matrix Market (.mtx) format.

Parameters:
  • path (str or Path) – Input file path.

  • dtype (torch.dtype, optional) – Data type for values. If None, inferred from file.

  • device (str or torch.device) – Device to load tensors to.

Returns:

The loaded sparse tensor.

Return type:

SparseTensor

Example

>>> A = load_mtx("matrix.mtx")
>>> A = load_mtx("matrix.mtx", dtype=torch.float32, device="cuda")

load_mtx_info

torch_sla.load_mtx_info(path: str | Path) Dict[source]

Read Matrix Market file header without loading data.

Parameters:

path (str or Path) – Input file path.

Returns:

Dictionary with keys: ‘shape’, ‘nnz’, ‘field’, ‘symmetry’.

Return type:

dict

Example

>>> info = load_mtx_info("matrix.mtx")
>>> print(f"Shape: {info['shape']}, NNZ: {info['nnz']}")

Partitioning Functions

partition_graph_metis

torch_sla.partition_graph_metis(row: Tensor, col: Tensor, num_nodes: int, num_parts: int) Tensor[source]

Partition a graph using METIS if available, falling back to the simple method otherwise.

Parameters:
  • row (torch.Tensor) – Row indices of graph edges.

  • col (torch.Tensor) – Column indices of graph edges.

  • num_nodes (int) – Number of nodes in the graph.

  • num_parts (int) – Number of partitions.

Returns:

partition_ids – Partition ID for each node [num_nodes]

Return type:

torch.Tensor

partition_coordinates

torch_sla.partition_coordinates(coords: Tensor, num_parts: int, method: str = 'rcb') Tensor[source]

Partition based on node coordinates using Recursive Coordinate Bisection (RCB).

This is common in CFD/FEM for mesh partitioning.

Parameters:
  • coords (torch.Tensor) – Node coordinates [num_nodes, dim]

  • num_parts (int) – Number of partitions (should be power of 2 for RCB)

  • method (str) – Partitioning method:
    - ‘rcb’: Recursive Coordinate Bisection
    - ‘slicing’: Simple slicing along longest axis

Returns:

partition_ids – Partition ID for each node

Return type:

torch.Tensor
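
Recursive Coordinate Bisection, as described above, can be sketched in a few lines of pure Python (the library version operates on tensors; rcb here is an illustrative helper assuming num_parts is a power of two):

```python
def rcb(coords, num_parts):
    # Split the node set in half along the axis with the largest
    # extent, then recurse on each half until num_parts is reached.
    ids = [0] * len(coords)

    def split(idx, parts, base):
        if parts == 1:
            for i in idx:
                ids[i] = base
            return
        dim = len(coords[0])
        # Axis with the largest spread over this subset
        axis = max(range(dim), key=lambda d:
                   max(coords[i][d] for i in idx) -
                   min(coords[i][d] for i in idx))
        idx = sorted(idx, key=lambda i: coords[i][axis])
        half = len(idx) // 2
        split(idx[:half], parts // 2, base)
        split(idx[half:], parts // 2, base + parts // 2)

    split(list(range(len(coords))), num_parts, 0)
    return ids
```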

partition_simple

torch_sla.partition_simple(num_nodes: int, num_parts: int) Tensor[source]

Simple vectorized 1D partitioning (fallback when METIS is not available).
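
The ‘simple’ scheme assigns nodes to contiguous, nearly equal-sized chunks. A pure-Python sketch (illustrative only; the real implementation is vectorized over tensors):

```python
def partition_simple_sketch(num_nodes, num_parts):
    # Node i goes to partition floor(i * num_parts / num_nodes),
    # which yields contiguous chunks differing in size by at most one.
    return [i * num_parts // num_nodes for i in range(num_nodes)]
```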


Backend Utilities

get_available_backends

torch_sla.get_available_backends() List[str][source]

Get list of available backends

get_backend_methods

torch_sla.get_backend_methods(backend: str) List[str][source]

Get list of methods supported by a backend

get_default_method

torch_sla.get_default_method(backend: str) str[source]

Get default method for a backend

select_backend

torch_sla.select_backend(device: device, n: int | None = None, dtype: dtype | None = None, prefer_direct: bool = True) str[source]

Auto-select the best backend based on device, problem size, and dtype.

Recommendations based on benchmark results:
  - CPU: scipy+superlu (all sizes, fast + machine precision)
  - CUDA (DOF < 2M): cudss+cholesky (fast + high precision)
  - CUDA (DOF >= 2M): pytorch+cg (memory efficient, ~1e-6 precision)

Parameters:
  • device (torch.device) – Target device (cpu or cuda)

  • n (int, optional) – Problem size (DOF). If > CUDA_ITERATIVE_THRESHOLD, prefer iterative.

  • dtype (torch.dtype, optional) – Data type. Note: cuSOLVER does not support float32!

  • prefer_direct (bool) – If True, prefer direct solvers over iterative (when applicable)

Returns:

Backend name (‘scipy’, ‘eigen’, ‘pytorch’, ‘cusolver’, or ‘cudss’)

Return type:

str

select_method

torch_sla.select_method(backend: str, is_symmetric: bool = False, is_spd: bool = False, prefer_direct: bool = True) str[source]

Auto-select the best method for a given backend and matrix properties.

Recommendations based on benchmark results:
  - scipy: superlu (direct, best precision) or cg (iterative, for SPD)
  - cudss: cholesky (SPD, fastest) > ldlt (symmetric) > lu (general)
  - pytorch: cg (SPD) or bicgstab (general), both with Jacobi preconditioning

Parameters:
  • backend (str) – Backend name

  • is_symmetric (bool) – Whether the matrix is symmetric

  • is_spd (bool) – Whether the matrix is symmetric positive definite

  • prefer_direct (bool) – If True, prefer direct solvers

Returns:

Method name

Return type:

str
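
The recommendation table above amounts to a small decision function. A hedged pure-Python sketch (the actual function may apply additional checks; select_method_sketch is a hypothetical name):

```python
def select_method_sketch(backend, is_symmetric=False, is_spd=False,
                         prefer_direct=True):
    # Decision table following the documented recommendations
    if backend == 'cudss':
        return 'cholesky' if is_spd else ('ldlt' if is_symmetric else 'lu')
    if backend == 'scipy':
        if prefer_direct:
            return 'superlu'
        return 'cg' if is_spd else 'bicgstab'
    if backend in ('pytorch', 'eigen'):
        return 'cg' if is_spd else 'bicgstab'
    if backend == 'cusolver':
        return 'cholesky' if is_spd else 'lu'
    raise ValueError(f"unknown backend: {backend}")
```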

Backend Availability Checks

torch_sla.is_scipy_available() bool[source]

Check if SciPy backend is available

torch_sla.is_eigen_available() bool[source]

Check if Eigen backend (C++ extension) is available

torch_sla.is_cusolver_available() bool[source]

Check if cuSOLVER backend is available

torch_sla.is_cudss_available() bool[source]

Check if cuDSS backend is available


Utility Functions

auto_select_method

torch_sla.auto_select_method(nnz: int, n: int, dtype: dtype, is_cuda: bool, is_spd: bool = False, memory_threshold: float = 0.8) Tuple[str, str][source]

Automatically select the best backend and method.

Parameters:
  • nnz (int) – Number of non-zero elements.

  • n (int) – Matrix dimension.

  • dtype (torch.dtype) – Data type of the matrix.

  • is_cuda (bool) – Whether the matrix is on CUDA.

  • is_spd (bool, optional) – Whether the matrix is symmetric positive definite. Default: False.

  • memory_threshold (float, optional) – Fraction of GPU memory to use. Default: 0.8.

Returns:

(backend, method) tuple.

Return type:

Tuple[str, str]

estimate_direct_solver_memory

torch_sla.estimate_direct_solver_memory(nnz: int, n: int, dtype: dtype) int[source]

Estimate memory required for direct sparse solver.

Parameters:
  • nnz (int) – Number of non-zero elements.

  • n (int) – Matrix dimension.

  • dtype (torch.dtype) – Data type of the matrix.

Returns:

Estimated memory in bytes.

Return type:

int

get_available_gpu_memory

torch_sla.get_available_gpu_memory() int[source]

Get available GPU memory in bytes.

Returns:

Available GPU memory in bytes, or 0 if CUDA is not available.

Return type:

int
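
A sketch of how these three utilities might combine: estimate the direct-solver footprint and fall back to an iterative method when it would exceed the memory budget. The 10x fill-in factor below is an assumption for illustration, not the library's actual memory model, and pick_solver_sketch is a hypothetical name.

```python
def pick_solver_sketch(nnz, n, itemsize, available_bytes,
                       memory_threshold=0.8):
    fill_in = 10  # assumed growth of stored nonzeros during factorization
    direct_bytes = fill_in * nnz * itemsize + 2 * n * itemsize
    if direct_bytes <= memory_threshold * available_bytes:
        return ('cudss', 'cholesky')   # direct solve fits in the budget
    return ('pytorch', 'cg')           # iterative fallback
```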


Constants

BACKEND_METHODS

Dictionary mapping backend names to available solver methods.

BACKEND_METHODS = {
    'scipy': ['superlu', 'umfpack', 'cg', 'bicgstab', 'gmres', 'minres'],
    'eigen': ['cg', 'bicgstab'],
    'pytorch': ['cg', 'bicgstab'],
    'cusolver': ['qr', 'cholesky', 'lu'],
    'cudss': ['lu', 'cholesky', 'ldlt'],
}

DEFAULT_METHODS

Dictionary mapping backend names to their default solver methods.

DEFAULT_METHODS = {
    'scipy': 'superlu',
    'eigen': 'cg',
    'pytorch': 'cg',
    'cusolver': 'cholesky',
    'cudss': 'cholesky',
}

Type Aliases

  • BackendType: Literal type for backend names: 'scipy', 'eigen', 'pytorch', 'cusolver', 'cudss'

  • MethodType: Literal type for solver methods: 'superlu', 'umfpack', 'cg', 'bicgstab', 'gmres', 'minres', 'qr', 'cholesky', 'lu', 'ldlt'