"""
Persistence utilities for SparseTensor and DSparseTensor.
Supports:
- safetensors format for efficient, safe serialization
- Matrix Market (.mtx) format for interoperability with other tools
- Distributed loading where different ranks load different partitions
Example
-------
>>> from torch_sla import SparseTensor
>>> from torch_sla.io import save_sparse, load_sparse, save_distributed, load_partition
>>>
>>> # Save single SparseTensor
>>> A = SparseTensor(val, row, col, shape)
>>> save_sparse(A, "matrix.safetensors")
>>> A_loaded = load_sparse("matrix.safetensors")
>>>
>>> # Matrix Market format
>>> save_mtx(A, "matrix.mtx")
>>> A = load_mtx("matrix.mtx")
>>>
>>> # Save partitioned for distributed loading
>>> save_distributed(A, "matrix_dist", num_partitions=4)
>>> # Each rank loads its partition
>>> partition = load_partition("matrix_dist", rank=rank, world_size=4)
"""
import os
import json
import torch
from typing import Dict, Optional, Tuple, Union, TYPE_CHECKING
from pathlib import Path
if TYPE_CHECKING:
from .sparse_tensor import SparseTensor
from .distributed import DSparseMatrix, DSparseTensor
try:
from safetensors.torch import save_file, load_file
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
def _ensure_safetensors():
    """Raise a helpful ImportError if the optional safetensors package is missing."""
    if SAFETENSORS_AVAILABLE:
        return
    raise ImportError(
        "safetensors is required for persistence. "
        "Install with: pip install safetensors"
    )
# =============================================================================
# Matrix Market (.mtx) Format
# =============================================================================
def save_mtx(
    tensor: "SparseTensor",
    path: Union[str, Path],
    comment: str = "",
    field: str = "real",
    symmetry: str = "general",
) -> None:
    """
    Save a SparseTensor to Matrix Market (.mtx) format.

    Parameters
    ----------
    tensor : SparseTensor
        The sparse tensor to save.
    path : str or Path
        Output file path (should end with .mtx).
    comment : str, optional
        Comment to include in the header.
    field : str, optional
        Field type: 'real', 'complex', 'integer', or 'pattern'.
        When left at the default 'real', the field is inferred from the
        tensor dtype (complex dtypes -> 'complex', integer dtypes ->
        'integer'). An explicitly passed field is respected as-is.
        Default: 'real'.
    symmetry : str, optional
        Symmetry type: 'general', 'symmetric', 'skew-symmetric', or 'hermitian'.
        Entries are written verbatim; for non-'general' symmetry the caller
        is responsible for supplying only the lower-triangular entries.
        Default: 'general'.

    Raises
    ------
    TypeError
        If tensor is not a SparseTensor.
    ValueError
        If the tensor is batched, or field/symmetry is not a recognized value.

    Example
    -------
    >>> A = SparseTensor(val, row, col, (100, 100))
    >>> save_mtx(A, "matrix.mtx")
    >>> save_mtx(A, "matrix.mtx", symmetry="symmetric")
    """
    from .sparse_tensor import SparseTensor
    if not isinstance(tensor, SparseTensor):
        raise TypeError(f"Expected SparseTensor, got {type(tensor)}")
    if tensor.is_batched:
        raise ValueError("Cannot save batched SparseTensor to Matrix Market format")
    valid_fields = ("real", "complex", "integer", "pattern")
    if field not in valid_fields:
        raise ValueError(f"field must be one of {valid_fields}, got: {field}")
    valid_symmetries = ("general", "symmetric", "skew-symmetric", "hermitian")
    if symmetry not in valid_symmetries:
        raise ValueError(f"symmetry must be one of {valid_symmetries}, got: {symmetry}")
    path = Path(path)
    M, N = tensor.sparse_shape
    nnz = tensor.nnz
    # Move data to CPU/numpy once, before the per-entry write loop.
    row = tensor.row_indices.cpu().numpy()
    col = tensor.col_indices.cpu().numpy()
    val = tensor.values.cpu().numpy()
    # Infer the field from the dtype only when the caller kept the default.
    # (Previously an explicit field such as 'pattern' was silently overridden
    # by the dtype-based inference, contradicting the documented behavior.)
    if field == "real":
        if tensor.dtype in (torch.complex64, torch.complex128):
            field = "complex"
        elif tensor.dtype in (torch.int32, torch.int64):
            field = "integer"
    with open(path, 'w') as f:
        # Banner line, e.g. "%%MatrixMarket matrix coordinate real general".
        f.write(f"%%MatrixMarket matrix coordinate {field} {symmetry}\n")
        if comment:
            for line in comment.split('\n'):
                f.write(f"% {line}\n")
        f.write("% Generated by torch-sla\n")
        # Size line: rows cols nnz
        f.write(f"{M} {N} {nnz}\n")
        # Data lines use 1-based indices per the Matrix Market convention.
        if field == "pattern":
            for i in range(nnz):
                f.write(f"{row[i] + 1} {col[i] + 1}\n")
        elif field == "complex":
            for i in range(nnz):
                v = val[i]
                f.write(f"{row[i] + 1} {col[i] + 1} {v.real} {v.imag}\n")
        else:
            for i in range(nnz):
                f.write(f"{row[i] + 1} {col[i] + 1} {val[i]}\n")
