# Source code for torch_sla.io

"""
Persistence utilities for SparseTensor and DSparseTensor.

Supports:
- safetensors format for efficient, safe serialization
- Matrix Market (.mtx) format for interoperability with other tools
- Distributed loading where different ranks load different partitions

Example
-------
>>> from torch_sla import SparseTensor
>>> from torch_sla.io import save_sparse, load_sparse, save_distributed, load_partition
>>>
>>> # Save single SparseTensor
>>> A = SparseTensor(val, row, col, shape)
>>> save_sparse(A, "matrix.safetensors")
>>> A_loaded = load_sparse("matrix.safetensors")
>>>
>>> # Matrix Market format
>>> save_mtx(A, "matrix.mtx")
>>> A = load_mtx("matrix.mtx")
>>>
>>> # Save partitioned for distributed loading
>>> save_distributed(A, "matrix_dist", num_partitions=4)
>>> # Each rank loads its partition
>>> partition = load_partition("matrix_dist", rank=rank, world_size=4)
"""

import os
import json
import torch
from typing import Dict, Optional, Tuple, Union, TYPE_CHECKING
from pathlib import Path

if TYPE_CHECKING:
    from .sparse_tensor import SparseTensor
    from .distributed import DSparseMatrix, DSparseTensor

try:
    from safetensors.torch import save_file, load_file
    SAFETENSORS_AVAILABLE = True
except ImportError:
    SAFETENSORS_AVAILABLE = False


def _ensure_safetensors():
    """Raise ImportError with install instructions if safetensors is missing."""
    if SAFETENSORS_AVAILABLE:
        return
    raise ImportError(
        "safetensors is required for persistence. "
        "Install with: pip install safetensors"
    )


# =============================================================================
# Matrix Market (.mtx) Format
# =============================================================================

def save_mtx(
    tensor: "SparseTensor",
    path: Union[str, Path],
    comment: str = "",
    field: str = "real",
    symmetry: str = "general",
) -> None:
    """
    Save a SparseTensor to Matrix Market (.mtx) format.

    Parameters
    ----------
    tensor : SparseTensor
        The sparse tensor to save.
    path : str or Path
        Output file path (should end with .mtx).
    comment : str, optional
        Comment to include in the header.
    field : str, optional
        Field type: 'real', 'complex', 'integer', or 'pattern'.
        Automatically overridden when the tensor dtype is complex or
        integer, so the declared field always matches the data.
        Default: 'real'.
    symmetry : str, optional
        Symmetry type: 'general', 'symmetric', 'skew-symmetric', or
        'hermitian'.  For the non-'general' types, only the lower
        triangle is written, per the Matrix Market convention.
        Default: 'general'.

    Raises
    ------
    TypeError
        If ``tensor`` is not a SparseTensor.
    ValueError
        If ``tensor`` is batched.

    Example
    -------
    >>> A = SparseTensor(val, row, col, (100, 100))
    >>> save_mtx(A, "matrix.mtx")
    >>> save_mtx(A, "matrix.mtx", symmetry="symmetric")
    """
    from .sparse_tensor import SparseTensor

    if not isinstance(tensor, SparseTensor):
        raise TypeError(f"Expected SparseTensor, got {type(tensor)}")
    if tensor.is_batched:
        raise ValueError("Cannot save batched SparseTensor to Matrix Market format")

    path = Path(path)
    M, N = tensor.sparse_shape

    # Pull COO data to host memory for text serialization.
    row = tensor.row_indices.cpu().numpy()
    col = tensor.col_indices.cpu().numpy()
    val = tensor.values.cpu().numpy()

    # The declared field must match the stored dtype, otherwise the file
    # would be unreadable; force it for complex/integer data.
    if tensor.dtype in (torch.complex64, torch.complex128):
        field = "complex"
    elif tensor.dtype in (torch.int32, torch.int64):
        field = "integer"

    # Matrix Market symmetric variants store only the lower triangle and
    # readers (including load_mtx in this module) mirror each off-diagonal
    # entry on load.  Writing every stored entry would double-count after
    # a round trip, so filter here.  Skew-symmetric matrices have a zero
    # diagonal, so the diagonal is dropped as well for that type.
    if symmetry in ("symmetric", "hermitian"):
        keep = row >= col
        row, col, val = row[keep], col[keep], val[keep]
    elif symmetry == "skew-symmetric":
        keep = row > col
        row, col, val = row[keep], col[keep], val[keep]
    nnz = len(row)

    with open(path, 'w') as f:
        # Banner line, then optional '%' comment lines.
        f.write(f"%%MatrixMarket matrix coordinate {field} {symmetry}\n")
        if comment:
            for line in comment.split('\n'):
                f.write(f"% {line}\n")
        f.write("% Generated by torch-sla\n")

        # Size line: rows cols nnz.
        f.write(f"{M} {N} {nnz}\n")

        # Entries are 1-indexed in Matrix Market.
        if field == "pattern":
            for i in range(nnz):
                f.write(f"{row[i] + 1} {col[i] + 1}\n")
        elif field == "complex":
            for i in range(nnz):
                v = val[i]
                f.write(f"{row[i] + 1} {col[i] + 1} {v.real} {v.imag}\n")
        else:
            for i in range(nnz):
                f.write(f"{row[i] + 1} {col[i] + 1} {val[i]}\n")
