Source code for torch_sla.linear_solve

"""
Sparse Linear Solve module for PyTorch

This module provides differentiable sparse linear equation solvers with multiple backends.

Backends:
---------
- 'scipy': SciPy backend (CPU only) - Direct solvers via SuperLU/UMFPACK
- 'eigen': Eigen backend (CPU only) - Iterative solvers (CG, BiCGStab)
- 'pytorch': PyTorch-native (CPU & CUDA) - Iterative solvers for large-scale problems
- 'cusolver': NVIDIA cuSOLVER (CUDA only) - Direct solvers (QR, Cholesky, LU)
- 'cudss': NVIDIA cuDSS (CUDA only) - Direct solvers (LU, Cholesky, LDLT)

Methods:
--------
Direct solvers:
- 'superlu': SuperLU direct solver (scipy backend)
- 'umfpack': UMFPACK direct solver (scipy backend)
- 'lu': LU decomposition
- 'qr': QR decomposition (cusolver)
- 'cholesky': Cholesky decomposition (SPD matrices)
- 'ldlt': LDLT decomposition (symmetric matrices, cudss)

Iterative solvers:
- 'cg': Conjugate Gradient (SPD matrices)
- 'bicgstab': BiCGStab (general matrices)
- 'gmres': GMRES (general matrices, scipy)
- 'minres': MINRES (symmetric matrices, scipy)

Usage:
------
    # Auto-select backend and method based on device and problem size
    x = spsolve(val, row, col, shape, b)
    
    # Specify backend and method
    x = spsolve(val, row, col, shape, b, backend='scipy', method='superlu')
    x = spsolve(val, row, col, shape, b, backend='cudss', method='lu')
    x = spsolve(val, row, col, shape, b, backend='pytorch', method='cg')  # GPU iterative
"""

import warnings
import torch
from torch.autograd.function import Function
from typing import Tuple, Optional, Union, Literal

from .backends import (
    get_eigen_module,
    get_cusolver_module,
    get_cudss_module,
    is_scipy_available,
    is_eigen_available,
    is_pytorch_available,
    is_cusolver_available,
    is_cudss_available,
    select_backend,
    select_method,
    BACKEND_METHODS,
    CUDA_ITERATIVE_THRESHOLD,
    BackendType,
    MethodType,
)
from .backends.scipy_backend import scipy_solve
from .backends.pytorch_backend import pytorch_solve


# ============================================================================
# Autograd Functions for gradient support
# ============================================================================

class SparseLinearSolveScipySuperLU(Function):
    """SciPy SuperLU solver with gradient support"""

    @staticmethod
    def forward(ctx, val, row, col, shape, b, method, atol, maxiter):
        u = scipy_solve(val, row, col, shape, b, method=method, atol=atol, maxiter=maxiter)
        ctx.save_for_backward(val, row, col, u)
        ctx.shape = shape
        ctx.method = method
        ctx.atol = atol
        ctx.maxiter = maxiter
        return u

    @staticmethod
    def backward(ctx, gradu):
        val, row, col, u = ctx.saved_tensors
        shape = ctx.shape
        method = ctx.method
        atol = ctx.atol
        maxiter = ctx.maxiter
        
        # Solve A^T * gradb = gradu
        gradb = scipy_solve(val, col, row, (shape[1], shape[0]), gradu,
                           method=method, atol=atol, maxiter=maxiter)
        gradval = -gradb[row] * u[col]
        return gradval, None, None, None, gradb, None, None, None


class SparseLinearSolveEigenCG(Function):
    """Eigen CG solver with gradient support"""

    @staticmethod
    def forward(ctx, val, row, col, shape, b, atol, maxiter):
        _eigen = get_eigen_module()
        u = _eigen.cg(torch.stack([row, col], 0), val, shape[0], shape[1], b, atol, maxiter)
        ctx.save_for_backward(val, row, col, u)
        ctx.A_shape = shape
        ctx.atol = atol
        ctx.maxiter = maxiter
        return u

    @staticmethod
    def backward(ctx, gradu):
        _eigen = get_eigen_module()
        val, row, col, u = ctx.saved_tensors
        m, n = ctx.A_shape
        atol = ctx.atol
        maxiter = ctx.maxiter
        gradb = _eigen.cg(torch.stack([col, row], 0), val, n, m, gradu, atol, maxiter)
        gradval = -gradb[row] * u[col]
        return gradval, None, None, None, gradb, None, None


