"""
Batch Sparse Linear Solve for PyTorch
This module provides batch solving capabilities for sparse linear equations:
1. Same-layout batch solve: All matrices share the same sparsity pattern
2. Different-layout batch solve: Each matrix can have different sparsity pattern
For same-layout batches, we can leverage optimized batch operations.
For different-layout batches, we solve each system independently.
"""
import torch
from torch.autograd.function import Function
from typing import Tuple, List, Optional, Literal

from .backends import (
get_cpu_module,
get_cusolver_module,
get_cudss_module,
is_cusolver_available,
is_cudss_available,
)
MethodType = Literal[
'cg', 'bicgstab',
'cusolver_qr', 'cusolver_cholesky', 'cusolver_lu',
'cudss', 'cudss_lu', 'cudss_cholesky', 'cudss_ldlt'
]
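
# 'cg' and 'bicgstab' dispatch to the pure-CPU iterative backend; the
# 'cusolver_*' and 'cudss*' methods require the corresponding CUDA backends,
# loaded lazily via get_cusolver_module() / get_cudss_module().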
class BatchSparseLinearSolveSameLayout(Function):
"""
Batch solve for matrices with the same sparsity pattern.
All matrices share the same (row, col) indices, but have different values.
This is common in optimization and neural network applications where
the matrix structure is fixed but values change.
"""
@staticmethod
def forward(ctx,
val_batch: torch.Tensor, # [batch, nnz]
row: torch.Tensor, # [nnz]
col: torch.Tensor, # [nnz]
shape: Tuple[int, int],
b_batch: torch.Tensor, # [batch, m]
method: str,
atol: float,
maxiter: int):
batch_size = val_batch.size(0)
m, n = shape
        # The COO index tensor is identical for every system in the batch,
        # so build it once and reuse it for every solve.
        indices = torch.stack([row, col], dim=0)
        results = []
        for i in range(batch_size):
            val = val_batch[i]
            b = b_batch[i]
            if method == 'cg':
                x = get_cpu_module().cg(indices, val, m, n, b, atol, maxiter)
            elif method == 'bicgstab':
                x = get_cpu_module().bicgstab(indices, val, m, n, b, atol, maxiter)
            elif method == 'cusolver_qr':
                x = get_cusolver_module().qr(indices, val, m, n, b, 1e-12)
            elif method == 'cusolver_cholesky':
                x = get_cusolver_module().cholesky(indices, val, m, n, b, 1e-12)
            elif method == 'cusolver_lu':
                x = get_cusolver_module().lu(indices, val, m, n, b, 1e-12)
            elif method == 'cudss_lu':
                x = get_cudss_module().lu(indices, val, m, n, b)
            elif method == 'cudss_cholesky':
                x = get_cudss_module().cholesky(indices, val, m, n, b)
            elif method == 'cudss_ldlt':
                x = get_cudss_module().ldlt(indices, val, m, n, b)
            else:
                raise ValueError(f"Unknown method: {method}")
            results.append(x)
u_batch = torch.stack(results, dim=0)
ctx.save_for_backward(val_batch, row, col, u_batch)
ctx.A_shape = shape
ctx.method = method
ctx.atol = atol
ctx.maxiter = maxiter
return u_batch
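
    # Adjoint derivation used by backward(): with u = A^{-1} b and upstream
    # gradient gradu = dL/du,
    #
    #     dL/db = A^{-T} @ gradu          (one transposed solve per system)
    #     dL/dA = -(dL/db) @ u^T          (rank-one outer product)
    #
    # Only the stored nnz entries of dL/dA are needed, so the outer product
    # is gathered at the (row, col) positions:
    #
    #     gradval[k] = -gradb[row[k]] * u[col[k]]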
@staticmethod
def backward(ctx, gradu_batch):
val_batch, row, col, u_batch = ctx.saved_tensors
m, n = ctx.A_shape
method = ctx.method
atol = ctx.atol
maxiter = ctx.maxiter
batch_size = val_batch.size(0)
        # Forward-layout and transposed index tensors, shared across the batch.
        indices = torch.stack([row, col], dim=0)
        indices_T = torch.stack([col, row], dim=0)
        gradval_list = []
        gradb_list = []
        for i in range(batch_size):
            val = val_batch[i]
            u = u_batch[i]
            gradu = gradu_batch[i]
            # Solve the adjoint system A^T @ gradb = gradu. The symmetric
            # factorizations (cholesky/ldlt) reuse the forward layout,
            # since A^T = A for them.
            if method == 'cg':
                gradb = get_cpu_module().cg(indices_T, val, n, m, gradu, atol, maxiter)
            elif method == 'bicgstab':
                gradb = get_cpu_module().bicgstab(indices_T, val, n, m, gradu, atol, maxiter)
            elif method == 'cusolver_qr':
                gradb = get_cusolver_module().qr(indices_T, val, n, m, gradu, 1e-12)
            elif method == 'cusolver_cholesky':
                gradb = get_cusolver_module().cholesky(indices, val, m, n, gradu, 1e-12)
            elif method == 'cusolver_lu':
                gradb = get_cusolver_module().lu(indices_T, val, n, m, gradu, 1e-12)
            elif method == 'cudss_lu':
                gradb = get_cudss_module().lu(indices_T, val, n, m, gradu)
            elif method == 'cudss_cholesky':
                gradb = get_cudss_module().cholesky(indices, val, m, n, gradu)
            elif method == 'cudss_ldlt':
                gradb = get_cudss_module().ldlt(indices, val, m, n, gradu)
            else:
                raise ValueError(f"Unknown method: {method}")
            # Gather the rank-one outer product -gradb @ u^T at the nnz positions.
            gradval = -gradb[row] * u[col]
            gradval_list.append(gradval)
            gradb_list.append(gradb)
gradval_batch = torch.stack(gradval_list, dim=0)
gradb_batch = torch.stack(gradb_list, dim=0)
return gradval_batch, None, None, None, gradb_batch, None, None, None
def spsolve_batch_same_layout(
val_batch: torch.Tensor,
row: torch.Tensor,
col: torch.Tensor,
shape: Tuple[int, int],
b_batch: torch.Tensor,
method: MethodType = "bicgstab",
atol: float = 1e-10,
maxiter: int = 10000
) -> torch.Tensor:
"""
Batch solve sparse linear systems with the SAME sparsity pattern.
.. deprecated::
Use SparseTensor.decompose().solve() instead for a more Pythonic interface:
>>> A = SparseTensor(val, row, col, shape)
>>> decomp = A.decompose(method='superlu')
>>> x_batch = decomp.solve(val_batch, b_batch)
All matrices A_i share the same (row, col) structure but have different values.
This is efficient when the sparsity pattern is fixed (e.g., FEM with fixed mesh).
Solves: A_i @ x_i = b_i for i = 0, 1, ..., batch_size-1
Parameters
----------
val_batch : torch.Tensor
[batch_size, nnz] Non-zero values for each matrix
row : torch.Tensor
[nnz] Row indices (shared across batch)
col : torch.Tensor
[nnz] Column indices (shared across batch)
shape : Tuple[int, int]
(m, n) Shape of each sparse matrix
b_batch : torch.Tensor
[batch_size, m] Right-hand side vectors
method : str
Solver method (same options as spsolve)
atol : float
Absolute tolerance for iterative solvers
maxiter : int
Maximum iterations for iterative solvers
Returns
-------
torch.Tensor
[batch_size, n] Solution vectors
Example
-------
>>> import torch
>>> from torch_sla import spsolve_batch_same_layout
>>>
>>> batch_size = 10
>>> n = 100
>>> nnz = 500
>>>
>>> # Same sparsity pattern, different values
>>> row = torch.randint(0, n, (nnz,))
>>> col = torch.randint(0, n, (nnz,))
>>> val_batch = torch.randn(batch_size, nnz, dtype=torch.float64)
>>> b_batch = torch.randn(batch_size, n, dtype=torch.float64)
>>>
>>> x_batch = spsolve_batch_same_layout(val_batch, row, col, (n, n), b_batch)
"""
    # Validation
    assert val_batch.dim() == 2, f"val_batch must be 2D [batch, nnz], got {val_batch.dim()}D"
    assert b_batch.dim() == 2, f"b_batch must be 2D [batch, m], got {b_batch.dim()}D"
    assert val_batch.size(0) == b_batch.size(0), "val_batch and b_batch batch sizes must match"
    assert val_batch.size(1) == row.size(0), "val_batch.size(1) must equal nnz (= row.size(0))"
    assert val_batch.size(1) == col.size(0), "val_batch.size(1) must equal nnz (= col.size(0))"
    assert b_batch.size(1) == shape[0], "b_batch.size(1) must equal m (= shape[0])"
return BatchSparseLinearSolveSameLayout.apply(
val_batch, row, col, shape, b_batch, method, atol, maxiter
)
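
# A minimal, hedged sanity check for the custom adjoint above. It is
# illustrative only (not part of the public API) and assumes the pure-CPU
# 'cg' backend is built; the system is diagonal so cg converges trivially
# and gradcheck's float64 finite differences are well conditioned.
def _gradcheck_same_layout_demo():  # pragma: no cover
    """Illustrative sketch: compare the analytic adjoint against finite
    differences on a tiny batch of diagonal SPD systems."""
    n, batch = 4, 2
    idx = torch.arange(n)
    # Strictly positive diagonal values keep the systems well conditioned.
    val_batch = (torch.rand(batch, n, dtype=torch.float64) + 1.0).requires_grad_()
    b_batch = torch.randn(batch, n, dtype=torch.float64).requires_grad_()

    def solve(v, b):
        return spsolve_batch_same_layout(v, idx, idx, (n, n), b, method='cg')

    assert torch.autograd.gradcheck(solve, (val_batch, b_batch))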
def spsolve_batch_different_layout(
matrices: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[int, int]]],
b_list: List[torch.Tensor],
method: MethodType = "bicgstab",
atol: float = 1e-10,
maxiter: int = 10000
) -> List[torch.Tensor]:
"""
Batch solve sparse linear systems with DIFFERENT sparsity patterns.
.. deprecated::
Use SparseTensorList.solve() instead for a more Pythonic interface:
>>> matrices = SparseTensorList([A1, A2, A3])
>>> x_list = matrices.solve([b1, b2, b3])
Each matrix can have a different structure. This is useful when dealing
with heterogeneous problems or adaptive mesh refinement.
Parameters
----------
matrices : List[Tuple[val, row, col, shape]]
List of sparse matrices, each as (values, row_indices, col_indices, shape)
b_list : List[torch.Tensor]
List of right-hand side vectors
method : str
Solver method (same options as spsolve)
atol : float
Absolute tolerance for iterative solvers
maxiter : int
Maximum iterations for iterative solvers
Returns
-------
List[torch.Tensor]
List of solution vectors
Example
-------
>>> import torch
>>> from torch_sla import spsolve_batch_different_layout
>>>
>>> # Different matrices with different sizes/patterns
>>> matrices = []
>>> b_list = []
>>> for n in [50, 100, 150]:
... nnz = n * 5
... val = torch.randn(nnz, dtype=torch.float64)
... row = torch.randint(0, n, (nnz,))
... col = torch.randint(0, n, (nnz,))
... matrices.append((val, row, col, (n, n)))
... b_list.append(torch.randn(n, dtype=torch.float64))
>>>
>>> x_list = spsolve_batch_different_layout(matrices, b_list)
"""
from .linear_solve import spsolve
assert len(matrices) == len(b_list), "Number of matrices must equal number of RHS vectors"
results = []
for (val, row, col, shape), b in zip(matrices, b_list):
x = spsolve(val, row, col, shape, b, method=method, atol=atol, maxiter=maxiter)
results.append(x)
return results
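
# A hedged usage sketch (illustrative only, not part of the public API).
# Fully random patterns like the docstring example can produce singular
# matrices, so this variant adds a dominant diagonal to keep every system
# comfortably nonsingular for bicgstab.
def _different_layout_demo():  # pragma: no cover
    """Illustrative sketch: solve diagonally dominant systems of mixed sizes."""
    matrices, b_list = [], []
    for n in (50, 100, 150):
        off_nnz = n * 4
        row = torch.cat([torch.arange(n), torch.randint(0, n, (off_nnz,))])
        col = torch.cat([torch.arange(n), torch.randint(0, n, (off_nnz,))])
        # Large diagonal entries plus small off-diagonal noise.
        val = torch.cat([torch.full((n,), 10.0, dtype=torch.float64),
                         0.1 * torch.randn(off_nnz, dtype=torch.float64)])
        matrices.append((val, row, col, (n, n)))
        b_list.append(torch.randn(n, dtype=torch.float64))
    return spsolve_batch_different_layout(matrices, b_list, method='bicgstab')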
def spsolve_batch_coo_same_layout(
A_template: torch.Tensor,
val_batch: torch.Tensor,
b_batch: torch.Tensor,
method: MethodType = "bicgstab",
**kwargs
) -> torch.Tensor:
"""
Batch solve using a template sparse COO tensor for the structure.
Parameters
----------
A_template : torch.Tensor
Sparse COO tensor defining the sparsity pattern
val_batch : torch.Tensor
[batch_size, nnz] Values for each matrix
b_batch : torch.Tensor
[batch_size, m] Right-hand side vectors
method : str
Solver method
**kwargs
Additional arguments passed to spsolve_batch_same_layout
Returns
-------
torch.Tensor
[batch_size, n] Solution vectors
"""
    assert A_template.is_sparse, "A_template must be a sparse COO tensor"
    # Use the template's raw (possibly uncoalesced) index order; val_batch
    # must follow this same nnz ordering.
    indices = A_template._indices()
    row = indices[0]
    col = indices[1]
    shape = tuple(A_template.shape)
return spsolve_batch_same_layout(val_batch, row, col, shape, b_batch, method, **kwargs)
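
# A hedged usage sketch (illustrative only, not part of the public API).
# It shows the template pattern: the COO tensor supplies only (row, col,
# shape), while per-matrix values arrive in val_batch and must follow the
# template's nnz ordering.
def _coo_template_demo():  # pragma: no cover
    """Illustrative sketch: drive the same-layout solve from a COO template."""
    n, batch = 4, 3
    idx = torch.arange(n)
    A_template = torch.sparse_coo_tensor(
        torch.stack([idx, idx]), torch.ones(n, dtype=torch.float64), (n, n)
    )
    val_batch = torch.rand(batch, n, dtype=torch.float64) + 1.0
    b_batch = torch.randn(batch, n, dtype=torch.float64)
    return spsolve_batch_coo_same_layout(A_template, val_batch, b_batch, method='cg')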
def spsolve_batch_coo_different_layout(
A_list: List[torch.Tensor],
b_list: List[torch.Tensor],
method: MethodType = "bicgstab",
**kwargs
) -> List[torch.Tensor]:
"""
Batch solve using sparse COO tensors with different structures.
Parameters
----------
A_list : List[torch.Tensor]
List of sparse COO tensors
b_list : List[torch.Tensor]
List of right-hand side vectors
method : str
Solver method
**kwargs
Additional arguments passed to spsolve_batch_different_layout
Returns
-------
List[torch.Tensor]
List of solution vectors
"""
matrices = []
for A in A_list:
assert A.is_sparse, "All matrices must be sparse"
indices = A._indices()
val = A._values()
row = indices[0]
col = indices[1]
shape = tuple(A.shape)
matrices.append((val, row, col, shape))
return spsolve_batch_different_layout(matrices, b_list, method, **kwargs)
# Convenience wrapper for repeated same-layout batch solves
class ParallelBatchSolver:
"""
High-performance parallel batch solver.
This class pre-analyzes the sparsity pattern and caches factorization
information for repeated solves with the same structure.
Example
-------
>>> solver = ParallelBatchSolver(row, col, shape, method='cudss_lu')
>>>
>>> # Solve multiple batches efficiently
>>> for val_batch, b_batch in data_loader:
... x_batch = solver.solve(val_batch, b_batch)
"""
def __init__(
self,
row: torch.Tensor,
col: torch.Tensor,
shape: Tuple[int, int],
method: MethodType = "bicgstab",
device: Optional[str] = None
):
"""
Initialize the parallel batch solver.
Parameters
----------
row : torch.Tensor
[nnz] Row indices
col : torch.Tensor
[nnz] Column indices
shape : Tuple[int, int]
(m, n) Matrix shape
method : str
Solver method
device : str, optional
Device for computation
"""
        self.shape = shape
        self.method = method
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        # Keep the shared index tensors resident on the target device.
        self.row = row.to(self.device)
        self.col = col.to(self.device)
def solve(
self,
val_batch: torch.Tensor,
b_batch: torch.Tensor,
atol: float = 1e-10,
maxiter: int = 10000
) -> torch.Tensor:
"""
Solve batch of linear systems.
Parameters
----------
val_batch : torch.Tensor
[batch_size, nnz] Matrix values
b_batch : torch.Tensor
[batch_size, m] Right-hand side vectors
atol : float
Tolerance for iterative solvers
maxiter : int
Maximum iterations
Returns
-------
torch.Tensor
[batch_size, n] Solution vectors
"""
val_batch = val_batch.to(self.device)
b_batch = b_batch.to(self.device)
return spsolve_batch_same_layout(
val_batch, self.row, self.col, self.shape, b_batch,
method=self.method, atol=atol, maxiter=maxiter
)
def __call__(self, val_batch: torch.Tensor, b_batch: torch.Tensor, **kwargs) -> torch.Tensor:
"""Callable interface for the solver."""
return self.solve(val_batch, b_batch, **kwargs)
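
if __name__ == "__main__":  # pragma: no cover
    # Minimal smoke test, a hedged sketch rather than a full test suite.
    # It assumes the pure-CPU 'cg' backend is built: solve a batch of
    # diagonal SPD systems through ParallelBatchSolver and check the residual.
    n, batch = 8, 4
    idx = torch.arange(n)
    solver = ParallelBatchSolver(idx, idx, (n, n), method='cg', device='cpu')
    val_batch = torch.rand(batch, n, dtype=torch.float64) + 1.0
    b_batch = torch.randn(batch, n, dtype=torch.float64)
    x_batch = solver(val_batch, b_batch)
    # For a diagonal system the residual A_i @ x_i - b_i reduces to val * x - b.
    assert torch.allclose(val_batch * x_batch, b_batch, atol=1e-8)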