def load_mtx(
    path: Union[str, Path],
    dtype: Optional[torch.dtype] = None,
    device: Union[str, torch.device] = "cpu",
) -> "SparseTensor":
    """
    Load a SparseTensor from Matrix Market (.mtx) format.

    Parameters
    ----------
    path : str or Path
        Input file path.
    dtype : torch.dtype, optional
        Data type for values. If None, inferred from the file's field type
        (complex -> complex128, integer -> int64, otherwise float64).
    device : str or torch.device
        Device to load tensors to.

    Returns
    -------
    SparseTensor
        The loaded sparse tensor. Symmetric / skew-symmetric / hermitian
        storage is expanded to explicit 'general' entries.

    Raises
    ------
    ValueError
        If the header banner, object/format type, or size line is invalid.

    Example
    -------
    >>> A = load_mtx("matrix.mtx")
    >>> A = load_mtx("matrix.mtx", dtype=torch.float32, device="cuda")
    """
    from .sparse_tensor import SparseTensor
    path = Path(path)
    with open(path, 'r') as f:
        # Parse the banner line.
        header = f.readline().strip()
        if not header.startswith("%%MatrixMarket"):
            raise ValueError(f"Invalid Matrix Market header: {header}")
        parts = header.split()
        if len(parts) < 4:
            raise ValueError(f"Invalid Matrix Market header: {header}")
        obj_type = parts[1].lower()      # expected: matrix
        format_type = parts[2].lower()   # expected: coordinate
        field_type = parts[3].lower() if len(parts) > 3 else "real"
        symmetry = parts[4].lower() if len(parts) > 4 else "general"
        if obj_type != "matrix":
            raise ValueError(f"Only 'matrix' object type supported, got: {obj_type}")
        if format_type != "coordinate":
            raise ValueError(f"Only 'coordinate' format supported, got: {format_type}")
        # Skip comment lines AND blank lines before the size line.
        # (Previously a blank line here caused an IndexError.)
        line = f.readline()
        while line and (line.startswith('%') or not line.strip()):
            line = f.readline()
        # Parse the size line: rows cols nnz.
        dims = line.strip().split()
        if len(dims) < 3:
            raise ValueError(f"Invalid size line in {path}: {line!r}")
        M, N, nnz = int(dims[0]), int(dims[1]), int(dims[2])
        # Parse the (1-indexed) coordinate entries.
        rows = []
        cols = []
        vals = []
        for line in f:
            parts = line.strip().split()
            if not parts:
                continue
            r = int(parts[0]) - 1  # convert to 0-indexed
            c = int(parts[1]) - 1
            if field_type == "pattern":
                v = 1.0
            elif field_type == "complex":
                v = complex(float(parts[2]), float(parts[3]))
            elif field_type == "integer":
                v = int(parts[2])
            else:
                v = float(parts[2])
            rows.append(r)
            cols.append(c)
            vals.append(v)
            # Expand symmetric storage: mirror off-diagonal entries.
            if symmetry == "symmetric" and r != c:
                rows.append(c)
                cols.append(r)
                vals.append(v)
            elif symmetry == "skew-symmetric" and r != c:
                rows.append(c)
                cols.append(r)
                vals.append(-v)
            elif symmetry == "hermitian" and r != c:
                rows.append(c)
                cols.append(r)
                vals.append(v.conjugate() if isinstance(v, complex) else v)
    # Convert to tensors.
    row_tensor = torch.tensor(rows, dtype=torch.long, device=device)
    col_tensor = torch.tensor(cols, dtype=torch.long, device=device)
    # Choose a value dtype when the caller did not specify one.
    if dtype is None:
        if field_type == "complex":
            dtype = torch.complex128
        elif field_type == "integer":
            dtype = torch.int64
        else:
            dtype = torch.float64
    val_tensor = torch.tensor(vals, dtype=dtype, device=device)
    return SparseTensor(val_tensor, row_tensor, col_tensor, (M, N))