class SparseLinearSolveEigenBiCGStab(Function):
    """Eigen BiCGStab solver with gradient support"""

    @staticmethod
    def forward(ctx, val, row, col, shape, b, atol, maxiter):
        _eigen = get_eigen_module()
        u = _eigen.bicgstab(torch.stack([row, col], 0), val, shape[0], shape[1], b, atol, maxiter)
        ctx.save_for_backward(val, row, col, u)
        ctx.A_shape = shape
        ctx.atol = atol
        ctx.maxiter = maxiter
        return u

    @staticmethod
    def backward(ctx, gradu):
        _eigen = get_eigen_module()
        val, row, col, u = ctx.saved_tensors
        m, n = ctx.A_shape
        atol = ctx.atol
        maxiter = ctx.maxiter
        gradb = _eigen.bicgstab(torch.stack([col, row], 0), val, n, m, gradu, atol, maxiter)
        gradval = -gradb[row] * u[col]
        return gradval, None, None, None, gradb, None, None


class SparseLinearSolveCuSolverQR(Function):
    """cuSOLVER QR solver with gradient support"""

    @staticmethod
    def forward(ctx, val, row, col, shape, b, tol):
        cusolver = get_cusolver_module()
        indices = torch.stack([row, col], 0)
        u = cusolver.qr(indices, val, shape[0], shape[1], b, tol)
        ctx.save_for_backward(val, row, col, u)
        ctx.A_shape = shape
        ctx.tol = tol
        return u

    @staticmethod
    def backward(ctx, gradu):
        cusolver = get_cusolver_module()
        val, row, col, u = ctx.saved_tensors
        m, n = ctx.A_shape
        tol = ctx.tol
        indices_T = torch.stack([col, row], 0)
        gradb = cusolver.qr(indices_T, val, n, m, gradu, tol)
        gradval = -gradb[row] * u[col]
        return gradval, None, None, None, gradb, None


class SparseLinearSolveCuSolverCholesky(Function):
    """cuSOLVER Cholesky solver with gradient support"""

    @staticmethod
    def forward(ctx, val, row, col, shape, b, tol):
        cusolver = get_cusolver_module()
        indices = torch.stack([row, col], 0)
        u = cusolver.cholesky(indices, val, shape[0], shape[1], b, tol)
        ctx.save_for_backward(val, row, col, u)
        ctx.A_shape = shape
        ctx.tol = tol
        return u

    @staticmethod
    def backward(ctx, gradu):
        cusolver = get_cusolver_module()
        val, row, col, u = ctx.saved_tensors
        m, n = ctx.A_shape
        tol = ctx.tol
        indices = torch.stack([row, col], 0)
        gradb = cusolver.cholesky(indices, val, m, n, gradu, tol)
        gradval = -gradb[row] * u[col]
        return gradval, None, None, None, gradb, None


class SparseLinearSolveCuSolverLU(Function):
    """cuSOLVER LU solver with gradient support"""

    @staticmethod
    def forward(ctx, val, row, col, shape, b, tol):
        cusolver = get_cusolver_module()
        indices = torch.stack([row, col], 0)
        u = cusolver.lu(indices, val, shape[0], shape[1], b, tol)
        ctx.save_for_backward(val, row, col, u)
        ctx.A_shape = shape
        ctx.tol = tol
        return u

    @staticmethod
    def backward(ctx, gradu):
        cusolver = get_cusolver_module()
        val, row, col, u = ctx.saved_tensors
        m, n = ctx.A_shape
        tol = ctx.tol
        indices_T = torch.stack([col, row], 0)
        gradb = cusolver.lu(indices_T, val, n, m, gradu, tol)
        gradval = -gradb[row] * u[col]
        return gradval, None, None, None, gradb, None


class SparseLinearSolveCuDSS(Function):
    """cuDSS general solver with gradient support"""

    @staticmethod
    def forward(ctx, val, row, col, shape, b, matrix_type):
        cudss = get_cudss_module()
        indices = torch.stack([row, col], 0)
        u = cudss.solve(indices, val, shape[0], shape[1], b, matrix_type, "default")
        ctx.save_for_backward(val, row, col, u)
        ctx.A_shape = shape
        ctx.matrix_type = matrix_type
        return u

    @staticmethod
    def backward(ctx, gradu):
        cudss = get_cudss_module()
        val, row, col, u = ctx.saved_tensors
        m, n = ctx.A_shape
        matrix_type = ctx.matrix_type
        
        if matrix_type in ['symmetric', 'spd', 'hpd']:
            indices = torch.stack([row, col], 0)
            gradb = cudss.solve(indices, val, m, n, gradu, matrix_type, "default")
        else:
            indices_T = torch.stack([col, row], 0)
            gradb = cudss.solve(indices_T, val, n, m, gradu, "general", "default")
        
        gradval = -gradb[row] * u[col]
        return gradval, None, None, None, gradb, None