def load_mtx(
    path: Union[str, Path],
    dtype: Optional[torch.dtype] = None,
    device: Union[str, torch.device] = "cpu",
) -> "SparseTensor":
    """
    Load a SparseTensor from Matrix Market (.mtx) format.

    Parameters
    ----------
    path : str or Path
        Input file path.
    dtype : torch.dtype, optional
        Data type for values. If None, inferred from the file's field
        type (complex128, int64, or float64).
    device : str or torch.device
        Device to load tensors to.

    Returns
    -------
    SparseTensor
        The loaded sparse tensor.

    Raises
    ------
    ValueError
        If the header is malformed, or the object/format type is not
        'matrix coordinate'.

    Example
    -------
    >>> A = load_mtx("matrix.mtx")
    >>> A = load_mtx("matrix.mtx", dtype=torch.float32, device="cuda")
    """
    from .sparse_tensor import SparseTensor

    path = Path(path)

    with open(path, 'r') as f:
        # Parse and validate the banner line.
        header = f.readline().strip()
        if not header.startswith("%%MatrixMarket"):
            raise ValueError(f"Invalid Matrix Market header: {header}")

        parts = header.split()
        if len(parts) < 4:
            raise ValueError(f"Invalid Matrix Market header: {header}")

        obj_type = parts[1].lower()      # expected: matrix
        format_type = parts[2].lower()   # expected: coordinate
        field_type = parts[3].lower()
        symmetry = parts[4].lower() if len(parts) > 4 else "general"

        if obj_type != "matrix":
            raise ValueError(f"Only 'matrix' object type supported, got: {obj_type}")
        if format_type != "coordinate":
            raise ValueError(f"Only 'coordinate' format supported, got: {format_type}")

        # Skip comment ('%') and blank lines.  readline() returns '' at
        # EOF, so the `line and` guard prevents an infinite loop on a
        # truncated file.
        line = f.readline()
        while line and (line.startswith('%') or not line.strip()):
            line = f.readline()

        # Size line: rows cols nnz.
        dims = line.strip().split()
        M, N, nnz = int(dims[0]), int(dims[1]), int(dims[2])

        rows = []
        cols = []
        vals = []
        for line in f:
            parts = line.strip().split()
            if not parts:
                continue
            r = int(parts[0]) - 1  # Matrix Market entries are 1-indexed
            c = int(parts[1]) - 1
            if field_type == "pattern":
                v = 1.0  # pattern files carry structure only
            elif field_type == "complex":
                v = complex(float(parts[2]), float(parts[3]))
            elif field_type == "integer":
                v = int(parts[2])
            else:
                v = float(parts[2])
            rows.append(r)
            cols.append(c)
            vals.append(v)

            # Symmetric variants store a single triangle: mirror each
            # off-diagonal entry to reconstruct the full matrix.
            if symmetry == "symmetric" and r != c:
                rows.append(c)
                cols.append(r)
                vals.append(v)
            elif symmetry == "skew-symmetric" and r != c:
                rows.append(c)
                cols.append(r)
                vals.append(-v)
            elif symmetry == "hermitian" and r != c:
                rows.append(c)
                cols.append(r)
                vals.append(v.conjugate() if isinstance(v, complex) else v)

    row_tensor = torch.tensor(rows, dtype=torch.long, device=device)
    col_tensor = torch.tensor(cols, dtype=torch.long, device=device)

    # Infer a dtype from the field type when the caller did not pick one.
    if dtype is None:
        if field_type == "complex":
            dtype = torch.complex128
        elif field_type == "integer":
            dtype = torch.int64
        else:
            dtype = torch.float64

    val_tensor = torch.tensor(vals, dtype=dtype, device=device)
    return SparseTensor(val_tensor, row_tensor, col_tensor, (M, N))