def load_mtx_info(path: Union[str, Path]) -> Dict:
    """
    Read a Matrix Market file header without loading the data.

    Parameters
    ----------
    path : str or Path
        Input file path.

    Returns
    -------
    dict
        Dictionary with keys: 'shape', 'nnz', 'field', 'symmetry'.

    Raises
    ------
    ValueError
        If the file does not start with a '%%MatrixMarket' banner or the
        size line is malformed.

    Example
    -------
    >>> info = load_mtx_info("matrix.mtx")
    >>> print(f"Shape: {info['shape']}, NNZ: {info['nnz']}")
    """
    path = Path(path)
    with open(path, 'r') as f:
        header = f.readline().strip()
        # Validate the banner, matching load_mtx (previously any first line
        # was accepted, yielding garbage field/symmetry values).
        if not header.startswith("%%MatrixMarket"):
            raise ValueError(f"Invalid Matrix Market header: {header}")
        parts = header.split()
        field_type = parts[3].lower() if len(parts) > 3 else "real"
        symmetry = parts[4].lower() if len(parts) > 4 else "general"
        # Skip comment lines and blank lines before the size line.
        line = f.readline()
        while line and (line.startswith('%') or not line.strip()):
            line = f.readline()
        dims = line.strip().split()
        if len(dims) < 3:
            raise ValueError(f"Invalid size line in {path}: {line!r}")
        M, N, nnz = int(dims[0]), int(dims[1]), int(dims[2])
    return {
        'shape': (M, N),
        'nnz': nnz,
        'field': field_type,
        'symmetry': symmetry,
    }
# =============================================================================
# SparseTensor I/O
# =============================================================================
def save_sparse(
    tensor: "SparseTensor",
    path: Union[str, Path],
    metadata: Optional[Dict[str, str]] = None
) -> None:
    """
    Serialize a SparseTensor to a .safetensors file.

    The COO components (values, row/col indices) and the sparse shape are
    stored as tensors; descriptive fields go into the string-valued
    safetensors metadata.

    Parameters
    ----------
    tensor : SparseTensor
        The sparse tensor to save.
    path : str or Path
        Output file path (should end with .safetensors).
    metadata : dict, optional
        Extra string key/value pairs merged into the file metadata.

    Example
    -------
    >>> A = SparseTensor(val, row, col, (100, 100))
    >>> save_sparse(A, "matrix.safetensors")
    """
    _ensure_safetensors()
    from .sparse_tensor import SparseTensor
    if not isinstance(tensor, SparseTensor):
        raise TypeError(f"Expected SparseTensor, got {type(tensor)}")
    # safetensors requires contiguous CPU tensors.
    payload = {
        "values": tensor.values.contiguous().cpu(),
        "row_indices": tensor.row_indices.contiguous().cpu(),
        "col_indices": tensor.col_indices.contiguous().cpu(),
        "shape": torch.tensor(tensor.sparse_shape, dtype=torch.int64),
    }
    # File-level metadata: all values must be strings; caller-supplied
    # entries override the defaults on key collision.
    file_meta = {
        "sparse_dim_0": str(tensor.sparse_dim[0]),
        "sparse_dim_1": str(tensor.sparse_dim[1]),
        "dtype": str(tensor.dtype),
        "format": "sparse_tensor",
        "version": "1.0",
        **(metadata or {}),
    }
    save_file(payload, str(path), metadata=file_meta)
def load_sparse(
    path: Union[str, Path],
    device: Union[str, torch.device] = "cpu"
) -> "SparseTensor":
    """
    Deserialize a SparseTensor from a .safetensors file.

    Parameters
    ----------
    path : str or Path
        Input file path (written by save_sparse).
    device : str or torch.device
        Device to load tensors to.

    Returns
    -------
    SparseTensor
        The reconstructed sparse tensor.

    Example
    -------
    >>> A = load_sparse("matrix.safetensors", device="cuda")
    """
    _ensure_safetensors()
    from .sparse_tensor import SparseTensor
    data = load_file(str(path), device=str(device))
    # The sparse shape was stored as an int64 tensor; turn it back into a tuple.
    sparse_shape = tuple(data["shape"].tolist())
    return SparseTensor(
        data["values"],
        data["row_indices"],
        data["col_indices"],
        sparse_shape,
    )