class SparseLinearSolveCuDSSLU(Function):
    """cuDSS LU solver with gradient support"""

    @staticmethod
    def forward(ctx, val, row, col, shape, b):
        cudss = get_cudss_module()
        indices = torch.stack([row, col], 0)
        u = cudss.lu(indices, val, shape[0], shape[1], b)
        ctx.save_for_backward(val, row, col, u)
        ctx.A_shape = shape
        return u

    @staticmethod
    def backward(ctx, gradu):
        cudss = get_cudss_module()
        val, row, col, u = ctx.saved_tensors
        m, n = ctx.A_shape
        indices_T = torch.stack([col, row], 0)
        gradb = cudss.lu(indices_T, val, n, m, gradu)
        gradval = -gradb[row] * u[col]
        return gradval, None, None, None, gradb


class SparseLinearSolveCuDSSCholesky(Function):
    """cuDSS Cholesky solver with gradient support"""

    @staticmethod
    def forward(ctx, val, row, col, shape, b):
        cudss = get_cudss_module()
        indices = torch.stack([row, col], 0)
        u = cudss.cholesky(indices, val, shape[0], shape[1], b)
        ctx.save_for_backward(val, row, col, u)
        ctx.A_shape = shape
        return u

    @staticmethod
    def backward(ctx, gradu):
        cudss = get_cudss_module()
        val, row, col, u = ctx.saved_tensors
        m, n = ctx.A_shape
        indices = torch.stack([row, col], 0)
        gradb = cudss.cholesky(indices, val, m, n, gradu)
        gradval = -gradb[row] * u[col]
        return gradval, None, None, None, gradb


class SparseLinearSolveCuDSSLDLT(Function):
    """cuDSS LDLT solver with gradient support"""

    @staticmethod
    def forward(ctx, val, row, col, shape, b):
        cudss = get_cudss_module()
        indices = torch.stack([row, col], 0)
        u = cudss.ldlt(indices, val, shape[0], shape[1], b)
        ctx.save_for_backward(val, row, col, u)
        ctx.A_shape = shape
        return u

    @staticmethod
    def backward(ctx, gradu):
        cudss = get_cudss_module()
        val, row, col, u = ctx.saved_tensors
        m, n = ctx.A_shape
        indices = torch.stack([row, col], 0)
        gradb = cudss.ldlt(indices, val, m, n, gradu)
        gradval = -gradb[row] * u[col]
        return gradval, None, None, None, gradb


class SparseLinearSolvePyTorch(Function):
    """PyTorch-native iterative solver with gradient support (works on CPU & CUDA)"""

    @staticmethod
    def forward(ctx, val, row, col, shape, b, method, atol, rtol, maxiter, preconditioner, mixed_precision):
        u = pytorch_solve(val, row, col, shape, b, method=method, atol=atol, rtol=rtol, 
                          maxiter=maxiter, preconditioner=preconditioner, mixed_precision=mixed_precision)
        # Convert back to input dtype if mixed precision was used
        if mixed_precision and u.dtype != val.dtype:
            u = u.to(val.dtype)
        ctx.save_for_backward(val, row, col, u)
        ctx.shape = shape
        ctx.method = method
        ctx.atol = atol
        ctx.rtol = rtol
        ctx.maxiter = maxiter
        ctx.preconditioner = preconditioner
        ctx.mixed_precision = mixed_precision
        return u

    @staticmethod
    def backward(ctx, gradu):
        val, row, col, u = ctx.saved_tensors
        shape = ctx.shape
        method = ctx.method
        atol = ctx.atol
        rtol = ctx.rtol
        maxiter = ctx.maxiter
        preconditioner = ctx.preconditioner
        mixed_precision = ctx.mixed_precision
        
        # Solve A^T * gradb = gradu
        gradb = pytorch_solve(val, col, row, (shape[1], shape[0]), gradu,
                              method=method, atol=atol, rtol=rtol, maxiter=maxiter,
                              preconditioner=preconditioner, mixed_precision=mixed_precision)
        if gradb.dtype != val.dtype:
            gradb = gradb.to(val.dtype)
        gradval = -gradb[row] * u[col]
        return gradval, None, None, None, gradb, None, None, None, None, None, None