def load_mtx_info(path: Union[str, Path]) -> Dict:
    """
    Read Matrix Market file header without loading data.

    Parameters
    ----------
    path : str or Path
        Input file path.

    Returns
    -------
    dict
        Dictionary with keys: 'shape', 'nnz', 'field', 'symmetry'.

    Raises
    ------
    ValueError
        If the file does not start with a Matrix Market banner.

    Example
    -------
    >>> info = load_mtx_info("matrix.mtx")
    >>> print(f"Shape: {info['shape']}, NNZ: {info['nnz']}")
    """
    path = Path(path)

    with open(path, 'r') as f:
        header = f.readline().strip()
        # Validate the banner the same way load_mtx does, so malformed
        # files fail with a clear error instead of a confusing parse error.
        if not header.startswith("%%MatrixMarket"):
            raise ValueError(f"Invalid Matrix Market header: {header}")

        parts = header.split()
        field_type = parts[3].lower() if len(parts) > 3 else "real"
        symmetry = parts[4].lower() if len(parts) > 4 else "general"

        # Skip comment and blank lines; readline() returns '' at EOF,
        # which the `line and` guard treats as a stop condition.
        line = f.readline()
        while line and (line.startswith('%') or not line.strip()):
            line = f.readline()

        dims = line.strip().split()
        M, N, nnz = int(dims[0]), int(dims[1]), int(dims[2])

    return {
        'shape': (M, N),
        'nnz': nnz,
        'field': field_type,
        'symmetry': symmetry,
    }
# ============================================================================= # SparseTensor I/O # =============================================================================
def save_sparse(
    tensor: "SparseTensor",
    path: Union[str, Path],
    metadata: Optional[Dict[str, str]] = None
) -> None:
    """
    Save a SparseTensor to safetensors format.

    Parameters
    ----------
    tensor : SparseTensor
        The sparse tensor to save.
    path : str or Path
        Output file path (should end with .safetensors).
    metadata : dict, optional
        Additional metadata to store in the file.

    Example
    -------
    >>> A = SparseTensor(val, row, col, (100, 100))
    >>> save_sparse(A, "matrix.safetensors")
    """
    _ensure_safetensors()
    from .sparse_tensor import SparseTensor

    if not isinstance(tensor, SparseTensor):
        raise TypeError(f"Expected SparseTensor, got {type(tensor)}")

    # Tensor payload: the COO triplet plus the sparse shape as a tensor.
    payload = {
        "values": tensor.values.contiguous().cpu(),
        "row_indices": tensor.row_indices.contiguous().cpu(),
        "col_indices": tensor.col_indices.contiguous().cpu(),
        "shape": torch.tensor(tensor.sparse_shape, dtype=torch.int64),
    }

    # File-level metadata; everything is stringified for safetensors.
    file_meta = {
        "sparse_dim_0": str(tensor.sparse_dim[0]),
        "sparse_dim_1": str(tensor.sparse_dim[1]),
        "dtype": str(tensor.dtype),
        "format": "sparse_tensor",
        "version": "1.0",
    }
    if metadata:
        file_meta.update(metadata)

    save_file(payload, str(path), metadata=file_meta)
def load_sparse(
    path: Union[str, Path],
    device: Union[str, torch.device] = "cpu"
) -> "SparseTensor":
    """
    Load a SparseTensor from safetensors format.

    Parameters
    ----------
    path : str or Path
        Input file path.
    device : str or torch.device
        Device to load tensors to.

    Returns
    -------
    SparseTensor
        The loaded sparse tensor.

    Example
    -------
    >>> A = load_sparse("matrix.safetensors", device="cuda")
    """
    _ensure_safetensors()
    from .sparse_tensor import SparseTensor

    data = load_file(str(path), device=str(device))

    # The shape was stored as an int64 tensor; convert back to a tuple.
    shape = tuple(data["shape"].tolist())

    return SparseTensor(
        data["values"],
        data["row_indices"],
        data["col_indices"],
        shape,
    )
