AI-Research-SKILLs pytorch-fsdp2

Adds PyTorch FSDP2 (fully_shard) to training scripts with correct init, sharding, mixed precision/offload config, and distributed checkpointing. Use when models exceed single-GPU memory or when you need DTensor-based sharding with DeviceMesh.

install
source · Clone the upstream repo
git clone https://github.com/Orchestra-Research/AI-Research-SKILLs
Claude Code · Install into ~/.claude/skills/
T=$(mktemp -d) && git clone --depth=1 https://github.com/Orchestra-Research/AI-Research-SKILLs "$T" && mkdir -p ~/.claude/skills && cp -r "$T/08-distributed-training/pytorch-fsdp2" ~/.claude/skills/orchestra-research-ai-research-skills-pytorch-fsdp2-b54004 && rm -rf "$T"
manifest: 08-distributed-training/pytorch-fsdp2/SKILL.md
source content

Skill: Use PyTorch FSDP2 (fully_shard) correctly in a training script

This skill teaches a coding agent how to add PyTorch FSDP2 to a training loop with correct initialization, sharding, mixed precision/offload configuration, and checkpointing.

FSDP2 in PyTorch is exposed primarily via torch.distributed.fsdp.fully_shard and the FSDPModule methods it adds in-place to modules. See: references/pytorch_fully_shard_api.md, references/pytorch_fsdp2_tutorial.md.


When to use this skill

Use FSDP2 when:

  • Your model doesn’t fit on one GPU (parameters + gradients + optimizer state).
  • You want eager-mode, DTensor-based per-parameter sharding (more inspectable, with simpler sharded state dicts) rather than FSDP1’s flat-parameter sharding.
  • You may later compose DP with Tensor Parallel using DeviceMesh.

Avoid (or be careful) if:

  • You need strict backwards-compatible checkpoints across PyTorch versions (DCP warns against this).
  • You’re forced onto older PyTorch versions without the FSDP2 stack.

Alternatives (when FSDP2 is not the best fit)

  • DistributedDataParallel (DDP): Use the standard data-parallel wrapper when the full model fits on each GPU and you only need to replicate it and synchronize gradients across workers.
  • FullyShardedDataParallel (FSDP1): Use the original FSDP wrapper for parameter sharding across data-parallel workers.

Reference: references/pytorch_ddp_notes.md, references/pytorch_fsdp1_api.md.


Contract the agent must follow

  1. Launch with torchrun and set the CUDA device per process (usually via LOCAL_RANK).
  2. Apply fully_shard() bottom-up, i.e., shard submodules (e.g., Transformer blocks) before the root module.
  3. Call model(input), not model.forward(input), so the FSDP2 hooks run (unless you explicitly unshard() or register the forward method).
  4. Create the optimizer after sharding so that it is built on the DTensor parameters (post-fully_shard).
  5. Checkpoint using Distributed Checkpoint (DCP) or the distributed-state-dict helpers, not naïve torch.save(model.state_dict()), unless you deliberately gather to full tensors.

(Each of these rules is directly described in the official API docs/tutorial; see references.)


Step-by-step procedure

0) Version & environment sanity

  • Prefer a recent stable PyTorch release whose docs cover FSDP2 and the current DCP APIs.
  • Use torchrun --nproc_per_node <gpus_per_node> ... and ensure RANK, WORLD_SIZE, and LOCAL_RANK are visible to each process.

Reference: references/pytorch_fsdp2_tutorial.md (launch commands and setup), references/pytorch_fully_shard_api.md (user contract).


1) Initialize distributed and set device

Minimal, correct pattern:

  • dist.init_process_group(backend="nccl")
  • torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
  • Optionally create a DeviceMesh to describe the data-parallel group(s)
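
A minimal sketch of this initialization, assuming a torchrun launch (the init_distributed helper name and the 1-D mesh shape are illustrative):

    import os
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    def init_distributed():
        # torchrun sets RANK, WORLD_SIZE, and LOCAL_RANK for each process.
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
        # Optional: a 1-D DeviceMesh over all ranks for plain data-parallel sharding.
        return init_device_mesh("cuda", (dist.get_world_size(),))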

Reference: references/pytorch_device_mesh_tutorial.md (why DeviceMesh exists and how it manages process groups).


2) Build model on meta device (recommended for very large models)

For big models, initialize on meta, apply sharding, then materialize weights on GPU:

  • with torch.device("meta"): model = ...
  • apply fully_shard(...) on submodules, then fully_shard(model)
  • model.to_empty(device="cuda")
  • model.reset_parameters() (or your init routine)
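
A sketch of that flow, assuming a transformer-style model whose modules define reset_parameters() (the Transformer/config names are placeholders):

    import torch
    from torch.distributed.fsdp import fully_shard

    # Allocate parameter metadata only, no real storage.
    with torch.device("meta"):
        model = Transformer(config)  # placeholder model class

    # Shard submodules first, then the root (see step 3).
    for block in model.layers:
        fully_shard(block)
    fully_shard(model)

    # Materialize the sharded parameters as empty tensors on the local GPU,
    # then run the model's own initialization to fill them in.
    model.to_empty(device="cuda")
    model.reset_parameters()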

Reference: references/pytorch_fsdp2_tutorial.md (the migration guide shows this flow explicitly).


3) Apply fully_shard() bottom-up (wrapping policy = “apply where needed”)

Do not call fully_shard only on the topmost module.

Recommended sharding pattern for transformer-like models:

  • iterate modules: if isinstance(m, TransformerBlock): fully_shard(m, ...)
  • then fully_shard(model, ...)

Why:

  • fully_shard forms “parameter groups” for collective efficiency and excludes params already grouped by earlier calls. Bottom-up gives better overlap and lower peak memory.
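
A sketch of that bottom-up pass (TransformerBlock stands in for whatever block class your model uses; mesh is the DeviceMesh from step 1):

    from torch.distributed.fsdp import fully_shard

    def shard_model(model, mesh):
        # Shard each block first so its parameters form their own group ...
        for module in model.modules():
            if isinstance(module, TransformerBlock):  # your block class
                fully_shard(module, mesh=mesh)
        # ... then shard the root, which picks up only the parameters not
        # already claimed by the per-block calls (embeddings, norms, head).
        fully_shard(model, mesh=mesh)
        return model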

Reference: references/pytorch_fully_shard_api.md (bottom-up requirement and why).


4) Configure reshard_after_forward for memory/perf trade-offs

Default behavior:

  • None means True for non-root modules and False for the root module (good default).

Heuristics:

  • If you’re memory-bound: keep defaults or force True on many blocks.
  • If you’re throughput-bound and can afford the memory: consider keeping unsharded params longer (root often False).
  • Advanced: use an int to reshard to a smaller mesh after forward (e.g., intra-node) if it’s a meaningful divisor.
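
For instance, a memory-bound configuration might force resharding on every block while leaving the root at its default (a sketch, reusing the step-3 loop):

    for module in model.modules():
        if isinstance(module, TransformerBlock):
            # Free unsharded block parameters right after forward, at the cost
            # of an extra all-gather in backward.
            fully_shard(module, mesh=mesh, reshard_after_forward=True)
    # reshard_after_forward=None resolves to False for the root, so its
    # parameters stay unsharded between forward and backward.
    fully_shard(model, mesh=mesh)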

Reference: references/pytorch_fully_shard_api.md (full semantics).


5) Mixed precision & offload (optional but common)

FSDP2 uses:

  • mp_policy=MixedPrecisionPolicy(param_dtype=..., reduce_dtype=..., output_dtype=..., cast_forward_inputs=...)
  • offload_policy=CPUOffloadPolicy() if you want CPU offload

Rules of thumb:

  • Start with BF16 parameters/reductions on H100/A100-class GPUs (if numerically stable for your model).
  • Keep reduce_dtype aligned with your gradient reduction expectations.
  • If you use CPU offload, budget for PCIe/NVLink traffic and runtime overhead.
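
A sketch of a BF16 policy with optional CPU offload (import paths follow recent PyTorch releases; the dtype choices are assumptions, not requirements):

    import torch
    from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard

    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,   # dtype used for compute/communication
        reduce_dtype=torch.bfloat16,  # dtype used for gradient reduction
    )
    # offload = CPUOffloadPolicy()  # uncomment to keep sharded params/grads on CPU

    for module in model.modules():
        if isinstance(module, TransformerBlock):
            fully_shard(module, mesh=mesh, mp_policy=mp_policy)
    fully_shard(model, mesh=mesh, mp_policy=mp_policy)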

Reference: references/pytorch_fully_shard_api.md (MixedPrecisionPolicy / OffloadPolicy classes).


6) Optimizer, gradient clipping, accumulation

  • Create the optimizer after sharding so it holds DTensor params.
  • If you need gradient accumulation / no_sync: use the FSDP2 mechanism (set_requires_gradient_sync) instead of FSDP1’s no_sync().

Gradient clipping:

  • Use the approach shown in the FSDP2 tutorial (“Gradient Clipping and Optimizer with DTensor”), because parameters/gradients are DTensors.
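
A sketch of the optimizer, accumulation, and clipping pattern (loader, accum_steps, the loss function, and max_norm are illustrative):

    import torch

    # Built after fully_shard, so it holds DTensor parameters.
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    for step, (inputs, labels) in enumerate(loader):
        sync = (step + 1) % accum_steps == 0
        # Skip the gradient reduce-scatter on intermediate micro-batches.
        model.set_requires_gradient_sync(sync)

        loss = torch.nn.functional.cross_entropy(model(inputs), labels)
        loss.backward()

        if sync:
            # Works on DTensor gradients; the returned total norm is itself a DTensor here.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()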

Reference: references/pytorch_fsdp2_tutorial.md.


7) Checkpointing: prefer DCP or distributed state dict helpers

Two recommended approaches:

A) Distributed Checkpoint (DCP) — best default

  • DCP saves/loads from multiple ranks in parallel and supports load-time resharding.
  • DCP produces multiple files (often at least one per rank) and operates “in place”.

B) Distributed state dict helpers

  • get_model_state_dict / set_model_state_dict with StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True, ...)
  • For optimizer: get_optimizer_state_dict / set_optimizer_state_dict

Avoid:

  • Saving DTensor state dicts with plain torch.save unless you intentionally convert with DTensor.full_tensor() and manage memory carefully.
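
A sketch of approach B, gathering a full state dict for a single-file checkpoint (the file name is a placeholder; approach A, DCP, is sketched under Workflow B below):

    import torch
    import torch.distributed as dist
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions, get_model_state_dict, set_model_state_dict,
    )

    # Save: gather full (unsharded) tensors, offloaded to CPU.
    save_opts = StateDictOptions(full_state_dict=True, cpu_offload=True)
    full_sd = get_model_state_dict(model, options=save_opts)
    if dist.get_rank() == 0:
        torch.save(full_sd, "model_full.pt")  # placeholder filename

    # Load: read the full dict to CPU, then let set_model_state_dict reshard
    # it onto the local DTensor parameters.
    full_sd = torch.load("model_full.pt", map_location="cpu", weights_only=True)
    set_model_state_dict(model, full_sd, options=StateDictOptions(full_state_dict=True))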

References:

  • references/pytorch_dcp_overview.md (DCP behavior and caveats)
  • references/pytorch_dcp_recipe.md and references/pytorch_dcp_async_recipe.md (end-to-end usage)
  • references/pytorch_fsdp2_tutorial.md (DTensor vs DCP state-dict flows)
  • references/pytorch_examples_fsdp2.md (working checkpoint scripts)

Workflow checklists (copy-paste friendly)

Workflow A: Retrofit FSDP2 into an existing training script

  • Launch with torchrun and initialize the process group.
  • Set the CUDA device from LOCAL_RANK; create a DeviceMesh if you need multi-dim parallelism.
  • Build the model (use meta if needed), apply fully_shard bottom-up, then fully_shard(model).
  • Create the optimizer after sharding so it captures DTensor parameters.
  • Use model(inputs) so hooks run; use set_requires_gradient_sync for accumulation.
  • Add DCP save/load via torch.distributed.checkpoint helpers.

Reference: references/pytorch_fsdp2_tutorial.md, references/pytorch_fully_shard_api.md, references/pytorch_device_mesh_tutorial.md, references/pytorch_dcp_recipe.md.

Workflow B: Add DCP save/load (minimal pattern)

  • Wrap state in Stateful or assemble state via get_state_dict.
  • Call dcp.save(...) from all ranks to a shared path.
  • Call dcp.load(...) and restore with set_state_dict.
  • Validate any resharding assumptions when loading into a different mesh.
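
A minimal sketch of that pattern (the AppState class and checkpoint path are illustrative):

    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
    from torch.distributed.checkpoint.stateful import Stateful

    class AppState(Stateful):
        """Bundles model + optimizer so dcp.save/load drive state_dict/load_state_dict."""
        def __init__(self, model, optimizer):
            self.model, self.optimizer = model, optimizer

        def state_dict(self):
            model_sd, optim_sd = get_state_dict(self.model, self.optimizer)
            return {"model": model_sd, "optim": optim_sd}

        def load_state_dict(self, state_dict):
            set_state_dict(self.model, self.optimizer,
                           model_state_dict=state_dict["model"],
                           optim_state_dict=state_dict["optim"])

    state = {"app": AppState(model, optimizer)}
    dcp.save(state, checkpoint_id="/shared/ckpt/step_1000")  # every rank calls save
    dcp.load(state, checkpoint_id="/shared/ckpt/step_1000")  # and load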

Reference: references/pytorch_dcp_recipe.md.

Debug checklist (what the agent should check first)

  1. All ranks on distinct GPUs? If not, verify torch.cuda.set_device(LOCAL_RANK) and your torchrun flags.
  2. Did you accidentally call forward() directly? Use model(input) or explicitly unshard() / register the forward method.
  3. Is fully_shard() applied bottom-up? If only the root is sharded, expect worse memory/perf and possible confusion.
  4. Optimizer created at the right time? It must be built on DTensor parameters after sharding.
  5. Checkpointing path consistent?
    • If using DCP, don’t mix with ad-hoc torch.save unless you understand the conversions.
    • Be mindful of PyTorch-version compatibility warnings for DCP.
Common issues and fixes

  • Forward hooks not running → Call model(inputs) (or unshard() explicitly) instead of model.forward(...).
  • Optimizer sees non-DTensor params → Create the optimizer after all fully_shard calls.
  • Only root module sharded → Apply fully_shard bottom-up on submodules before the root.
  • Memory spikes after forward → Set reshard_after_forward=True for more modules.
  • Gradient accumulation desync → Use set_requires_gradient_sync instead of FSDP1’s no_sync().

Reference: references/pytorch_fully_shard_api.md, references/pytorch_fsdp2_tutorial.md.


Minimal reference implementation outline (agent-friendly)

The coding agent should implement a script with these labeled blocks:

  • init_distributed(): init the process group, set the device
  • build_model_meta(): model on meta, apply fully_shard, materialize weights
  • build_optimizer(): optimizer created after sharding
  • train_step(): forward/backward/step with model(inputs) and DTensor-aware patterns
  • checkpoint_save/load(): DCP or distributed state dict helpers
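
A compact skeleton tying these blocks together (the model class, data loader, and hyperparameters are placeholders; the per-step sketches above fill in the details):

    import os
    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import fully_shard

    def main():
        # init_distributed
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

        # build_model_meta
        with torch.device("meta"):
            model = Transformer(config)  # placeholder
        for block in model.layers:
            fully_shard(block)
        fully_shard(model)
        model.to_empty(device="cuda")
        model.reset_parameters()

        # build_optimizer (after sharding, so it sees DTensor params)
        optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

        # train_step
        for inputs, labels in loader:  # placeholder data loader
            loss = torch.nn.functional.cross_entropy(model(inputs), labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # checkpoint_save/load: see the DCP / state-dict sketches above
        dist.destroy_process_group()

    if __name__ == "__main__":
        main()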

Concrete examples live in references/pytorch_examples_fsdp2.md and the official tutorial reference.


References

  • references/pytorch_fsdp2_tutorial.md
  • references/pytorch_fully_shard_api.md
  • references/pytorch_ddp_notes.md
  • references/pytorch_fsdp1_api.md
  • references/pytorch_device_mesh_tutorial.md
  • references/pytorch_tp_tutorial.md
  • references/pytorch_dcp_overview.md
  • references/pytorch_dcp_recipe.md
  • references/pytorch_dcp_async_recipe.md
  • references/pytorch_examples_fsdp2.md
  • references/torchtitan_fsdp_notes.md (optional, production notes)
  • references/ray_train_fsdp2_example.md (optional, integration example)