# torch_sla.sparse_tensor -- source listing

"""
SparseTensor wrapper class for PyTorch sparse tensors.

Supports batched and block sparse tensors with shape [...batch, M, N, ...block]:
- Leading dimensions: batch dimensions [B1, B2, ...]
- Matrix dimensions: (M, N) at positions (sparse_dim[0], sparse_dim[1]), default (-2, -1)
- Trailing dimensions: block dimensions [K1, K2, ...]

Key Features:
- Automatic symmetry and positive definiteness detection
- Sparse linear equation solving with gradient support
- Sparse-sparse multiplication with sparse gradients
- Batched operations for all methods
- CUDA support with LOBPCG for eigenvalue computation

Examples
--------
>>> # Create a simple sparse matrix
>>> val = torch.tensor([4.0, -1.0, -1.0, 4.0])
>>> row = torch.tensor([0, 0, 1, 1])
>>> col = torch.tensor([0, 1, 0, 1])
>>> A = SparseTensor(val, row, col, (2, 2))
>>>
>>> # Check properties (returns boolean tensor for batched)
>>> is_sym = A.is_symmetric()  # tensor(True)
>>> is_pd = A.is_positive_definite()  # tensor(True)
>>>
>>> # Solve linear system
>>> b = torch.tensor([1.0, 2.0])
>>> x = A.solve(b)
>>>
>>> # Matrix operations
>>> y = A @ x  # Sparse @ Dense
>>> C = A @ A  # Sparse @ Sparse (sparse gradient)
"""

import os
import torch
from torch.autograd.function import Function
from typing import Tuple, Optional, Union, Literal, List, Dict
import warnings
import math

from .backends import (
    is_scipy_available,
    is_eigen_available,
    is_cusolver_available,
    is_cudss_available,
    select_backend,
    select_method,
    BackendType,
    MethodType,
)
from .backends.scipy_backend import (
    scipy_solve,
    scipy_eigs,
    scipy_eigsh,
    scipy_svds,
    scipy_norm,
    scipy_lu,
    scipy_det,
)


# =============================================================================
# Adjoint Determinant Solver
# =============================================================================

class DetAdjoint(Function):
    """
    Adjoint-based differentiable determinant computation.

    Uses implicit differentiation: for a matrix A with d = det(A),
        ∂d/∂A = d * (A^{-1})^T,  i.e.  ∂d/∂A_ij = d * (A^{-1})_ji

    The backward pass only needs the columns of A^{-1} whose indices appear
    as row indices in the sparsity pattern; each column is obtained by
    solving A @ x = e_i with the existing solve infrastructure.
    """

    @staticmethod
    def forward(ctx, val, row, col, shape, device, is_cuda):
        """Forward pass: compute det(A) from the COO triplets (val, row, col)."""
        from .backends.scipy_backend import scipy_det

        if is_cuda:
            # For CUDA, densify and use torch.linalg.det.
            # NOTE: inefficient for sparse matrices (O(n^2) memory, O(n^3)
            # compute); cuSOLVER/cuDSS do not expose a sparse determinant.
            # Recommendation: use .cpu().det() for better performance.
            val_detached = val.detach()
            indices = torch.stack([row, col], dim=0).to(device)
            sparse_coo = torch.sparse_coo_tensor(indices, val_detached, shape, device=device)
            det_val = torch.linalg.det(sparse_coo.to_dense())
        else:
            # For CPU, use the scipy backend.
            det_val = scipy_det(val.detach(), row, col, shape)

        # Save for backward
        ctx.save_for_backward(val, row, col, det_val)
        ctx.shape = shape
        ctx.device = device
        ctx.is_cuda = is_cuda

        return det_val

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: compute gradient using the adjoint method.

        Gradient formula: ∂L/∂val_k = ∂L/∂d * d * (A^{-1})_{col_k, row_k}
        """
        val, row, col, det_val = ctx.saved_tensors
        shape = ctx.shape
        device = ctx.device
        is_cuda = ctx.is_cuda

        # Singular matrix: the determinant gradient is undefined; return
        # zeros for numerical stability instead of propagating inf/nan.
        if abs(det_val.item()) < 1e-15:
            return torch.zeros_like(val), None, None, None, None, None

        # Use detached values for the solves below: external backends may
        # call .numpy() (or similar), which rejects grad-requiring tensors,
        # and the forward pass already detaches explicitly.
        val_detached = val.detach()

        # Only the columns of A^{-1} indexed by the distinct row indices of
        # the sparsity pattern are needed (∂d/∂A_ij needs (A^{-1})_ji).
        # torch.unique returns a sorted tensor, which searchsorted relies on.
        unique_rows = torch.unique(row)
        n_unique = unique_rows.numel()

        if is_cuda:
            # Dense solve on CUDA: densify ONCE (not per column) and solve
            # all required unit right-hand sides in a single batched call.
            indices = torch.stack([row, col], dim=0).to(device)
            dense = torch.sparse_coo_tensor(
                indices, val_detached, shape, device=device
            ).to_dense()
            rhs = torch.zeros(shape[0], n_unique, dtype=val.dtype, device=device)
            rhs[unique_rows, torch.arange(n_unique, device=device)] = 1.0
            inv_cols = torch.linalg.solve(dense, rhs)  # [n, n_unique]
        else:
            # scipy backend on CPU: one solve per needed column of A^{-1}.
            from .backends.scipy_backend import scipy_solve
            cols = []
            for i in unique_rows:
                e_i = torch.zeros(shape[0], dtype=val.dtype, device=device)
                e_i[i] = 1.0
                # Solve A @ x = e_i -> x is the i-th column of A^{-1}.
                cols.append(scipy_solve(val_detached, row, col, shape, e_i, method='superlu'))
            inv_cols = torch.stack(cols, dim=1)  # [n, n_unique]

        # Vectorized gather replacing the former per-nnz Python loop:
        #   grad_val[k] = d * (A^{-1})_{col_k, row_k}
        # searchsorted maps each row index to its column slot in inv_cols.
        pos = torch.searchsorted(unique_rows, row)
        grad_val = det_val * inv_cols[col, pos]

        # Chain rule with the upstream gradient.
        grad_val = grad_val * grad_output

        return grad_val, None, None, None, None, None


# =============================================================================
# Adjoint Eigenvalue Solver
# =============================================================================

class EigshAdjoint(Function):
    """
    Adjoint-based differentiable eigenvalue solver.

    Uses implicit differentiation to compute gradients with O(1) graph
    nodes, regardless of the number of iterations in the forward solve.

    For a symmetric matrix A with eigenpair (λ, v), A @ v = λ * v, and
        ∂λ/∂A = v @ v.T  (outer product)

    NOTE: gradients arriving through the eigenvector output are currently
    ignored; propagating them would require solving (A - λI) @ dv = ...
    systems.
    """

    @staticmethod
    def forward(ctx, val, row, col, shape, k, which, return_eigenvectors, device):
        """Forward pass: compute eigenvalues using LOBPCG (CUDA) or dense eigh (CPU)."""
        n = shape[0]

        # Detach: the forward solve must not be traced; gradients are
        # supplied analytically in backward().
        val_detached = val.detach()

        # Build sparse matrix for matvec
        indices = torch.stack([row, col], dim=0).to(device)
        sparse_coo = torch.sparse_coo_tensor(indices, val_detached, shape, device=device)

        def matvec(x):
            if x.dim() == 1:
                return torch.sparse.mm(sparse_coo, x.unsqueeze(1)).squeeze(1)
            return torch.sparse.mm(sparse_coo, x)

        # Compute eigenvalues
        if device.type == 'cuda':
            # Use LOBPCG on CUDA
            largest = which in ('LM', 'LA')
            eigenvalues, eigenvectors = _lobpcg_eigsh(
                matvec, n, k, val.dtype, device, largest=largest
            )
        else:
            # Dense fallback on CPU (SciPy would break the autograd chain).
            # Element-wise fill keeps last-wins semantics for any duplicate
            # (row, col) entries, matching the assignment order of the COO data.
            A_dense = torch.zeros(n, n, dtype=val.dtype, device=device)
            for i in range(len(row)):
                A_dense[row[i], col[i]] = val_detached[i]

            # torch.linalg.eigh returns eigenvalues in ascending order.
            eigenvalues_all, eigenvectors_all = torch.linalg.eigh(A_dense)

            if which in ('LM', 'LA'):
                # Largest eigenvalues
                eigenvalues = eigenvalues_all[-k:]
                eigenvectors = eigenvectors_all[:, -k:]
            else:
                # Smallest eigenvalues
                eigenvalues = eigenvalues_all[:k]
                eigenvectors = eigenvectors_all[:, :k]

        # Save for backward
        ctx.save_for_backward(val, eigenvalues, eigenvectors)
        ctx.row = row
        ctx.col = col
        ctx.shape = shape
        ctx.k = k
        ctx.return_eigenvectors = return_eigenvectors

        if return_eigenvectors:
            return eigenvalues, eigenvectors
        return eigenvalues, None

    @staticmethod
    def backward(ctx, grad_eigenvalues, grad_eigenvectors):
        """
        Backward pass using the adjoint method.

        For eigenvalue λ_m with eigenvector v_m:
            ∂L/∂A[i,j] = Σ_m (∂L/∂λ_m) * v_m[i] * v_m[j]
        restricted to the sparsity pattern:
            ∂L/∂val[idx] = Σ_m (∂L/∂λ_m) * v_m[row[idx]] * v_m[col[idx]]
        """
        val, eigenvalues, eigenvectors = ctx.saved_tensors
        row = ctx.row
        col = ctx.col

        if grad_eigenvalues is None:
            return None, None, None, None, None, None, None, None

        # Vectorized over both the k eigenpairs and the nnz entries
        # (replaces the former per-eigenpair Python loop):
        # (eigenvectors[row] * eigenvectors[col]) is [nnz, k]; contracting
        # with grad_eigenvalues [k] sums the per-eigenpair contributions.
        grad_val = (eigenvectors[row] * eigenvectors[col]) @ grad_eigenvalues

        # Eigenvector gradients are intentionally not propagated (see class
        # docstring); doing so would require solving (A - λI) systems.

        return grad_val, None, None, None, None, None, None, None


# =============================================================================
# Utility Functions
# =============================================================================

def estimate_direct_solver_memory(nnz: int, n: int, dtype: torch.dtype) -> int:
    """
    Estimate memory required for a direct sparse solver.

    Combines an estimate of factor fill-in (the non-zeros grow by a
    heuristic fill factor during LU/Cholesky factorization) with solver
    workspace proportional to the matrix dimension.

    Parameters
    ----------
    nnz : int
        Number of non-zero elements.
    n : int
        Matrix dimension.
    dtype : torch.dtype
        Data type of the matrix.

    Returns
    -------
    int
        Estimated memory in bytes.
    """
    # Query the actual element size so half-precision (2 B) and complex
    # dtypes (8/16 B) are sized correctly; the previous 8-or-4-byte guess
    # underestimated complex128 by 4x. float32/float64 results are unchanged.
    bytes_per_element = torch.empty((), dtype=dtype).element_size()
    # Heuristic fill-in factor, growing with n and clamped to [2, 10].
    fill_factor = min(10, max(2, n / 100))
    factor_memory = int(nnz * fill_factor * bytes_per_element)
    workspace_memory = n * bytes_per_element * 10
    return factor_memory + workspace_memory


def get_available_gpu_memory() -> int:
    """
    Get available GPU memory in bytes.

    Returns
    -------
    int
        Free bytes on the current CUDA device, or 0 when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return 0
    try:
        free_bytes, _total_bytes = torch.cuda.mem_get_info()
        return free_bytes
    except Exception:
        # Fallback when mem_get_info is unsupported: device total minus what
        # this process has allocated (an optimistic upper bound).
        props = torch.cuda.get_device_properties(0)
        return props.total_memory - torch.cuda.memory_allocated()


def auto_select_method(
    nnz: int, n: int, dtype: torch.dtype, is_cuda: bool, is_spd: bool = False,
    memory_threshold: float = 0.8
) -> Tuple[str, str]:
    """
    Automatically select the best backend and method.

    Parameters
    ----------
    nnz : int
        Number of non-zero elements.
    n : int
        Matrix dimension.
    dtype : torch.dtype
        Data type of the matrix.
    is_cuda : bool
        Whether the matrix is on CUDA.
    is_spd : bool, optional
        Whether the matrix is symmetric positive definite. Default: False.
    memory_threshold : float, optional
        Fraction of GPU memory to use. Default: 0.8.

    Returns
    -------
    Tuple[str, str]
        (backend, method) tuple.

    Raises
    ------
    RuntimeError
        If no suitable backend is available.
    """
    # CPU: prefer SciPy's direct SuperLU solver; otherwise fall back to
    # Eigen's iterative solvers (CG only for SPD matrices).
    if not is_cuda:
        if is_scipy_available():
            return ("scipy", "superlu")
        if is_eigen_available():
            return ("eigen", "cg" if is_spd else "bicgstab")
        raise RuntimeError("No CPU backend available")

    # CUDA: use a GPU direct solver only when the estimated factorization
    # fits within the configured fraction of free GPU memory.
    needed = estimate_direct_solver_memory(nnz, n, dtype)
    available = get_available_gpu_memory()

    if available > 0 and needed < available * memory_threshold:
        if is_cudss_available():
            return ("cudss", "cholesky" if is_spd else "lu")
        if is_cusolver_available():
            return ("cusolver", "cholesky" if is_spd else "qr")

    # GPU direct solvers unavailable or the matrix is too large: fall back
    # to the CPU SciPy path.
    if is_scipy_available():
        return ("scipy", "superlu")

    raise RuntimeError("No suitable backend available")


# =============================================================================
# Autograd Functions
# =============================================================================

class SparseSolveFunction(Function):
    """
    Differentiable sparse solve using scipy for CPU.

    Forward solves A @ u = b. Backward uses the adjoint system:
        grad_b   = A^{-T} @ grad_u
        grad_val = -grad_b[row] * u[col]   (sampled at A's sparsity pattern)
    """

    @staticmethod
    def forward(ctx, val, row, col, shape, b, method, atol, maxiter):
        """Solve A @ u = b and stash everything backward() needs."""
        u = scipy_solve(val, row, col, shape, b, method=method, atol=atol, maxiter=maxiter)
        ctx.save_for_backward(val, row, col, u, b)
        ctx.shape = shape
        ctx.method = method
        ctx.atol = atol
        ctx.maxiter = maxiter
        return u

    @staticmethod
    def backward(ctx, grad_u):
        """Adjoint solve: swapping (row, col) and transposing shape solves with A^T."""
        val, row, col, u, b = ctx.saved_tensors
        m, n = ctx.shape
        grad_b = scipy_solve(
            val, col, row, (n, m), grad_u,
            method=ctx.method, atol=ctx.atol, maxiter=ctx.maxiter,
        )
        # Gradient w.r.t. A's values, sampled at the sparsity pattern:
        # d(loss)/d(A_ij) = -grad_b_i * u_j for each stored entry (i, j).
        grad_val = -grad_b[row] * u[col]
        return grad_val, None, None, None, grad_b, None, None, None


class SparseSparseMatmulFunction(Function):
    """
    Differentiable Sparse @ Sparse multiplication with SPARSE gradients.

    Forward: C = A @ B where A is [M, K], B is [K, N], C is [M, N]

    Backward:
    - grad_A_values = (grad_C @ B^T)[A_row, A_col]  (sampled at A's positions)
    - grad_B_values = (A^T @ grad_C)[B_row, B_col]  (sampled at B's positions)

    The returned gradients are sparse (one value per stored non-zero of the
    corresponding input). NOTE: the backward pass currently materializes
    dense [M, K] / [K, N] intermediates before sampling, so its peak memory
    is O(M*K + K*N), not O(nnz).
    """

    @staticmethod
    def forward(ctx, val_a, row_a, col_a, shape_a, val_b, row_b, col_b, shape_b):
        """Compute C = A @ B in COO form; returns (values, row_indices, col_indices)."""
        M, K = shape_a
        K2, N = shape_b
        assert K == K2, f"Inner dimensions must match: {K} vs {K2}"

        # Create torch sparse tensors for the multiplication.
        A_coo = torch.sparse_coo_tensor(
            torch.stack([row_a, col_a]), val_a, (M, K)
        ).coalesce()
        B_coo = torch.sparse_coo_tensor(
            torch.stack([row_b, col_b]), val_b, (K, N)
        ).coalesce()

        # Sparse @ Sparse -> Sparse. no_grad: gradients are supplied
        # analytically by backward(), not traced through torch.sparse.mm.
        with torch.no_grad():
            C_coo = torch.sparse.mm(A_coo, B_coo).coalesce()

        # Extract result
        C_indices = C_coo._indices()
        C_values = C_coo._values()

        # Save for backward
        ctx.save_for_backward(val_a, row_a, col_a, val_b, row_b, col_b,
                              C_indices[0], C_indices[1], C_values)
        ctx.shape_a = shape_a
        ctx.shape_b = shape_b

        return C_values, C_indices[0], C_indices[1]

    @staticmethod
    def backward(ctx, grad_C_values, grad_row_c, grad_col_c):
        """Compute sparse-sampled gradients for A's and B's values."""
        (val_a, row_a, col_a, val_b, row_b, col_b,
         row_c, col_c, val_c) = ctx.saved_tensors
        M, K = ctx.shape_a
        K2, N = ctx.shape_b

        grad_val_a = None
        grad_val_b = None

        need_a = ctx.needs_input_grad[0]
        need_b = ctx.needs_input_grad[4]

        # grad_C is shared by both input gradients; build it once instead of
        # once per branch.
        grad_C_coo = None
        if need_a or need_b:
            grad_C_coo = torch.sparse_coo_tensor(
                torch.stack([row_c, col_c]), grad_C_values, (M, N)
            ).coalesce()

        if need_a:
            # grad_A = grad_C @ B^T, sampled at A's non-zero positions.
            B_T_coo = torch.sparse_coo_tensor(
                torch.stack([col_b, row_b]), val_b, (N, K)
            ).coalesce()
            grad_A_dense = torch.sparse.mm(grad_C_coo, B_T_coo).to_dense()
            grad_val_a = grad_A_dense[row_a, col_a]

        if need_b:
            # grad_B = A^T @ grad_C, sampled at B's non-zero positions.
            A_T_coo = torch.sparse_coo_tensor(
                torch.stack([col_a, row_a]), val_a, (K, M)
            ).coalesce()
            grad_B_dense = torch.sparse.mm(A_T_coo, grad_C_coo).to_dense()
            grad_val_b = grad_B_dense[row_b, col_b]

        return grad_val_a, None, None, None, grad_val_b, None, None, None