def load_sparse_as_partition(
    path: Union[str, Path],
    rank: int,
    world_size: int,
    partition_method: str = "simple",
    coords: Optional[torch.Tensor] = None,
    device: Union[str, torch.device] = "cpu"
) -> "DSparseMatrix":
    """
    Load a SparseTensor file and return only this rank's partition.

    This allows distributed reading of a single SparseTensor file,
    where each rank loads the full file but only keeps its partition.
    For very large matrices, use save_distributed() instead to avoid
    loading the full matrix on each rank.

    Parameters
    ----------
    path : str or Path
        Path to SparseTensor file (.safetensors).
    rank : int
        Rank of this process.
    world_size : int
        Total number of processes.
    partition_method : str
        'simple', 'metis', or 'geometric'.
    coords : torch.Tensor, optional
        Node coordinates for geometric partitioning; when given, it takes
        precedence over ``partition_method``.
    device : str or torch.device
        Device to load partition to.

    Returns
    -------
    DSparseMatrix
        This rank's partition of the matrix.

    Example
    -------
    >>> # Each rank calls this:
    >>> rank = dist.get_rank()
    >>> world_size = dist.get_world_size()
    >>> partition = load_sparse_as_partition("matrix.safetensors", rank, world_size)
    """
    _ensure_safetensors()
    # NOTE: the unused `from .sparse_tensor import SparseTensor` was removed;
    # load_sparse() already constructs the SparseTensor.
    from .distributed import DSparseMatrix, partition_graph_metis, partition_coordinates, partition_simple

    # Load full matrix (could be optimized for very large matrices).
    A = load_sparse(path, device="cpu")
    shape = A.sparse_shape

    # Compute partition IDs locally — identical on all ranks, so the
    # decomposition is deterministic across the job.
    if coords is not None:
        partition_ids = partition_coordinates(coords, world_size)
    elif partition_method == 'metis':
        partition_ids = partition_graph_metis(A.row_indices, A.col_indices, shape[0], world_size)
    else:
        partition_ids = partition_simple(shape[0], world_size)

    # Build and return only this rank's slice of the global matrix.
    return DSparseMatrix.from_global(
        A.values, A.row_indices, A.col_indices, shape,
        world_size, rank,
        partition_ids=partition_ids,
        device=device,
        verbose=(rank == 0)
    )
# ============================================================================= # Distributed I/O - Save partitioned for multi-rank loading # =============================================================================
def save_distributed(
    tensor: "SparseTensor",
    directory: Union[str, Path],
    num_partitions: int,
    partition_method: str = "simple",
    coords: Optional[torch.Tensor] = None,
    verbose: bool = False
) -> None:
    """
    Save a SparseTensor as partitioned files for distributed loading.

    Creates a directory with:
    - metadata.json: Global metadata and partition info
    - partition_0.safetensors, partition_1.safetensors, ...: Per-partition data

    Parameters
    ----------
    tensor : SparseTensor
        The global sparse tensor to partition and save.
    directory : str or Path
        Output directory path.
    num_partitions : int
        Number of partitions to create.
    partition_method : str
        Partitioning method: 'simple', 'metis', or 'geometric'.
    coords : torch.Tensor, optional
        Node coordinates for geometric partitioning.
    verbose : bool
        Print progress information.

    Example
    -------
    >>> A = SparseTensor(val, row, col, (1000, 1000))
    >>> save_distributed(A, "matrix_dist", num_partitions=4)
    # Creates:
    #   matrix_dist/metadata.json
    #   matrix_dist/partition_0.safetensors
    #   matrix_dist/partition_1.safetensors
    #   matrix_dist/partition_2.safetensors
    #   matrix_dist/partition_3.safetensors
    """
    _ensure_safetensors()
    from .sparse_tensor import SparseTensor
    from .distributed import DSparseTensor

    if not isinstance(tensor, SparseTensor):
        raise TypeError(f"Expected SparseTensor, got {type(tensor)}")

    directory = Path(directory)
    directory.mkdir(parents=True, exist_ok=True)

    # The partitioning itself is delegated to DSparseTensor; this function
    # only serializes the resulting per-partition state.
    d_tensor = DSparseTensor(
        tensor.values, tensor.row_indices, tensor.col_indices,
        tensor.sparse_shape,
        num_partitions=num_partitions,
        coords=coords,
        partition_method=partition_method,
        verbose=verbose
    )

    # Global metadata, written as JSON so load_partition can validate
    # world_size/rank without opening any partition file.
    metadata = {
        "version": "1.0",
        "format": "distributed_sparse_tensor",
        "shape": list(tensor.sparse_shape),
        "dtype": str(tensor.dtype),
        "nnz": int(tensor.nnz),
        "num_partitions": num_partitions,
        "partition_method": partition_method,
        "sparse_dim": list(tensor.sparse_dim),
    }

    # Per-partition summary (counts + neighbor list).  load_partition reads
    # the neighbor list from here, not from the safetensors metadata.
    partition_info = []
    for i, partition in enumerate(d_tensor._partitions):
        info = {
            "partition_id": i,
            "num_owned": int(partition.num_owned),
            "num_halo": int(partition.num_halo),
            "num_local": int(partition.num_local),
            "nnz": int(partition.nnz),
            "neighbors": partition.partition.neighbor_partitions,
        }
        partition_info.append(info)
    metadata["partitions"] = partition_info

    with open(directory / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=2)

    # One safetensors file per partition.  Key names here are the on-disk
    # contract with load_partition / load_distributed_as_sparse — do not
    # rename them without updating the loaders.
    for i, partition in enumerate(d_tensor._partitions):
        tensors = {
            "values": partition.local_values.contiguous().cpu(),
            "row_indices": partition.local_row.contiguous().cpu(),
            "col_indices": partition.local_col.contiguous().cpu(),
            "owned_nodes": partition.partition.owned_nodes.contiguous().cpu(),
            "halo_nodes": partition.partition.halo_nodes.contiguous().cpu(),
            "local_nodes": partition.partition.local_nodes.contiguous().cpu(),
            "global_to_local": partition.partition.global_to_local.contiguous().cpu(),
            "local_to_global": partition.partition.local_to_global.contiguous().cpu(),
        }

        # Halo-exchange index tensors, keyed by neighbor id.  A neighbor may
        # appear in only one direction, hence the separate membership checks.
        for neighbor_id in partition.partition.neighbor_partitions:
            send_key = f"send_to_{neighbor_id}"
            recv_key = f"recv_from_{neighbor_id}"
            if neighbor_id in partition.partition.send_indices:
                tensors[send_key] = partition.partition.send_indices[neighbor_id].contiguous().cpu()
            if neighbor_id in partition.partition.recv_indices:
                tensors[recv_key] = partition.partition.recv_indices[neighbor_id].contiguous().cpu()

        # safetensors metadata values must be strings; neighbors is encoded
        # as JSON text.  (Loaders currently rely on metadata.json instead.)
        meta = {
            "partition_id": str(i),
            "num_owned": str(partition.num_owned),
            "num_halo": str(partition.num_halo),
            "neighbors": json.dumps(partition.partition.neighbor_partitions),
        }

        save_file(tensors, str(directory / f"partition_{i}.safetensors"), metadata=meta)

    if verbose:
        print(f"Saved {num_partitions} partitions to {directory}")
