Claude-skill-registry debug-distributed

Guide for debugging distributed training issues in AReaL. Use when user encounters hangs, wrong results, OOM, or communication errors.

install
source · Clone the upstream repo
git clone https://github.com/majiayu000/claude-skill-registry
Claude Code · Install into ~/.claude/skills/
T=$(mktemp -d) && git clone --depth=1 https://github.com/majiayu000/claude-skill-registry "$T" && mkdir -p ~/.claude/skills && cp -r "$T/skills/data/debug-distributed" ~/.claude/skills/majiayu000-claude-skill-registry-debug-distributed && rm -rf "$T"
manifest: skills/data/debug-distributed/SKILL.md
source content

Debug Distributed Training

Debugging guide for distributed training issues in AReaL (FSDP2, TP, CP, EP).

When to Use

This skill is triggered when:

  • Training hangs or deadlocks
  • Results differ across ranks or are numerically wrong
  • OOM errors in distributed settings
  • NCCL/communication errors or device mesh issues

Debugging Principles

Minimal Reproduction

Always start from a minimal reproduction: isolate the failure in the smallest script that still triggers it, so there is less code to search.

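# Bad: debugging inside the full training loop
# Good: a minimal script, launched with e.g. `torchrun --nproc_per_node=2 repro.py`
import os

import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
# Bind each process to its own GPU (torchrun sets LOCAL_RANK);
# otherwise every rank on a node allocates on cuda:0
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Reproduce the exact operation that fails
tensor = torch.ones(10, device="cuda")
dist.all_reduce(tensor)  # <-- Isolate the failing op
print(f"Rank {rank}: {tensor}")

dist.destroy_process_group()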

Reduction strategy:

  1. Remove unrelated model components
  2. Use small tensor sizes
  3. Reduce world_size to minimum (e.g., 2 GPUs)
  4. Remove torch.compile if possible
  5. Disable activation checkpointing

Step-by-Step Debugging Guide

1. Hang Debugging (Deadlocks, Synchronization)

Environment Variables for Debugging:

# Full debug logging
export TORCH_DISTRIBUTED_DEBUG=DETAIL
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL

# torch.compile debugging
export TORCH_LOGS="+dynamo,recompiles"
export TORCHDYNAMO_VERBOSE=1
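
When the job is started from a Python wrapper rather than a shell, the same variables can be set in-process. The following is a minimal sketch, assuming it runs before the process group is created; the helper name `enable_distributed_debug` is hypothetical and not part of AReaL or PyTorch.

```python
import os

# Variables taken from the list above. Values are strings because
# os.environ only accepts strings.
DEBUG_ENV = {
    "TORCH_DISTRIBUTED_DEBUG": "DETAIL",
    "NCCL_DEBUG": "INFO",
    "NCCL_DEBUG_SUBSYS": "ALL",
}


def enable_distributed_debug(env=os.environ):
    """Set the debug variables, keeping any value that is already present.

    Returns the subset of variables that this call actually set.
    """
    applied = {}
    for key, value in DEBUG_ENV.items():
        if key not in env:
            env[key] = value
            applied[key] = value
    return applied
```

Keeping existing values means a narrower setting chosen on the command line (for example `NCCL_DEBUG=WARN`) is not silently overridden.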

Dump Call Stack with py-spy (for hung processes):

# Find process IDs
ps aux | grep python

# Dump call stack of specific rank
py-spy dump --pid <PID>

# Record flame graph for performance analysis
py-spy record -o profile.svg --pid <PID> --duration 30

Common Causes:

  1. Mismatched Collectives: One rank calls all_reduce, another doesn't.
  2. Wrong Process Group: Using wrong group for collective.
  3. Tensor Shape Mismatch: Different shapes across ranks.
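
To locate a shape mismatch, it can help to group ranks by the shape each one reports. The sketch below is pure Python and works on a list that has already been gathered; in a real job that list might come from `dist.all_gather_object(shapes, tuple(tensor.shape))`, which is itself a collective and must be reached by every rank. The function name `find_shape_mismatches` is hypothetical.

```python
from collections import defaultdict


def find_shape_mismatches(shapes_by_rank):
    """Group ranks by the shape they reported.

    shapes_by_rank: list where index i holds the shape seen on rank i.
    Returns a dict mapping each distinct shape to the ranks that reported it,
    or an empty dict when every rank agrees.
    """
    ranks_by_shape = defaultdict(list)
    for rank, shape in enumerate(shapes_by_rank):
        ranks_by_shape[tuple(shape)].append(rank)
    return dict(ranks_by_shape) if len(ranks_by_shape) > 1 else {}
```

An empty result means the shapes agree, so the hang is more likely a mismatched collective or a wrong process group.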

Debug Steps:

# Verify group membership
mesh = parallel_dims.get_mesh("dp_shard_cp")
group = mesh.get_group()
print(f"Rank {dist.get_rank()}: group size = {dist.get_world_size(group)}")

# Print shapes on all ranks
print(f"Rank {dist.get_rank()}: tensor.shape = {tensor.shape}")
dist.barrier()

Timeout Adjustment (for debugging only):

from areal.utils.distributed import patch_dist_group_timeout
from datetime import timedelta
patch_dist_group_timeout(timedelta(minutes=30))

2. Wrong Results (Gradient, Reduction Issues)

Check DTensor Placements:

from torch.distributed.tensor import DTensor

# Inspect how each parameter is sharded or replicated
for name, param in model.named_parameters():
    if isinstance(param, DTensor):
        print(f"Param {name}: placements={param.placements}, mesh={param.device_mesh}")

Verify Gradient Reduction:

for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"Rank {dist.get_rank()}: {name} grad_sum = {param.grad.sum().item()}")

3. OOM Issues (Memory, Sharding)

Check Memory Usage:

print(f"Rank {dist.get_rank()}: "
      f"allocated={torch.cuda.memory_allocated()/1e9:.2f}GB, "
      f"reserved={torch.cuda.memory_reserved()/1e9:.2f}GB")

Check FSDP Coverage:

for name, param in model.named_parameters():
    is_dtensor = isinstance(param, DTensor)
    print(f"{name}: is_dtensor={is_dtensor}, shape={param.shape}")
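
The per-parameter printout can be long for large models, so a summary of what was left unsharded may be easier to read. This is a sketch under stated assumptions: the input tuples would be built as `(name, isinstance(param, DTensor), param.numel())` from the loop above, and `summarize_sharding` is a hypothetical helper. A plain (non-DTensor) parameter is not necessarily a bug, and a DTensor can still be replicated, so treat the result as a hint.

```python
def summarize_sharding(params):
    """Summarize which parameters are not DTensors.

    params: iterable of (name, is_dtensor, numel) tuples.
    Returns (names of non-DTensor parameters, their share of all elements).
    """
    total = 0
    unsharded_names = []
    unsharded_numel = 0
    for name, is_dtensor, numel in params:
        total += numel
        if not is_dtensor:
            unsharded_names.append(name)
            unsharded_numel += numel
    fraction = unsharded_numel / total if total else 0.0
    return unsharded_names, fraction
```

A large fraction points at modules that were never wrapped and are held in full on every rank.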

4. Communication Errors

| Error | Cause | Solution |
| --- | --- | --- |
| NCCL WARN Cuda failure | GPU communication | Check NCCL version, GPU topology |
| RuntimeError: Timed out | Rank synchronization | Increase timeout, check code paths |
| Invalid device mesh | Mesh configuration | Verify world_size = dp * tp * cp |
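
The mesh-size check can be run before any process group exists. Below is a minimal sketch, assuming the parallel degrees are available as plain integers; `check_mesh_sizes` is a hypothetical helper, and it accepts any named dimensions because a real configuration may include more than dp, tp, and cp.

```python
import math


def check_mesh_sizes(world_size, **dims):
    """Return an error message if the degrees do not multiply to world_size, else None."""
    for name, size in dims.items():
        if size < 1:
            return f"{name}={size} must be >= 1"
    product = math.prod(dims.values())
    if product != world_size:
        factors = " * ".join(f"{name}={size}" for name, size in dims.items())
        return f"world_size={world_size} but {factors} = {product}"
    return None
```

For example, `check_mesh_sizes(8, dp=2, tp=2, cp=1)` reports that the product is 4 rather than 8, while `check_mesh_sizes(8, dp=2, tp=2, cp=2)` returns `None`.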

Debugging Tools

Environment Variables Reference

| Variable | Purpose |
| --- | --- |
| TORCH_DISTRIBUTED_DEBUG=DETAIL | Detailed distributed logging |
| NCCL_DEBUG=INFO | NCCL communication logging |
| NCCL_DEBUG_SUBSYS=ALL | All NCCL subsystems |
| TORCH_LOGS="+dynamo,recompiles" | torch.compile logging |
| TORCHDYNAMO_VERBOSE=1 | Dynamo verbose output |
| CUDA_LAUNCH_BLOCKING=1 | Synchronous CUDA (slow, for debugging) |

py-spy for Call Stack Analysis

# Install
pip install py-spy

# Dump call stack of hung process
py-spy dump --pid <PID>

# Dump all Python processes
pgrep -f python | xargs -I {} py-spy dump --pid {}

# Record flame graph
py-spy record -o profile.svg --pid <PID> --duration 30

Rank-Conditional Printing

def print_all_ranks(msg):
    for r in range(dist.get_world_size()):
        if dist.get_rank() == r:
            print(f"[Rank {r}] {msg}")
        dist.barrier()

Check Device Mesh

def debug_mesh(parallel_dims):
    mesh = parallel_dims.world_mesh
    for dim_name in mesh.mesh_dim_names:
        submesh = parallel_dims.get_mesh(dim_name)
        if submesh is not None:
            print(f"Rank {dist.get_rank()}: {dim_name} size={submesh.size()}")

Validate Tensor Consistency

def check_tensor_consistency(tensor, name, group=None):
    # Compare a cheap per-rank checksum; equal sums do not prove equal tensors
    local_sum = tensor.sum().item()
    tensor_sums = [None] * dist.get_world_size(group)
    dist.all_gather_object(tensor_sums, local_sum, group=group)
    # Use the rank within `group` so one member of each group reports
    if dist.get_rank(group) == 0 and len(set(tensor_sums)) > 1:
        print(f"WARNING: {name} inconsistent: {tensor_sums}")
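
The helper above uses exact equality, which suits tensors that should be bit-identical replicas. For values that may legitimately differ in the last bits (for instance after reductions performed in a different order), a tolerance-based comparison of the gathered sums is one option. The sketch below is pure Python; `inconsistent_ranks` and its default tolerance are assumptions, not AReaL conventions.

```python
import math


def inconsistent_ranks(values, rel_tol=1e-6, abs_tol=0.0):
    """Return the ranks whose value differs from rank 0's beyond the tolerance.

    values: list where index i holds the checksum gathered from rank i.
    NaN is never close to anything, so a NaN on rank 0 flags every rank.
    """
    reference = values[0]
    return [
        rank
        for rank, value in enumerate(values)
        if not math.isclose(value, reference, rel_tol=rel_tol, abs_tol=abs_tol)
    ]
```

Passing the `tensor_sums` list from the helper above yields the specific ranks to inspect instead of a single yes/no warning.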

Key Files Reference

| Component | File |
| --- | --- |
| Parallel Dims | areal/experimental/models/archon/parallel_dims.py |
| Expert Parallel | areal/experimental/models/archon/expert_parallel.py |
| Ulysses (CP) | areal/experimental/models/archon/ulysses.py |
| FSDP/TP Apply | areal/experimental/models/archon/qwen2/infra/parallelize.py |

<!--
================================================================================
MAINTAINER GUIDE
================================================================================

Location: .claude/skills/debug-distributed/SKILL.md
Invocation: /debug-distributed

## Purpose

Debugging guide for distributed training issues. Covers FSDP2, Tensor Parallelism, Context Parallelism, and Expert Parallelism.

## How to Update

### When Adding New Parallelism Features

1. Add section for the parallelism type
2. Document common error patterns and debugging snippets

### When PyTorch Distributed APIs Change

1. Update DTensor/DeviceMesh examples
2. Update environment variable references

### When New Error Patterns Emerge

1. Add to "Common Errors and Solutions" table
2. Reference relevant source files

================================================================================
-->