def load_sparse_as_partition(
    path: Union[str, Path],
    rank: int,
    world_size: int,
    partition_method: str = "simple",
    coords: Optional[torch.Tensor] = None,
    device: Union[str, torch.device] = "cpu"
) -> "DSparseMatrix":
    """
    Load a SparseTensor file and return only this rank's partition.

    Each rank loads the full file on CPU, computes the partitioning locally
    (deterministically, so all ranks agree), and keeps only its own part.
    For very large matrices, use save_distributed()/load_partition() instead
    to avoid loading the full matrix on each rank.

    Parameters
    ----------
    path : str or Path
        Path to SparseTensor file (.safetensors).
    rank : int
        Rank of this process.
    world_size : int
        Total number of processes.
    partition_method : str
        'simple', 'metis', or 'geometric'.
    coords : torch.Tensor, optional
        Node coordinates. When given, geometric partitioning is used
        regardless of partition_method.
    device : str or torch.device
        Device to load partition to.

    Returns
    -------
    DSparseMatrix
        This rank's partition of the matrix.

    Raises
    ------
    ValueError
        If partition_method is 'geometric' but coords is None.

    Example
    -------
    >>> # Each rank calls this:
    >>> rank = dist.get_rank()
    >>> world_size = dist.get_world_size()
    >>> partition = load_sparse_as_partition("matrix.safetensors", rank, world_size)
    """
    _ensure_safetensors()
    from .sparse_tensor import SparseTensor
    from .distributed import DSparseMatrix, partition_graph_metis, partition_coordinates, partition_simple
    # Previously 'geometric' without coords silently fell back to simple
    # partitioning; fail loudly instead.
    if partition_method == 'geometric' and coords is None:
        raise ValueError("partition_method='geometric' requires coords")
    # Load full matrix on CPU (could be optimized for very large matrices).
    A = load_sparse(path, device="cpu")
    # Compute partition IDs locally (same on all ranks for determinism).
    shape = A.sparse_shape
    if coords is not None:
        partition_ids = partition_coordinates(coords, world_size)
    elif partition_method == 'metis':
        partition_ids = partition_graph_metis(A.row_indices, A.col_indices, shape[0], world_size)
    else:
        partition_ids = partition_simple(shape[0], world_size)
    # Slice out this rank's partition; halo structure is derived inside.
    return DSparseMatrix.from_global(
        A.values, A.row_indices, A.col_indices, shape,
        world_size, rank,
        partition_ids=partition_ids,
        device=device,
        verbose=(rank == 0)
    )