def load_partition(
    directory: Union[str, Path],
    rank: int,
    world_size: Optional[int] = None,
    device: Union[str, torch.device] = "cpu"
) -> "DSparseMatrix":
    """
    Load a single partition for the given rank.

    Each rank loads only its own partition, enabling efficient
    distributed loading.

    Parameters
    ----------
    directory : str or Path
        Directory containing partitioned data.
    rank : int
        Rank of this process.
    world_size : int, optional
        Total number of processes (must match num_partitions).
        If None, reads from metadata.
    device : str or torch.device
        Device to load tensors to.

    Returns
    -------
    DSparseMatrix
        The partition for this rank.

    Raises
    ------
    ValueError
        If ``world_size`` does not match the saved partition count, or
        ``rank`` is out of range.

    Example
    -------
    >>> # In distributed context
    >>> rank = dist.get_rank()
    >>> world_size = dist.get_world_size()
    >>> partition = load_partition("matrix_dist", rank, world_size, device="cuda")
    """
    _ensure_safetensors()
    from .distributed import DSparseMatrix, Partition

    directory = Path(directory)

    # Global metadata drives validation and provides the neighbor lists.
    with open(directory / "metadata.json", "r") as f:
        metadata = json.load(f)

    num_partitions = metadata["num_partitions"]
    global_shape = tuple(metadata["shape"])

    if world_size is not None and world_size != num_partitions:
        raise ValueError(
            f"world_size ({world_size}) must match num_partitions ({num_partitions})"
        )
    if rank >= num_partitions:
        raise ValueError(
            f"rank ({rank}) must be < num_partitions ({num_partitions})"
        )

    # Only this rank's file is read — the whole point of this loader.
    tensors = load_file(str(directory / f"partition_{rank}.safetensors"), device=str(device))

    values = tensors["values"]
    row_indices = tensors["row_indices"]
    col_indices = tensors["col_indices"]
    owned_nodes = tensors["owned_nodes"]
    halo_nodes = tensors["halo_nodes"]
    local_nodes = tensors["local_nodes"]
    global_to_local = tensors["global_to_local"]
    local_to_global = tensors["local_to_global"]

    # Neighbor ids come from metadata.json; the matching halo-exchange
    # tensors were stored under "send_to_<id>" / "recv_from_<id>" keys by
    # save_distributed.  A neighbor may exist in only one direction.
    partition_meta = metadata["partitions"][rank]
    neighbors = partition_meta["neighbors"]

    send_indices = {}
    recv_indices = {}
    for neighbor_id in neighbors:
        send_key = f"send_to_{neighbor_id}"
        recv_key = f"recv_from_{neighbor_id}"
        if send_key in tensors:
            send_indices[neighbor_id] = tensors[send_key]
        if recv_key in tensors:
            recv_indices[neighbor_id] = tensors[recv_key]

    # Rebuild the Partition descriptor exactly as it was saved.
    partition = Partition(
        partition_id=rank,
        local_nodes=local_nodes,
        owned_nodes=owned_nodes,
        halo_nodes=halo_nodes,
        neighbor_partitions=neighbors,
        send_indices=send_indices,
        recv_indices=recv_indices,
        global_to_local=global_to_local,
        local_to_global=local_to_global,
    )

    # Local shape is square (owned + halo on both axes), matching the
    # layout save_distributed writes for square global operators.
    num_owned = len(owned_nodes)
    num_halo = len(halo_nodes)
    num_local = num_owned + num_halo
    local_shape = (num_local, num_local)

    return DSparseMatrix(
        partition=partition,
        local_values=values,
        local_row=row_indices,
        local_col=col_indices,
        local_shape=local_shape,
        global_shape=global_shape,
        num_partitions=num_partitions,
        device=device,
        verbose=False,
    )
