Source code for torch_sla.nonlinear_solve

"""
Adjoint Nonlinear Solve module for PyTorch

This module provides differentiable nonlinear equation solvers using the adjoint method.

For a nonlinear equation F(u, θ) = 0, where:
- u is the solution variable
- θ are parameters (e.g., neural network weights)

The forward pass solves for u* such that F(u*, θ) = 0.
The backward pass computes gradients ∂L/∂θ using the adjoint method:

    ∂L/∂θ = -λᵀ · ∂F/∂θ

where λ is the adjoint variable satisfying:

    (∂F/∂u)ᵀ · λ = (∂L/∂u)ᵀ

This avoids storing intermediate Jacobians and is memory-efficient.
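
The adjoint formula follows from implicit differentiation: since F(u*(θ), θ) = 0
holds identically in θ,

    ∂F/∂u · du*/dθ + ∂F/∂θ = 0   ⟹   du*/dθ = -(∂F/∂u)⁻¹ · ∂F/∂θ

and therefore ∂L/∂θ = ∂L/∂u · du*/dθ = -λᵀ · ∂F/∂θ with λ as defined above. Each
backward pass thus needs only one extra linear solve with (∂F/∂u)ᵀ.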

Supported methods:
- Newton-Raphson with line search
- Picard iteration (fixed-point)
- Anderson acceleration

Usage:
------
    from torch_sla import nonlinear_solve
    
    # Define residual function F(u, θ) -> residual
    def residual_fn(u, theta):
        # Your nonlinear equation
        return F(u, theta)
    
    # Solve with adjoint gradients
    u = nonlinear_solve(residual_fn, u0, theta, method='newton')
    
    # Gradients flow through automatically
    loss = loss_fn(u)
    loss.backward()  # Computes ∂L/∂θ via adjoint method
"""

import torch
from torch import Tensor
from torch.autograd import Function
from typing import Callable, Optional, Tuple, Union, Dict, Any
import warnings

from .linear_solve import spsolve


class NonlinearSolveAdjoint(Function):
    """
    Adjoint-based nonlinear solver with automatic differentiation.

    Uses implicit differentiation to compute gradients without storing
    intermediate Jacobians. Memory-efficient for large-scale problems.
    """
    @staticmethod
    def forward(
        ctx,
        u0: Tensor,
        num_params: int,  # Number of parameter tensors
        *args,  # params tensors followed by config dict
    ) -> Tensor:
        """
        Forward pass: solve F(u, θ) = 0 for u.

        Args:
            u0: Initial guess for solution
            num_params: Number of parameter tensors
            *args: First num_params elements are param tensors, last is config dict

        Returns:
            u: Solution satisfying F(u, θ) ≈ 0
        """
        # Extract params and config
        params = args[:num_params]
        config = args[num_params]

        # Extract config
        residual_fn = config['residual_fn']
        jacobian_fn = config.get('jacobian_fn', None)
        method = config.get('method', 'newton')
        tol = config.get('tol', 1e-6)
        atol = config.get('atol', 1e-10)
        max_iter = config.get('max_iter', 50)
        line_search = config.get('line_search', True)
        verbose = config.get('verbose', False)
        linear_solver = config.get('linear_solver', 'pytorch')
        linear_method = config.get('linear_method', 'cg')

        # Detach for forward solve (no gradient tracking during iteration)
        u = u0.detach().clone()
        params_detached = tuple(
            p.detach() if isinstance(p, Tensor) else p for p in params
        )

        if method == 'newton':
            u, info = _newton_solve(
                u, params_detached, residual_fn, jacobian_fn,
                tol=tol, atol=atol, max_iter=max_iter,
                line_search=line_search, verbose=verbose,
                linear_solver=linear_solver, linear_method=linear_method
            )
        elif method == 'picard':
            u, info = _picard_solve(
                u, params_detached, residual_fn,
                tol=tol, atol=atol, max_iter=max_iter, verbose=verbose
            )
        elif method == 'anderson':
            u, info = _anderson_solve(
                u, params_detached, residual_fn,
                tol=tol, atol=atol, max_iter=max_iter, verbose=verbose
            )
        else:
            raise ValueError(f"Unknown method: {method}. Use 'newton', 'picard', or 'anderson'")

        # Save for backward - save u and all param tensors that require grad
        tensors_to_save = [u]
        param_requires_grad = []
        params_no_grad = []  # Store non-grad params for backward
        for p in params:
            if isinstance(p, Tensor):
                param_requires_grad.append(p.requires_grad)
                if p.requires_grad:
                    tensors_to_save.append(p)
                    params_no_grad.append(None)  # placeholder
                else:
                    params_no_grad.append(p.detach())  # save detached copy
            else:
                param_requires_grad.append(False)
                params_no_grad.append(p)  # non-tensor param

        ctx.save_for_backward(*tensors_to_save)
        ctx.residual_fn = residual_fn
        ctx.jacobian_fn = jacobian_fn
        ctx.num_params = num_params
        ctx.param_requires_grad = param_requires_grad
        ctx.params_no_grad = params_no_grad  # for reconstructing params in backward
        ctx.linear_solver = linear_solver
        ctx.linear_method = linear_method
        ctx.tol = tol
        ctx.atol = atol
        ctx.info = info

        return u
    @staticmethod
    def backward(ctx, grad_u: Tensor):
        """
        Backward pass using adjoint method.

        Computes ∂L/∂θ = -λᵀ · ∂F/∂θ where (∂F/∂u)ᵀ · λ = grad_u

        Returns:
            Tuple of gradients: (grad_u0, grad_num_params, *grad_params, grad_config)
        """
        saved = ctx.saved_tensors
        u = saved[0]
        param_tensors = saved[1:]  # Only tensors that required grad

        residual_fn = ctx.residual_fn
        jacobian_fn = ctx.jacobian_fn
        num_params = ctx.num_params
        param_requires_grad = ctx.param_requires_grad
        params_no_grad = ctx.params_no_grad

        # Reconstruct params list using saved tensors and cached non-grad params
        param_idx = 0
        params_for_backward = []
        for i, requires_grad in enumerate(param_requires_grad):
            if requires_grad:
                params_for_backward.append(param_tensors[param_idx])
                param_idx += 1
            else:
                params_for_backward.append(params_no_grad[i])  # use cached non-grad param

        # Enable gradient computation (backward is called in no_grad context)
        with torch.enable_grad():
            # Step 1: Solve adjoint equation (∂F/∂u)ᵀ · λ = grad_u
            lambda_adj = _solve_adjoint_system(
                u, params_for_backward, residual_fn, jacobian_fn, grad_u,
                linear_solver=ctx.linear_solver, linear_method=ctx.linear_method,
                tol=ctx.tol, atol=ctx.atol
            )

            # Step 2: Compute ∂L/∂θ = -λᵀ · ∂F/∂θ for each parameter
            grad_params = []

            # Setup variables for gradient computation
            u_var = u.detach().requires_grad_(True)
            params_var = []
            for i, requires_grad in enumerate(param_requires_grad):
                if requires_grad:
                    params_var.append(params_for_backward[i].detach().requires_grad_(True))
                else:
                    params_var.append(params_for_backward[i])

            # Compute F(u, θ) with gradient tracking
            F = residual_fn(u_var, *params_var)

            # Compute gradients for each parameter
            for i, requires_grad in enumerate(param_requires_grad):
                if requires_grad:
                    # ∂L/∂θᵢ = -λᵀ · ∂F/∂θᵢ
                    grad_p = torch.autograd.grad(
                        F, params_var[i],
                        grad_outputs=lambda_adj,
                        retain_graph=True,
                        allow_unused=True
                    )[0]
                    if grad_p is not None:
                        grad_params.append(-grad_p)
                    else:
                        grad_params.append(None)
                else:
                    grad_params.append(None)

        # Return: (grad_u0, grad_num_params, *grad_params, grad_config)
        # grad_u0 = None, grad_num_params = None, grad_config = None
        return (None, None) + tuple(grad_params) + (None,)
def _newton_solve(
    u: Tensor,
    params: Tuple,
    residual_fn: Callable,
    jacobian_fn: Optional[Callable],
    tol: float,
    atol: float,
    max_iter: int,
    line_search: bool,
    verbose: bool,
    linear_solver: str,
    linear_method: str,
) -> Tuple[Tensor, Dict[str, Any]]:
    """Newton-Raphson solver with optional line search."""
    device = u.device
    dtype = u.dtype
    n = u.numel()

    for iteration in range(max_iter):
        # Compute residual
        F = residual_fn(u, *params)
        F_norm = torch.norm(F).item()

        if verbose:
            print(f"  Newton iter {iteration}: ||F|| = {F_norm:.2e}")

        # Check convergence
        if F_norm < atol:
            if verbose:
                print(f"  Converged (atol) at iteration {iteration}")
            return u, {'converged': True, 'iterations': iteration, 'residual': F_norm}
        if iteration > 0 and F_norm < tol * F_norm_0:
            if verbose:
                print(f"  Converged (rtol) at iteration {iteration}")
            return u, {'converged': True, 'iterations': iteration, 'residual': F_norm}
        if iteration == 0:
            F_norm_0 = F_norm

        # Compute Newton step: J(u) * du = -F
        if jacobian_fn is not None:
            # Explicit Jacobian provided
            val, row, col, shape = jacobian_fn(u, *params)
            du = spsolve(val, row, col, shape, -F,
                         backend=linear_solver, method=linear_method,
                         tol=tol * 0.1, maxiter=max(100, n // 10))
        else:
            # Use Jacobian-free Newton-Krylov
            du = _jacobian_free_solve(
                u, params, residual_fn, -F,
                linear_solver=linear_solver, linear_method=linear_method,
                tol=tol * 0.1, max_iter=max(100, n // 10)
            )

        # Line search
        if line_search:
            alpha = _armijo_line_search(u, du, params, residual_fn, F_norm)
        else:
            alpha = 1.0

        # Update
        u = u + alpha * du

    warnings.warn(f"Newton did not converge in {max_iter} iterations, ||F|| = {F_norm:.2e}")
    return u, {'converged': False, 'iterations': max_iter, 'residual': F_norm}


def _picard_solve(
    u: Tensor,
    params: Tuple,
    residual_fn: Callable,
    tol: float,
    atol: float,
    max_iter: int,
    verbose: bool,
) -> Tuple[Tensor, Dict[str, Any]]:
    """Picard (fixed-point) iteration solver."""
    for iteration in range(max_iter):
        F = residual_fn(u, *params)
        F_norm = torch.norm(F).item()

        if verbose:
            print(f"  Picard iter {iteration}: ||F|| = {F_norm:.2e}")

        if F_norm < atol:
            return u, {'converged': True, 'iterations': iteration, 'residual': F_norm}
        if iteration > 0 and F_norm < tol * F_norm_0:
            return u, {'converged': True, 'iterations': iteration, 'residual': F_norm}
        if iteration == 0:
            F_norm_0 = F_norm

        # Fixed-point update: u_new = u - F(u)
        # This assumes F(u) = u - g(u) form, so u = g(u)
        u = u - F

    warnings.warn(f"Picard did not converge in {max_iter} iterations")
    return u, {'converged': False, 'iterations': max_iter, 'residual': F_norm}


def _anderson_solve(
    u: Tensor,
    params: Tuple,
    residual_fn: Callable,
    tol: float,
    atol: float,
    max_iter: int,
    verbose: bool,
    m: int = 5,  # Anderson depth
) -> Tuple[Tensor, Dict[str, Any]]:
    """Anderson acceleration solver."""
    device = u.device
    dtype = u.dtype
    n = u.numel()

    # History storage
    X_hist = []  # Previous iterates
    F_hist = []  # Previous residuals

    for iteration in range(max_iter):
        F = residual_fn(u, *params)
        F_norm = torch.norm(F).item()

        if verbose:
            print(f"  Anderson iter {iteration}: ||F|| = {F_norm:.2e}")

        if F_norm < atol:
            return u, {'converged': True, 'iterations': iteration, 'residual': F_norm}
        if iteration > 0 and F_norm < tol * F_norm_0:
            return u, {'converged': True, 'iterations': iteration, 'residual': F_norm}
        if iteration == 0:
            F_norm_0 = F_norm

        # Store history
        X_hist.append(u.clone())
        F_hist.append(F.clone())

        # Limit history size
        if len(X_hist) > m + 1:
            X_hist.pop(0)
            F_hist.pop(0)

        # Anderson mixing
        if len(F_hist) >= 2:
            # Build matrix of residual differences
            k = len(F_hist) - 1
            dF = torch.stack([F_hist[i+1] - F_hist[i] for i in range(k)], dim=1)  # [n, k]

            # Solve least squares: min ||F_k - dF @ alpha||^2
            # (dF^T dF) alpha = dF^T F_k
            gram = dF.T @ dF + 1e-10 * torch.eye(k, device=device, dtype=dtype)
            rhs = dF.T @ F_hist[-1]
            alpha = torch.linalg.solve(gram, rhs)

            # Compute new iterate
            u_new = X_hist[-1] - F_hist[-1]  # Simple fixed-point
            for i in range(k):
                u_new = u_new - alpha[i] * (X_hist[i+1] - X_hist[i] - (F_hist[i+1] - F_hist[i]))
            u = u_new
        else:
            # Simple fixed-point for first iteration
            u = u - F

    warnings.warn(f"Anderson did not converge in {max_iter} iterations")
    return u, {'converged': False, 'iterations': max_iter, 'residual': F_norm}


def _jacobian_free_solve(
    u: Tensor,
    params: Tuple,
    residual_fn: Callable,
    rhs: Tensor,
    linear_solver: str,
    linear_method: str,
    tol: float,
    max_iter: int,
) -> Tensor:
    """
    Jacobian-free Newton-Krylov solve.

    Solves J(u) @ x = rhs using Krylov methods with Jacobian-vector products
    computed via automatic differentiation.
    """
    device = u.device
    dtype = u.dtype
    n = u.numel()

    # Detach params - we only need Jacobian w.r.t. u, not params
    params_detached = tuple(
        p.detach() if isinstance(p, Tensor) else p for p in params
    )

    def matvec(v):
        """Compute a Jacobian product with v using autograd."""
        # Enable gradient tracking - needed when called from autograd.Function.forward
        with torch.enable_grad():
            # Enable gradient tracking for u only
            u_var = u.detach().clone().requires_grad_(True)

            # Compute F(u) with gradient tracking (params detached)
            F = residual_fn(u_var, *params_detached)

            # Vector-Jacobian product via autograd: grad_outputs=v yields Jᵀ @ v.
            # This equals J @ v when the Jacobian is symmetric, which CG assumes anyway.
            Jv = torch.autograd.grad(
                outputs=F, inputs=u_var, grad_outputs=v,
                create_graph=False, retain_graph=False
            )[0]
        return Jv

    # Use CG with matvec
    x = torch.zeros_like(rhs)
    r = rhs.clone()  # r = b - A @ x, initially r = b
    p = r.clone()
    rs_old = torch.dot(r.flatten(), r.flatten())

    rhs_norm = torch.norm(rhs)
    if rhs_norm < 1e-30:
        return x

    for i in range(max_iter):
        Ap = matvec(p)
        pAp = torch.dot(p.flatten(), Ap.flatten())

        if abs(pAp) < 1e-30:
            break

        alpha = rs_old / pAp
        x = x + alpha * p
        r = r - alpha * Ap

        rs_new = torch.dot(r.flatten(), r.flatten())
        if torch.sqrt(rs_new) < tol * rhs_norm:
            break

        beta = rs_new / rs_old
        p = r + beta * p
        rs_old = rs_new

    return x


def _solve_adjoint_system(
    u: Tensor,
    params: Tuple,
    residual_fn: Callable,
    jacobian_fn: Optional[Callable],
    rhs: Tensor,
    linear_solver: str,
    linear_method: str,
    tol: float,
    atol: float,
) -> Tensor:
    """
    Solve the adjoint system: (∂F/∂u)ᵀ @ λ = rhs
    """
    device = u.device
    dtype = u.dtype
    n = u.numel()

    if jacobian_fn is not None:
        # Explicit Jacobian: transpose and solve
        val, row, col, shape = jacobian_fn(u, *params)
        # Transpose: swap row and col
        lambda_adj = spsolve(val, col, row, (shape[1], shape[0]), rhs,
                             backend=linear_solver, method=linear_method,
                             tol=tol, maxiter=max(100, n // 10))
    else:
        # Jacobian-free: use CG with Jᵀv products
        def matvec_transpose(v):
            """Compute Jᵀ @ v using autograd."""
            u_var = u.detach().requires_grad_(True)
            F = residual_fn(u_var, *params)
            # For Jᵀv, we use the identity: Jᵀv = ∂(v·F)/∂u
            vF = torch.dot(v.flatten(), F.flatten())
            Jtv = torch.autograd.grad(vF, u_var, retain_graph=False)[0]
            return Jtv

        # CG for Jᵀ @ λ = rhs
        lambda_adj = torch.zeros_like(rhs)
        r = rhs.clone()
        p = r.clone()
        rs_old = torch.dot(r.flatten(), r.flatten())

        for i in range(max(100, n // 10)):
            Ap = matvec_transpose(p)
            pAp = torch.dot(p.flatten(), Ap.flatten())

            if abs(pAp) < 1e-30:
                break

            alpha = rs_old / pAp
            lambda_adj = lambda_adj + alpha * p
            r = r - alpha * Ap

            rs_new = torch.dot(r.flatten(), r.flatten())
            if torch.sqrt(rs_new) < tol * torch.norm(rhs) + atol:
                break

            beta = rs_new / rs_old
            p = r + beta * p
            rs_old = rs_new

    return lambda_adj


def _armijo_line_search(
    u: Tensor,
    du: Tensor,
    params: Tuple,
    residual_fn: Callable,
    F_norm: float,
    c: float = 1e-4,
    rho: float = 0.5,
    max_iter: int = 20,
) -> float:
    """Armijo backtracking line search."""
    alpha = 1.0

    for _ in range(max_iter):
        u_new = u + alpha * du
        F_new = residual_fn(u_new, *params)
        F_new_norm = torch.norm(F_new).item()

        # Sufficient decrease condition on the residual norm:
        # ||F(u + α·du)|| ≤ (1 - c·α) · ||F(u)||
        if F_new_norm <= (1 - c * alpha) * F_norm:
            return alpha

        alpha *= rho

    return alpha


# ============================================================================
# High-level API
# ============================================================================
def nonlinear_solve(
    residual_fn: Callable,
    u0: Tensor,
    *params,
    jacobian_fn: Optional[Callable] = None,
    method: str = 'newton',
    tol: float = 1e-6,
    atol: float = 1e-10,
    max_iter: int = 50,
    line_search: bool = True,
    verbose: bool = False,
    linear_solver: str = 'pytorch',
    linear_method: str = 'cg',
) -> Tensor:
    """
    Solve nonlinear equation F(u, θ) = 0 with adjoint-based gradients.

    Args:
        residual_fn: Function F(u, *params) -> residual tensor
        u0: Initial guess for solution
        *params: Parameters θ (tensors with requires_grad=True for gradient computation)
        jacobian_fn: Optional function J(u, *params) -> (val, row, col, shape)
            Returns sparse Jacobian in COO format. If None, uses autograd.
        method: Nonlinear solver method
            - 'newton': Newton-Raphson with optional line search (default)
            - 'picard': Fixed-point iteration
            - 'anderson': Anderson acceleration
        tol: Relative convergence tolerance
        atol: Absolute convergence tolerance
        max_iter: Maximum number of nonlinear iterations
        line_search: Use Armijo line search for Newton (default: True)
        verbose: Print convergence information
        linear_solver: Backend for linear solves ('pytorch', 'scipy', 'cudss')
        linear_method: Method for linear solves ('cg', 'bicgstab', 'lu')

    Returns:
        u: Solution tensor satisfying F(u, θ) ≈ 0

    Example:
        >>> def residual(u, A_val, b):
        ...     # Nonlinear: A(u) @ u - b, where the sparse matrix A is assembled
        ...     # from A_val and depends on u (assembly omitted for brevity)
        ...     return torch.sparse.mm(A, u.unsqueeze(1)).squeeze() - b
        ...
        >>> u0 = torch.zeros(n, requires_grad=False)
        >>> A_val = torch.randn(nnz, requires_grad=True)
        >>> b = torch.randn(n, requires_grad=True)
        >>>
        >>> u = nonlinear_solve(residual, u0, A_val, b, method='newton')
        >>> loss = some_loss(u)
        >>> loss.backward()  # Computes ∂L/∂A_val and ∂L/∂b via adjoint
    """
    config = {
        'residual_fn': residual_fn,
        'jacobian_fn': jacobian_fn,
        'method': method,
        'tol': tol,
        'atol': atol,
        'max_iter': max_iter,
        'line_search': line_search,
        'verbose': verbose,
        'linear_solver': linear_solver,
        'linear_method': linear_method,
    }

    # Call apply with: u0, num_params, *params, config
    return NonlinearSolveAdjoint.apply(u0, len(params), *params, config)
# Alias
adjoint_solve = nonlinear_solve
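

# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the module's API): supplying an explicit
# sparse Jacobian. The scalar equation, and the names `residual`, `jacobian`,
# `idx`, and `theta` below are hypothetical; the only documented contract is
# that jacobian_fn returns the COO triplet (val, row, col, shape) described in
# `nonlinear_solve`.
#
#     import torch
#     from torch_sla import nonlinear_solve
#
#     n = 4
#     idx = torch.arange(n)          # diagonal COO indices
#
#     def residual(u, theta):
#         # Componentwise equation: F(u, θ) = u + θ·u³ - 1 = 0
#         return u + theta * u**3 - 1.0
#
#     def jacobian(u, theta):
#         # Diagonal Jacobian dF/du = 1 + 3·θ·u², returned in COO format
#         val = 1.0 + 3.0 * theta * u**2
#         return val, idx, idx, (n, n)
#
#     theta = torch.tensor(0.5, requires_grad=True)
#     u0 = torch.zeros(n)
#     u = nonlinear_solve(residual, u0, theta, jacobian_fn=jacobian)
#     u.sum().backward()             # ∂(Σu)/∂θ computed via the adjoint solve
# ----------------------------------------------------------------------------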