# =============================================================================
# Distributed I/O - Save partitioned for multi-rank loading
# =============================================================================
def save_distributed(
    tensor: "SparseTensor",
    directory: Union[str, Path],
    num_partitions: int,
    partition_method: str = "simple",
    coords: Optional[torch.Tensor] = None,
    verbose: bool = False
) -> None:
    """
    Save a SparseTensor as partitioned files for distributed loading.

    Creates a directory with:
    - metadata.json: Global metadata and partition info
    - partition_0.safetensors, partition_1.safetensors, ...: Per-partition data

    Parameters
    ----------
    tensor : SparseTensor
        The global sparse tensor to partition and save.
    directory : str or Path
        Output directory path (created if missing).
    num_partitions : int
        Number of partitions to create.
    partition_method : str
        Partitioning method: 'simple', 'metis', or 'geometric'.
    coords : torch.Tensor, optional
        Node coordinates for geometric partitioning.
    verbose : bool
        Print progress information.

    Raises
    ------
    TypeError
        If tensor is not a SparseTensor.

    Example
    -------
    >>> A = SparseTensor(val, row, col, (1000, 1000))
    >>> save_distributed(A, "matrix_dist", num_partitions=4)
    # Creates:
    #   matrix_dist/metadata.json
    #   matrix_dist/partition_0.safetensors
    #   matrix_dist/partition_1.safetensors
    #   matrix_dist/partition_2.safetensors
    #   matrix_dist/partition_3.safetensors
    """
    _ensure_safetensors()
    from .sparse_tensor import SparseTensor
    from .distributed import DSparseTensor
    if not isinstance(tensor, SparseTensor):
        raise TypeError(f"Expected SparseTensor, got {type(tensor)}")
    directory = Path(directory)
    directory.mkdir(parents=True, exist_ok=True)
    # Delegate partitioning to DSparseTensor: it produces the per-partition
    # local data plus the owned/halo node sets and send/recv exchange indices
    # that get serialized below.
    d_tensor = DSparseTensor(
        tensor.values,
        tensor.row_indices,
        tensor.col_indices,
        tensor.sparse_shape,
        num_partitions=num_partitions,
        coords=coords,
        partition_method=partition_method,
        verbose=verbose
    )
    # Global metadata, consumed by load_partition / load_distributed_as_sparse.
    metadata = {
        "version": "1.0",
        "format": "distributed_sparse_tensor",
        "shape": list(tensor.sparse_shape),
        "dtype": str(tensor.dtype),
        "nnz": int(tensor.nnz),
        "num_partitions": num_partitions,
        "partition_method": partition_method,
        "sparse_dim": list(tensor.sparse_dim),
    }
    # Per-partition summary (sizes and neighbor ids) for metadata.json.
    partition_info = []
    for i, partition in enumerate(d_tensor._partitions):
        info = {
            "partition_id": i,
            "num_owned": int(partition.num_owned),
            "num_halo": int(partition.num_halo),
            "num_local": int(partition.num_local),
            "nnz": int(partition.nnz),
            "neighbors": partition.partition.neighbor_partitions,
        }
        partition_info.append(info)
    metadata["partitions"] = partition_info
    with open(directory / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=2)
    # Write one safetensors file per partition; tensors are made contiguous
    # and moved to CPU as safetensors requires.
    for i, partition in enumerate(d_tensor._partitions):
        tensors = {
            "values": partition.local_values.contiguous().cpu(),
            "row_indices": partition.local_row.contiguous().cpu(),
            "col_indices": partition.local_col.contiguous().cpu(),
            "owned_nodes": partition.partition.owned_nodes.contiguous().cpu(),
            "halo_nodes": partition.partition.halo_nodes.contiguous().cpu(),
            "local_nodes": partition.partition.local_nodes.contiguous().cpu(),
            "global_to_local": partition.partition.global_to_local.contiguous().cpu(),
            "local_to_global": partition.partition.local_to_global.contiguous().cpu(),
        }
        # Halo-exchange index tensors, one pair per neighbor, keyed by the
        # neighbor partition id embedded in the tensor name.
        for neighbor_id in partition.partition.neighbor_partitions:
            send_key = f"send_to_{neighbor_id}"
            recv_key = f"recv_from_{neighbor_id}"
            if neighbor_id in partition.partition.send_indices:
                tensors[send_key] = partition.partition.send_indices[neighbor_id].contiguous().cpu()
            if neighbor_id in partition.partition.recv_indices:
                tensors[recv_key] = partition.partition.recv_indices[neighbor_id].contiguous().cpu()
        # safetensors file metadata values must be strings, hence str()/json.dumps.
        meta = {
            "partition_id": str(i),
            "num_owned": str(partition.num_owned),
            "num_halo": str(partition.num_halo),
            "neighbors": json.dumps(partition.partition.neighbor_partitions),
        }
        save_file(tensors, str(directory / f"partition_{i}.safetensors"), metadata=meta)
    if verbose:
        print(f"Saved {num_partitions} partitions to {directory}")
def load_partition(
    directory: Union[str, Path],
    rank: int,
    world_size: Optional[int] = None,
    device: Union[str, torch.device] = "cpu"
) -> "DSparseMatrix":
    """
    Load a single partition for the given rank.

    Each rank reads metadata.json plus only its own
    partition_<rank>.safetensors file, enabling efficient distributed
    loading (no rank materializes the full matrix).

    Parameters
    ----------
    directory : str or Path
        Directory containing partitioned data (from save_distributed).
    rank : int
        Rank of this process; selects which partition file to read.
    world_size : int, optional
        Total number of processes (must match num_partitions).
        If None, reads from metadata.
    device : str or torch.device
        Device to load tensors to.

    Returns
    -------
    DSparseMatrix
        The partition for this rank.

    Raises
    ------
    ValueError
        If world_size does not match the stored num_partitions, or rank is
        out of range.

    Example
    -------
    >>> # In distributed context
    >>> rank = dist.get_rank()
    >>> world_size = dist.get_world_size()
    >>> partition = load_partition("matrix_dist", rank, world_size, device="cuda")
    """
    _ensure_safetensors()
    from .distributed import DSparseMatrix, Partition
    directory = Path(directory)
    # Global metadata written by save_distributed.
    with open(directory / "metadata.json", "r") as f:
        metadata = json.load(f)
    num_partitions = metadata["num_partitions"]
    global_shape = tuple(metadata["shape"])
    if world_size is not None and world_size != num_partitions:
        raise ValueError(
            f"world_size ({world_size}) must match num_partitions ({num_partitions})"
        )
    if rank >= num_partitions:
        raise ValueError(
            f"rank ({rank}) must be < num_partitions ({num_partitions})"
        )
    # Read only this rank's shard.
    tensors = load_file(str(directory / f"partition_{rank}.safetensors"), device=str(device))
    values = tensors["values"]
    row_indices = tensors["row_indices"]
    col_indices = tensors["col_indices"]
    owned_nodes = tensors["owned_nodes"]
    halo_nodes = tensors["halo_nodes"]
    local_nodes = tensors["local_nodes"]
    global_to_local = tensors["global_to_local"]
    local_to_global = tensors["local_to_global"]
    # Rebuild the halo-exchange maps from the per-neighbor tensors that
    # save_distributed stored as "send_to_<id>" / "recv_from_<id>".
    partition_meta = metadata["partitions"][rank]
    neighbors = partition_meta["neighbors"]
    send_indices = {}
    recv_indices = {}
    for neighbor_id in neighbors:
        send_key = f"send_to_{neighbor_id}"
        recv_key = f"recv_from_{neighbor_id}"
        if send_key in tensors:
            send_indices[neighbor_id] = tensors[send_key]
        if recv_key in tensors:
            recv_indices[neighbor_id] = tensors[recv_key]
    # Reassemble the Partition descriptor exactly as it was saved.
    partition = Partition(
        partition_id=rank,
        local_nodes=local_nodes,
        owned_nodes=owned_nodes,
        halo_nodes=halo_nodes,
        neighbor_partitions=neighbors,
        send_indices=send_indices,
        recv_indices=recv_indices,
        global_to_local=global_to_local,
        local_to_global=local_to_global,
    )
    # The local matrix is square over owned + halo nodes.
    num_owned = len(owned_nodes)
    num_halo = len(halo_nodes)
    num_local = num_owned + num_halo
    local_shape = (num_local, num_local)
    return DSparseMatrix(
        partition=partition,
        local_values=values,
        local_row=row_indices,
        local_col=col_indices,
        local_shape=local_shape,
        global_shape=global_shape,
        num_partitions=num_partitions,
        device=device,
        verbose=False,
    )