def _sparse_sparse_matmul_with_sparse_grad(
    val_a: torch.Tensor, row_a: torch.Tensor, col_a: torch.Tensor, shape_a: Tuple[int, int],
    val_b: torch.Tensor, row_b: torch.Tensor, col_b: torch.Tensor, shape_b: Tuple[int, int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[int, int]]:
    """
    Multiply two sparse COO matrices, keeping gradients sparse.

    Thin wrapper around SparseSparseMatmulFunction that also returns the
    shape of the product.

    Parameters
    ----------
    val_a, row_a, col_a : torch.Tensor
        COO representation of matrix A.
    shape_a : Tuple[int, int]
        Shape of matrix A (M, K).
    val_b, row_b, col_b : torch.Tensor
        COO representation of matrix B.
    shape_b : Tuple[int, int]
        Shape of matrix B (K, N).

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[int, int]]
        (values, row_indices, col_indices, shape) of the result C = A @ B.
    """
    out_rows, _ = shape_a
    _, out_cols = shape_b

    values, rows, cols = SparseSparseMatmulFunction.apply(
        val_a, row_a, col_a, shape_a,
        val_b, row_b, col_b, shape_b,
    )

    return values, rows, cols, (out_rows, out_cols)


# =============================================================================
# LOBPCG and Power Iteration for CUDA
# =============================================================================

def _lobpcg_eigsh(
    A_matvec,
    n: int,
    k: int,
    dtype: torch.dtype,
    device: torch.device,
    largest: bool = True,
    maxiter: int = 1000,
    tol: float = 1e-8
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    LOBPCG eigenvalue solver for sparse matrices on any device.
    
    Uses subspace iteration with Rayleigh-Ritz procedure to find
    the k largest or smallest eigenvalues.
    
    Parameters
    ----------
    A_matvec : callable
        Function that computes A @ x for input x of shape [n] or [n, m].
    n : int
        Matrix dimension.
    k : int
        Number of eigenvalues to compute.
    dtype : torch.dtype
        Data type.
    device : torch.device
        Device to compute on.
    largest : bool, optional
        If True, compute largest eigenvalues. Default: True.
    maxiter : int, optional
        Maximum iterations. Default: 1000.
    tol : float, optional
        Convergence tolerance. Default: 1e-8.
    
    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        (eigenvalues, eigenvectors) with shapes [k] and [n, k].
    """
    m = min(2 * k, n)
    X = torch.randn(n, m, dtype=dtype, device=device)
    X, _ = torch.linalg.qr(X)
    
    eigenvalues_prev = None
    
    for iteration in range(maxiter):
        AX = A_matvec(X)
        H = X.T @ AX
        eigenvalues, eigenvectors = torch.linalg.eigh(H)
        
        if largest:
            idx = eigenvalues.argsort(descending=True)
        else:
            idx = eigenvalues.argsort()
        
        eigenvalues = eigenvalues[idx]
        eigenvectors = eigenvectors[:, idx]
        X = X @ eigenvectors
        
        if eigenvalues_prev is not None:
            diff = (eigenvalues[:k] - eigenvalues_prev[:k]).abs()
            if (diff < tol * eigenvalues[:k].abs().clamp(min=1e-10)).all():
                break
        eigenvalues_prev = eigenvalues.clone()
        
        if iteration < maxiter - 1:
            AX = A_matvec(X)
            residual = AX - X * eigenvalues.unsqueeze(0)
            combined = torch.cat([X[:, :k], residual[:, :k]], dim=1)
            X, _ = torch.linalg.qr(combined)
            if X.size(1) < m:
                extra = torch.randn(n, m - X.size(1), dtype=dtype, device=device)
                X = torch.cat([X, extra], dim=1)
                X, _ = torch.linalg.qr(X)
    
    return eigenvalues[:k], X[:, :k]


def _power_iteration_svd(
    A_matvec,
    At_matvec,
    m: int,
    n: int,
    k: int,
    dtype: torch.dtype,
    device: torch.device,
    maxiter: int = 100,
    tol: float = 1e-6
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Power iteration based SVD for sparse matrices on any device.
    
    Parameters
    ----------
    A_matvec : callable
        Function that computes A @ x.
    At_matvec : callable
        Function that computes A^T @ x.
    m, n : int
        Matrix dimensions (m rows, n columns).
    k : int
        Number of singular values to compute.
    dtype : torch.dtype
        Data type.
    device : torch.device
        Device to compute on.
    maxiter : int, optional
        Maximum iterations. Default: 100.
    tol : float, optional
        Convergence tolerance. Default: 1e-6.
    
    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        (U, S, Vt) with shapes [m, k], [k], [k, n].
    """
    V = torch.randn(n, k, dtype=dtype, device=device)
    V, _ = torch.linalg.qr(V)
    
    for _ in range(maxiter):
        U = A_matvec(V)
        U, R = torch.linalg.qr(U)
        V_new = At_matvec(U)
        S = V_new.norm(dim=0)
        V_new = V_new / S.unsqueeze(0).clamp(min=1e-10)
        diff = (V_new - V).norm()
        V = V_new
        if diff < tol:
            break
    
    return U, S, V.T


# =============================================================================
# SparseTensor Class
# =============================================================================

class SparseTensor:
    """
    Wrapper class for PyTorch sparse tensors with batched and block support.

    Supports tensors with shape [...batch, M, N, ...block] where:
    - Leading dimensions [...batch] are batch dimensions
    - (M, N) are the sparse matrix dimensions (at sparse_dim positions)
    - Trailing dimensions [...block] are block dimensions

    Parameters
    ----------
    values : torch.Tensor
        Non-zero values with shape:
        - Simple: [nnz]
        - Batched: [...batch, nnz]
        - Block: [nnz, *block_shape]
        - Batched+Block: [...batch, nnz, *block_shape]
    row_indices : torch.Tensor
        Row indices with shape [nnz]. Must be on the same device as values.
    col_indices : torch.Tensor
        Column indices with shape [nnz]. Must be on the same device as values.
    shape : Tuple[int, ...]
        Full tensor shape [...batch, M, N, *block_shape].
    sparse_dim : Tuple[int, int], optional
        Which dimensions are sparse (M, N). Default: (-2, -1), i.e. the
        last two dimensions before any block dimensions.

    Attributes
    ----------
    values : torch.Tensor
        The non-zero values.
    row_indices : torch.Tensor
        Row indices of non-zeros.
    col_indices : torch.Tensor
        Column indices of non-zeros.
    shape : Tuple[int, ...]
        Full tensor shape.
    sparse_shape : Tuple[int, int]
        The (M, N) dimensions.
    batch_shape : Tuple[int, ...]
        The batch dimensions.
    block_shape : Tuple[int, ...]
        The block dimensions.

    Examples
    --------
    Simple 2D sparse matrix [M, N]:

    >>> val = torch.tensor([4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0])
    >>> row = torch.tensor([0, 0, 1, 1, 1, 2, 2])
    >>> col = torch.tensor([0, 1, 0, 1, 2, 1, 2])
    >>> A = SparseTensor(val, row, col, (3, 3))
    >>> x = A.solve(torch.tensor([1.0, 2.0, 3.0]))

    Batched matrices [B, M, N] share one sparsity pattern with per-batch
    values; multi-dimensional batches [B1, B2, M, N] work the same way:

    >>> val_batch = val.unsqueeze(0).expand(4, -1).clone()   # [4, 7]
    >>> A_batch = SparseTensor(val_batch, row, col, (4, 3, 3))
    >>> x_batch = A_batch.solve(torch.randn(4, 3))           # [4, 3]

    Block sparse [M, N, K, K]: each stored entry is a K x K dense block:

    >>> val_block = torch.randn(3, 2, 2)                     # [nnz, K, K]
    >>> row_b = torch.tensor([0, 0, 1])
    >>> col_b = torch.tensor([0, 1, 1])
    >>> A_block = SparseTensor(val_block, row_b, col_b, (2, 2, 2, 2))

    Other constructors, property detection and operations:

    >>> A = SparseTensor.from_dense(torch.eye(4))
    >>> A = SparseTensor.from_torch_sparse(torch.eye(4).to_sparse_coo())
    >>> A.is_symmetric()          # tensor(True) - tensor for batch support
    >>> A.is_positive_definite()  # tensor(True)
    >>> y = A @ torch.randn(4)    # sparse @ dense
    >>> C = A @ A                 # sparse @ sparse with sparse gradients
    >>> A_cuda = A.cuda()         # CUDA solve via cuDSS / cuSOLVER
    """

    def __init__(
        self,
        values: torch.Tensor,
        row_indices: torch.Tensor,
        col_indices: torch.Tensor,
        shape: Tuple[int, ...],
        sparse_dim: Tuple[int, int] = (-2, -1),
    ):
        self.values = values
        self.row_indices = row_indices
        self.col_indices = col_indices
        self._shape = tuple(shape)
        self._sparse_dim = self._normalize_sparse_dim(sparse_dim, len(shape))

        # Caches for lazily computed structural properties.
        self._is_symmetric_cache = None
        self._is_positive_definite_cache = None

        self._validate()

    def _normalize_sparse_dim(self, sparse_dim: Tuple[int, int], ndim: int) -> Tuple[int, int]:
        """Normalize negative indices in sparse_dim to absolute positions."""
        dim_m = sparse_dim[0] if sparse_dim[0] >= 0 else ndim + sparse_dim[0]
        dim_n = sparse_dim[1] if sparse_dim[1] >= 0 else ndim + sparse_dim[1]
        return (dim_m, dim_n)

    def _validate(self):
        """Validate tensor dimensions and sparse_dim placement."""
        ndim = len(self._shape)
        dim_m, dim_n = self._sparse_dim
        if ndim < 2:
            raise ValueError(f"Shape must have at least 2 dimensions, got {ndim}")
        if not (0 <= dim_m < ndim and 0 <= dim_n < ndim):
            raise ValueError(f"sparse_dim {self._sparse_dim} out of range for shape {self._shape}")
        if dim_m == dim_n:
            # Include the offending value (the original f-string had no placeholders).
            raise ValueError(f"sparse_dim dimensions must be different, got {self._sparse_dim}")

    # =========================================================================
    # Class Methods
    # =========================================================================
[文档] @classmethod def from_dense( cls, A: torch.Tensor, sparse_dim: Tuple[int, int] = (-2, -1) ) -> "SparseTensor": """ Create SparseTensor from dense tensor. Parameters ---------- A : torch.Tensor Dense tensor with shape [...batch, M, N, ...block]. sparse_dim : Tuple[int, int], optional Which dimensions are sparse. Default: (-2, -1). Returns ------- SparseTensor Sparse representation of A. Examples -------- >>> A_dense = torch.randn(3, 3) >>> A_dense[A_dense.abs() < 0.5] = 0 >>> A = SparseTensor.from_dense(A_dense) """ ndim = A.dim() dim_m = sparse_dim[0] if sparse_dim[0] >= 0 else ndim + sparse_dim[0] dim_n = sparse_dim[1] if sparse_dim[1] >= 0 else ndim + sparse_dim[1] if ndim == 2 and dim_m == 0 and dim_n == 1: A_sparse = A.to_sparse_coo() indices = A_sparse._indices() values = A_sparse._values() return cls(values, indices[0], indices[1], tuple(A.shape), sparse_dim=sparse_dim) perm = [i for i in range(ndim) if i not in (dim_m, dim_n)] + [dim_m, dim_n] A_perm = A.permute(*perm) batch_shape = A_perm.shape[:-2] M, N = A_perm.shape[-2], A_perm.shape[-1] A_flat = A_perm.reshape(-1, M, N) A_2d = A_flat[0].to_sparse_coo() indices = A_2d._indices() row = indices[0] col = indices[1] nnz = row.size(0) values = A_flat[:, row, col] if len(batch_shape) > 0: values = values.reshape(*batch_shape, nnz) else: values = values.squeeze(0) return cls(values, row, col, tuple(A.shape), sparse_dim=sparse_dim)
[文档] @classmethod def from_torch_sparse(cls, A: torch.Tensor) -> "SparseTensor": """ Create SparseTensor from PyTorch sparse tensor. Parameters ---------- A : torch.Tensor PyTorch sparse COO or CSR tensor (2D only). Returns ------- SparseTensor SparseTensor representation. Examples -------- >>> A_coo = torch.randn(3, 3).to_sparse_coo() >>> A = SparseTensor.from_torch_sparse(A_coo) """ if A.layout == torch.sparse_csr: A = A.to_sparse_coo() indices = A._indices() values = A._values() return cls(values, indices[0], indices[1], tuple(A.shape))
# ========================================================================= # Properties # ========================================================================= @property def shape(self) -> Tuple[int, ...]: """Full tensor shape [...batch, M, N, ...block].""" return self._shape @property def sparse_shape(self) -> Tuple[int, int]: """The (M, N) sparse matrix dimensions.""" dim_m, dim_n = self._sparse_dim return (self._shape[dim_m], self._shape[dim_n]) @property def batch_shape(self) -> Tuple[int, ...]: """The batch dimensions before the sparse dimensions.""" dim_m, dim_n = self._sparse_dim min_dim = min(dim_m, dim_n) return self._shape[:min_dim] @property def block_shape(self) -> Tuple[int, ...]: """The block dimensions after the sparse dimensions.""" dim_m, dim_n = self._sparse_dim max_dim = max(dim_m, dim_n) return self._shape[max_dim + 1:] @property def sparse_dim(self) -> Tuple[int, int]: """The dimensions that are sparse (M, N).""" return self._sparse_dim @property def ndim(self) -> int: """Number of dimensions.""" return len(self._shape) @property def nnz(self) -> int: """Number of non-zero elements (per batch/block).""" return self.row_indices.size(0) @property def dtype(self) -> torch.dtype: """Data type of the values.""" return self.values.dtype @property def device(self) -> torch.device: """Device of the tensor.""" return self.values.device @property def is_cuda(self) -> bool: """Whether the tensor is on CUDA.""" return self.values.is_cuda @property def is_batched(self) -> bool: """Whether the tensor has batch dimensions.""" return len(self.batch_shape) > 0 @property def is_block(self) -> bool: """Whether the tensor has block dimensions.""" return len(self.block_shape) > 0 @property def batch_size(self) -> int: """Total number of batch elements (product of batch_shape).""" return math.prod(self.batch_shape) if self.batch_shape else 1 @property def is_square(self) -> bool: """Whether the sparse dimensions are square (M == N).""" M, N = self.sparse_shape 
return M == N # ========================================================================= # Device and Type Management # =========================================================================
[文档] def to( self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None ) -> "SparseTensor": """ Move tensor to device and/or convert dtype. Parameters ---------- device : str or torch.device, optional Target device (e.g., 'cuda', 'cpu', 'cuda:0'). dtype : torch.dtype, optional Target data type (e.g., torch.float32, torch.float64). Returns ------- SparseTensor New SparseTensor on the target device/dtype. Examples -------- >>> A = SparseTensor(val, row, col, shape) >>> A_cuda = A.to('cuda') >>> A_float32 = A.to(dtype=torch.float32) >>> A_cuda_float32 = A.to('cuda', torch.float32) """ new_values = self.values new_row = self.row_indices new_col = self.col_indices if device is not None: new_values = new_values.to(device) new_row = new_row.to(device) new_col = new_col.to(device) if dtype is not None: new_values = new_values.to(dtype) result = SparseTensor( new_values, new_row, new_col, self._shape, sparse_dim=self._sparse_dim ) return result
[文档] def cuda(self, device: Optional[int] = None) -> "SparseTensor": """ Move tensor to CUDA device. Parameters ---------- device : int, optional CUDA device index. Default: current device. Returns ------- SparseTensor Tensor on CUDA. """ if device is None: return self.to('cuda') return self.to(f'cuda:{device}')
[文档] def cpu(self) -> "SparseTensor": """ Move tensor to CPU. Returns ------- SparseTensor Tensor on CPU. """ return self.to('cpu')
[文档] def float(self) -> "SparseTensor": """Convert to float32.""" return self.to(dtype=torch.float32)
[文档] def double(self) -> "SparseTensor": """Convert to float64.""" return self.to(dtype=torch.float64)
[文档] def half(self) -> "SparseTensor": """Convert to float16.""" return self.to(dtype=torch.float16)
# ========================================================================= # Conversion Methods # =========================================================================
[文档] def to_torch_sparse(self, batch_idx: Optional[Tuple[int, ...]] = None) -> torch.Tensor: """ Convert to PyTorch sparse COO tensor. Parameters ---------- batch_idx : Tuple[int, ...], optional For batched tensors, which batch element to convert. Default: (0, 0, ...) for first batch element. Returns ------- torch.Tensor PyTorch sparse COO tensor. """ if self.is_batched: if batch_idx is None: batch_idx = (0,) * len(self.batch_shape) vals = self.values[batch_idx] else: vals = self.values M, N = self.sparse_shape indices = torch.stack([self.row_indices, self.col_indices], dim=0) if self.is_block: return torch.sparse_coo_tensor(indices, vals, (M, N) + self.block_shape) else: return torch.sparse_coo_tensor(indices, vals, (M, N))
[文档] def to_dense(self, batch_idx: Optional[Tuple[int, ...]] = None) -> torch.Tensor: """ Convert to dense tensor. Parameters ---------- batch_idx : Tuple[int, ...], optional For batched tensors, which batch element to convert. Returns ------- torch.Tensor Dense tensor. """ return self.to_torch_sparse(batch_idx).to_dense()
[文档] def to_csr(self, batch_idx: Optional[Tuple[int, ...]] = None) -> torch.Tensor: """ Convert to CSR format. Parameters ---------- batch_idx : Tuple[int, ...], optional For batched tensors, which batch element to convert. Returns ------- torch.Tensor PyTorch sparse CSR tensor. """ return self.to_torch_sparse(batch_idx).to_sparse_csr()
[文档] def partition( self, num_partitions: int, coords: Optional[torch.Tensor] = None, partition_method: str = 'auto', verbose: bool = False ) -> "DSparseTensor": """ Partition into a distributed sparse tensor. Creates a DSparseTensor with automatic domain decomposition. This is useful for distributed computing and parallel solvers. Parameters ---------- num_partitions : int Number of partitions to create coords : torch.Tensor, optional Node coordinates for geometric partitioning [num_nodes, dim]. Required for 'rcb' and 'slicing' methods. partition_method : str Partitioning method: - 'auto': Auto-select (uses 'rcb' if coords provided, else 'metis') - 'metis': Graph-based partitioning (requires pymetis) - 'rcb': Recursive Coordinate Bisection (requires coords) - 'slicing': Simple coordinate slicing (requires coords) - 'simple': Simple 1D partitioning by node index verbose : bool Whether to print partition info Returns ------- DSparseTensor Distributed sparse tensor with the specified partitions Example ------- >>> A = SparseTensor(val, row, col, shape) >>> D = A.partition(num_partitions=4) >>> for i in range(4): ... partition = D[i] ... y = partition.matvec(x_local) Notes ----- - Use `D.to_sparse_tensor()` to gather back to a SparseTensor - For distributed training, use `partition_for_rank()` instead """ from .distributed import DSparseTensor if self.is_batched: raise ValueError("partition() does not support batched SparseTensor. " "Use a 2D SparseTensor.") return DSparseTensor( self.values, self.row_indices, self.col_indices, self.sparse_shape, num_partitions=num_partitions, coords=coords, partition_method=partition_method, device=self.device, verbose=verbose )
[文档] def partition_for_rank( self, rank: int, world_size: int, coords: Optional[torch.Tensor] = None, partition_method: str = 'simple', verbose: bool = False ) -> "DSparseMatrix": """ Get partition for a specific rank in distributed environment. This is the recommended API for multi-process distributed computing. Each rank calls this method with its own rank ID to get its local partition. The partitioning is deterministic and consistent across all ranks. Parameters ---------- rank : int This process's rank (0 to world_size-1) world_size : int Total number of processes coords : torch.Tensor, optional Node coordinates for geometric partitioning partition_method : str Partitioning method ('simple', 'metis', 'rcb', 'slicing') verbose : bool Print partition info Returns ------- DSparseMatrix Local partition for this rank Example ------- >>> # In multi-process code: >>> A = SparseTensor(val, row, col, shape) >>> partition = A.partition_for_rank(rank, world_size) >>> y_local = partition.matvec(x_local) Notes ----- - This uses `DSparseTensor.from_global_distributed()` internally, which broadcasts partition IDs from rank 0 for consistency. - Requires `torch.distributed` to be initialized. """ from .distributed import DSparseTensor if self.is_batched: raise ValueError("partition_for_rank() does not support batched SparseTensor.") return DSparseTensor.from_global_distributed( self.values, self.row_indices, self.col_indices, self.sparse_shape, rank=rank, world_size=world_size, coords=coords, partition_method=partition_method, verbose=verbose )
[文档] def T(self) -> "SparseTensor": """ Transpose the sparse dimensions. Returns ------- SparseTensor Transposed tensor with row/col indices swapped. """ new_shape = list(self._shape) dim_m, dim_n = self._sparse_dim new_shape[dim_m], new_shape[dim_n] = new_shape[dim_n], new_shape[dim_m] result = SparseTensor( self.values, self.col_indices, # Swap row and col self.row_indices, tuple(new_shape), sparse_dim=self._sparse_dim ) return result
[文档] def flatten_blocks(self) -> "SparseTensor": """ Flatten block dimensions into the sparse (M, N) dimensions. For a block-sparse tensor with shape [...batch, M, N, *block_shape], this creates a new tensor with shape [...batch, M*block_M, N*block_N] where each block entry becomes multiple scalar entries. Returns ------- SparseTensor Flattened tensor without block dimensions. Example ------- >>> # Block sparse: shape (10, 10, 2, 2), block_shape=(2, 2) >>> A = SparseTensor(val, row, col, (10, 10, 2, 2)) >>> A_flat = A.flatten_blocks() >>> print(A_flat.shape) # (20, 20) >>> print(A_flat.nnz) # nnz * 4 (each block has 4 elements) Notes ----- - Only works for 2D block shapes (block_M, block_N). - Use `unflatten_blocks(block_shape)` to reverse this operation. - The flattened tensor's sparsity pattern may have duplicates that need to be coalesced. """ if not self.is_block: return self # No blocks, return as is block_shape = self.block_shape if len(block_shape) != 2: raise ValueError(f"flatten_blocks only supports 2D blocks, got {block_shape}") block_M, block_N = block_shape M, N = self.sparse_shape batch_shape = self.batch_shape nnz = self.nnz # New sparse shape new_M = M * block_M new_N = N * block_N # Expand block entries into individual entries # Original: values shape [...batch, nnz, block_M, block_N] # New: values shape [...batch, nnz * block_M * block_N] # Create new row/col indices # For each (row, col) block at position (i, j), create indices: # (i*block_M + bi, j*block_N + bj) for bi in [0, block_M), bj in [0, block_N) row = self.row_indices # [nnz] col = self.col_indices # [nnz] # Create block offsets block_offsets = torch.arange(block_M * block_N, device=self.device) bi = block_offsets // block_N # [block_M * block_N] bj = block_offsets % block_N # [block_M * block_N] # Expand row/col to new indices # new_row[k * block_M * block_N + offset] = row[k] * block_M + bi[offset] new_row = (row.unsqueeze(-1) * block_M + bi.unsqueeze(0)).reshape(-1) # [nnz * 
block_size] new_col = (col.unsqueeze(-1) * block_N + bj.unsqueeze(0)).reshape(-1) # [nnz * block_size] # Flatten values if len(batch_shape) > 0: # [...batch, nnz, block_M, block_N] -> [...batch, nnz * block_M * block_N] vals = self.values.reshape(*batch_shape, nnz * block_M * block_N) else: # [nnz, block_M, block_N] -> [nnz * block_M * block_N] vals = self.values.reshape(nnz * block_M * block_N) new_shape = batch_shape + (new_M, new_N) return SparseTensor( vals, new_row, new_col, new_shape, sparse_dim=self._sparse_dim )
    def unflatten_blocks(self, block_shape: Tuple[int, int]) -> "SparseTensor":
        """
        Restore block structure from a flattened tensor.

        This is the inverse of `flatten_blocks()`. It groups scalar entries
        back into block entries.

        Parameters
        ----------
        block_shape : Tuple[int, int]
            The (block_M, block_N) dimensions to create. M and N must be
            divisible by block_M and block_N respectively.

        Returns
        -------
        SparseTensor
            Block-sparse tensor with the specified block shape.

        Raises
        ------
        ValueError
            If the tensor already has blocks, block_shape is not 2D, (M, N)
            is not divisible by the block shape, or the sparsity pattern is
            not block-aligned (every block must be dense within itself).

        Example
        -------
        >>> A_flat = SparseTensor(val, row, col, (20, 20))
        >>> A_block = A_flat.unflatten_blocks((2, 2))
        >>> print(A_block.shape)  # (10, 10, 2, 2)

        Notes
        -----
        - Requires that the sparsity pattern is block-aligned.
        - All block entries must be present (dense within each block).
        - For sparse blocks, use `to_block_sparse()` instead.
        """
        if self.is_block:
            raise ValueError("Tensor already has block structure. Use flatten_blocks first.")
        if len(block_shape) != 2:
            raise ValueError(f"block_shape must be 2D, got {block_shape}")

        block_M, block_N = block_shape
        M, N = self.sparse_shape
        batch_shape = self.batch_shape

        if M % block_M != 0 or N % block_N != 0:
            raise ValueError(
                f"Sparse shape ({M}, {N}) not divisible by block_shape ({block_M}, {block_N})"
            )

        new_M = M // block_M
        new_N = N // block_N
        block_size = block_M * block_N

        row = self.row_indices
        col = self.col_indices
        nnz = self.nnz

        # A block-aligned pattern has a whole number of dense blocks.
        if nnz % block_size != 0:
            raise ValueError(
                f"Number of non-zeros ({nnz}) not divisible by block size ({block_size}). "
                "The sparsity pattern may not be block-aligned."
            )

        # Compute block indices: which block an entry lands in, and where
        # inside that block.
        block_row = row // block_M    # Which block row
        block_col = col // block_N    # Which block col
        local_row = row % block_M     # Position within block
        local_col = col % block_N     # Position within block

        # Group entries by (block_row, block_col) via a unique scalar block ID.
        block_id = block_row * new_N + block_col

        # Sort by block_id, then by local (row-major) position inside the
        # block, so a plain reshape below produces [block_M, block_N] blocks.
        local_offset = local_row * block_N + local_col
        sort_key = block_id * block_size + local_offset
        sort_idx = torch.argsort(sort_key)
        sorted_block_id = block_id[sort_idx]
        # NOTE(review): sorted_local_offset is computed but never used below.
        sorted_local_offset = local_offset[sort_idx]

        # Extract unique blocks; every block must contribute exactly
        # block_size entries, otherwise some block is incomplete.
        unique_blocks, counts = torch.unique_consecutive(sorted_block_id, return_counts=True)
        if not torch.all(counts == block_size):
            raise ValueError(
                "Not all blocks are complete. Each block must have exactly "
                f"{block_size} entries."
            )

        num_blocks = unique_blocks.size(0)
        # Recover 2D block coordinates from the scalar block ID.
        new_row_indices = unique_blocks // new_N
        new_col_indices = unique_blocks % new_N

        # Reshape values to include block dimensions (sorted order makes the
        # trailing reshape land each entry at its (local_row, local_col)).
        if len(batch_shape) > 0:
            # [...batch, nnz] -> [...batch, num_blocks, block_M, block_N]
            sorted_vals = self.values[..., sort_idx]
            new_vals = sorted_vals.reshape(*batch_shape, num_blocks, block_M, block_N)
        else:
            sorted_vals = self.values[sort_idx]
            new_vals = sorted_vals.reshape(num_blocks, block_M, block_N)

        new_shape = batch_shape + (new_M, new_N, block_M, block_N)
        return SparseTensor(
            new_vals,
            new_row_indices,
            new_col_indices,
            new_shape,
            sparse_dim=self._sparse_dim
        )
# ========================================================================= # Property Detection (returns tensor for batched) # =========================================================================
[文档] def is_symmetric( self, atol: float = 1e-8, rtol: float = 1e-5, force_recompute: bool = False ) -> torch.Tensor: """ Check if the matrix is symmetric (A == A^T). For batched tensors, checks each matrix independently and returns a boolean tensor with shape matching the batch dimensions. Parameters ---------- atol : float, optional Absolute tolerance for comparison. Default: 1e-8. rtol : float, optional Relative tolerance for comparison. Default: 1e-5. force_recompute : bool, optional If True, recompute even if cached. Default: False. Returns ------- torch.Tensor Boolean tensor with shape: - [] (scalar) for non-batched tensors - [*batch_shape] for batched tensors Examples -------- >>> A = SparseTensor(val, row, col, (3, 3)) >>> A.is_symmetric() # tensor(True) or tensor(False) >>> A_batch = SparseTensor(val_batch, row, col, (4, 3, 3)) >>> A_batch.is_symmetric() # tensor([True, True, True, True]) """ if self._is_symmetric_cache is not None and not force_recompute: return self._is_symmetric_cache if not self.is_square: result = torch.tensor(False, device=self.device) if self.is_batched: result = result.expand(self.batch_shape) self._is_symmetric_cache = result return result row = self.row_indices col = self.col_indices M, N = self.sparse_shape # Create hash for (row, col) pairs forward_hash = row * N + col transpose_hash = col * N + row # Sort both to align forward_order = forward_hash.argsort() transpose_order = transpose_hash.argsort() sorted_forward_hash = forward_hash[forward_order] sorted_transpose_hash = transpose_hash[transpose_order] # Check sparsity pattern if not torch.equal(sorted_forward_hash, sorted_transpose_hash): result = torch.tensor(False, device=self.device) if self.is_batched: result = result.expand(self.batch_shape).clone() self._is_symmetric_cache = result return result # Compare values if self.is_batched: B = self.batch_size vals_flat = self.values.reshape(B, self.nnz) vals_forward = vals_flat[:, forward_order] vals_transpose = vals_flat[:, 
transpose_order] diff = (vals_forward - vals_transpose).abs() threshold = atol + rtol * vals_forward.abs() is_sym = (diff <= threshold).all(dim=-1) result = is_sym.reshape(self.batch_shape) else: vals_forward = self.values[forward_order] vals_transpose = self.values[transpose_order] diff = (vals_forward - vals_transpose).abs() threshold = atol + rtol * vals_forward.abs() result = torch.tensor((diff <= threshold).all().item(), device=self.device) self._is_symmetric_cache = result return result
[文档] def is_positive_definite( self, method: Literal["gershgorin", "cholesky", "eigenvalue"] = "gershgorin", force_recompute: bool = False ) -> torch.Tensor: """ Check if the matrix is positive definite. For batched tensors, checks each matrix independently and returns a boolean tensor with shape matching the batch dimensions. Parameters ---------- method : {"gershgorin", "cholesky", "eigenvalue"}, optional Method for checking: - "gershgorin": Fast check using Gershgorin circles (sufficient but not necessary) - "cholesky": Try Cholesky decomposition (necessary and sufficient, slower) - "eigenvalue": Check smallest eigenvalues (necessary and sufficient, slowest) Default: "gershgorin". force_recompute : bool, optional If True, recompute even if cached. Default: False. Returns ------- torch.Tensor Boolean tensor with shape: - [] (scalar) for non-batched tensors - [*batch_shape] for batched tensors Examples -------- >>> A = SparseTensor(val, row, col, (3, 3)) >>> A.is_positive_definite() # tensor(True) or tensor(False) >>> A.is_positive_definite(method="cholesky") # More accurate check >>> A_batch = SparseTensor(val_batch, row, col, (4, 3, 3)) >>> A_batch.is_positive_definite() # tensor([True, True, True, True]) """ if self._is_positive_definite_cache is not None and not force_recompute: return self._is_positive_definite_cache if not self.is_square: result = torch.tensor(False, device=self.device) if self.is_batched: result = result.expand(self.batch_shape).clone() self._is_positive_definite_cache = result return result row = self.row_indices col = self.col_indices M, N = self.sparse_shape if method == "gershgorin": result = self._check_pd_gershgorin() elif method == "cholesky": result = self._check_pd_cholesky() else: # eigenvalue result = self._check_pd_eigenvalue() self._is_positive_definite_cache = result return result
    def _check_pd_gershgorin(self) -> torch.Tensor:
        """Check positive definiteness using Gershgorin circles.

        Sufficient (not necessary) criterion: every diagonal entry is
        positive and strictly dominates the absolute row sum of its
        off-diagonal entries, so every Gershgorin disc lies in the right
        half-plane. Returns a boolean tensor (scalar, or [*batch_shape]).
        """
        row = self.row_indices
        col = self.col_indices
        M, N = self.sparse_shape
        is_diag = (row == col)
        if self.is_batched:
            B = self.batch_size
            vals_flat = self.values.reshape(B, self.nnz)
            # Diagonal elements scattered into a dense [B, M] buffer.
            diag_rows = row[is_diag]
            diag_vals = vals_flat[:, is_diag]  # [B, num_diag]
            diag = torch.zeros(B, M, dtype=self.dtype, device=self.device)
            diag.scatter_(1, diag_rows.unsqueeze(0).expand(B, -1), diag_vals)
            # Off-diagonal absolute row sums accumulated per row.
            is_offdiag = ~is_diag
            offdiag_rows = row[is_offdiag]
            offdiag_vals = vals_flat[:, is_offdiag].abs()  # [B, num_offdiag]
            offdiag_sum = torch.zeros(B, M, dtype=self.dtype, device=self.device)
            offdiag_sum.scatter_add_(1, offdiag_rows.unsqueeze(0).expand(B, -1), offdiag_vals)
            # Check: diag > offdiag_sum AND diag > 0, for every row.
            is_pd = ((diag > offdiag_sum) & (diag > 0)).all(dim=-1)
            return is_pd.reshape(self.batch_shape)
        else:
            diag_rows = row[is_diag]
            diag_vals = self.values[is_diag]
            diag = torch.zeros(M, dtype=self.dtype, device=self.device)
            diag.scatter_(0, diag_rows, diag_vals)
            is_offdiag = ~is_diag
            offdiag_rows = row[is_offdiag]
            offdiag_vals = self.values[is_offdiag].abs()
            offdiag_sum = torch.zeros(M, dtype=self.dtype, device=self.device)
            offdiag_sum.scatter_add_(0, offdiag_rows, offdiag_vals)
            is_pd = ((diag > offdiag_sum) & (diag > 0)).all()
            return torch.tensor(is_pd.item(), device=self.device)

    def _check_pd_cholesky(self) -> torch.Tensor:
        """Check positive definiteness using Cholesky decomposition.

        Necessary and sufficient: torch.linalg.cholesky raises RuntimeError
        for non-PD input. Densifies each (batch) matrix, so intended for
        moderate sizes. Returns a boolean tensor (scalar, or [*batch_shape]).
        """
        if self.is_batched:
            results = []
            for idx in self._batch_indices():
                try:
                    A_dense = self.to_dense(idx)
                    torch.linalg.cholesky(A_dense)
                    results.append(True)
                except RuntimeError:
                    # Factorization failure => not positive definite.
                    results.append(False)
            return torch.tensor(results, device=self.device).reshape(self.batch_shape)
        else:
            try:
                A_dense = self.to_dense()
                torch.linalg.cholesky(A_dense)
                return torch.tensor(True, device=self.device)
            except RuntimeError:
                return torch.tensor(False, device=self.device)

    def _check_pd_eigenvalue(self) -> torch.Tensor:
        """Check positive definiteness via the full (dense) eigenspectrum.

        Necessary and sufficient for symmetric input: all eigenvalues of
        eigvalsh must be strictly positive. Densifies each (batch) matrix.
        Any failure during computation is reported as "not PD".
        Returns a boolean tensor (scalar, or [*batch_shape]).
        """
        if self.is_batched:
            results = []
            for idx in self._batch_indices():
                try:
                    A_dense = self.to_dense(idx)
                    eigenvalues = torch.linalg.eigvalsh(A_dense)
                    results.append((eigenvalues > 0).all().item())
                except Exception:
                    results.append(False)
            return torch.tensor(results, device=self.device).reshape(self.batch_shape)
        else:
            try:
                A_dense = self.to_dense()
                eigenvalues = torch.linalg.eigvalsh(A_dense)
                return torch.tensor((eigenvalues > 0).all().item(), device=self.device)
            except Exception:
                return torch.tensor(False, device=self.device)

    def _batch_indices(self):
        """Generate all batch index tuples (the Cartesian product of ranges
        over batch_shape), for iterating batched matrices one at a time."""
        import itertools
        ranges = [range(s) for s in self.batch_shape]
        return itertools.product(*ranges)

    # =========================================================================
    # Graph / Connected Components
    # =========================================================================
[文档] def connected_components(self) -> Tuple[torch.Tensor, int]: """ Find connected components of the graph represented by this sparse matrix. Uses union-find algorithm for efficiency. Treats the matrix as an undirected graph adjacency matrix. Returns ------- labels : torch.Tensor Component label for each node, shape [N]. Labels are in range [0, num_components). num_components : int Number of connected components. Notes ----- - Only works for non-batched 2D matrices - Matrix is treated as undirected (edges in either direction count) - Self-loops are ignored for connectivity Examples -------- >>> # Block diagonal matrix with 3 components >>> A = SparseTensor(val, row, col, (100, 100)) >>> labels, num_comp = A.connected_components() >>> print(f"Found {num_comp} components") """ if self.is_batched: raise NotImplementedError("connected_components not supported for batched tensors") M, N = self.sparse_shape if M != N: raise ValueError("connected_components requires square matrix") # Union-Find with path compression and union by rank parent = torch.arange(N, device=self.device, dtype=torch.long) rank = torch.zeros(N, device=self.device, dtype=torch.long) def find(x: int) -> int: """Find root with path compression.""" root = x while parent[root].item() != root: root = parent[root].item() # Path compression while parent[x].item() != root: next_x = parent[x].item() parent[x] = root x = next_x return root def union(x: int, y: int): """Union by rank.""" rx, ry = find(x), find(y) if rx == ry: return if rank[rx] < rank[ry]: rx, ry = ry, rx parent[ry] = rx if rank[rx] == rank[ry]: rank[rx] += 1 # Process all edges row = self.row_indices.cpu() col = self.col_indices.cpu() for i in range(len(row)): r, c = row[i].item(), col[i].item() if r != c: # Skip self-loops union(r, c) # Find all roots and relabel labels = torch.zeros(N, dtype=torch.long, device=self.device) for i in range(N): labels[i] = find(i) # Relabel to consecutive integers starting from 0 unique_labels = 
labels.unique() num_components = len(unique_labels) label_map = torch.zeros(N, dtype=torch.long, device=self.device) for new_label, old_label in enumerate(unique_labels): label_map[labels == old_label] = new_label return label_map, num_components
[文档] def has_isolated_components(self) -> bool: """ Check if the matrix has multiple connected components. Returns ------- bool True if matrix has more than one connected component. Examples -------- >>> A = SparseTensor(val, row, col, (100, 100)) >>> if A.has_isolated_components(): ... components = A.to_connected_components() """ _, num_components = self.connected_components() return num_components > 1
[文档] def to_connected_components(self) -> "SparseTensorList": """ Split the matrix into a list of connected component subgraphs. Each component becomes a separate SparseTensor with reindexed nodes. Returns ------- SparseTensorList List of SparseTensors, one per connected component. Notes ----- - Each component's nodes are reindexed from 0 - Original node indices can be recovered from the mapping Examples -------- >>> A = SparseTensor(val, row, col, (100, 100)) >>> components = A.to_connected_components() >>> print(f"Split into {len(components)} components") >>> for i, comp in enumerate(components): ... print(f" Component {i}: {comp.shape}") """ if self.is_batched: raise NotImplementedError("to_connected_components not supported for batched tensors") labels, num_components = self.connected_components() if num_components == 1: # Single component, return list with self return SparseTensorList([self]) # Split into components components = [] row = self.row_indices col = self.col_indices val = self.values for comp_id in range(num_components): # Find nodes in this component node_mask = (labels == comp_id) comp_nodes = torch.where(node_mask)[0] num_comp_nodes = len(comp_nodes) # Create mapping from old to new indices old_to_new = torch.full((self.sparse_shape[0],), -1, dtype=torch.long, device=self.device) old_to_new[comp_nodes] = torch.arange(num_comp_nodes, device=self.device) # Find edges within this component row_in_comp = node_mask[row] col_in_comp = node_mask[col] edge_mask = row_in_comp & col_in_comp # Extract and remap edges comp_row = old_to_new[row[edge_mask]] comp_col = old_to_new[col[edge_mask]] comp_val = val[edge_mask] comp_sparse = SparseTensor( comp_val, comp_row, comp_col, (num_comp_nodes, num_comp_nodes) ) components.append(comp_sparse) return SparseTensorList(components)
    # =========================================================================
    # Matrix Multiplication
    # =========================================================================

    def _spmv_coo(self, x: torch.Tensor) -> torch.Tensor:
        """
        Sparse matrix-vector/matrix multiply using COO format with scatter_add.

        Computes A @ x where A is this sparse tensor and x is dense.
        Works on any device without explicit CSR conversion.

        Parameters
        ----------
        x : torch.Tensor
            Dense tensor to multiply. Shape depends on batching:
            - Non-batched: [N] or [N, K]
            - Batched: [B, N] or [B, N, K]

        Returns
        -------
        torch.Tensor
            Result of A @ x.
        """
        row = self.row_indices
        col = self.col_indices
        M, N = self.sparse_shape

        if self.is_batched:
            batch_shape = self.batch_shape
            B = self.batch_size
            vals_flat = self.values.reshape(B, self.nnz)

            # Determine output dtype via type promotion
            out_dtype = torch.result_type(self.values, x)

            if x.dim() == 1:
                # x: [N] - same for all batches -> result [B, M]
                # Gather x at the column of each nnz entry, multiply by the
                # value, and scatter-accumulate into the output row.
                x_gathered = x[col]
                products = vals_flat * x_gathered
                result = torch.zeros(B, M, dtype=out_dtype, device=self.device)
                row_expanded = row.unsqueeze(0).expand(B, -1)
                result.scatter_add_(1, row_expanded, products)
                return result.reshape(*batch_shape, M)
            elif x.dim() == len(batch_shape) + 1:
                # x: [...batch, N] -> result [...batch, M]
                x_flat = x.reshape(B, N)
                x_gathered = x_flat[:, col]
                products = vals_flat * x_gathered
                result = torch.zeros(B, M, dtype=out_dtype, device=self.device)
                row_expanded = row.unsqueeze(0).expand(B, -1)
                result.scatter_add_(1, row_expanded, products)
                return result.reshape(*batch_shape, M)
            else:
                # x: [...batch, N, K] -> result [...batch, M, K]
                K = x.size(-1)
                x_flat = x.reshape(B, N, K)
                x_gathered = x_flat[:, col, :]
                products = vals_flat.unsqueeze(-1) * x_gathered
                result = torch.zeros(B, M, K, dtype=out_dtype, device=self.device)
                row_expanded = row.unsqueeze(0).unsqueeze(-1).expand(B, -1, K)
                result.scatter_add_(1, row_expanded, products)
                return result.reshape(*batch_shape, M, K)
        else:
            # Determine output dtype via type promotion (handles float32 @ float64, etc.)
            out_dtype = torch.result_type(self.values, x)

            if x.dim() == 1:
                # y[row[i]] += values[i] * x[col[i]]
                x_gathered = x[col]
                products = self.values * x_gathered
                result = torch.zeros(M, dtype=out_dtype, device=self.device)
                result.scatter_add_(0, row, products)
                return result
            else:
                # Multiple right-hand sides: x is [N, K].
                K = x.size(1)
                x_gathered = x[col]
                products = self.values.unsqueeze(1) * x_gathered
                result = torch.zeros(M, K, dtype=out_dtype, device=self.device)
                row_expanded = row.unsqueeze(1).expand(-1, K)
                result.scatter_add_(0, row_expanded, products)
                return result

    def _dense_sparse_mm(self, X: torch.Tensor) -> torch.Tensor:
        """
        Dense @ Sparse: X @ A where X is [..., M] or [..., K, M], A is [..., M, N].

        Mirrors _spmv_coo with the roles of row/col swapped: gather X at the
        row of each entry, scatter-accumulate into the output column.

        Parameters
        ----------
        X : torch.Tensor
            Dense tensor.

        Returns
        -------
        torch.Tensor
            Result of X @ A.
        """
        row = self.row_indices
        col = self.col_indices
        M, N = self.sparse_shape

        if self.is_batched:
            batch_shape = self.batch_shape
            B = self.batch_size
            vals_flat = self.values.reshape(B, self.nnz)

            # Determine output dtype via type promotion
            out_dtype = torch.result_type(self.values, X)

            if X.dim() == 1:
                # X: [M] shared across batches -> result [...batch, N]
                X_gathered = X[row]
                products = vals_flat * X_gathered
                result = torch.zeros(B, N, dtype=out_dtype, device=self.device)
                col_expanded = col.unsqueeze(0).expand(B, -1)
                result.scatter_add_(1, col_expanded, products)
                return result.reshape(*batch_shape, N)
            elif X.dim() == len(batch_shape) + 1:
                # X: [...batch, M] -> result [...batch, N]
                X_flat = X.reshape(B, M)
                X_gathered = X_flat[:, row]
                products = vals_flat * X_gathered
                result = torch.zeros(B, N, dtype=out_dtype, device=self.device)
                col_expanded = col.unsqueeze(0).expand(B, -1)
                result.scatter_add_(1, col_expanded, products)
                return result.reshape(*batch_shape, N)
            else:
                # X: [...batch, K, M] -> result [...batch, K, N]
                K = X.size(-2)
                X_flat = X.reshape(B, K, M)
                X_gathered = X_flat[:, :, row]
                products = vals_flat.unsqueeze(1) * X_gathered
                result = torch.zeros(B, K, N, dtype=out_dtype, device=self.device)
                col_expanded = col.unsqueeze(0).unsqueeze(0).expand(B, K, -1)
                result.scatter_add_(2, col_expanded, products)
                return result.reshape(*batch_shape, K, N)
        else:
            # Determine output dtype via type promotion
            out_dtype = torch.result_type(self.values, X)

            if X.dim() == 1:
                # y[col[i]] += values[i] * X[row[i]]
                X_gathered = X[row]
                products = self.values * X_gathered
                result = torch.zeros(N, dtype=out_dtype, device=self.device)
                result.scatter_add_(0, col, products)
                return result
            else:
                # X is [K, M] -> result [K, N]
                K = X.size(0)
                X_gathered = X[:, row]
                products = self.values.unsqueeze(0) * X_gathered
                result = torch.zeros(K, N, dtype=out_dtype, device=self.device)
                col_expanded = col.unsqueeze(0).expand(K, -1)
                result.scatter_add_(1, col_expanded, products)
                return result

    def _spsp_multiply(self, other: "SparseTensor") -> "SparseTensor":
        """
        Sparse-Sparse multiplication: A @ B where both are sparse.

        Uses custom autograd function to provide SPARSE gradients.
        Memory usage is O(nnz) not O(M*N).

        Parameters
        ----------
        other : SparseTensor
            Right-hand side sparse matrix.

        Returns
        -------
        SparseTensor
            Result C = A @ B.

        Raises
        ------
        ValueError
            If the inner dimensions of the two matrices do not match.
        """
        M, K = self.sparse_shape
        K2, N = other.sparse_shape
        if K != K2:
            raise ValueError(f"Inner dimensions don't match: {K} vs {K2}")

        C_values, C_row, C_col, C_shape = _sparse_sparse_matmul_with_sparse_grad(
            self.values, self.row_indices, self.col_indices, (M, K),
            other.values, other.row_indices, other.col_indices, (K, N)
        )
        return SparseTensor(C_values, C_row, C_col, C_shape)

    def __matmul__(self, other: Union[torch.Tensor, "SparseTensor"]) -> Union[torch.Tensor, "SparseTensor"]:
        """
        Matrix multiplication: A @ other.

        Supports:
        - Sparse @ Dense vector: A @ x -> y
        - Sparse @ Dense matrix: A @ X -> Y
        - Sparse @ Sparse: A @ B -> C (with sparse gradients)

        Parameters
        ----------
        other : torch.Tensor or SparseTensor
            Right-hand side operand.

        Returns
        -------
        torch.Tensor or SparseTensor
            Result of multiplication.
        """
        if isinstance(other, SparseTensor):
            return self._spsp_multiply(other)
        return self._spmv_coo(other)

    def __rmatmul__(self, other: torch.Tensor) -> torch.Tensor:
        """
        Dense @ Sparse multiplication: other @ A.

        Parameters
        ----------
        other : torch.Tensor
            Left-hand side dense tensor.

        Returns
        -------
        torch.Tensor
            Result of multiplication.
        """
        return self._dense_sparse_mm(other)

    # =========================================================================
    # Linear Solve
    # =========================================================================
    def solve(
        self,
        b: torch.Tensor,
        backend: BackendType = "auto",
        method: MethodType = "auto",
        atol: float = 1e-10,
        maxiter: int = 10000,
        tol: float = 1e-12,
    ) -> torch.Tensor:
        """
        Solve the sparse linear system Ax = b.

        Automatically handles batched tensors: if A is [...batch, M, N] and
        b is [...batch, M], returns x with shape [...batch, N].

        Parameters
        ----------
        b : torch.Tensor
            Right-hand side vector(s). Shape:
            - Non-batched: [M] or [M, K] for multiple RHS
            - Batched: [...batch, M] or [...batch, M, K]
        backend : {"auto", "scipy", "eigen", "cusolver", "cudss"}, optional
            Solver backend. Default: "auto" (selects based on device).
            - "scipy": Uses SciPy's sparse solvers (CPU only)
            - "eigen": Uses Eigen C++ library (CPU only)
            - "cusolver": Uses NVIDIA cuSOLVER (CUDA only)
            - "cudss": Uses NVIDIA cuDSS (CUDA only)
        method : str, optional
            Solver method. Default: "auto" (selects based on matrix properties).
            - Direct methods: "superlu", "umfpack", "lu", "qr", "cholesky", "ldlt"
            - Iterative methods: "cg", "bicgstab", "gmres", "minres"
        atol : float, optional
            Absolute tolerance for iterative solvers. Default: 1e-10.
        maxiter : int, optional
            Maximum iterations for iterative solvers. Default: 10000.
        tol : float, optional
            Relative tolerance for direct solvers. Default: 1e-12.

        Returns
        -------
        torch.Tensor
            Solution x with same batch shape as b.

        Raises
        ------
        ValueError
            If matrix is not square.
        NotImplementedError
            If block sparse tensors are used (not yet supported).

        Examples
        --------
        >>> # Simple solve
        >>> A = SparseTensor(val, row, col, (3, 3))
        >>> b = torch.randn(3)
        >>> x = A.solve(b)

        >>> # Batched solve
        >>> A_batch = SparseTensor(val_batch, row, col, (4, 3, 3))
        >>> b_batch = torch.randn(4, 3)
        >>> x_batch = A_batch.solve(b_batch)

        >>> # Specify backend
        >>> x = A.solve(b, backend='scipy', method='cg')
        """
        if not self.is_square:
            raise ValueError("Matrix must be square for solve()")
        if self.is_block:
            raise NotImplementedError("solve() not yet supported for block sparse tensors")

        # Get matrix properties once up front so the backend/method selection
        # can exploit symmetry / positive definiteness.
        # For batched tensors the flags are only set if EVERY batch element
        # satisfies the property (.all()).
        is_sym = self.is_symmetric().all().item() if self.is_batched else self.is_symmetric().item()
        is_pd = self.is_positive_definite().all().item() if self.is_batched else self.is_positive_definite().item()
        is_spd = is_sym and is_pd

        # Local import avoids a circular dependency at module load time.
        from .linear_solve import spsolve

        M, N = self.sparse_shape

        if self.is_batched:
            batch_shape = self.batch_shape
            # Flatten all batch dims into one loop dimension; the sparsity
            # pattern (row/col indices) is shared across the batch.
            vals_flat = self.values.reshape(-1, self.nnz)
            # NOTE(review): this reshape assumes b is [...batch, M]; a
            # trailing multi-RHS dim ([...batch, M, K]) as advertised in the
            # docstring would be flattened incorrectly here — confirm.
            b_flat = b.reshape(-1, M)
            results = []
            for i in range(vals_flat.size(0)):
                x = spsolve(
                    vals_flat[i], self.row_indices, self.col_indices, (M, N),
                    b_flat[i],
                    backend=backend, method=method,
                    atol=atol, maxiter=maxiter, tol=tol,
                    is_symmetric=is_sym, is_spd=is_spd
                )
                results.append(x)
            return torch.stack(results).reshape(*batch_shape, N)
        else:
            return spsolve(
                self.values, self.row_indices, self.col_indices, (M, N),
                b,
                backend=backend, method=method,
                atol=atol, maxiter=maxiter, tol=tol,
                is_symmetric=is_sym, is_spd=is_spd
            )
    def solve_batch(
        self,
        values: torch.Tensor,
        b: torch.Tensor,
        backend: BackendType = "auto",
        method: MethodType = "auto",
        atol: float = 1e-10,
        maxiter: int = 10000,
        tol: float = 1e-12
    ) -> torch.Tensor:
        """
        Solve with different values but same sparsity structure.

        This is efficient when you have the same structure but different
        values (e.g., time-stepping, optimization, parameter sweeps).

        Parameters
        ----------
        values : torch.Tensor
            Matrix values. Shape [...batch, nnz] where ... are batch dimensions.
            All matrices share the same row_indices and col_indices.
        b : torch.Tensor
            Right-hand side. Shape [...batch, M].
        backend : {"auto", "scipy", "eigen", "cusolver", "cudss"}, optional
            Solver backend. See solve() for details. Default: "auto".
        method : str, optional
            Solver method. See solve() for details. Default: "auto".
        atol : float, optional
            Absolute tolerance for iterative solvers. Default: 1e-10.
        maxiter : int, optional
            Maximum iterations for iterative solvers. Default: 10000.
        tol : float, optional
            Relative tolerance. Default: 1e-12.

        Returns
        -------
        torch.Tensor
            Solution x with shape [...batch, N].

        Examples
        --------
        >>> # Template matrix
        >>> A = SparseTensor(val, row, col, (10, 10))
        >>> # Batch of different values
        >>> val_batch = torch.stack([val * (1 + 0.1*i) for i in range(4)])  # [4, nnz]
        >>> b_batch = torch.randn(4, 10)
        >>> # Solve all at once
        >>> x_batch = A.solve_batch(val_batch, b_batch)  # [4, 10]
        """
        # Local import avoids a circular dependency at module load time.
        from .linear_solve import spsolve

        M, N = self.sparse_shape

        # Check properties using first batch element.
        # NOTE(review): symmetry/SPD is detected from the FIRST batch element
        # only and assumed to hold for every matrix in the batch — confirm
        # callers guarantee this (values only differ in magnitude, etc.).
        temp = SparseTensor(values[0] if values.dim() > 1 else values,
                            self.row_indices, self.col_indices, (M, N))
        is_sym = temp.is_symmetric().item()
        is_pd = temp.is_positive_definite().item()
        is_spd = is_sym and is_pd

        if values.dim() > 1:
            batch_shape = values.shape[:-1]
            # Flatten all batch dims; pattern (row/col) is shared.
            vals_flat = values.reshape(-1, self.nnz)
            b_flat = b.reshape(-1, M)
            results = []
            for i in range(vals_flat.size(0)):
                x = spsolve(
                    vals_flat[i], self.row_indices, self.col_indices, (M, N),
                    b_flat[i],
                    backend=backend, method=method,
                    atol=atol, maxiter=maxiter, tol=tol,
                    is_symmetric=is_sym, is_spd=is_spd
                )
                results.append(x)
            return torch.stack(results).reshape(*batch_shape, N)
        else:
            return spsolve(
                values, self.row_indices, self.col_indices, (M, N),
                b,
                backend=backend, method=method,
                atol=atol, maxiter=maxiter, tol=tol,
                is_symmetric=is_sym, is_spd=is_spd
            )
[文档] def nonlinear_solve( self, residual_fn, u0: torch.Tensor, *params, method: Literal['newton', 'picard', 'anderson'] = 'newton', tol: float = 1e-6, atol: float = 1e-10, max_iter: int = 50, line_search: bool = True, verbose: bool = False, linear_solver: BackendType = 'pytorch', linear_method: MethodType = 'cg', ) -> torch.Tensor: """ Solve nonlinear equation F(u, A, θ) = 0 with adjoint-based gradients. The SparseTensor A is automatically passed as the first parameter to the residual function, enabling gradients to flow through A's values. Parameters ---------- residual_fn : Callable Function F(u, A, *params) -> residual tensor. - u: Current solution estimate - A: This SparseTensor (passed automatically) - *params: Additional parameters with requires_grad=True u0 : torch.Tensor Initial guess for solution. *params : torch.Tensor Additional parameters (e.g., boundary conditions, coefficients). Tensors with requires_grad=True will receive gradients. method : {'newton', 'picard', 'anderson'}, optional Nonlinear solver method: - 'newton': Newton-Raphson with line search (default, fast) - 'picard': Fixed-point iteration (simple, slow) - 'anderson': Anderson acceleration (memory efficient) tol : float, optional Relative convergence tolerance. Default: 1e-6. atol : float, optional Absolute convergence tolerance. Default: 1e-10. max_iter : int, optional Maximum nonlinear iterations. Default: 50. line_search : bool, optional Use Armijo line search for Newton. Default: True. verbose : bool, optional Print convergence information. Default: False. linear_solver : str, optional Backend for linear solves. Default: 'pytorch'. linear_method : str, optional Method for linear solves. Default: 'cg'. Returns ------- torch.Tensor Solution u* satisfying F(u*, A, θ) ≈ 0. Examples -------- >>> # Nonlinear PDE: A @ u + u² = f >>> def residual(u, A, f): ... return A @ u + u**2 - f ... 
>>> A = SparseTensor(val, row, col, (n, n)) >>> f = torch.randn(n, requires_grad=True) >>> u0 = torch.zeros(n) >>> >>> u = A.nonlinear_solve(residual, u0, f, method='newton') >>> >>> # Gradients flow via adjoint method >>> loss = u.sum() >>> loss.backward() >>> print(f.grad) # ∂u/∂f >>> print(A.values.grad) # ∂u/∂A (if A.values.requires_grad) >>> # Nonlinear elasticity: K(u) @ u = F >>> def residual_elasticity(u, K, F, material): ... # K depends on displacement through material nonlinearity ... return K @ u - F + material * u**3 ... >>> u = K.nonlinear_solve(residual_elasticity, u0, F, material) """ from .nonlinear_solve import nonlinear_solve as _nonlinear_solve # Wrap residual_fn to pass SparseTensor as matvec M, N = self.sparse_shape def wrapped_residual(u, *all_params): # First param is the values tensor, rest are user params # Reconstruct sparse matvec capability return residual_fn(u, self, *all_params) # Include self.values in params if it requires grad all_params = params return _nonlinear_solve( wrapped_residual, u0, *all_params, method=method, tol=tol, atol=atol, max_iter=max_iter, line_search=line_search, verbose=verbose, linear_solver=linear_solver, linear_method=linear_method, )
# ========================================================================= # Norms # =========================================================================
[文档] def norm(self, ord: Literal['fro', 1, 2] = 'fro') -> torch.Tensor: """ Compute matrix norm. For batched tensors, returns norm for each batch element. Parameters ---------- ord : {'fro', 1, 2}, optional Norm type: - 'fro': Frobenius norm (default) - 1: Maximum absolute column sum - 2: Spectral norm (largest singular value) Returns ------- torch.Tensor Norm value(s). Shape [] for non-batched, [*batch_shape] for batched. Examples -------- >>> A = SparseTensor(val, row, col, (3, 3)) >>> A.norm('fro') # tensor(5.0) >>> A_batch = SparseTensor(val_batch, row, col, (4, 3, 3)) >>> A_batch.norm('fro') # tensor([5.0, 5.0, 5.0, 5.0]) """ if self.is_batched: batch_shape = self.batch_shape vals_flat = self.values.reshape(-1, self.nnz) norms = [] for i in range(vals_flat.size(0)): if ord == 'fro': norms.append(vals_flat[i].norm()) else: idx = self._flat_to_batch_idx(i) A_dense = self.to_dense(idx) norms.append(torch.linalg.norm(A_dense, ord=ord)) return torch.stack(norms).reshape(*batch_shape) else: if ord == 'fro': return self.values.norm() if self.is_cuda or not is_scipy_available(): A = self.to_dense() return torch.linalg.norm(A, ord=ord) M, N = self.sparse_shape return scipy_norm(self.values, self.row_indices, self.col_indices, (M, N), ord=ord)
def _flat_to_batch_idx(self, flat_idx: int) -> Tuple[int, ...]: """Convert flat batch index to tuple.""" idx = [] for s in reversed(self.batch_shape): idx.append(flat_idx % s) flat_idx //= s return tuple(reversed(idx)) # ========================================================================= # Visualization # =========================================================================
[文档] def spy( self, batch_idx: Optional[Tuple[int, ...]] = None, ax=None, title: Optional[str] = None, cmap: str = 'viridis', show_grid: bool = True, grid_color: str = '#cccccc', grid_linewidth: float = 0.5, show_colorbar: bool = True, figsize: Tuple[float, float] = (8, 8), save_path: Optional[str] = None, dpi: int = 150, ): """ Visualize the sparsity pattern with values shown as color intensity. Creates a spy plot where each matrix element is rendered as a pixel. Non-zero elements are colored with intensity proportional to the absolute value, while zero elements are shown as white. This provides a pixel-perfect visualization without overlapping markers. Parameters ---------- batch_idx : Tuple[int, ...], optional For batched tensors, which batch element to visualize. Required if the tensor is batched. ax : matplotlib.axes.Axes, optional Axes to plot on. If None, creates a new figure. title : str, optional Plot title. Defaults to showing matrix info. cmap : str, optional Colormap for values. Default: 'viridis'. Other options: 'plasma', 'hot', 'coolwarm', 'Greys', etc. show_grid : bool, optional Whether to show grid lines (only for matrices <= 30x30). Default: True. grid_color : str, optional Color of grid lines. Default: '#cccccc' (light gray). grid_linewidth : float, optional Width of grid lines. Default: 0.5. show_colorbar : bool, optional Whether to show colorbar for values. Default: True. figsize : Tuple[float, float], optional Figure size in inches. Default: (8, 8). save_path : str, optional If provided, save figure to this path. dpi : int, optional DPI for saved figure. Default: 150. Returns ------- ax : matplotlib.axes.Axes The axes object with the plot. 
Examples -------- >>> A = SparseTensor(val, row, col, (100, 100)) >>> A.spy() # Basic spy plot >>> A.spy(cmap='hot', show_grid=False) # Custom colormap, no grid >>> A.spy(save_path='matrix.png') # Save to file >>> # For batched tensor >>> A_batch = SparseTensor(val_batch, row, col, (4, 100, 100)) >>> A_batch.spy(batch_idx=(0,)) # Visualize first batch element """ try: import matplotlib.pyplot as plt import matplotlib.colors as mcolors from matplotlib.collections import PathCollection except ImportError: raise ImportError("matplotlib is required for spy(). Install with: pip install matplotlib") # Get indices and values if self.is_batched: if batch_idx is None: raise ValueError("batch_idx is required for batched tensors") # Flatten batch_idx to linear index flat_idx = 0 for i, (idx, s) in enumerate(zip(batch_idx, self.batch_shape)): flat_idx = flat_idx * s + idx vals = self.values.reshape(-1, self.nnz)[flat_idx] else: vals = self.values row = self.row_indices.cpu().numpy() col = self.col_indices.cpu().numpy() vals_np = vals.abs().cpu().numpy() M, N = self.sparse_shape # Create figure if needed created_fig = False if ax is None: fig, ax = plt.subplots(figsize=figsize) created_fig = True else: fig = ax.get_figure() # Normalize values for colormap if vals_np.max() > 0: vals_norm = vals_np / vals_np.max() else: vals_norm = vals_np # Build a dense image for visualization # Use NaN for empty cells (will be shown as white) import numpy as np image = np.full((M, N), np.nan, dtype=np.float32) image[row, col] = vals_norm # Create a colormap with white for NaN values cmap_obj = plt.cm.get_cmap(cmap).copy() cmap_obj.set_bad(color='white') # Use imshow for pixel-perfect rendering im = ax.imshow( image, cmap=cmap_obj, aspect='equal', interpolation='nearest', vmin=0, vmax=1, origin='upper' ) # Add colorbar if show_colorbar: cbar = fig.colorbar(im, ax=ax, shrink=0.8, pad=0.02) cbar.set_label('|value| (normalized)', fontsize=10) # Clean up axes - hide ticks for cleaner look 
ax.set_xticks([]) ax.set_yticks([]) # Add border for spine in ax.spines.values(): spine.set_visible(True) spine.set_color('#333333') spine.set_linewidth(1) # Add grid only for small matrices if show_grid and max(M, N) <= 30: ax.set_xticks([i - 0.5 for i in range(N + 1)], minor=True) ax.set_yticks([i - 0.5 for i in range(M + 1)], minor=True) ax.grid(which='minor', color=grid_color, linewidth=grid_linewidth) ax.tick_params(which='minor', length=0) # Set title if title is None: nnz = len(vals_np) sparsity = 1 - nnz / (M * N) title = f'Sparse Matrix: {M}×{N}, nnz={nnz:,}, sparsity={sparsity:.1%}' ax.set_title(title, fontsize=12, fontweight='bold') # Tight layout if created_fig: plt.tight_layout() # Save if requested if save_path is not None: fig.savefig(save_path, dpi=dpi, bbox_inches='tight') return ax
# ========================================================================= # Eigenvalues and SVD # =========================================================================
    def eigs(
        self,
        k: int = 6,
        which: str = "LM",
        sigma: Optional[float] = None,
        return_eigenvectors: bool = True
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Compute k eigenvalues and eigenvectors.

        For batched tensors, computes for each batch element.
        For CUDA tensors, uses LOBPCG algorithm.

        Parameters
        ----------
        k : int, optional
            Number of eigenvalues to compute. Default: 6.
        which : {"LM", "SM", "LR", "SR", "LA", "SA"}, optional
            Which eigenvalues to find:
            - "LM": Largest magnitude (default)
            - "SM": Smallest magnitude
            - "LR"/"SR": Largest/smallest real part
            - "LA"/"SA": Largest/smallest algebraic (for symmetric)
        sigma : float, optional
            Find eigenvalues near sigma (shift-invert mode).
        return_eigenvectors : bool, optional
            Whether to return eigenvectors. Default: True.

        Returns
        -------
        eigenvalues : torch.Tensor
            Shape [k] for non-batched, [*batch_shape, k] for batched.
        eigenvectors : torch.Tensor or None
            Shape [M, k] for non-batched, [*batch_shape, M, k] for batched.
            None if return_eigenvectors is False.

        Notes
        -----
        **Gradient Support:**

        - Both CPU and CUDA: Fully differentiable via adjoint method
        - Uses O(1) graph nodes regardless of iteration count
        - For symmetric matrices, prefer eigsh() for efficiency

        **Warning**: For non-symmetric matrices with complex eigenvalues,
        gradient computation is only supported for the real part.

        Examples
        --------
        >>> A = SparseTensor(val.requires_grad_(True), row, col, (n, n))
        >>> eigenvalues, eigenvectors = A.eigs(k=3)
        >>> loss = eigenvalues.real.sum()  # For complex eigenvalues
        >>> loss.backward()
        """
        M, N = self.sparse_shape

        if self.is_batched:
            # Recurse per batch element; the sparsity pattern is shared.
            batch_shape = self.batch_shape
            eigenvalues_list = []
            eigenvectors_list = []
            for idx in self._batch_indices():
                A_single = SparseTensor(
                    self.values[idx], self.row_indices, self.col_indices, (M, N)
                )
                evals, evecs = A_single.eigs(k, which, sigma, return_eigenvectors)
                eigenvalues_list.append(evals)
                if return_eigenvectors:
                    eigenvectors_list.append(evecs)
            eigenvalues = torch.stack(eigenvalues_list).reshape(*batch_shape, k)
            if return_eigenvectors:
                eigenvectors = torch.stack(eigenvectors_list).reshape(*batch_shape, M, k)
                return eigenvalues, eigenvectors
            return eigenvalues, None

        # For symmetric matrices or when using LA/SA, use eigsh (more efficient)
        if which in ("LA", "SA") or self.is_symmetric().item():
            return self.eigsh(k=k, which=which, sigma=sigma,
                              return_eigenvectors=return_eigenvectors)

        # Use adjoint-based eigs for differentiability on all devices.
        # NOTE(review): the non-symmetric path also goes through
        # EigshAdjoint — confirm that class handles general (non-symmetric)
        # matrices despite its name. sigma is not forwarded here — confirm
        # shift-invert is intentionally unsupported on this path.
        return EigshAdjoint.apply(
            self.values, self.row_indices, self.col_indices, (M, N),
            k, which, return_eigenvectors, self.device
        )
[文档] def eigsh( self, k: int = 6, which: str = "LM", sigma: Optional[float] = None, return_eigenvectors: bool = True ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Compute k eigenvalues for symmetric matrices. More efficient than eigs() for symmetric matrices. Parameters ---------- k : int, optional Number of eigenvalues to compute. Default: 6. which : {"LM", "SM", "LA", "SA"}, optional Which eigenvalues to find: - "LM": Largest magnitude (default) - "SM": Smallest magnitude - "LA"/"SA": Largest/smallest algebraic sigma : float, optional Find eigenvalues near sigma. return_eigenvectors : bool, optional Whether to return eigenvectors. Default: True. Returns ------- eigenvalues : torch.Tensor Shape [k] for non-batched, [*batch_shape, k] for batched. eigenvectors : torch.Tensor or None Shape [M, k] for non-batched, [*batch_shape, M, k] for batched. Notes ----- **Gradient Support:** - Both CPU and CUDA: Fully differentiable via adjoint method - Uses O(1) graph nodes regardless of iteration count - Gradient computed as: ∂L/∂A = Σ_i (∂L/∂λ_i) * v_i @ v_i.T Examples -------- >>> A = SparseTensor(val.requires_grad_(True), row, col, (n, n)) >>> eigenvalues, eigenvectors = A.eigsh(k=3) >>> loss = eigenvalues.sum() >>> loss.backward() # Computes ∂loss/∂val """ M, N = self.sparse_shape if self.is_batched: batch_shape = self.batch_shape eigenvalues_list = [] eigenvectors_list = [] for idx in self._batch_indices(): A_single = SparseTensor( self.values[idx], self.row_indices, self.col_indices, (M, N) ) evals, evecs = A_single.eigsh(k, which, sigma, return_eigenvectors) eigenvalues_list.append(evals) if return_eigenvectors: eigenvectors_list.append(evecs) eigenvalues = torch.stack(eigenvalues_list).reshape(*batch_shape, k) if return_eigenvectors: eigenvectors = torch.stack(eigenvectors_list).reshape(*batch_shape, M, k) return eigenvalues, eigenvectors return eigenvalues, None # Use adjoint-based eigsh for differentiability on all devices return EigshAdjoint.apply( 
self.values, self.row_indices, self.col_indices, (M, N), k, which, return_eigenvectors, self.device )
[文档] def svd(self, k: int = 6) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Compute truncated SVD. Parameters ---------- k : int, optional Number of singular values to compute. Default: 6. Returns ------- U : torch.Tensor Left singular vectors. Shape [M, k] or [*batch_shape, M, k]. S : torch.Tensor Singular values. Shape [k] or [*batch_shape, k]. Vt : torch.Tensor Right singular vectors. Shape [k, N] or [*batch_shape, k, N]. Notes ----- **Gradient Support:** - CUDA: Fully differentiable (uses power iteration with PyTorch operations) - CPU: NOT differentiable (uses SciPy which breaks gradient chain) For differentiable SVD on CPU, use `A.to_dense()` and `torch.linalg.svd()`. """ M, N = self.sparse_shape if self.is_batched: batch_shape = self.batch_shape U_list, S_list, Vt_list = [], [], [] for idx in self._batch_indices(): A_single = SparseTensor( self.values[idx], self.row_indices, self.col_indices, (M, N) ) U, S, Vt = A_single.svd(k) U_list.append(U) S_list.append(S) Vt_list.append(Vt) U = torch.stack(U_list).reshape(*batch_shape, M, k) S = torch.stack(S_list).reshape(*batch_shape, k) Vt = torch.stack(Vt_list).reshape(*batch_shape, k, N) return U, S, Vt if self.is_cuda: matvec = lambda x: self._spmv_coo(x) matvec_T = lambda x: self.T()._spmv_coo(x) U, S, Vt = _power_iteration_svd(matvec, matvec_T, M, N, k, self.dtype, self.device) return U, S, Vt if not is_scipy_available(): raise RuntimeError("SciPy is required for SVD on CPU") return scipy_svds(self.values, self.row_indices, self.col_indices, (M, N), k=k)
[文档] def condition_number(self, ord: int = 2) -> torch.Tensor: """ Estimate condition number. Parameters ---------- ord : int, optional Norm order for condition number. Default: 2 (spectral). Returns ------- torch.Tensor Condition number. Shape [] or [*batch_shape]. """ M, N = self.sparse_shape if self.is_batched: batch_shape = self.batch_shape cond_list = [] for idx in self._batch_indices(): A_single = SparseTensor( self.values[idx], self.row_indices, self.col_indices, (M, N) ) cond_list.append(A_single.condition_number(ord)) return torch.stack(cond_list).reshape(*batch_shape) if ord == 2: k = min(6, min(M, N) - 2) if k < 2: A_dense = self.to_dense() S = torch.linalg.svdvals(A_dense) return S.max() / S.min() _, S, _ = self.svd(k=k) return S.max() / S.min() norm_A = self.norm(ord=ord) e = torch.randn(M, dtype=self.dtype, device=self.device) e = e / e.norm() x = self.solve(e) return norm_A * x.norm() / e.norm()
[文档] def det(self) -> torch.Tensor: """ Compute determinant of the sparse matrix with gradient support. Uses LU decomposition (CPU) or dense conversion (CUDA) to compute the determinant efficiently. Supports automatic differentiation via the adjoint method. Returns ------- torch.Tensor Determinant value. Shape [] for single matrix or [*batch_shape] for batched. Raises ------ ValueError If matrix is not square Notes ----- - Only square matrices have determinants - For large matrices, determinant values can overflow/underflow - Consider using log-determinant for numerical stability in such cases - Supports both CPU (via SciPy) and CUDA (via torch.linalg.det) - For batched tensors, computes determinant independently for each batch - Fully differentiable: gradients computed via adjoint method - Gradient formula: ∂det(A)/∂A = det(A) * (A^{-1})^T Performance Warning ------------------- **CUDA performance is significantly slower than CPU for sparse matrices!** - CPU: Uses sparse LU decomposition (O(nnz^1.5)), ~0.3-0.8ms for n=10-1000 - CUDA: Converts to dense (O(n²) memory + O(n³) compute), ~0.2-2.5ms The CUDA version requires converting the sparse matrix to dense format because cuSOLVER/cuDSS don't expose determinant computation for sparse matrices. This makes it inefficient for large sparse matrices. **Recommendation**: For sparse matrices, use `.cpu().det().cuda()` instead: >>> # Slow: CUDA with dense conversion >>> det_slow = A_cuda.det() # ~2.5ms for n=1000 >>> >>> # Fast: CPU with sparse LU >>> det_fast = A_cuda.cpu().det() # ~0.8ms for n=1000 >>> det_fast = det_fast.cuda() # Move result back if needed Examples -------- >>> # Simple 2x2 matrix >>> val = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True) >>> row = torch.tensor([0, 0, 1, 1]) >>> col = torch.tensor([0, 1, 0, 1]) >>> A = SparseTensor(val, row, col, (2, 2)) >>> det = A.det() >>> print(det) # Should be -2.0 >>> det.backward() >>> print(val.grad) # Gradient w.r.t. 
matrix values >>> >>> # CUDA support >>> A_cuda = A.cuda() >>> det_cuda = A_cuda.det() >>> >>> # Batched matrices >>> val_batch = val.unsqueeze(0).expand(3, -1).clone() >>> A_batch = SparseTensor(val_batch, row, col, (3, 2, 2)) >>> det_batch = A_batch.det() >>> print(det_batch.shape) # torch.Size([3]) """ M, N = self.sparse_shape if M != N: raise ValueError(f"Matrix must be square for determinant, got shape ({M}, {N})") if self.is_batched: batch_shape = self.batch_shape det_list = [] for idx in self._batch_indices(): A_single = SparseTensor( self.values[idx], self.row_indices, self.col_indices, (M, N) ) det_list.append(A_single.det()) return torch.stack(det_list).reshape(*batch_shape) # Use adjoint method for gradient support return DetAdjoint.apply( self.values, self.row_indices, self.col_indices, (M, N), self.device, self.is_cuda )
# ========================================================================= # LU Factorization # =========================================================================
[文档] def lu(self) -> "LUFactorization": """ Compute LU decomposition for repeated solves. Returns ------- LUFactorization Factorization object with solve() method. Examples -------- >>> A = SparseTensor(val, row, col, (10, 10)) >>> lu = A.lu() >>> x1 = lu.solve(b1) >>> x2 = lu.solve(b2) # Reuses factorization """ if self.is_batched: raise NotImplementedError("lu() not supported for batched tensors") if self.is_cuda: raise NotImplementedError("LU decomposition on CUDA not yet supported") if not is_scipy_available(): raise RuntimeError("SciPy is required for LU decomposition") M, N = self.sparse_shape lu = scipy_lu(self.values, self.row_indices, self.col_indices, (M, N)) return LUFactorization(lu, (M, N), self.dtype, self.device)
# ========================================================================= # String Representation # ========================================================================= def __repr__(self) -> str: parts = [f"SparseTensor(shape={self._shape}"] if self.is_batched: parts.append(f"batch={self.batch_shape}") parts.append(f"sparse={self.sparse_shape}") if self.is_block: parts.append(f"block={self.block_shape}") parts.append(f"nnz={self.nnz}") parts.append(f"dtype={self.dtype}") parts.append(f"device={self.device}") return ", ".join(parts) + ")" # ========================================================================= # Reduction Operations (sum, mean, prod) # ========================================================================= def _normalize_axis(self, axis: Optional[Union[int, Tuple[int, ...]]]) -> Optional[Tuple[int, ...]]: """Normalize axis to tuple of positive indices.""" if axis is None: return None if isinstance(axis, int): axis = (axis,) ndim = len(self._shape) return tuple(a if a >= 0 else ndim + a for a in axis) def _get_dim_type(self, dim: int) -> str: """Get the type of dimension: 'batch', 'sparse_m', 'sparse_n', or 'block'.""" dim_m, dim_n = self._sparse_dim min_sparse = min(dim_m, dim_n) max_sparse = max(dim_m, dim_n) if dim < min_sparse: return 'batch' elif dim == dim_m: return 'sparse_m' elif dim == dim_n: return 'sparse_n' else: return 'block' def _values_axis_for_dim(self, dim: int) -> int: """ Map tensor dimension to values tensor dimension. Values shape: [...batch, nnz, ...block] Tensor shape: [...batch, M, N, ...block] """ dim_m, dim_n = self._sparse_dim min_sparse = min(dim_m, dim_n) max_sparse = max(dim_m, dim_n) if dim < min_sparse: # Batch dimension - same position return dim elif dim == dim_m or dim == dim_n: # Sparse dimension - maps to nnz axis return min_sparse # nnz is at the position of first sparse dim else: # Block dimension - after nnz axis # Shift by -1 because we replaced (M, N) with (nnz,) return dim - 1
    def sum(
        self,
        axis: Optional[Union[int, Tuple[int, ...]]] = None,
        keepdim: bool = False
    ) -> Union[torch.Tensor, "SparseTensor"]:
        """
        Sum of sparse tensor elements over specified axis.

        Parameters
        ----------
        axis : int, tuple of ints, or None
            Axis or axes along which to sum. Axes correspond to:
            - Batch dimensions: [...batch] at the beginning
            - Sparse dimensions: (M, N) at sparse_dim positions
            - Block dimensions: [...block] at the end
            If None, sum over all elements (returns scalar tensor).
        keepdim : bool
            Whether to keep the reduced dimensions.

        Returns
        -------
        torch.Tensor or SparseTensor
            - If reducing over sparse dimensions: returns dense tensor
            - If reducing over batch/block dimensions only: returns SparseTensor
            - If axis=None: returns scalar tensor

        Examples
        --------
        >>> # Shape: [batch=2, M=10, N=10, block=3]
        >>> A = SparseTensor(val, row, col, (2, 10, 10, 3))
        >>>
        >>> A.sum()            # Scalar: sum all elements
        >>> A.sum(axis=0)      # Sum over batch -> [10, 10, 3]
        >>> A.sum(axis=1)      # Sum over M (rows) -> [2, 10, 3] (dense)
        >>> A.sum(axis=2)      # Sum over N (cols) -> [2, 10, 3] (dense)
        >>> A.sum(axis=3)      # Sum over block -> SparseTensor [2, 10, 10]
        >>> A.sum(axis=(1,2))  # Sum over M and N -> [2, 3] (dense)
        """
        if axis is None:
            # Sum over all elements: implicit zeros contribute nothing, so
            # summing the stored values suffices.
            return self.values.sum()

        axes = self._normalize_axis(axis)
        dim_types = [self._get_dim_type(d) for d in axes]

        # Check if we're reducing over sparse dimensions
        has_sparse_reduction = any(dt in ('sparse_m', 'sparse_n') for dt in dim_types)

        if has_sparse_reduction:
            # Need to convert to dense for sparse reduction
            return self._sum_over_sparse(axes, keepdim)
        else:
            # Only batch/block reduction - can stay sparse
            return self._sum_over_batch_block(axes, keepdim)
    def _sum_over_sparse(
        self,
        axes: Tuple[int, ...],
        keepdim: bool
    ) -> torch.Tensor:
        """Sum that involves sparse dimensions - returns dense.

        Scatter-adds the stored values into a dense output along the reduced
        sparse dimension(s), then applies any remaining batch/block
        reductions.

        NOTE(review): keepdim is only applied to the trailing non-sparse
        reduction; reduced sparse dimensions are always dropped from the
        result -- confirm whether that is intended.
        """
        M, N = self.sparse_shape
        dim_m, dim_n = self._sparse_dim
        row, col = self.row_indices, self.col_indices

        # Separate sparse and non-sparse axes
        # (sparse_axes is currently unused; reductions key off reduce_m/reduce_n)
        sparse_axes = [a for a in axes if self._get_dim_type(a) in ('sparse_m', 'sparse_n')]
        other_axes = [a for a in axes if self._get_dim_type(a) not in ('sparse_m', 'sparse_n')]

        reduce_m = dim_m in axes
        reduce_n = dim_n in axes

        if self.is_batched:
            B = self.batch_size
            batch_shape = self.batch_shape
            # Flatten all batch dims into one leading axis so scatter_add_
            # can operate uniformly: [B, nnz] or [B, nnz, *block].
            vals_flat = self.values.reshape(B, self.nnz, *self.block_shape) if self.is_block else self.values.reshape(B, self.nnz)

            if reduce_m and reduce_n:
                # Sum all sparse entries per batch
                result = vals_flat.sum(dim=1)  # [B, *block]
                result = result.reshape(*batch_shape, *self.block_shape) if self.is_block else result.reshape(*batch_shape)
            elif reduce_m:
                # Sum over rows -> result is [B, N, *block]
                result = torch.zeros(B, N, *self.block_shape, dtype=self.dtype, device=self.device)
                col_idx = col.unsqueeze(0).expand(B, -1)
                if self.is_block:
                    # scatter_add_ needs the index tensor shaped like the
                    # source, so expand across block dims and add per batch.
                    for i in range(B):
                        result[i].scatter_add_(0, col_idx[i].unsqueeze(-1).expand(-1, *self.block_shape), vals_flat[i])
                else:
                    result.scatter_add_(1, col_idx, vals_flat)
                result = result.reshape(*batch_shape, N, *self.block_shape) if self.is_block else result.reshape(*batch_shape, N)
            else:  # reduce_n
                # Sum over cols -> result is [B, M, *block]
                result = torch.zeros(B, M, *self.block_shape, dtype=self.dtype, device=self.device)
                row_idx = row.unsqueeze(0).expand(B, -1)
                if self.is_block:
                    for i in range(B):
                        result[i].scatter_add_(0, row_idx[i].unsqueeze(-1).expand(-1, *self.block_shape), vals_flat[i])
                else:
                    result.scatter_add_(1, row_idx, vals_flat)
                result = result.reshape(*batch_shape, M, *self.block_shape) if self.is_block else result.reshape(*batch_shape, M)
        else:
            vals = self.values
            if reduce_m and reduce_n:
                # Total over all stored entries (per block component if any).
                result = vals.sum(dim=0) if self.is_block else vals.sum()
            elif reduce_m:
                result = torch.zeros(N, *self.block_shape, dtype=self.dtype, device=self.device) if self.is_block else torch.zeros(N, dtype=self.dtype, device=self.device)
                if self.is_block:
                    result.scatter_add_(0, col.unsqueeze(-1).expand(-1, *self.block_shape), vals)
                else:
                    result.scatter_add_(0, col, vals)
            else:  # reduce_n
                result = torch.zeros(M, *self.block_shape, dtype=self.dtype, device=self.device) if self.is_block else torch.zeros(M, dtype=self.dtype, device=self.device)
                if self.is_block:
                    result.scatter_add_(0, row.unsqueeze(-1).expand(-1, *self.block_shape), vals)
                else:
                    result.scatter_add_(0, row, vals)

        # Handle other axes reduction
        # NOTE(review): _values_axis_for_dim maps axes of the ORIGINAL tensor
        # layout, but `result` has already lost one or both sparse dims here;
        # this looks correct only for the default [...batch, M, N, ...block]
        # layout -- confirm for custom sparse_dim placements.
        if other_axes:
            result_axes = [self._values_axis_for_dim(a) for a in other_axes]
            result = result.sum(dim=tuple(result_axes), keepdim=keepdim)

        return result

    def _sum_over_batch_block(
        self,
        axes: Tuple[int, ...],
        keepdim: bool
    ) -> "SparseTensor":
        """Sum over batch/block dimensions only - stays sparse.

        The sparsity pattern (row/col indices) is untouched; only the values
        tensor is reduced and the logical shape metadata updated.
        """
        # Map tensor axes to values axes
        val_axes = tuple(self._values_axis_for_dim(a) for a in axes)
        new_values = self.values.sum(dim=val_axes, keepdim=keepdim)

        # Compute new shape
        new_shape = list(self._shape)
        if keepdim:
            for a in axes:
                new_shape[a] = 1
        else:
            # Delete from the back so earlier indices stay valid.
            for a in sorted(axes, reverse=True):
                del new_shape[a]

        # Adjust sparse_dim if needed: each removed axis before a sparse dim
        # shifts that sparse dim left by one.
        new_sparse_dim = list(self._sparse_dim)
        if not keepdim:
            removed_before_m = sum(1 for a in axes if a < self._sparse_dim[0])
            removed_before_n = sum(1 for a in axes if a < self._sparse_dim[1])
            new_sparse_dim[0] -= removed_before_m
            new_sparse_dim[1] -= removed_before_n

        return SparseTensor(
            new_values, self.row_indices, self.col_indices,
            tuple(new_shape), sparse_dim=tuple(new_sparse_dim)
        )
    def mean(
        self,
        axis: Optional[Union[int, Tuple[int, ...]]] = None,
        keepdim: bool = False
    ) -> Union[torch.Tensor, "SparseTensor"]:
        """
        Mean of sparse tensor elements over specified axis.

        Note: For sparse dimensions, this computes mean of non-zero values
        only, NOT the mean over all M*N elements. For full mean, use
        to_dense().mean().

        Parameters
        ----------
        axis : int, tuple of ints, or None
            Axis or axes along which to compute mean.
        keepdim : bool
            Whether to keep the reduced dimensions.

        Returns
        -------
        torch.Tensor or SparseTensor
            Mean values.

        Examples
        --------
        >>> A = SparseTensor(val, row, col, (10, 10))
        >>> A.mean()        # Mean of all non-zero values
        >>> A.mean(axis=0)  # Mean over batch dimension
        """
        if axis is None:
            return self.values.mean()

        axes = self._normalize_axis(axis)

        # For sparse dims, we compute sum/count of nnz (not M*N)
        sum_result = self.sum(axis=axis, keepdim=keepdim)

        # Compute divisor based on axes
        divisor = 1
        for a in axes:
            divisor *= self._shape[a]

        # But for sparse dimensions, divisor should be nnz not M*N
        dim_types = [self._get_dim_type(a) for a in axes]
        if 'sparse_m' in dim_types or 'sparse_n' in dim_types:
            # For sparse reduction, we're averaging over nnz values
            sparse_divisor = 1
            if 'sparse_m' in dim_types:
                sparse_divisor *= self.sparse_shape[0]
            if 'sparse_n' in dim_types:
                sparse_divisor *= self.sparse_shape[1]
            # Replace M*N with nnz
            # NOTE(review): when only ONE sparse dim is reduced (e.g. a
            # per-column mean), every slice is divided by the TOTAL nnz, not
            # the per-row/per-column nonzero count -- confirm this
            # approximation is intended.
            divisor = divisor // sparse_divisor * self.nnz

        if isinstance(sum_result, SparseTensor):
            # Batch/block-only reduction: divide values, keep the pattern.
            return SparseTensor(
                sum_result.values / divisor,
                sum_result.row_indices,
                sum_result.col_indices,
                sum_result.shape,
                sparse_dim=sum_result.sparse_dim
            )
        return sum_result / divisor
[文档] def prod( self, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False ) -> Union[torch.Tensor, "SparseTensor"]: """ Product of sparse tensor elements over specified axis. Warning: For sparse matrices, zero elements are not included in the product. This means prod() computes the product of non-zero values only. Parameters ---------- axis : int, tuple of ints, or None Axis or axes along which to compute product. keepdim : bool Whether to keep the reduced dimensions. Returns ------- torch.Tensor or SparseTensor Product values. Examples -------- >>> A = SparseTensor(val, row, col, (10, 10)) >>> A.prod() # Product of all non-zero values >>> A.prod(axis=0) # Product over batch dimension """ if axis is None: return self.values.prod() axes = self._normalize_axis(axis) dim_types = [self._get_dim_type(d) for d in axes] # Check if we're reducing over sparse dimensions has_sparse_reduction = any(dt in ('sparse_m', 'sparse_n') for dt in dim_types) if has_sparse_reduction: # For sparse reduction, prod is complex - convert to dense warnings.warn( "prod() over sparse dimensions converts to dense. " "This may use significant memory for large matrices." ) dense = self.to_dense() return dense.prod(dim=axes, keepdim=keepdim) else: # Only batch/block reduction val_axes = tuple(self._values_axis_for_dim(a) for a in axes) new_values = self.values.prod(dim=val_axes, keepdim=keepdim) new_shape = list(self._shape) if keepdim: for a in axes: new_shape[a] = 1 else: for a in sorted(axes, reverse=True): del new_shape[a] new_sparse_dim = list(self._sparse_dim) if not keepdim: removed_before_m = sum(1 for a in axes if a < self._sparse_dim[0]) removed_before_n = sum(1 for a in axes if a < self._sparse_dim[1]) new_sparse_dim[0] -= removed_before_m new_sparse_dim[1] -= removed_before_n return SparseTensor( new_values, self.row_indices, self.col_indices, tuple(new_shape), sparse_dim=tuple(new_sparse_dim) )
[文档] def max( self, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False ) -> Union[torch.Tensor, "SparseTensor"]: """Max of non-zero values over specified axis.""" if axis is None: return self.values.max() axes = self._normalize_axis(axis) dim_types = [self._get_dim_type(d) for d in axes] has_sparse_reduction = any(dt in ('sparse_m', 'sparse_n') for dt in dim_types) if has_sparse_reduction: dense = self.to_dense() return dense.max(dim=axes[0], keepdim=keepdim).values if len(axes) == 1 else dense.amax(dim=axes, keepdim=keepdim) else: val_axes = tuple(self._values_axis_for_dim(a) for a in axes) new_values = self.values.amax(dim=val_axes, keepdim=keepdim) new_shape = list(self._shape) if keepdim: for a in axes: new_shape[a] = 1 else: for a in sorted(axes, reverse=True): del new_shape[a] new_sparse_dim = list(self._sparse_dim) if not keepdim: removed_before_m = sum(1 for a in axes if a < self._sparse_dim[0]) removed_before_n = sum(1 for a in axes if a < self._sparse_dim[1]) new_sparse_dim[0] -= removed_before_m new_sparse_dim[1] -= removed_before_n return SparseTensor( new_values, self.row_indices, self.col_indices, tuple(new_shape), sparse_dim=tuple(new_sparse_dim) )
[文档] def min( self, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdim: bool = False ) -> Union[torch.Tensor, "SparseTensor"]: """Min of non-zero values over specified axis.""" if axis is None: return self.values.min() axes = self._normalize_axis(axis) dim_types = [self._get_dim_type(d) for d in axes] has_sparse_reduction = any(dt in ('sparse_m', 'sparse_n') for dt in dim_types) if has_sparse_reduction: dense = self.to_dense() return dense.min(dim=axes[0], keepdim=keepdim).values if len(axes) == 1 else dense.amin(dim=axes, keepdim=keepdim) else: val_axes = tuple(self._values_axis_for_dim(a) for a in axes) new_values = self.values.amin(dim=val_axes, keepdim=keepdim) new_shape = list(self._shape) if keepdim: for a in axes: new_shape[a] = 1 else: for a in sorted(axes, reverse=True): del new_shape[a] new_sparse_dim = list(self._sparse_dim) if not keepdim: removed_before_m = sum(1 for a in axes if a < self._sparse_dim[0]) removed_before_n = sum(1 for a in axes if a < self._sparse_dim[1]) new_sparse_dim[0] -= removed_before_m new_sparse_dim[1] -= removed_before_n return SparseTensor( new_values, self.row_indices, self.col_indices, tuple(new_shape), sparse_dim=tuple(new_sparse_dim) )
# ========================================================================= # Element-wise Operations # ========================================================================= def _apply_elementwise(self, func, *args, **kwargs) -> "SparseTensor": """Apply element-wise function to values. Returns the same type as self to support subclassing. Subclasses should ensure their __init__ accepts (values, row_indices, col_indices, shape) or override this method. """ new_values = func(self.values, *args, **kwargs) # Use type(self) to preserve subclass type try: return type(self)( new_values, self.row_indices, self.col_indices, self._shape ) except TypeError: # Fallback for subclasses with incompatible __init__ return SparseTensor( new_values, self.row_indices, self.col_indices, self._shape, sparse_dim=self._sparse_dim ) # Arithmetic operations def __add__(self, other: Union[torch.Tensor, "SparseTensor", float, int]) -> "SparseTensor": """Element-wise addition. For SparseTensor + SparseTensor, patterns must match.""" if isinstance(other, SparseTensor): if not torch.equal(self.row_indices, other.row_indices) or \ not torch.equal(self.col_indices, other.col_indices): raise ValueError("SparseTensor addition requires matching sparsity patterns") return self._apply_elementwise(lambda v: v + other.values) return self._apply_elementwise(lambda v: v + other) def __radd__(self, other): return self.__add__(other) def __sub__(self, other: Union[torch.Tensor, "SparseTensor", float, int]) -> "SparseTensor": if isinstance(other, SparseTensor): if not torch.equal(self.row_indices, other.row_indices) or \ not torch.equal(self.col_indices, other.col_indices): raise ValueError("SparseTensor subtraction requires matching sparsity patterns") return self._apply_elementwise(lambda v: v - other.values) return self._apply_elementwise(lambda v: v - other) def __rsub__(self, other): return self._apply_elementwise(lambda v: other - v) def __mul__(self, other: Union[torch.Tensor, "SparseTensor", 
float, int]) -> "SparseTensor": """Element-wise multiplication (Hadamard product for sparse tensors).""" if isinstance(other, SparseTensor): if not torch.equal(self.row_indices, other.row_indices) or \ not torch.equal(self.col_indices, other.col_indices): raise ValueError("SparseTensor multiplication requires matching sparsity patterns") return self._apply_elementwise(lambda v: v * other.values) return self._apply_elementwise(lambda v: v * other) def __rmul__(self, other): return self.__mul__(other) def __truediv__(self, other: Union[torch.Tensor, "SparseTensor", float, int]) -> "SparseTensor": if isinstance(other, SparseTensor): if not torch.equal(self.row_indices, other.row_indices) or \ not torch.equal(self.col_indices, other.col_indices): raise ValueError("SparseTensor division requires matching sparsity patterns") return self._apply_elementwise(lambda v: v / other.values) return self._apply_elementwise(lambda v: v / other) def __rtruediv__(self, other): return self._apply_elementwise(lambda v: other / v) def __floordiv__(self, other): if isinstance(other, SparseTensor): return self._apply_elementwise(lambda v: v // other.values) return self._apply_elementwise(lambda v: v // other) def __pow__(self, exponent: Union[float, int, torch.Tensor]) -> "SparseTensor": return self._apply_elementwise(lambda v: v ** exponent) def __neg__(self) -> "SparseTensor": return self._apply_elementwise(lambda v: -v) def __pos__(self) -> "SparseTensor": return self def __abs__(self) -> "SparseTensor": return self._apply_elementwise(torch.abs) # Math functions - directly delegate to values
[文档] def abs(self) -> "SparseTensor": """Element-wise absolute value.""" return self._apply_elementwise(torch.abs)
[文档] def sqrt(self) -> "SparseTensor": """Element-wise square root.""" return self._apply_elementwise(torch.sqrt)
[文档] def square(self) -> "SparseTensor": """Element-wise square.""" return self._apply_elementwise(torch.square)
[文档] def exp(self) -> "SparseTensor": """Element-wise exponential.""" return self._apply_elementwise(torch.exp)
[文档] def log(self) -> "SparseTensor": """Element-wise natural logarithm.""" return self._apply_elementwise(torch.log)
[文档] def log10(self) -> "SparseTensor": """Element-wise base-10 logarithm.""" return self._apply_elementwise(torch.log10)
[文档] def log2(self) -> "SparseTensor": """Element-wise base-2 logarithm.""" return self._apply_elementwise(torch.log2)
[文档] def sin(self) -> "SparseTensor": """Element-wise sine.""" return self._apply_elementwise(torch.sin)
[文档] def cos(self) -> "SparseTensor": """Element-wise cosine.""" return self._apply_elementwise(torch.cos)
[文档] def tan(self) -> "SparseTensor": """Element-wise tangent.""" return self._apply_elementwise(torch.tan)
[文档] def sinh(self) -> "SparseTensor": """Element-wise hyperbolic sine.""" return self._apply_elementwise(torch.sinh)
[文档] def cosh(self) -> "SparseTensor": """Element-wise hyperbolic cosine.""" return self._apply_elementwise(torch.cosh)
[文档] def tanh(self) -> "SparseTensor": """Element-wise hyperbolic tangent.""" return self._apply_elementwise(torch.tanh)
[文档] def sigmoid(self) -> "SparseTensor": """Element-wise sigmoid.""" return self._apply_elementwise(torch.sigmoid)
[文档] def relu(self) -> "SparseTensor": """Element-wise ReLU.""" return self._apply_elementwise(torch.relu)
[文档] def clamp(self, min: Optional[float] = None, max: Optional[float] = None) -> "SparseTensor": """Element-wise clamp.""" return self._apply_elementwise(lambda v: torch.clamp(v, min=min, max=max))
[文档] def sign(self) -> "SparseTensor": """Element-wise sign.""" return self._apply_elementwise(torch.sign)
[文档] def floor(self) -> "SparseTensor": """Element-wise floor.""" return self._apply_elementwise(torch.floor)
[文档] def ceil(self) -> "SparseTensor": """Element-wise ceil.""" return self._apply_elementwise(torch.ceil)
[文档] def round(self) -> "SparseTensor": """Element-wise round.""" return self._apply_elementwise(torch.round)
[文档] def reciprocal(self) -> "SparseTensor": """Element-wise reciprocal (1/x).""" return self._apply_elementwise(torch.reciprocal)
[文档] def pow(self, exponent: Union[float, int, torch.Tensor]) -> "SparseTensor": """Element-wise power.""" return self._apply_elementwise(lambda v: torch.pow(v, exponent))
# Comparison operations (return SparseTensor with bool values) def __eq__(self, other) -> "SparseTensor": if isinstance(other, SparseTensor): return self._apply_elementwise(lambda v: v == other.values) return self._apply_elementwise(lambda v: v == other) def __ne__(self, other) -> "SparseTensor": if isinstance(other, SparseTensor): return self._apply_elementwise(lambda v: v != other.values) return self._apply_elementwise(lambda v: v != other) def __lt__(self, other) -> "SparseTensor": if isinstance(other, SparseTensor): return self._apply_elementwise(lambda v: v < other.values) return self._apply_elementwise(lambda v: v < other) def __le__(self, other) -> "SparseTensor": if isinstance(other, SparseTensor): return self._apply_elementwise(lambda v: v <= other.values) return self._apply_elementwise(lambda v: v <= other) def __gt__(self, other) -> "SparseTensor": if isinstance(other, SparseTensor): return self._apply_elementwise(lambda v: v > other.values) return self._apply_elementwise(lambda v: v > other) def __ge__(self, other) -> "SparseTensor": if isinstance(other, SparseTensor): return self._apply_elementwise(lambda v: v >= other.values) return self._apply_elementwise(lambda v: v >= other) # Boolean operations
[文档] def logical_not(self) -> "SparseTensor": """Element-wise logical NOT.""" return self._apply_elementwise(torch.logical_not)
[文档] def logical_and(self, other: "SparseTensor") -> "SparseTensor": """Element-wise logical AND.""" return self._apply_elementwise(lambda v: torch.logical_and(v, other.values))
[文档] def logical_or(self, other: "SparseTensor") -> "SparseTensor": """Element-wise logical OR.""" return self._apply_elementwise(lambda v: torch.logical_or(v, other.values))
[文档] def logical_xor(self, other: "SparseTensor") -> "SparseTensor": """Element-wise logical XOR.""" return self._apply_elementwise(lambda v: torch.logical_xor(v, other.values))
# Type checking
[文档] def isnan(self) -> "SparseTensor": """Element-wise isnan check.""" return self._apply_elementwise(torch.isnan)
[文档] def isinf(self) -> "SparseTensor": """Element-wise isinf check.""" return self._apply_elementwise(torch.isinf)
[文档] def isfinite(self) -> "SparseTensor": """Element-wise isfinite check.""" return self._apply_elementwise(torch.isfinite)
# Gradient-related
[文档] def detach(self) -> "SparseTensor": """Detach from computation graph. Preserves subclass type.""" try: return type(self)( self.values.detach(), self.row_indices, self.col_indices, self._shape ) except TypeError: return SparseTensor( self.values.detach(), self.row_indices, self.col_indices, self._shape, sparse_dim=self._sparse_dim )
    def requires_grad_(self, requires_grad: bool = True) -> "SparseTensor":
        """Enable/disable gradient tracking on the values tensor (in-place).

        Returns self so calls can be chained, mirroring the torch API.
        """
        self.values.requires_grad_(requires_grad)
        return self
    @property
    def requires_grad(self) -> bool:
        """Whether gradient tracking is enabled on the underlying values tensor."""
        return self.values.requires_grad

    @property
    def grad(self) -> Optional[torch.Tensor]:
        """Gradient of values if available (None until backward has populated it)."""
        return self.values.grad
[文档] def clone(self) -> "SparseTensor": """Create a copy of this SparseTensor. Preserves subclass type.""" try: return type(self)( self.values.clone(), self.row_indices.clone(), self.col_indices.clone(), self._shape ) except TypeError: return SparseTensor( self.values.clone(), self.row_indices.clone(), self.col_indices.clone(), self._shape, sparse_dim=self._sparse_dim )
[文档] def contiguous(self) -> "SparseTensor": """Make values contiguous in memory. Preserves subclass type.""" try: return type(self)( self.values.contiguous(), self.row_indices.contiguous(), self.col_indices.contiguous(), self._shape ) except TypeError: return SparseTensor( self.values.contiguous(), self.row_indices.contiguous(), self.col_indices.contiguous(), self._shape, sparse_dim=self._sparse_dim )
# ========================================================================= # Persistence (I/O) # =========================================================================
    def save(
        self,
        path: Union[str, "os.PathLike"],
        metadata: Optional[Dict[str, str]] = None
    ) -> None:
        """
        Save SparseTensor to safetensors format.

        Parameters
        ----------
        path : str or PathLike
            Output file path (should end with .safetensors).
        metadata : dict, optional
            Additional metadata to store.

        Example
        -------
        >>> A = SparseTensor(val, row, col, (100, 100))
        >>> A.save("matrix.safetensors")
        """
        # Imported lazily -- presumably to avoid a circular import with .io;
        # confirm before hoisting to module level.
        from .io import save_sparse
        save_sparse(self, path, metadata)
    @classmethod
    def load(
        cls,
        path: Union[str, "os.PathLike"],
        device: Union[str, torch.device] = "cpu"
    ) -> "SparseTensor":
        """
        Load SparseTensor from safetensors format.

        Parameters
        ----------
        path : str or PathLike
            Input file path.
        device : str or torch.device
            Device to load tensors to.

        Returns
        -------
        SparseTensor
            The loaded sparse tensor.

        Example
        -------
        >>> A = SparseTensor.load("matrix.safetensors", device="cuda")
        """
        # NOTE(review): `cls` is not forwarded -- load_sparse decides the
        # concrete type, so Subclass.load() may not return the subclass;
        # verify against .io.load_sparse.
        from .io import load_sparse
        return load_sparse(path, device)
    def save_distributed(
        self,
        directory: Union[str, "os.PathLike"],
        num_partitions: int,
        partition_method: str = "simple",
        coords: Optional[torch.Tensor] = None,
        verbose: bool = False
    ) -> None:
        """
        Save as partitioned files for distributed loading.

        Creates a directory with metadata and per-partition files.
        Each rank can then load only its own partition.

        Parameters
        ----------
        directory : str or PathLike
            Output directory path.
        num_partitions : int
            Number of partitions to create.
        partition_method : str
            'simple', 'metis', or 'geometric'.
        coords : torch.Tensor, optional
            Node coordinates for geometric partitioning.
        verbose : bool
            Print progress.

        Example
        -------
        >>> A.save_distributed("matrix_dist", num_partitions=4)
        # Each rank loads its partition:
        >>> partition = DSparseMatrix.load("matrix_dist", rank)
        """
        # Lazy import, same pattern as save()/load() above.
        from .io import save_distributed
        save_distributed(self, directory, num_partitions, partition_method,
                         coords, verbose)
# ============================================================================= # LUFactorization Class # =============================================================================
class LUFactorization:
    """
    LU factorization wrapper for efficient repeated solves.

    Created by SparseTensor.lu().

    Parameters
    ----------
    lu_factor : scipy.sparse.linalg.SuperLU
        The SciPy LU factorization object.
    shape : Tuple[int, int]
        Matrix shape.
    dtype : torch.dtype
        Data type.
    device : torch.device
        Device.

    Examples
    --------
    >>> A = SparseTensor(val, row, col, (10, 10))
    >>> lu = A.lu()
    >>> x1 = lu.solve(b1)  # First solve
    >>> x2 = lu.solve(b2)  # Much faster - reuses factorization
    """

    def __init__(self, lu_factor, shape: Tuple[int, int],
                 dtype: torch.dtype, device: torch.device):
        self._lu = lu_factor        # SciPy SuperLU-like object with .solve()
        self._shape = shape
        self._dtype = dtype         # dtype/device to restore on solutions
        self._device = device

    def solve(self, b: torch.Tensor) -> torch.Tensor:
        """
        Solve Ax = b using the cached factorization.

        Parameters
        ----------
        b : torch.Tensor
            Right-hand side vector.

        Returns
        -------
        torch.Tensor
            Solution x, cast back to the factorization's dtype/device.
        """
        # Fix: removed the unused `import numpy as np` -- the round-trip only
        # needs Tensor.numpy() and torch.from_numpy().
        # Round-trip through NumPy because the SciPy factorization operates
        # on host arrays; detach() drops any autograd history.
        b_np = b.detach().cpu().numpy()
        x_np = self._lu.solve(b_np)
        return torch.from_numpy(x_np).to(dtype=self._dtype, device=self._device)

    def __repr__(self) -> str:
        return f"LUFactorization(shape={self._shape})"
# ============================================================================= # SparseTensorList Class # =============================================================================
class SparseTensorList:
    """
    A list of SparseTensors with different structures.

    Provides a unified interface for batch operations on matrices with
    different sparsity patterns. Unlike batched SparseTensor (which requires
    same structure), SparseTensorList allows each matrix to have different
    shape and sparsity pattern.

    Parameters
    ----------
    tensors : List[SparseTensor]
        List of SparseTensor objects.

    Attributes
    ----------
    shapes : List[Tuple[int, ...]]
        List of shapes for each tensor.
    device : torch.device
        Device (from first tensor).
    dtype : torch.dtype
        Data type (from first tensor).

    Examples
    --------
    >>> # Create matrices with different sizes
    >>> A1 = SparseTensor(val1, row1, col1, (10, 10))
    >>> A2 = SparseTensor(val2, row2, col2, (20, 20))
    >>> A3 = SparseTensor(val3, row3, col3, (30, 30))
    >>> # Create list
    >>> matrices = SparseTensorList([A1, A2, A3])
    >>> print(matrices.shapes)  # [(10, 10), (20, 20), (30, 30)]
    >>> # Batch solve
    >>> x_list = matrices.solve([b1, b2, b3])
    >>> # Check properties for all
    >>> is_sym = matrices.is_symmetric()  # [tensor(True), tensor(True), tensor(True)]
    """

    def __init__(self, tensors: List["SparseTensor"]):
        if not tensors:
            raise ValueError("SparseTensorList cannot be empty")
        # Shallow defensive copy: later mutation of the caller's list does
        # not affect this container (the tensors themselves are shared).
        self._tensors = list(tensors)
[文档] @classmethod def from_coo_list( cls, matrices: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[int, ...]]], ) -> "SparseTensorList": """ Create from list of COO data tuples. Parameters ---------- matrices : List[Tuple] List of (values, row_indices, col_indices, shape) tuples. Returns ------- SparseTensorList List of SparseTensors. Examples -------- >>> data = [ ... (val1, row1, col1, (10, 10)), ... (val2, row2, col2, (20, 20)), ... ] >>> matrices = SparseTensorList.from_coo_list(data) """ tensors = [ SparseTensor(val, row, col, shape) for val, row, col, shape in matrices ] return cls(tensors)
[文档] @classmethod def from_torch_sparse_list(cls, A_list: List[torch.Tensor]) -> "SparseTensorList": """ Create from list of PyTorch sparse tensors. Parameters ---------- A_list : List[torch.Tensor] List of PyTorch sparse COO tensors. Returns ------- SparseTensorList List of SparseTensors. """ tensors = [SparseTensor.from_torch_sparse(A) for A in A_list] return cls(tensors)
@property def shapes(self) -> List[Tuple[int, ...]]: """List of shapes for each tensor.""" return [t.shape for t in self._tensors] @property def device(self) -> torch.device: """Device of the first tensor.""" return self._tensors[0].device @property def dtype(self) -> torch.dtype: """Data type of the first tensor.""" return self._tensors[0].dtype def __len__(self) -> int: """Number of tensors in the list.""" return len(self._tensors) def __getitem__(self, idx: int) -> "SparseTensor": """ Get tensor by index. Parameters ---------- idx : int Index (supports negative indexing). Returns ------- SparseTensor The tensor at that index. """ if idx < 0: idx = len(self._tensors) + idx return self._tensors[idx] def __iter__(self): """Iterate over tensors.""" return iter(self._tensors)
[文档] def to(self, device: Union[str, torch.device]) -> "SparseTensorList": """ Move all tensors to device. Parameters ---------- device : str or torch.device Target device. Returns ------- SparseTensorList New list with tensors on target device. """ return SparseTensorList([t.to(device) for t in self._tensors])
[文档] def cuda(self) -> "SparseTensorList": """Move all tensors to CUDA.""" return self.to('cuda')
[文档] def cpu(self) -> "SparseTensorList": """Move all tensors to CPU.""" return self.to('cpu')
# ========================================================================= # Arithmetic Operations # ========================================================================= def __matmul__(self, x_list: Union[List[torch.Tensor], torch.Tensor]) -> List[torch.Tensor]: """ Batch matrix-vector/matrix multiplication. Parameters ---------- x_list : List[torch.Tensor] or torch.Tensor If List: one vector/matrix per sparse tensor, each with compatible shape. If Tensor: broadcasted to all matrices (must have compatible shape for all). Returns ------- List[torch.Tensor] List of results [A1 @ x1, A2 @ x2, ...] or [A1 @ x, A2 @ x, ...] Examples -------- >>> matrices = SparseTensorList([A1, A2, A3]) >>> # Per-matrix vectors >>> y_list = matrices @ [x1, x2, x3] >>> # Broadcast same vector >>> y_list = matrices @ x # x applied to all """ if isinstance(x_list, torch.Tensor): # Broadcast same tensor to all return [t @ x_list for t in self._tensors] if len(x_list) != len(self._tensors): raise ValueError(f"Expected {len(self._tensors)} vectors, got {len(x_list)}") return [t @ x for t, x in zip(self._tensors, x_list)] def __add__(self, other: Union["SparseTensorList", float, int]) -> "SparseTensorList": """ Element-wise addition. Parameters ---------- other : SparseTensorList or scalar If SparseTensorList: add corresponding matrices (must have same length). If scalar: add to all matrices. Returns ------- SparseTensorList Result of addition. 
""" if isinstance(other, SparseTensorList): if len(other) != len(self._tensors): raise ValueError(f"Length mismatch: {len(self._tensors)} vs {len(other)}") return SparseTensorList([a + b for a, b in zip(self._tensors, other._tensors)]) # Scalar addition - add to values return SparseTensorList([ SparseTensor(t.values + other, t.row_indices, t.col_indices, t.shape) for t in self._tensors ]) def __radd__(self, other): return self.__add__(other) def __sub__(self, other: Union["SparseTensorList", float, int]) -> "SparseTensorList": """Element-wise subtraction.""" if isinstance(other, SparseTensorList): if len(other) != len(self._tensors): raise ValueError(f"Length mismatch: {len(self._tensors)} vs {len(other)}") return SparseTensorList([a - b for a, b in zip(self._tensors, other._tensors)]) return SparseTensorList([ SparseTensor(t.values - other, t.row_indices, t.col_indices, t.shape) for t in self._tensors ]) def __rsub__(self, other): return SparseTensorList([ SparseTensor(other - t.values, t.row_indices, t.col_indices, t.shape) for t in self._tensors ]) def __mul__(self, other: Union["SparseTensorList", float, int, torch.Tensor]) -> "SparseTensorList": """ Element-wise multiplication. Parameters ---------- other : SparseTensorList, scalar, or Tensor If SparseTensorList: multiply corresponding matrices element-wise. If scalar/Tensor: multiply all values. Returns ------- SparseTensorList Result of multiplication. 
""" if isinstance(other, SparseTensorList): if len(other) != len(self._tensors): raise ValueError(f"Length mismatch: {len(self._tensors)} vs {len(other)}") return SparseTensorList([a * b for a, b in zip(self._tensors, other._tensors)]) return SparseTensorList([t * other for t in self._tensors]) def __rmul__(self, other): return self.__mul__(other) def __truediv__(self, other: Union[float, int, torch.Tensor]) -> "SparseTensorList": """Element-wise division by scalar.""" return SparseTensorList([t / other for t in self._tensors]) def __neg__(self) -> "SparseTensorList": """Negate all values.""" return SparseTensorList([-t for t in self._tensors])
[文档] def sum(self, axis: Optional[int] = None) -> Union[List[torch.Tensor], torch.Tensor]: """ Sum values in each matrix. Parameters ---------- axis : int, optional If None: sum all values in each matrix, return List[scalar]. If 0: sum over rows for each matrix. If 1: sum over columns for each matrix. Returns ------- List[torch.Tensor] or torch.Tensor If axis is None: List of scalar tensors (one per matrix). If axis is 0 or 1: List of 1D tensors. Examples -------- >>> matrices = SparseTensorList([A1, A2, A3]) >>> totals = matrices.sum() # [sum(A1), sum(A2), sum(A3)] >>> row_sums = matrices.sum(axis=1) # [A1.sum(1), A2.sum(1), ...] """ return [t.sum(axis=axis) for t in self._tensors]
[文档] def mean(self, axis: Optional[int] = None) -> List[torch.Tensor]: """ Mean of values in each matrix. Parameters ---------- axis : int, optional Same as sum(). Returns ------- List[torch.Tensor] List of mean values/vectors. """ return [t.mean(axis=axis) for t in self._tensors]
[文档] def max(self) -> List[torch.Tensor]: """Maximum value in each matrix.""" return [t.max() for t in self._tensors]
[文档] def min(self) -> List[torch.Tensor]: """Minimum value in each matrix.""" return [t.min() for t in self._tensors]
[文档] def abs(self) -> "SparseTensorList": """Absolute value of all elements.""" return SparseTensorList([t.abs() for t in self._tensors])
[文档] def clamp(self, min: Optional[float] = None, max: Optional[float] = None) -> "SparseTensorList": """Clamp values in all matrices.""" return SparseTensorList([t.clamp(min=min, max=max) for t in self._tensors])
[文档] def pow(self, exponent: float) -> "SparseTensorList": """Element-wise power.""" return SparseTensorList([t.pow(exponent) for t in self._tensors])
[文档] def sqrt(self) -> "SparseTensorList": """Element-wise square root.""" return SparseTensorList([t.sqrt() for t in self._tensors])
[文档] def exp(self) -> "SparseTensorList": """Element-wise exponential.""" return SparseTensorList([t.exp() for t in self._tensors])
[文档] def log(self) -> "SparseTensorList": """Element-wise natural logarithm.""" return SparseTensorList([t.log() for t in self._tensors])
# ========================================================================= # Linear Algebra # =========================================================================
[文档] def solve(self, b_list: List[torch.Tensor], **kwargs) -> List[torch.Tensor]: """ Solve linear systems for all matrices. Parameters ---------- b_list : List[torch.Tensor] List of right-hand side vectors, one per matrix. **kwargs Additional arguments passed to SparseTensor.solve(). Returns ------- List[torch.Tensor] List of solutions. Examples -------- >>> matrices = SparseTensorList([A1, A2, A3]) >>> x_list = matrices.solve([b1, b2, b3]) """ if len(b_list) != len(self._tensors): raise ValueError(f"Expected {len(self._tensors)} RHS vectors, got {len(b_list)}") return [t.solve(b, **kwargs) for t, b in zip(self._tensors, b_list)]
[文档] def is_symmetric(self, **kwargs) -> List[torch.Tensor]: """ Check symmetry for all matrices. Parameters ---------- **kwargs Arguments passed to SparseTensor.is_symmetric(). Returns ------- List[torch.Tensor] List of boolean tensors. """ return [t.is_symmetric(**kwargs) for t in self._tensors]
[文档] def is_positive_definite(self, **kwargs) -> List[torch.Tensor]: """ Check positive definiteness for all matrices. Parameters ---------- **kwargs Arguments passed to SparseTensor.is_positive_definite(). Returns ------- List[torch.Tensor] List of boolean tensors. """ return [t.is_positive_definite(**kwargs) for t in self._tensors]
[文档] def norm(self, ord: Literal['fro', 1, 2] = 'fro') -> List[torch.Tensor]: """ Compute norms for all matrices. Parameters ---------- ord : {'fro', 1, 2} Norm type. Returns ------- List[torch.Tensor] List of norm values. """ return [t.norm(ord=ord) for t in self._tensors]
[文档] def eigs(self, k: int = 6, **kwargs) -> List[Tuple[torch.Tensor, Optional[torch.Tensor]]]: """ Compute eigenvalues for all matrices. Parameters ---------- k : int Number of eigenvalues. **kwargs Additional arguments. Returns ------- List[Tuple[torch.Tensor, Optional[torch.Tensor]]] List of (eigenvalues, eigenvectors) tuples. """ return [t.eigs(k=k, **kwargs) for t in self._tensors]
[文档] def eigsh(self, k: int = 6, **kwargs) -> List[Tuple[torch.Tensor, Optional[torch.Tensor]]]: """ Compute eigenvalues for symmetric matrices. Parameters ---------- k : int Number of eigenvalues. **kwargs Additional arguments. Returns ------- List[Tuple[torch.Tensor, Optional[torch.Tensor]]] List of (eigenvalues, eigenvectors) tuples. """ return [t.eigsh(k=k, **kwargs) for t in self._tensors]
[文档] def svd(self, k: int = 6) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """ Compute SVD for all matrices. Parameters ---------- k : int Number of singular values. Returns ------- List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] List of (U, S, Vt) tuples. """ return [t.svd(k=k) for t in self._tensors]
[文档] def condition_number(self, ord: int = 2) -> List[torch.Tensor]: """ Compute condition numbers for all matrices. Parameters ---------- ord : int Norm order. Returns ------- List[torch.Tensor] List of condition numbers. """ return [t.condition_number(ord=ord) for t in self._tensors]
[文档] def det(self) -> List[torch.Tensor]: """ Compute determinants for all matrices. Returns ------- List[torch.Tensor] List of determinant values. Examples -------- >>> matrices = SparseTensorList([A1, A2, A3]) >>> dets = matrices.det() >>> print([d.item() for d in dets]) """ return [t.det() for t in self._tensors]
[文档] def spy( self, indices: Optional[List[int]] = None, ncols: int = 3, figsize: Optional[Tuple[float, float]] = None, **kwargs ): """ Visualize sparsity patterns for multiple matrices in a grid. Parameters ---------- indices : List[int], optional Which matrices to visualize. Default: all. ncols : int, optional Number of columns in subplot grid. Default: 3. figsize : Tuple[float, float], optional Figure size. Auto-computed if None. **kwargs Additional arguments passed to SparseTensor.spy(). Returns ------- fig : matplotlib.figure.Figure The figure object. Examples -------- >>> matrices = SparseTensorList([A1, A2, A3, A4]) >>> matrices.spy() # Visualize all in grid >>> matrices.spy(indices=[0, 2]) # Visualize specific ones """ try: import matplotlib.pyplot as plt except ImportError: raise ImportError("matplotlib is required for spy(). Install with: pip install matplotlib") if indices is None: indices = list(range(len(self._tensors))) n = len(indices) nrows = (n + ncols - 1) // ncols if figsize is None: figsize = (4 * ncols, 4 * nrows) fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False) for i, idx in enumerate(indices): row, col = i // ncols, i % ncols ax = axes[row, col] self._tensors[idx].spy(ax=ax, show_colorbar=False, **kwargs) M, N = self._tensors[idx].sparse_shape ax.set_title(f'[{idx}] {M}×{N}, nnz={self._tensors[idx].nnz:,}', fontsize=10) # Hide unused subplots for i in range(n, nrows * ncols): row, col = i // ncols, i % ncols axes[row, col].axis('off') plt.tight_layout() return fig
# ========================================================================= # Conversion Methods # =========================================================================
[文档] def to_block_diagonal(self) -> "SparseTensor": """ Merge all matrices into a single block-diagonal SparseTensor. Creates a sparse matrix where each input matrix appears as a block on the diagonal: diag(A1, A2, ..., An). Returns ------- SparseTensor Block-diagonal matrix with shape (sum(M_i), sum(N_i)). Notes ----- The resulting matrix has the structure: ``` [A1 0 0 ...] [ 0 A2 0 ...] [ 0 0 A3 ...] [... ... ... ] ``` Examples -------- >>> A1 = SparseTensor(val1, row1, col1, (10, 10)) >>> A2 = SparseTensor(val2, row2, col2, (20, 20)) >>> stl = SparseTensorList([A1, A2]) >>> A_block = stl.to_block_diagonal() # Shape (30, 30) """ if len(self._tensors) == 0: raise ValueError("Cannot convert empty SparseTensorList to block diagonal") if len(self._tensors) == 1: return self._tensors[0] # Compute offsets row_offsets = [0] col_offsets = [0] for t in self._tensors: M, N = t.sparse_shape row_offsets.append(row_offsets[-1] + M) col_offsets.append(col_offsets[-1] + N) total_rows = row_offsets[-1] total_cols = col_offsets[-1] # Concatenate all COO data with offsets all_values = [] all_rows = [] all_cols = [] for i, t in enumerate(self._tensors): all_values.append(t.values) all_rows.append(t.row_indices + row_offsets[i]) all_cols.append(t.col_indices + col_offsets[i]) values = torch.cat(all_values) rows = torch.cat(all_rows) cols = torch.cat(all_cols) return SparseTensor(values, rows, cols, (total_rows, total_cols))
[文档] @classmethod def from_block_diagonal( cls, sparse: "SparseTensor", sizes: List[Tuple[int, int]] ) -> "SparseTensorList": """ Split a block-diagonal SparseTensor into a list of matrices. Parameters ---------- sparse : SparseTensor Block-diagonal matrix to split. sizes : List[Tuple[int, int]] List of (rows, cols) for each block. Must sum to sparse.shape. Returns ------- SparseTensorList List of extracted blocks. Examples -------- >>> A_block = SparseTensor(val, row, col, (30, 30)) >>> stl = SparseTensorList.from_block_diagonal(A_block, [(10, 10), (20, 20)]) >>> print(len(stl)) # 2 """ if sparse.is_batched: raise NotImplementedError("from_block_diagonal not supported for batched tensors") # Validate sizes total_rows = sum(s[0] for s in sizes) total_cols = sum(s[1] for s in sizes) if (total_rows, total_cols) != sparse.sparse_shape: raise ValueError( f"Sizes sum to ({total_rows}, {total_cols}) but sparse has shape {sparse.sparse_shape}" ) # Compute offsets row_offsets = [0] col_offsets = [0] for m, n in sizes: row_offsets.append(row_offsets[-1] + m) col_offsets.append(col_offsets[-1] + n) tensors = [] row = sparse.row_indices col = sparse.col_indices val = sparse.values for i, (m, n) in enumerate(sizes): r_start, r_end = row_offsets[i], row_offsets[i + 1] c_start, c_end = col_offsets[i], col_offsets[i + 1] # Find entries in this block mask = (row >= r_start) & (row < r_end) & (col >= c_start) & (col < c_end) block_row = row[mask] - r_start block_col = col[mask] - c_start block_val = val[mask] tensors.append(SparseTensor(block_val, block_row, block_col, (m, n))) return cls(tensors)
@property def block_sizes(self) -> List[Tuple[int, int]]: """ Get the (rows, cols) size of each matrix. Returns ------- List[Tuple[int, int]] List of (M, N) tuples. """ return [t.sparse_shape for t in self._tensors] @property def total_nnz(self) -> int: """Total number of non-zeros across all matrices.""" return sum(t.nnz for t in self._tensors) @property def total_shape(self) -> Tuple[int, int]: """Shape of the block-diagonal representation.""" total_rows = sum(t.sparse_shape[0] for t in self._tensors) total_cols = sum(t.sparse_shape[1] for t in self._tensors) return (total_rows, total_cols)
[文档] def partition( self, num_partitions: int, threshold: int = 1000, partition_method: str = 'auto', device: Optional[Union[str, torch.device]] = None, verbose: bool = False ) -> "DSparseTensorList": """ Create distributed version for parallel computing. Parameters ---------- num_partitions : int Number of partitions (typically = world_size). threshold : int Graphs with nodes >= threshold are partitioned across ranks. Smaller graphs are assigned whole to individual ranks. partition_method : str Method for partitioning large graphs: 'metis', 'simple', 'auto'. device : torch.device, optional Target device. verbose : bool Print partition info. Returns ------- DSparseTensorList Distributed sparse tensor list. Notes ----- **Hybrid Strategy:** - Small graphs (< threshold nodes): Assigned whole to ranks round-robin. Zero edge cuts, no halo exchange needed. - Large graphs (>= threshold nodes): Partitioned across all ranks. Uses halo exchange for boundary nodes. This is optimal for molecular datasets where most molecules are small but some (proteins, polymers) can be very large. Examples -------- >>> stl = SparseTensorList([A1, A2, A3, ...]) >>> dstl = stl.partition(num_partitions=4, threshold=1000) >>> y_list = dstl @ x_list # Distributed matmul """ from .distributed import DSparseTensorList return DSparseTensorList.from_sparse_tensor_list( self, num_partitions=num_partitions, threshold=threshold, partition_method=partition_method, device=device, verbose=verbose )
def __repr__(self) -> str: return f"SparseTensorList(n={len(self._tensors)}, device={self.device})"