Skillshub pytorch-fsdp2
Skill: Use PyTorch FSDP2 (`fully_shard`) correctly in a training script
git clone https://github.com/ComeOnOliver/skillshub
T=$(mktemp -d) && git clone --depth=1 https://github.com/ComeOnOliver/skillshub "$T" && mkdir -p ~/.claude/skills && cp -r "$T/skills/Orchestra-Research/AI-Research-SKILLs/pytorch-fsdp2" ~/.claude/skills/comeonoliver-skillshub-pytorch-fsdp2 && rm -rf "$T"
skills/Orchestra-Research/AI-Research-SKILLs/pytorch-fsdp2/SKILL.md

Skill: Use PyTorch FSDP2 (`fully_shard`) correctly in a training script

This skill teaches a coding agent how to add PyTorch FSDP2 to a training loop with correct initialization, sharding, mixed precision/offload configuration, and checkpointing.

FSDP2 in PyTorch is exposed primarily via `torch.distributed.fsdp.fully_shard` and the `FSDPModule` methods it adds in-place to modules. See: `references/pytorch_fully_shard_api.md`, `references/pytorch_fsdp2_tutorial.md`.
When to use this skill
Use FSDP2 when:
- Your model doesn’t fit on one GPU (parameters + gradients + optimizer state).
- You want an eager-mode approach with DTensor-based per-parameter sharding (more inspectable, with simpler sharded state dicts than FSDP1).
- You may later compose DP with Tensor Parallel using `DeviceMesh`.
Avoid (or be careful) if:
- You need strict backwards-compatible checkpoints across PyTorch versions (the DCP docs warn that cross-version compatibility is not guaranteed).
- You’re forced onto older PyTorch versions without the FSDP2 stack.
Alternatives (when FSDP2 is not the best fit)
- DistributedDataParallel (DDP): Use the standard data-parallel wrapper when you want classic distributed data parallel training.
- FullyShardedDataParallel (FSDP1): Use the original FSDP wrapper for parameter sharding across data-parallel workers.
Reference:
references/pytorch_ddp_notes.md, references/pytorch_fsdp1_api.md.
Contract the agent must follow
- Launch with `torchrun` and set the CUDA device per process (usually via `LOCAL_RANK`).
- Apply `fully_shard()` bottom-up, i.e., shard submodules (e.g., Transformer blocks) before the root module.
- Call `model(input)`, not `model.forward(input)`, so the FSDP2 hooks run (unless you explicitly `unshard()` or register the forward method).
- Create the optimizer after sharding and make sure it is built on the DTensor parameters (post-`fully_shard`).
- Checkpoint using Distributed Checkpoint (DCP) or the distributed-state-dict helpers, not naïve `torch.save(model.state_dict())`, unless you deliberately gather to full tensors.
(Each of these rules is directly described in the official API docs/tutorial; see references.)
Step-by-step procedure
0) Version & environment sanity
- Prefer a recent stable PyTorch release whose documentation covers FSDP2 and the current DCP APIs.
- Use `torchrun --nproc_per_node <gpus_per_node> ...` and ensure `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` are visible.
Reference:
references/pytorch_fsdp2_tutorial.md (launch commands and setup), references/pytorch_fully_shard_api.md (user contract).
1) Initialize distributed and set device
Minimal, correct pattern:
- `dist.init_process_group(backend="nccl")`
- `torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))`
- Optionally create a `DeviceMesh` to describe the data-parallel group(s).
Reference:
references/pytorch_device_mesh_tutorial.md (why DeviceMesh exists & how it manages process groups).
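A fuller sketch of this step with imports (assumes `torchrun` has set the usual environment variables; the 1-D `DeviceMesh` is optional and only needed if you want the data-parallel group to be explicit):

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])  # set per process by torchrun
torch.cuda.set_device(local_rank)

# Optional: a 1-D mesh over all ranks, passed later as fully_shard(..., mesh=dp_mesh).
dp_mesh = init_device_mesh("cuda", (dist.get_world_size(),))
```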
2) Build model on meta device (recommended for very large models)
For big models, initialize on `meta`, apply sharding, then materialize weights on GPU:

- `with torch.device("meta"): model = ...`
- apply `fully_shard(...)` on submodules, then `fully_shard(model)`
- `model.to_empty(device="cuda")`
- `model.reset_parameters()` (or your init routine)
Reference:
references/pytorch_fsdp2_tutorial.md (migration guide shows this flow explicitly).
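A sketch of the same flow in code, where `Transformer` and `config` stand in for your own model class and configuration, and modules are assumed to define `reset_parameters()`:

```python
import torch
from torch.distributed.fsdp import fully_shard

with torch.device("meta"):
    model = Transformer(config)      # parameters exist but have no storage

# ... apply fully_shard() bottom-up here (see step 3) ...
fully_shard(model)

model.to_empty(device="cuda")        # allocate the (sharded) parameters on GPU
for m in model.modules():            # re-run init on real storage
    if hasattr(m, "reset_parameters"):
        m.reset_parameters()
```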
3) Apply `fully_shard()` bottom-up (wrapping policy = "apply where needed")

Do not only call `fully_shard` on the topmost module.

Recommended sharding pattern for transformer-like models:

- iterate modules: `if isinstance(m, TransformerBlock): fully_shard(m, ...)`
- then `fully_shard(model, ...)`

Why: `fully_shard` forms "parameter groups" for collective efficiency and excludes params already grouped by earlier calls. Bottom-up gives better overlap and lower peak memory.
Reference:
references/pytorch_fully_shard_api.md (bottom-up requirement and why).
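A sketch of the bottom-up pattern, with `TransformerBlock` standing in for your model's block class:

```python
from torch.distributed.fsdp import fully_shard

# Shard each block first so its parameters form their own group; the final
# root call only groups parameters not already claimed by an earlier call
# (e.g., embeddings and the output head).
for m in model.modules():
    if isinstance(m, TransformerBlock):
        fully_shard(m)
fully_shard(model)
```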
4) Configure `reshard_after_forward` for memory/perf trade-offs

Default behavior: `None` means `True` for non-root modules and `False` for root modules (a good default).

Heuristics:

- If you’re memory-bound: keep the defaults or force `True` on more blocks.
- If you’re throughput-bound and can afford the memory: consider keeping unsharded params longer (root often `False`).
- Advanced: use an `int` to reshard to a smaller mesh after forward (e.g., intra-node) if it’s a meaningful divisor.
Reference:
references/pytorch_fully_shard_api.md (full semantics).
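For illustration (defaults vary across PyTorch versions, so verify against the API reference), a sketch of setting the knob explicitly:

```python
from torch.distributed.fsdp import fully_shard

for m in model.modules():
    if isinstance(m, TransformerBlock):
        # Free unsharded params after forward; re-all-gather during backward.
        fully_shard(m, reshard_after_forward=True)

# The root's params are needed again at the start of backward, so keeping
# them unsharded between forward and backward usually costs little.
fully_shard(model, reshard_after_forward=False)
```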
5) Mixed precision & offload (optional but common)
FSDP2 uses:

- `mp_policy=MixedPrecisionPolicy(param_dtype=..., reduce_dtype=..., output_dtype=..., cast_forward_inputs=...)`
- `offload_policy=CPUOffloadPolicy()` if you want CPU offload

Rules of thumb:

- Start with BF16 parameters/reductions on H100/A100-class GPUs (if numerically stable for your model).
- Keep `reduce_dtype` aligned with your gradient reduction expectations.
- If you use CPU offload, budget for PCIe/NVLink traffic and runtime overhead.
Reference:
references/pytorch_fully_shard_api.md (MixedPrecisionPolicy / OffloadPolicy classes).
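A sketch of a common starting configuration (BF16 compute, FP32 gradient reduction), with `TransformerBlock` again a stand-in:

```python
import torch
from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard

mp = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # dtype of all-gathered params (and thus compute)
    reduce_dtype=torch.float32,   # gradients are reduced in full precision
)

for m in model.modules():
    if isinstance(m, TransformerBlock):
        fully_shard(m, mp_policy=mp)
fully_shard(model, mp_policy=mp)

# Only if you must trade throughput for GPU memory:
# fully_shard(model, mp_policy=mp, offload_policy=CPUOffloadPolicy())
```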
6) Optimizer, gradient clipping, accumulation
- Create the optimizer after sharding so it holds DTensor params.
- If you need gradient accumulation / no_sync: use the FSDP2 mechanism (`set_requires_gradient_sync`) instead of FSDP1’s `no_sync()`.
Gradient clipping:
- Use the approach shown in the FSDP2 tutorial (“Gradient Clipping and Optimizer with DTensor”), because parameters/gradients are DTensors.
Reference:
references/pytorch_fsdp2_tutorial.md.
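A hedged sketch of accumulation plus clipping; `accum_steps`, the loader, and the loss are illustrative, and `set_requires_gradient_sync` is one of the methods `fully_shard` adds to the root module:

```python
import torch

# Created after sharding, so it holds DTensor parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

accum_steps = 4
for step, (inputs, labels) in enumerate(loader):
    is_last_microbatch = (step + 1) % accum_steps == 0
    # Skip gradient reduce-scatter on non-final microbatches (FSDP2's no_sync).
    model.set_requires_gradient_sync(is_last_microbatch)

    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    (loss / accum_steps).backward()

    if is_last_microbatch:
        # clip_grad_norm_ handles DTensor grads in recent PyTorch versions.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
```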
7) Checkpointing: prefer DCP or distributed state dict helpers
Two recommended approaches:
A) Distributed Checkpoint (DCP) — best default
- DCP saves/loads from multiple ranks in parallel and supports load-time resharding.
- DCP produces multiple files (often at least one per rank) and operates “in place”.
B) Distributed state dict helpers
- `get_model_state_dict` / `set_model_state_dict` with `StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True, ...)`
- For optimizer: `get_optimizer_state_dict` / `set_optimizer_state_dict`

Avoid:

- Saving DTensor state dicts with plain `torch.save` unless you intentionally convert with `DTensor.full_tensor()` and manage memory carefully.
References:

- `references/pytorch_dcp_overview.md` (DCP behavior and caveats)
- `references/pytorch_dcp_recipe.md` and `references/pytorch_dcp_async_recipe.md` (end-to-end usage)
- `references/pytorch_fsdp2_tutorial.md` (DTensor vs DCP state-dict flows)
- `references/pytorch_examples_fsdp2.md` (working checkpoint scripts)
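A sketch of approach B: gather a full state dict to CPU for a rank-0 `torch.save`, then restore by broadcasting from rank 0 (check the references for the exact load-side contract in your PyTorch version):

```python
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    set_model_state_dict,
)

# Save: collective call on all ranks; full tensors are offloaded to CPU.
save_opts = StateDictOptions(full_state_dict=True, cpu_offload=True)
full_sd = get_model_state_dict(model, options=save_opts)
if dist.get_rank() == 0:
    torch.save(full_sd, "model_full.pt")

# Load: read on CPU (mmap keeps host memory modest), scatter from rank 0.
load_opts = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
full_sd = torch.load("model_full.pt", mmap=True, map_location="cpu")
set_model_state_dict(model, full_sd, options=load_opts)
```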
Workflow checklists (copy-paste friendly)
Workflow A: Retrofit FSDP2 into an existing training script
- Launch with `torchrun` and initialize the process group.
- Set the CUDA device from `LOCAL_RANK`; create a `DeviceMesh` if you need multi-dim parallelism.
- Build the model (use `meta` if needed), apply `fully_shard` bottom-up, then `fully_shard(model)`.
- Create the optimizer after sharding so it captures DTensor parameters.
- Use `model(inputs)` so hooks run; use `set_requires_gradient_sync` for accumulation.
- Add DCP save/load via `torch.distributed.checkpoint` helpers.
Reference:
references/pytorch_fsdp2_tutorial.md, references/pytorch_fully_shard_api.md, references/pytorch_device_mesh_tutorial.md, references/pytorch_dcp_recipe.md.
Workflow B: Add DCP save/load (minimal pattern)
- Wrap state in `Stateful` or assemble state via `get_state_dict` (see the sketch below).
- Call `dcp.save(...)` from all ranks to a shared path.
- Call `dcp.load(...)` and restore with `set_state_dict`.
- Validate any resharding assumptions when loading into a different mesh.
Reference:
references/pytorch_dcp_recipe.md.
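A minimal sketch following this checklist, modeled on the DCP getting-started recipe; the `AppState` wrapper and checkpoint path are illustrative:

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

class AppState(Stateful):
    """Bundle model + optimizer so DCP drives our state_dict hooks."""

    def __init__(self, model, optimizer):
        self.model, self.optimizer = model, optimizer

    def state_dict(self):
        model_sd, optim_sd = get_state_dict(self.model, self.optimizer)
        return {"model": model_sd, "optim": optim_sd}

    def load_state_dict(self, state_dict):
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )

state = {"app": AppState(model, optimizer)}
dcp.save(state, checkpoint_id="ckpt/step_100")  # all ranks, shared path
dcp.load(state, checkpoint_id="ckpt/step_100")  # restores "in place"
```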
Debug checklist (what the agent should check first)
- All ranks on distinct GPUs? If not, verify `torch.cuda.set_device(LOCAL_RANK)` and your `torchrun` flags.
- Did you accidentally call `forward()` directly? Use `model(input)` or explicitly `unshard()` / register forward.
- Is `fully_shard()` applied bottom-up? If only the root is sharded, expect worse memory/perf and possible confusion.
- Optimizer created at the right time? It must be built on DTensor parameters after sharding.
- Checkpointing path consistent?
  - If using DCP, don’t mix with ad-hoc `torch.save` unless you understand the conversions.
  - Be mindful of PyTorch-version compatibility warnings for DCP.
Common issues and fixes
- Forward hooks not running → Call `model(inputs)` (or `unshard()` explicitly) instead of `model.forward(...)`.
- Optimizer sees non-DTensor params → Create the optimizer after all `fully_shard` calls.
- Only root module sharded → Apply `fully_shard` bottom-up on submodules before the root.
- Memory spikes after forward → Set `reshard_after_forward=True` for more modules.
- Gradient accumulation desync → Use `set_requires_gradient_sync` instead of FSDP1’s `no_sync()`.
Reference:
references/pytorch_fully_shard_api.md, references/pytorch_fsdp2_tutorial.md.
Minimal reference implementation outline (agent-friendly)
The coding agent should implement a script with these labeled blocks:
- `init_distributed()`: init process group, set device
- `build_model_meta()`: model on meta, apply `fully_shard`, materialize weights
- `build_optimizer()`: optimizer created after sharding
- `train_step()`: forward/backward/step with `model(inputs)` and DTensor-aware patterns
- `checkpoint_save/load()`: DCP or distributed state dict helpers

Concrete examples live in `references/pytorch_examples_fsdp2.md` and the official tutorial reference.
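Assuming the same illustrative `Transformer`/`TransformerBlock` classes used above, a condensed sketch of those blocks:

```python
import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

def init_distributed():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

def build_model_meta(config):
    with torch.device("meta"):
        model = Transformer(config)                 # your model class
    mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16,
                              reduce_dtype=torch.float32)
    for m in model.modules():                       # bottom-up sharding
        if isinstance(m, TransformerBlock):
            fully_shard(m, mp_policy=mp)
    fully_shard(model, mp_policy=mp)
    model.to_empty(device="cuda")                   # materialize shards
    for m in model.modules():
        if hasattr(m, "reset_parameters"):
            m.reset_parameters()
    return model

def build_optimizer(model):
    # After sharding, so the optimizer holds DTensor parameters.
    return torch.optim.AdamW(model.parameters(), lr=3e-4)

def train_step(model, optimizer, inputs, labels):
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)  # not .forward()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()
    return loss

def checkpoint_save(model, optimizer, path):
    model_sd, optim_sd = get_state_dict(model, optimizer)
    dcp.save({"model": model_sd, "optim": optim_sd}, checkpoint_id=path)
```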
References
- `references/pytorch_fsdp2_tutorial.md`
- `references/pytorch_fully_shard_api.md`
- `references/pytorch_ddp_notes.md`
- `references/pytorch_fsdp1_api.md`
- `references/pytorch_device_mesh_tutorial.md`
- `references/pytorch_tp_tutorial.md`
- `references/pytorch_dcp_overview.md`
- `references/pytorch_dcp_recipe.md`
- `references/pytorch_dcp_async_recipe.md`
- `references/pytorch_examples_fsdp2.md`
- `references/torchtitan_fsdp_notes.md` (optional, production notes)
- `references/ray_train_fsdp2_example.md` (optional, integration example)