# ============================================================================
# Main solve function
# ============================================================================

[docs] def spsolve( val: torch.Tensor, row: torch.Tensor, col: torch.Tensor, shape: Tuple[int, int], b: torch.Tensor, backend: BackendType = "auto", method: MethodType = "auto", atol: float = 1e-10, maxiter: int = 10000, tol: float = 1e-12, matrix_type: str = "general", is_symmetric: bool = False, is_spd: bool = False, preconditioner: str = "jacobi", mixed_precision: bool = False, ) -> torch.Tensor: """ Solve the Sparse Linear Equation Ax = b with gradient support. Supports multiple backends for CPU and CUDA tensors. Parameters ---------- val : torch.Tensor [nnz] Non-zero values of sparse matrix A in COO format row : torch.Tensor [nnz] Row indices col : torch.Tensor [nnz] Column indices shape : Tuple[int, int] (m, n) Shape of sparse matrix A b : torch.Tensor [m] Right-hand side vector backend : str, optional Backend to use: - 'auto': Auto-select based on device and problem size (default) - 'scipy': SciPy (CPU only, uses SuperLU/UMFPACK) - 'eigen': Eigen C++ (CPU only, iterative) - 'pytorch': PyTorch-native (CPU & CUDA, iterative) - best for large problems - 'cusolver': NVIDIA cuSOLVER (CUDA only, direct) - 'cudss': NVIDIA cuDSS (CUDA only, direct) method : str, optional Solver method. Available methods depend on backend: - 'auto': Auto-select based on matrix properties - 'superlu', 'umfpack': Direct solvers (scipy) - 'cg', 'bicgstab', 'gmres': Iterative solvers - 'lu', 'qr', 'cholesky', 'ldlt': Direct solvers (CUDA) atol : float, optional Absolute tolerance for iterative solvers, by default 1e-10 maxiter : int, optional Maximum iterations for iterative solvers, by default 10000 tol : float, optional Tolerance for direct solvers, by default 1e-12 matrix_type : str, optional Matrix type for cuDSS: 'general', 'symmetric', 'spd', by default "general" is_symmetric : bool, optional Hint that matrix is symmetric (for auto method selection) is_spd : bool, optional Hint that matrix is symmetric positive definite Returns ------- torch.Tensor [n] Solution vector x Examples -------- >>> import torch >>> from torch_sla import spsolve >>> >>> # Create a simple SPD matrix >>> val = torch.tensor([4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0], dtype=torch.float64) >>> row = torch.tensor([0, 0, 1, 1, 1, 2, 2], dtype=torch.int64) >>> col = torch.tensor([0, 1, 0, 1, 2, 1, 2], dtype=torch.int64) >>> shape = (3, 3) >>> b = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64) >>> >>> # Auto-select backend and method >>> x = spsolve(val, row, col, shape, b) >>> >>> # Specify backend and method >>> x = spsolve(val, row, col, shape, b, backend='scipy', method='superlu') >>> >>> # On CUDA >>> val_cuda = val.cuda() >>> row_cuda = row.cuda() >>> col_cuda = col.cuda() >>> b_cuda = b.cuda() >>> x_cuda = spsolve(val_cuda, row_cuda, col_cuda, shape, b_cuda, backend='cudss', method='lu') """ # Input validation assert val.dim() == 1, f"val must be 1D tensor, got {val.dim()}" assert row.dim() == 1, f"row must be 1D tensor, got {row.dim()}" assert col.dim() == 1, f"col must be 1D tensor, got {col.dim()}" assert b.dim() == 1, f"b must be 1D tensor, got {b.dim()}" assert shape[0] > 0, f"shape[0] must be positive, got {shape[0]}" assert shape[1] > 0, f"shape[1] must be positive, got {shape[1]}" assert val.size(0) == row.size(0), "val and row must have same size" assert val.size(0) == col.size(0), "val and col must have same size" assert b.size(0) == shape[0], "b and shape[0] must have same size" assert val.dtype == b.dtype, "val and b must have same dtype" device = val.device n = shape[0] # Problem size (DOF) # Auto-select backend based on device and problem size if backend == "auto": backend = select_backend(device, n=n) # Auto-select method if method == "auto": method = select_method(backend, is_symmetric=is_symmetric, is_spd=is_spd) # Validate backend-method combination valid_methods = BACKEND_METHODS.get(backend, []) if method not in valid_methods and method != "auto": raise ValueError(f"Method '{method}' not supported by backend '{backend}'. " f"Available methods: {valid_methods}") # ======================================================================== # SciPy backend (CPU) # ======================================================================== if backend == "scipy": if val.is_cuda: warnings.warn("SciPy backend requires CPU, moving tensors to CPU") val = val.cpu() row = row.cpu() col = col.cpu() b = b.cpu() if not is_scipy_available(): raise RuntimeError("SciPy is not available. Install with: pip install scipy") return SparseLinearSolveScipySuperLU.apply( val, row, col, shape, b, method, atol, maxiter ) # ======================================================================== # Eigen backend (CPU) # ======================================================================== elif backend == "eigen": if val.is_cuda: warnings.warn("Eigen backend requires CPU, moving tensors to CPU") val = val.cpu() row = row.cpu() col = col.cpu() b = b.cpu() if not is_eigen_available(): raise RuntimeError("Eigen backend is not available. Ensure C++ extension is compiled.") if val.dtype != torch.float64: warnings.warn("Using float64 is recommended for good precision with iterative solvers") if method == "cg": return SparseLinearSolveEigenCG.apply(val, row, col, shape, b, atol, maxiter) else: # bicgstab return SparseLinearSolveEigenBiCGStab.apply(val, row, col, shape, b, atol, maxiter) # ======================================================================== # cuSOLVER backend (CUDA) # ======================================================================== elif backend == "cusolver": if not val.is_cuda: raise ValueError("cuSOLVER backend requires CUDA tensors") if not is_cusolver_available(): raise RuntimeError("cuSOLVER backend is not available") if method == "qr": return SparseLinearSolveCuSolverQR.apply(val, row, col, shape, b, tol) elif method == "cholesky": return SparseLinearSolveCuSolverCholesky.apply(val, row, col, shape, b, tol) else: # lu return SparseLinearSolveCuSolverLU.apply(val, row, col, shape, b, tol) # ======================================================================== # cuDSS backend (CUDA) # ======================================================================== elif backend == "cudss": if not val.is_cuda: raise ValueError("cuDSS backend requires CUDA tensors") if not is_cudss_available(): raise RuntimeError("cuDSS backend is not available. Install with: pip install nvidia-cudss-cu12") if method == "lu": return SparseLinearSolveCuDSSLU.apply(val, row, col, shape, b) elif method == "cholesky": return SparseLinearSolveCuDSSCholesky.apply(val, row, col, shape, b) elif method == "ldlt": return SparseLinearSolveCuDSSLDLT.apply(val, row, col, shape, b) else: # Use general solver with matrix_type return SparseLinearSolveCuDSS.apply(val, row, col, shape, b, matrix_type) # ======================================================================== # PyTorch backend (CPU & CUDA - iterative) # ======================================================================== elif backend == "pytorch": # PyTorch-native iterative solvers work on both CPU and CUDA if val.dtype != torch.float64: warnings.warn("Using float64 is recommended for good precision with iterative solvers") rtol = 1e-10 # Relative tolerance (stricter for better accuracy) return SparseLinearSolvePyTorch.apply(val, row, col, shape, b, method, atol, rtol, maxiter, preconditioner, mixed_precision) else: raise ValueError(f"Unknown backend: {backend}. Available: scipy, eigen, pytorch, cusolver, cudss")
[docs] def spsolve_coo(A: torch.Tensor, b: torch.Tensor, **kwargs) -> torch.Tensor: """Solve Ax = b where A is a sparse COO tensor Parameters ---------- A : torch.Tensor Sparse COO tensor representing the matrix b : torch.Tensor Right-hand side vector **kwargs Additional arguments passed to spsolve() Returns ------- torch.Tensor Solution vector x """ assert A.is_sparse, "A must be a sparse tensor" assert A.layout == torch.sparse_coo, "A must be in COO format" indices = A._indices() values = A._values() shape = tuple(A.shape) row = indices[0] col = indices[1] return spsolve(values, row, col, shape, b, **kwargs)
[docs] def spsolve_csr(A: torch.Tensor, b: torch.Tensor, **kwargs) -> torch.Tensor: """Solve Ax = b where A is a sparse CSR tensor Parameters ---------- A : torch.Tensor Sparse CSR tensor representing the matrix b : torch.Tensor Right-hand side vector **kwargs Additional arguments passed to spsolve() Returns ------- torch.Tensor Solution vector x """ # Convert CSR to COO A_coo = A.to_sparse_coo() return spsolve_coo(A_coo, b, **kwargs)