def load_distributed_as_sparse(
    directory: Union[str, Path],
    device: Union[str, torch.device] = "cpu"
) -> "SparseTensor":
    """
    Load a distributed/partitioned save as a single SparseTensor.

    Gathers the owned entries of every partition into one SparseTensor.
    Useful when you have partitioned data but want to use it on a single node.

    Parameters
    ----------
    directory : str or Path
        Directory containing partitioned data (from save_distributed or
        DSparseTensor.save).
    device : str or torch.device
        Device to load to.

    Returns
    -------
    SparseTensor
        The complete sparse tensor.

    Example
    -------
    >>> # Load partitioned data as single SparseTensor
    >>> A = load_distributed_as_sparse("matrix_dist", device="cuda")
    """
    _ensure_safetensors()
    # Note: the previously imported Partition class was never used here.
    from .sparse_tensor import SparseTensor
    directory = Path(directory)
    with open(directory / "metadata.json", "r") as f:
        metadata = json.load(f)
    num_partitions = metadata["num_partitions"]
    global_shape = tuple(metadata["shape"])
    # Collect COO entries from every partition file.
    all_values = []
    all_rows = []
    all_cols = []
    for i in range(num_partitions):
        tensors = load_file(str(directory / f"partition_{i}.safetensors"), device="cpu")
        values = tensors["values"]
        row_indices = tensors["row_indices"]
        col_indices = tensors["col_indices"]
        owned_nodes = tensors["owned_nodes"]
        local_to_global = tensors["local_to_global"]
        num_owned = len(owned_nodes)
        # Local row indices < num_owned belong to this partition; halo rows
        # are duplicated across files and must be skipped here.
        owned_mask = row_indices < num_owned
        local_vals = values[owned_mask]
        local_rows = row_indices[owned_mask]
        local_cols = col_indices[owned_mask]
        # Map local indices back to global ids.
        global_rows = local_to_global[local_rows]
        global_cols = local_to_global[local_cols]
        all_values.append(local_vals)
        all_rows.append(global_rows)
        all_cols.append(global_cols)
    # Concatenate on CPU, then move once to the target device.
    global_values = torch.cat(all_values).to(device)
    global_rows = torch.cat(all_rows).to(device)
    global_cols = torch.cat(all_cols).to(device)
    return SparseTensor(global_values, global_rows, global_cols, global_shape)