def load_metadata(directory: Union[str, Path]) -> Dict:
    """
    Load metadata from a distributed sparse tensor directory.

    Parameters
    ----------
    directory : str or Path
        Directory containing partitioned data.

    Returns
    -------
    dict
        Metadata including shape, dtype, num_partitions, etc.

    Example
    -------
    >>> meta = load_metadata("matrix_dist")
    >>> print(f"Shape: {meta['shape']}, Partitions: {meta['num_partitions']}")
    """
    meta_path = Path(directory) / "metadata.json"
    with open(meta_path, "r") as fh:
        return json.load(fh)
def load_distributed_as_sparse(
    directory: Union[str, Path],
    device: Union[str, torch.device] = "cpu"
) -> "SparseTensor":
    """
    Load a distributed/partitioned save as a single SparseTensor.

    This gathers all partitions into one SparseTensor. Useful when you
    have partitioned data but want to use it on a single node.

    Parameters
    ----------
    directory : str or Path
        Directory containing partitioned data (from save_distributed
        or DSparseTensor.save).
    device : str or torch.device
        Device to load to.

    Returns
    -------
    SparseTensor
        The complete sparse tensor.

    Example
    -------
    >>> # Load partitioned data as single SparseTensor
    >>> A = load_distributed_as_sparse("matrix_dist", device="cuda")
    """
    _ensure_safetensors()
    # NOTE: removed the unused `from .distributed import Partition` import;
    # assembly only needs the raw tensors from each partition file.
    from .sparse_tensor import SparseTensor

    directory = Path(directory)

    with open(directory / "metadata.json", "r") as f:
        metadata = json.load(f)

    num_partitions = metadata["num_partitions"]
    global_shape = tuple(metadata["shape"])

    # Accumulate entries from every partition, owned rows only.
    all_values = []
    all_rows = []
    all_cols = []

    for i in range(num_partitions):
        tensors = load_file(str(directory / f"partition_{i}.safetensors"), device="cpu")

        values = tensors["values"]
        row_indices = tensors["row_indices"]
        col_indices = tensors["col_indices"]
        owned_nodes = tensors["owned_nodes"]
        local_to_global = tensors["local_to_global"]

        num_owned = len(owned_nodes)

        # Only keep entries whose local row index is owned by this
        # partition; halo rows are duplicated on neighboring partitions
        # and would double-count otherwise.
        owned_mask = row_indices < num_owned
        local_vals = values[owned_mask]
        local_rows = row_indices[owned_mask]
        local_cols = col_indices[owned_mask]

        # Map partition-local indices back to global indices.
        global_rows = local_to_global[local_rows]
        global_cols = local_to_global[local_cols]

        all_values.append(local_vals)
        all_rows.append(global_rows)
        all_cols.append(global_cols)

    global_values = torch.cat(all_values).to(device)
    global_rows = torch.cat(all_rows).to(device)
    global_cols = torch.cat(all_cols).to(device)

    return SparseTensor(global_values, global_rows, global_cols, global_shape)
# =============================================================================
# Alternative: Save as single file with partition index
# =============================================================================

