torch-sla - PyTorch Sparse Linear Algebra with GPU Acceleration

torch-sla: PyTorch Sparse Linear Algebra

torch-sla (Torch Sparse Linear Algebra) is a memory-efficient, differentiable sparse linear equation solver library for PyTorch with multiple backends. Perfect for scientific computing, FEM, CFD, and machine learning applications requiring sparse matrix operations with automatic differentiation.

arXiv GitHub PyPI License: MIT

Why torch-sla?

  • 🚀 High Performance: CUDA-accelerated solvers via cuSOLVER and cuDSS
  • 💾 Memory Efficient: Store only non-zero elements, enabling solving of systems with millions of unknowns
  • 🔄 Differentiable: Full gradient support through torch.autograd
  • 📦 Batch Processing: Solve thousands of systems in parallel
  • 🌐 Distributed: Domain decomposition with halo exchange for large-scale problems
  • 🔧 Flexible: Multiple backends and solver methods

Key Features

  • Memory efficient: Only stores non-zero elements; a 1M×1M matrix with ~10 non-zeros per row stores ~10 million values (~80 MB) instead of ~8 TB dense
  • Full gradient support via torch.autograd for end-to-end differentiable pipelines
  • Multiple backends: SciPy, Eigen, cuSOLVER, cuDSS
  • Batch solving: Same-layout and different-layout sparse matrices
  • Distributed solving: Domain decomposition with halo exchange
  • 169M+ DOF tested: Scales to very large problems with near-linear complexity
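The arithmetic behind the memory claim can be sketched in a few lines. The density figure here is an illustrative assumption (~10 non-zeros per row, typical of FEM/finite-difference stencils), with float64 values and int64 indices:

```python
# Back-of-envelope storage for a 1M x 1M system, assuming float64 values
# and int64 indices in COO format (the density is an illustrative choice).
n = 1_000_000
nnz = 10 * n                      # ~10 non-zeros per row

dense_bytes = n * n * 8           # every float64 entry stored
value_bytes = nnz * 8             # non-zero values only
coo_bytes = nnz * (8 + 8 + 8)     # values + row indices + col indices

print(dense_bytes // 10**12, "TB dense")   # 8 TB dense
print(value_bytes // 10**6, "MB values")   # 80 MB of values
print(coo_bytes // 10**6, "MB COO total")  # 240 MB including indices
```

Even counting the two index arrays, the sparse representation is four orders of magnitude smaller than the dense one.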

Quick Start

Installation

pip install torch-sla

Basic Usage

import torch
from torch_sla import SparseTensor

# Create a sparse matrix from dense (easier to read for small matrices)
dense = torch.tensor([[4.0, -1.0,  0.0],
                      [-1.0, 4.0, -1.0],
                      [ 0.0, -1.0, 4.0]], dtype=torch.float64)

A = SparseTensor.from_dense(dense)

# Solve Ax = b
b = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
x = A.solve(b)

CUDA Acceleration

# Move to GPU for CUDA-accelerated solving
A_cuda = A.cuda()
b_cuda = b.cuda()
x = A_cuda.solve(b_cuda)  # Uses cuDSS or cuSOLVER automatically

Use Cases

torch-sla is ideal for:

  • Finite Element Method (FEM): Solve large sparse systems from FEM discretization

  • Computational Fluid Dynamics (CFD): Efficient sparse solvers for Navier-Stokes

  • Physics-Informed Neural Networks (PINNs): Differentiable sparse operations for physics constraints

  • Graph Neural Networks: Sparse message passing and Laplacian operations

  • Optimization: Gradient-based optimization involving sparse linear systems


Frequently Asked Questions (FAQ)

What is torch-sla?

torch-sla (Torch Sparse Linear Algebra) is a Python library that provides differentiable sparse linear equation solvers for PyTorch. It solves systems of the form Ax = b where A is a sparse matrix, with full support for automatic differentiation (autograd) and GPU acceleration via CUDA.

How do I solve a sparse linear system in PyTorch?

Use torch-sla's SparseTensor class:

from torch_sla import SparseTensor

# Create sparse matrix from COO format (values, row indices, column indices)
A = SparseTensor(values, row, col, shape)

# Solve Ax = b
x = A.solve(b)

This works on both CPU and GPU, and supports gradient computation.
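For reference, the COO triplet layout that the constructor expects can be illustrated on the 3×3 tridiagonal system from the quick start. SciPy stands in for the solve here so the snippet runs anywhere; torch-sla's SparseTensor consumes the same (values, row, col, shape) triplets:

```python
import numpy as np
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import spsolve as scipy_spsolve

# COO triplets for the tridiagonal matrix [[4,-1,0],[-1,4,-1],[0,-1,4]]:
# entry k of `values` sits at position (row[k], col[k]).
values = np.array([4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0])
row    = np.array([0,    0,    1,    1,    1,    2,    2])
col    = np.array([0,    1,    0,    1,    2,    1,    2])

A = coo_matrix((values, (row, col)), shape=(3, 3)).tocsr()
b = np.array([1.0, 2.0, 3.0])
x = scipy_spsolve(A, b)
print(np.allclose(A @ x, b))  # True
```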

What sparse solvers does torch-sla support?

torch-sla supports multiple backends:

  • CPU: SciPy (SuperLU, UMFPACK, CG, BiCGStab, GMRES), Eigen (CG, BiCGStab)

  • GPU: cuSOLVER (QR, Cholesky, LU), cuDSS (LU, Cholesky, LDLT)

The library automatically selects the best solver based on your hardware and matrix properties.
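For intuition, the conjugate gradient (CG) iteration offered by several of these backends can be sketched in a few lines of PyTorch. This is an illustration of the algorithm for symmetric positive-definite matrices, not torch-sla's implementation:

```python
import torch

def cg(A, b, tol=1e-10, max_iter=1000):
    """Minimal conjugate gradient for SPD A (illustration only)."""
    x = torch.zeros_like(b)
    r = b - A @ x          # residual
    p = r.clone()          # search direction
    rs = r @ r
    for _ in range(max_iter):
        Ap = A @ p
        alpha = rs / (p @ Ap)      # step length along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if rs_new.sqrt() < tol:    # converged: residual norm below tol
            break
        p = r + (rs_new / rs) * p  # next A-conjugate direction
        rs = rs_new
    return x

A = torch.tensor([[4.0, -1.0,  0.0],
                  [-1.0, 4.0, -1.0],
                  [ 0.0, -1.0, 4.0]], dtype=torch.float64)
b = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
x = cg(A, b)
print(torch.allclose(A @ x, b))  # True
```

In production the backends apply the same iteration to sparse matrix-vector products, which is why iterative methods scale to millions of unknowns.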

Can I compute gradients through sparse solve?

Yes. torch-sla fully supports PyTorch autograd:

import torch
from torch_sla import spsolve

val = torch.tensor([...], requires_grad=True)
x = spsolve(val, row, col, shape, b)
loss = x.sum()
loss.backward()  # Computes gradients w.r.t. val and b
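Such gradients follow the adjoint rule for linear solves: with x = A⁻¹b and upstream gradient g = ∂L/∂x, one extra solve gives ∂L/∂b = A⁻ᵀg and ∂L/∂A = -(A⁻ᵀg)xᵀ. A dense stand-in (plain torch.linalg.solve, not torch-sla) makes the rule easy to verify:

```python
import torch

# Check the adjoint rule for x = A^{-1} b against autograd.
torch.manual_seed(0)
A = torch.randn(3, 3, dtype=torch.float64) + 4 * torch.eye(3, dtype=torch.float64)
A.requires_grad_(True)
b = torch.randn(3, dtype=torch.float64, requires_grad=True)

x = torch.linalg.solve(A, b)
x.sum().backward()                          # L = x.sum(), so g = dL/dx = ones

g = torch.ones(3, dtype=torch.float64)
adj = torch.linalg.solve(A.detach().T, g)   # A^{-T} g, one adjoint solve

print(torch.allclose(b.grad, adj))                            # True
print(torch.allclose(A.grad, -torch.outer(adj, x.detach())))  # True
```

Because the backward pass is a single extra solve rather than a replay of solver internals, the gradient costs about as much as the forward solve.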

How do I solve batched sparse systems?

torch-sla supports batched solving for matrices with the same sparsity pattern:

# Batched values: [batch_size, nnz]
A = SparseTensor(val_batch, row, col, (batch_size, M, N))
x = A.solve(b_batch)  # Solves all systems in parallel

For matrices with different patterns, use SparseTensorList. See batched solve examples.
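The idea can be sketched with PyTorch's dense batched solver as a stand-in: the batch dimension rides along through a single solver call instead of a Python loop over systems:

```python
import torch

# One call solves all systems in the batch (dense stand-in for illustration).
torch.manual_seed(0)
batch, n = 4, 3
A = torch.randn(batch, n, n, dtype=torch.float64) + 4 * torch.eye(n, dtype=torch.float64)
b = torch.randn(batch, n, 1, dtype=torch.float64)

x = torch.linalg.solve(A, b)       # shape [batch, n, 1]
print(torch.allclose(A @ x, b))    # True
```

In the sparse same-pattern case, only the values tensor carries the batch dimension; the row/col indices are shared, so symbolic factorization is done once.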

How do I use torch-sla on GPU?

Simply move your tensors to CUDA:

A_cuda = A.cuda()
x = A_cuda.solve(b.cuda())  # Uses cuDSS or cuSOLVER

What is the difference between SparseTensor and DSparseTensor?

  • SparseTensor: Single sparse matrix (optionally batched), for standard solving

  • DSparseTensor: Distributed sparse tensor with domain decomposition, for large-scale parallel computing with halo exchange
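The halo-exchange idea behind DSparseTensor can be sketched serially: each subdomain keeps a one-cell copy of its neighbors' boundary values so local matrix-vector products see the full stencil. This is a conceptual numpy stand-in (no MPI, no torch-sla internals):

```python
import numpy as np

# 1-D domain decomposition with one-cell halos: each partition is padded
# with the adjacent boundary value from its neighbor.
global_u = np.arange(10.0)
parts = [global_u[0:5], global_u[5:10]]   # two subdomains

def with_halos(parts):
    padded = []
    for i, p in enumerate(parts):
        left  = parts[i - 1][-1:] if i > 0 else np.empty(0)
        right = parts[i + 1][:1]  if i + 1 < len(parts) else np.empty(0)
        padded.append(np.concatenate([left, p, right]))
    return padded

padded = with_halos(parts)
print(padded[0])  # [0. 1. 2. 3. 4. 5.]  (right halo copied from neighbor)
print(padded[1])  # [4. 5. 6. 7. 8. 9.]  (left halo copied from neighbor)
```

In the distributed setting the same exchange happens over the network after each iteration, which is what keeps communication proportional to the subdomain surface rather than its volume.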

Comparison with Alternatives

torch-sla vs scipy.sparse.linalg

| Feature                | torch-sla                 | scipy.sparse.linalg    |
|------------------------|---------------------------|------------------------|
| PyTorch Integration    | ✅ Native tensors          | ❌ Requires numpy copy  |
| GPU Acceleration       | ✅ CUDA (cuDSS, cuSOLVER)  | ❌ CPU only             |
| Autograd Gradients     | ✅ Full support (adjoint)  | ❌ No gradients         |
| Batched Solve          | ✅ Parallel batch solve    | ❌ Loop required        |
| Large Scale (>2M DOF)  | ✅ 169M DOF tested         | ⚠️ Memory limited      |
| Distributed Computing  | ✅ DSparseTensor           | ❌ Not supported        |
| Eigenvalue/SVD         | ✅ Differentiable          | ⚠️ No gradients        |
| Nonlinear Solve        | ✅ Newton/Anderson         | ❌ Not included         |

torch-sla vs torch.linalg.solve

| Feature                     | torch-sla                        | torch.linalg.solve   |
|-----------------------------|----------------------------------|----------------------|
| Matrix Type                 | ✅ Sparse (COO/CSR)               | ❌ Dense only         |
| Memory (1M×1M, ~10 nnz/row) | ✅ ~80 MB                         | ❌ ~8 TB (impossible) |
| Max Problem Size            | ✅ 500M+ DOF (multi-GPU, scalable)| ❌ ~50K (GPU memory)  |
| Specialized Solvers         | ✅ LU, Cholesky, CG, BiCGStab     | ⚠️ Dense LU only     |
| Batched Operations          | ✅ Same/different patterns        | ⚠️ Same shape only   |
| GPU Support                 | ✅ cuDSS, cuSOLVER, PyTorch       | ✅ Yes                |
| Autograd                    | ✅ O(1) graph nodes               | ✅ Yes                |

torch-sla vs NVIDIA AmgX

| Feature             | torch-sla                | NVIDIA AmgX             |
|---------------------|--------------------------|-------------------------|
| Installation        | ✅ pip install torch-sla  | ❌ Complex build process |
| PyTorch Integration | ✅ Native                 | ❌ Requires wrapper      |
| Autograd Support    | ✅ Full gradient flow     | ❌ No gradients          |
| Python API          | ✅ Pythonic               | ⚠️ C++ focused          |
| Multigrid (AMG)     | ❌ Not yet                | ✅ Core feature          |
| Preconditioners     | ⚠️ Jacobi                | ✅ ILU, AMG, etc.        |
| Documentation       | ✅ Comprehensive          | ⚠️ Limited examples     |

torch-sla vs PETSc

| Feature             | torch-sla                  | PETSc                        |
|---------------------|----------------------------|------------------------------|
| Installation        | ✅ pip install              | ❌ Complex (MPI, compilers)   |
| Learning Curve      | ✅ Simple Python API        | ❌ Steep (C/Fortran heritage) |
| PyTorch Integration | ✅ Native tensors           | ❌ Requires petsc4py + copies |
| Autograd            | ✅ Full support             | ❌ No gradients               |
| Solver Variety      | ⚠️ Core methods            | ✅ Extensive (KSP, SNES)      |
| Distributed         | ✅ DSparseTensor multi-GPU  | ✅ Full MPI support           |
| Production Scale    | ✅ 500M+ DOF (multi-GPU)    | ✅ Exascale proven            |

Summary: When to Use torch-sla

| Use torch-sla When                         | Consider Alternatives When          |
|--------------------------------------------|-------------------------------------|
| ✅ You need PyTorch integration             | You're not using PyTorch            |
| ✅ You need gradient flow through solve     | Gradients not needed                |
| ✅ Problem size up to 500M+ DOF (multi-GPU) | Exascale problems (use PETSc)       |
| ✅ You want simple pip install              | You need AMG preconditioners (AmgX) |
| ✅ Batched sparse systems                   | Complex preconditioning (PETSc)     |
| ✅ GPU acceleration with minimal setup      | Full MPI distributed (PETSc)        |