# =============================================================================
# Alternative: Save as single file with partition index
# =============================================================================
def save_sparse_sharded(
    tensor: "SparseTensor",
    path: Union[str, Path],
    num_shards: int,
    partition_method: str = "simple",
    coords: Optional[torch.Tensor] = None,
    metadata: Optional[Dict[str, str]] = None,
    verbose: bool = False
) -> None:
    """
    Save a SparseTensor with shard information embedded.

    This saves all shards in a single directory but with clear shard
    boundaries, allowing selective loading. Equivalent to save_distributed,
    plus optional user metadata stored under the "user_metadata" key of
    metadata.json.

    Parameters
    ----------
    tensor : SparseTensor
        The global sparse tensor to save.
    path : str or Path
        Output directory path.
    num_shards : int
        Number of shards to create.
    partition_method : str
        Partitioning method.
    coords : torch.Tensor, optional
        Node coordinates for geometric partitioning.
    metadata : dict, optional
        Additional metadata, written into metadata.json as "user_metadata".
    verbose : bool
        Print progress information.
    """
    # Delegate partitioning and writing to save_distributed.
    save_distributed(tensor, path, num_shards, partition_method, coords, verbose)
    # save_distributed has no metadata hook, so merge the user metadata into
    # metadata.json afterwards. (Previously this argument was silently ignored.)
    if metadata:
        meta_path = Path(path) / "metadata.json"
        with open(meta_path, "r") as f:
            full_meta = json.load(f)
        full_meta["user_metadata"] = dict(metadata)
        with open(meta_path, "w") as f:
            json.dump(full_meta, f, indent=2)
def load_sparse_shard(
    path: Union[str, Path],
    shard_id: int,
    device: Union[str, torch.device] = "cpu"
) -> "DSparseMatrix":
    """
    Load a specific shard from a sharded sparse tensor.

    Thin wrapper around load_partition.

    Parameters
    ----------
    path : str or Path
        Directory containing sharded data.
    shard_id : int
        Shard ID to load.
    device : str or torch.device
        Device to load to.

    Returns
    -------
    DSparseMatrix
        The loaded shard.
    """
    # load_partition reads num_partitions from metadata.json when world_size
    # is None. (The previous implementation called an undefined
    # load_metadata() helper, raising NameError at runtime.)
    return load_partition(path, shard_id, None, device)
# =============================================================================
# DSparseTensor I/O
# =============================================================================
def save_dsparse(
    tensor: "DSparseTensor",
    directory: Union[str, Path],
    verbose: bool = False
) -> None:
    """
    Save a DSparseTensor to disk.

    Writes metadata.json plus one partition_<i>.safetensors file per
    partition. Same on-disk layout as save_distributed, except the partition
    files here carry no embedded safetensors metadata.

    Parameters
    ----------
    tensor : DSparseTensor
        The distributed sparse tensor to save.
    directory : str or Path
        Output directory (created if missing).
    verbose : bool
        Print progress.

    Raises
    ------
    TypeError
        If tensor is not a DSparseTensor.
    """
    _ensure_safetensors()
    from .distributed import DSparseTensor
    if not isinstance(tensor, DSparseTensor):
        raise TypeError(f"Expected DSparseTensor, got {type(tensor)}")
    directory = Path(directory)
    directory.mkdir(parents=True, exist_ok=True)
    # Global metadata consumed by load_dsparse.
    metadata = {
        "version": "1.0",
        "format": "dsparse_tensor",
        "shape": list(tensor.shape),
        "dtype": str(tensor.dtype),
        "num_partitions": tensor.num_partitions,
    }
    # Per-partition summary (sizes and neighbor ids) for metadata.json.
    partition_info = []
    for i, partition in enumerate(tensor._partitions):
        info = {
            "partition_id": i,
            "num_owned": int(partition.num_owned),
            "num_halo": int(partition.num_halo),
            "num_local": int(partition.num_local),
            "nnz": int(partition.nnz),
            "neighbors": partition.partition.neighbor_partitions,
        }
        partition_info.append(info)
    metadata["partitions"] = partition_info
    with open(directory / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=2)
    # Write one safetensors file per partition; tensors are made contiguous
    # and moved to CPU as safetensors requires.
    for i, partition in enumerate(tensor._partitions):
        tensors = {
            "values": partition.local_values.contiguous().cpu(),
            "row_indices": partition.local_row.contiguous().cpu(),
            "col_indices": partition.local_col.contiguous().cpu(),
            "owned_nodes": partition.partition.owned_nodes.contiguous().cpu(),
            "halo_nodes": partition.partition.halo_nodes.contiguous().cpu(),
            "local_nodes": partition.partition.local_nodes.contiguous().cpu(),
            "global_to_local": partition.partition.global_to_local.contiguous().cpu(),
            "local_to_global": partition.partition.local_to_global.contiguous().cpu(),
        }
        # Halo-exchange index tensors, keyed by the neighbor partition id
        # embedded in the tensor name ("send_to_<id>" / "recv_from_<id>").
        for neighbor_id in partition.partition.neighbor_partitions:
            if neighbor_id in partition.partition.send_indices:
                tensors[f"send_to_{neighbor_id}"] = partition.partition.send_indices[neighbor_id].contiguous().cpu()
            if neighbor_id in partition.partition.recv_indices:
                tensors[f"recv_from_{neighbor_id}"] = partition.partition.recv_indices[neighbor_id].contiguous().cpu()
        save_file(tensors, str(directory / f"partition_{i}.safetensors"))
    if verbose:
        print(f"Saved DSparseTensor with {tensor.num_partitions} partitions to {directory}")