def save_sparse_sharded(
    tensor: "SparseTensor",
    path: Union[str, Path],
    num_shards: int,
    partition_method: str = "simple",
    coords: Optional[torch.Tensor] = None,
    metadata: Optional[Dict[str, str]] = None,
    verbose: bool = False
) -> None:
    """
    Save a SparseTensor with shard information embedded.

    This saves all shards in a single directory but with clear shard
    boundaries, allowing selective loading.

    Parameters
    ----------
    tensor : SparseTensor
        The global sparse tensor to save.
    path : str or Path
        Output directory path.
    num_shards : int
        Number of shards to create.
    partition_method : str
        Partitioning method.
    coords : torch.Tensor, optional
        Node coordinates for geometric partitioning.
    metadata : dict, optional
        Additional metadata.
        NOTE(review): this argument is currently not forwarded to
        save_distributed and is therefore ignored.
    verbose : bool
        Print progress information.
    """
    # Thin wrapper: sharded saving uses the same on-disk layout as
    # save_distributed, so delegate directly.
    save_distributed(tensor, path, num_shards, partition_method, coords, verbose)


def load_sparse_shard(
    path: Union[str, Path],
    shard_id: int,
    device: Union[str, torch.device] = "cpu"
) -> "DSparseMatrix":
    """
    Load a specific shard from sharded sparse tensor.

    Parameters
    ----------
    path : str or Path
        Directory containing sharded data.
    shard_id : int
        Shard ID to load.
    device : str or torch.device
        Device to load to.

    Returns
    -------
    DSparseMatrix
        The loaded shard.
    """
    # Thin wrapper over load_partition; the saved partition count doubles
    # as the expected world size for validation.
    info = load_metadata(path)
    return load_partition(path, shard_id, info["num_partitions"], device)


# =============================================================================
# DSparseTensor I/O
# =============================================================================
def save_dsparse(
    tensor: "DSparseTensor",
    directory: Union[str, Path],
    verbose: bool = False
) -> None:
    """
    Save a DSparseTensor to disk.

    Writes the same on-disk layout as save_distributed (metadata.json
    plus one partition_<i>.safetensors per partition), so load_dsparse
    and load_partition can read the result.

    Parameters
    ----------
    tensor : DSparseTensor
        The distributed sparse tensor to save.
    directory : str or Path
        Output directory.
    verbose : bool
        Print progress.

    Raises
    ------
    TypeError
        If ``tensor`` is not a DSparseTensor.
    """
    _ensure_safetensors()
    from .distributed import DSparseTensor

    if not isinstance(tensor, DSparseTensor):
        raise TypeError(f"Expected DSparseTensor, got {type(tensor)}")

    directory = Path(directory)
    directory.mkdir(parents=True, exist_ok=True)

    # Global metadata; load_dsparse reads partition counts and neighbor
    # lists from this file rather than from the safetensors files.
    metadata = {
        "version": "1.0",
        "format": "dsparse_tensor",
        "shape": list(tensor.shape),
        "dtype": str(tensor.dtype),
        "num_partitions": tensor.num_partitions,
    }

    partition_info = []
    for i, partition in enumerate(tensor._partitions):
        info = {
            "partition_id": i,
            "num_owned": int(partition.num_owned),
            "num_halo": int(partition.num_halo),
            "num_local": int(partition.num_local),
            "nnz": int(partition.nnz),
            "neighbors": partition.partition.neighbor_partitions,
        }
        partition_info.append(info)
    metadata["partitions"] = partition_info

    with open(directory / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=2)

    # One safetensors file per partition.  Key names are the on-disk
    # contract with load_dsparse — keep them in sync.
    for i, partition in enumerate(tensor._partitions):
        tensors = {
            "values": partition.local_values.contiguous().cpu(),
            "row_indices": partition.local_row.contiguous().cpu(),
            "col_indices": partition.local_col.contiguous().cpu(),
            "owned_nodes": partition.partition.owned_nodes.contiguous().cpu(),
            "halo_nodes": partition.partition.halo_nodes.contiguous().cpu(),
            "local_nodes": partition.partition.local_nodes.contiguous().cpu(),
            "global_to_local": partition.partition.global_to_local.contiguous().cpu(),
            "local_to_global": partition.partition.local_to_global.contiguous().cpu(),
        }

        # Halo-exchange index tensors, keyed by neighbor id; a neighbor may
        # exist in only one direction, hence the separate checks.
        for neighbor_id in partition.partition.neighbor_partitions:
            if neighbor_id in partition.partition.send_indices:
                tensors[f"send_to_{neighbor_id}"] = partition.partition.send_indices[neighbor_id].contiguous().cpu()
            if neighbor_id in partition.partition.recv_indices:
                tensors[f"recv_from_{neighbor_id}"] = partition.partition.recv_indices[neighbor_id].contiguous().cpu()

        # Unlike save_distributed, no per-file safetensors metadata is
        # written here; all descriptive data lives in metadata.json.
        save_file(tensors, str(directory / f"partition_{i}.safetensors"))

    if verbose:
        print(f"Saved DSparseTensor with {tensor.num_partitions} partitions to {directory}")
