Add a script to reshard FSDP checkpoints

Merge request !459 · Merged · Administrator requested to merge github/fork/tangbinh/reshard-fsdp into main · Oct 26, 2022

Created by: tangbinh

Summary

We add a new script to reshard raw FSDP checkpoints as part of our efforts to consolidate the checkpoint resharding logic. This script is a bit more general than some of the existing ones:

  • Compared to reshard_mp.py, it allows us to optionally unflatten model weights and to be compatible with the generator interface when ddp-backend is set to pytorch_ddp.
  • Compared to the consolidate_shard_weights and build_unflat_state_dict functions from FSDP (the former is used in stitch_fsdp_ckpt.py), it supports both unsharding and resharding of model weights and optimizer states.
  • Compared to checkpoint_utils.py, which is used in convert_to_singleton.py, it doesn't require instantiating FSDP instances and avoids the various requirements that come with that (DDP, vocab files, configs, etc.). We also decouple the filename handling to make it a bit more flexible.

Note that this script doesn't include the logic for model parallel resharding. We should probably have a separate script for it, which can be used together with this one.
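
For intuition, resharding a flattened FSDP parameter boils down to stripping each shard's trailing padding, concatenating the 1-D slices, and re-splitting the result into the desired number of equally sized, padded shards. The sketch below is illustrative only; the helper names and the explicit paddings argument are hypothetical and do not reflect the script's actual metadata handling:

from typing import List
import torch

def consolidate_flat_shards(shards: List[torch.Tensor], paddings: List[int]) -> torch.Tensor:
    # Strip each shard's trailing padding and concatenate the 1-D slices into one flat tensor.
    return torch.cat([s[: s.numel() - p] if p > 0 else s for s, p in zip(shards, paddings)], dim=0)

def reshard_flat_tensor(flat: torch.Tensor, num_shards: int) -> List[torch.Tensor]:
    # Pad the flat tensor so it divides evenly, then split it into equally sized slices.
    shard_size = (flat.numel() + num_shards - 1) // num_shards
    padded = torch.nn.functional.pad(flat, (0, shard_size * num_shards - flat.numel()))
    return list(padded.chunk(num_shards, dim=0))

Round-tripping through these two steps (e.g. 256 shards to 1 and back to 256) should reproduce the original slices, which is what the second test below verifies on real checkpoints.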

Testing

  • Run the script to merge the sharded checkpoints of the 2.7B parameter model into one shard for each model parallel part and load the resharded checkpoints with the interactive CLI (see the unflattening sketch after the testing steps):
for j in {0..3}; do
    python -m metaseq.scripts.reshard_fsdp \
    --input-glob-pattern "/data/gpt-z/models/gptz/2.7B/raw/checkpoint_last-model_part-$j-shard*.pt" \
    --output-shard-name "/shared/home/binhtang/checkpoints/opt-2.7b/reshard-model_part-$j.pt" \
    --num-output-shards 1 --skip-optimizer-state True --unflatten-weights True;
done
python -m metaseq.cli.interactive_cli
> what is the meaning of life?
To be happy.
  • Run the script to reshard the 6.7B parameter model checkpoint for each model parallel part from 256 shards to 1 shard and then from 1 shard back to 256 shards. The sharded checkpoints we get back are almost identical to the original ones, except for some rank-specific data that are lost during the first conversion because only rank 0's copies are kept (e.g. optimizer_history, extra_state, cfg.distributed_training.distributed_rank).
for j in {0..1}; do
    python -m metaseq.scripts.reshard_fsdp \
    --input-glob-pattern "/data/gpt-z/models/gptz/6.7B/raw/checkpoint_last-model_part-$j-shard*.pt" \
    --output-shard-name "/shared/home/binhtang/checkpoints/opt-6.7b/reshard-model_part-$j.pt" \
    --num-output-shards 1 --skip-optimizer-state False --unflatten-weights False;
done

for j in {0..1}; do
    python -m metaseq.scripts.reshard_fsdp \
    --input-glob-pattern "/shared/home/binhtang/checkpoints/opt-6.7b/reshard-model_part-$j.pt" \
    --output-shard-name "/shared/home/binhtang/checkpoints/opt-6.7b-reshard/checkpoint_last-model_part-$j-shard{i}.pt" \
    --num-output-shards 256 --skip-optimizer-state False --unflatten-weights False;
done
import torch

# Compare each original shard of model part 0 with its resharded counterpart.
for i in range(256):
    before = torch.load(f"/data/gpt-z/models/gptz/6.7B/raw/checkpoint_last-model_part-0-shard{i}.pt", map_location=torch.device("cpu"))
    after = torch.load(f"/shared/home/binhtang/checkpoints/opt-6.7b-reshard/checkpoint_last-model_part-0-shard{i}.pt", map_location=torch.device("cpu"))
    assert all(torch.allclose(before["model"][k], after["model"][k]) for k in before["model"])
    assert before["shard_metadata"] == after["shard_metadata"]
    # Assumes the usual optimizer state_dict layout, where "state" maps param id -> {"exp_avg": ..., "exp_avg_sq": ...}.
    opt_before, opt_after = before["last_optimizer_state"]["state"], after["last_optimizer_state"]["state"]
    assert all(torch.allclose(opt_before[p][key], opt_after[p][key])
               for p in opt_before for key in ("exp_avg", "exp_avg_sq"))
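
For reference, the unflattening enabled with --unflatten-weights True (used in the first test above) conceptually splits each consolidated flat parameter by the per-parameter shapes recorded in the shard metadata and reshapes the pieces back into named tensors. The sketch below assumes a simplified metadata layout, a list of (name, shape) pairs, rather than the exact structure stored under shard_metadata:

from typing import Dict, List, Tuple
import torch

def unflatten_weights(flat: torch.Tensor, param_info: List[Tuple[str, torch.Size]]) -> Dict[str, torch.Tensor]:
    # Split the flat tensor by each parameter's element count (dropping any trailing padding),
    # then reshape every piece to its recorded shape.
    numels = [int(torch.Size(shape).numel()) for _, shape in param_info]
    pieces = torch.split(flat[: sum(numels)], numels)
    return {name: piece.reshape(shape) for (name, shape), piece in zip(param_info, pieces)}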