def load_dsparse(
    directory: Union[str, Path],
    device: Union[str, torch.device] = "cpu"
) -> "DSparseTensor":
    """
    Load a complete DSparseTensor from disk.

    Reconstructs every partition (as a DSparseMatrix) plus the global COO
    arrays, then assembles a DSparseTensor directly without re-partitioning.

    Parameters
    ----------
    directory : str or Path
        Directory containing saved data (from save_dsparse).
    device : str or torch.device
        Device to load to.

    Returns
    -------
    DSparseTensor
        The loaded distributed sparse tensor.
    """
    _ensure_safetensors()
    from .distributed import DSparseTensor, DSparseMatrix, Partition
    directory = Path(directory)
    with open(directory / "metadata.json", "r") as f:
        metadata = json.load(f)
    num_partitions = metadata["num_partitions"]
    global_shape = tuple(metadata["shape"])
    # Rebuild each partition from its safetensors file.
    partitions = []
    for i in range(num_partitions):
        tensors = load_file(str(directory / f"partition_{i}.safetensors"), device=str(device))
        partition_meta = metadata["partitions"][i]
        neighbors = partition_meta["neighbors"]
        # Halo-exchange maps were saved as "send_to_<id>" / "recv_from_<id>".
        send_indices = {}
        recv_indices = {}
        for neighbor_id in neighbors:
            send_key = f"send_to_{neighbor_id}"
            recv_key = f"recv_from_{neighbor_id}"
            if send_key in tensors:
                send_indices[neighbor_id] = tensors[send_key]
            if recv_key in tensors:
                recv_indices[neighbor_id] = tensors[recv_key]
        # Reassemble the Partition descriptor exactly as it was saved.
        partition = Partition(
            partition_id=i,
            local_nodes=tensors["local_nodes"],
            owned_nodes=tensors["owned_nodes"],
            halo_nodes=tensors["halo_nodes"],
            neighbor_partitions=neighbors,
            send_indices=send_indices,
            recv_indices=recv_indices,
            global_to_local=tensors["global_to_local"],
            local_to_global=tensors["local_to_global"],
        )
        # The local matrix is square over owned + halo nodes.
        num_owned = len(tensors["owned_nodes"])
        num_halo = len(tensors["halo_nodes"])
        num_local = num_owned + num_halo
        dsm = DSparseMatrix(
            partition=partition,
            local_values=tensors["values"],
            local_row=tensors["row_indices"],
            local_col=tensors["col_indices"],
            local_shape=(num_local, num_local),
            global_shape=global_shape,
            num_partitions=num_partitions,
            device=device,
            verbose=False,
        )
        partitions.append(dsm)
    # Create DSparseTensor from partitions
    # We need to reconstruct global data from partitions for the gather method
    # Collect values, rows, cols from all partitions (owned portion only):
    # rows with local index < num_owned belong to the partition; halo rows
    # are duplicated across partitions and must be skipped.
    all_values = []
    all_rows = []
    all_cols = []
    for p in partitions:
        owned_mask = p.local_row < p.num_owned
        local_vals = p.local_values[owned_mask]
        local_rows = p.local_row[owned_mask]
        local_cols = p.local_col[owned_mask]
        # Convert local to global indices
        global_rows = p.partition.local_to_global[local_rows]
        global_cols = p.partition.local_to_global[local_cols]
        all_values.append(local_vals)
        all_rows.append(global_rows)
        all_cols.append(global_cols)
    global_values = torch.cat(all_values)
    global_rows = torch.cat(all_rows)
    global_cols = torch.cat(all_cols)
    # Bypass DSparseTensor.__init__ (which would re-partition from scratch)
    # and populate the private fields directly from the loaded data.
    d_tensor = DSparseTensor.__new__(DSparseTensor)
    d_tensor._partitions = partitions
    d_tensor._shape = global_shape
    d_tensor._device = torch.device(device)
    d_tensor._values = global_values.to(device)
    d_tensor._row_indices = global_rows.to(device)
    d_tensor._col_indices = global_cols.to(device)
    d_tensor._distributed_mode = False
    d_tensor._num_partitions = num_partitions
    d_tensor._coords = None
    # Sentinel marking that the partitioning came from disk, not recomputed.
    d_tensor._partition_method = 'loaded'
    d_tensor._verbose = False
    return d_tensor