def load_dsparse(
    directory: Union[str, Path],
    device: Union[str, torch.device] = "cpu"
) -> "DSparseTensor":
    """
    Load a complete DSparseTensor from disk.

    Reconstructs every partition (as DSparseMatrix objects) and the
    global COO arrays, then assembles a DSparseTensor without re-running
    partitioning.

    Parameters
    ----------
    directory : str or Path
        Directory containing saved data.
    device : str or torch.device
        Device to load to.

    Returns
    -------
    DSparseTensor
        The loaded distributed sparse tensor.
    """
    _ensure_safetensors()
    from .distributed import DSparseTensor, DSparseMatrix, Partition

    directory = Path(directory)

    with open(directory / "metadata.json", "r") as f:
        metadata = json.load(f)

    num_partitions = metadata["num_partitions"]
    global_shape = tuple(metadata["shape"])

    # Rebuild each partition from its safetensors file plus the neighbor
    # lists stored in metadata.json.
    partitions = []
    for i in range(num_partitions):
        tensors = load_file(str(directory / f"partition_{i}.safetensors"), device=str(device))
        partition_meta = metadata["partitions"][i]
        neighbors = partition_meta["neighbors"]

        # Halo-exchange tensors were saved under "send_to_<id>" /
        # "recv_from_<id>"; a neighbor may exist in only one direction.
        send_indices = {}
        recv_indices = {}
        for neighbor_id in neighbors:
            send_key = f"send_to_{neighbor_id}"
            recv_key = f"recv_from_{neighbor_id}"
            if send_key in tensors:
                send_indices[neighbor_id] = tensors[send_key]
            if recv_key in tensors:
                recv_indices[neighbor_id] = tensors[recv_key]

        partition = Partition(
            partition_id=i,
            local_nodes=tensors["local_nodes"],
            owned_nodes=tensors["owned_nodes"],
            halo_nodes=tensors["halo_nodes"],
            neighbor_partitions=neighbors,
            send_indices=send_indices,
            recv_indices=recv_indices,
            global_to_local=tensors["global_to_local"],
            local_to_global=tensors["local_to_global"],
        )

        # Square local shape (owned + halo), mirroring load_partition.
        num_owned = len(tensors["owned_nodes"])
        num_halo = len(tensors["halo_nodes"])
        num_local = num_owned + num_halo

        dsm = DSparseMatrix(
            partition=partition,
            local_values=tensors["values"],
            local_row=tensors["row_indices"],
            local_col=tensors["col_indices"],
            local_shape=(num_local, num_local),
            global_shape=global_shape,
            num_partitions=num_partitions,
            device=device,
            verbose=False,
        )
        partitions.append(dsm)

    # Create DSparseTensor from partitions.
    # We need to reconstruct global data from partitions for the gather
    # method.  Collect values, rows, cols from all partitions (owned
    # portion only — halo rows are owned by a neighbor and would be
    # duplicated otherwise).
    all_values = []
    all_rows = []
    all_cols = []
    for p in partitions:
        owned_mask = p.local_row < p.num_owned
        local_vals = p.local_values[owned_mask]
        local_rows = p.local_row[owned_mask]
        local_cols = p.local_col[owned_mask]
        # Convert local to global indices.
        global_rows = p.partition.local_to_global[local_rows]
        global_cols = p.partition.local_to_global[local_cols]
        all_values.append(local_vals)
        all_rows.append(global_rows)
        all_cols.append(global_cols)

    global_values = torch.cat(all_values)
    global_rows = torch.cat(all_rows)
    global_cols = torch.cat(all_cols)

    # __new__ bypasses DSparseTensor.__init__ (and its partitioning work);
    # every private attribute the class relies on is assigned explicitly
    # below — keep this list in sync with DSparseTensor.__init__.
    d_tensor = DSparseTensor.__new__(DSparseTensor)
    d_tensor._partitions = partitions
    d_tensor._shape = global_shape
    d_tensor._device = torch.device(device)
    d_tensor._values = global_values.to(device)
    d_tensor._row_indices = global_rows.to(device)
    d_tensor._col_indices = global_cols.to(device)
    d_tensor._distributed_mode = False
    d_tensor._num_partitions = num_partitions
    d_tensor._coords = None
    # Sentinel marking that the partitioning came from disk, not a method.
    d_tensor._partition_method = 'loaded'
    d_tensor._verbose = False

    return d_tensor