Add a new script to reshard model parallel parts

Merged Administrator requested to merge reshard-mp into main Dec 21, 2022

Created by: tangbinh

Summary of Changes

The existing script for resharding model parallel parts (metaseq/scripts/reshard_model_parallel.py) loads all checkpoint parts at once, which can cause out-of-memory (OOM) failures when RAM is limited, especially for very large models. Here, we rewrite the script to reduce memory usage: it first allocates an unsharded model state dict, then iteratively merges the model parallel parts into it.

Previously, peak memory usage was close to 2x the model size because both the input and output state dicts had to be held in memory. With the iterative merge, the theoretical peak is closer to 1x the model size.
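The iterative merge described above can be illustrated with a minimal sketch. This is not the code in reshard_mp: it uses plain Python lists as stand-ins for tensors, assumes every entry is sharded evenly along a single dimension, and the names `merge_parts_iteratively`, `split_into_parts`, `load_part`, and `part_sizes` are hypothetical. Replicated tensors and entries sharded along other dimensions, which a real checkpoint may contain, are ignored here.

```python
from typing import Callable, Dict, List

Tensor = List[float]  # stand-in for a 1-D tensor (assumption for this sketch)
StateDict = Dict[str, Tensor]


def merge_parts_iteratively(
    load_part: Callable[[int], StateDict],
    num_parts: int,
    part_sizes: Dict[str, int],
) -> StateDict:
    """Merge model parallel parts into a preallocated unsharded state dict."""
    # Allocate the full output once, before any input part is loaded.
    merged = {key: [0.0] * (size * num_parts) for key, size in part_sizes.items()}
    for rank in range(num_parts):
        part = load_part(rank)  # only one input part is resident at a time
        for key, shard in part.items():
            start = rank * part_sizes[key]
            merged[key][start:start + len(shard)] = shard
        del part  # drop the reference before loading the next part
    return merged


def split_into_parts(merged: StateDict, num_output_parts: int) -> List[StateDict]:
    """Split each unsharded entry evenly into num_output_parts slices."""
    outputs = []
    for rank in range(num_output_parts):
        out = {}
        for key, full in merged.items():
            size = len(full) // num_output_parts
            out[key] = full[rank * size:(rank + 1) * size]
        outputs.append(out)
    return outputs


# Toy example: reshard 2 input parts into 4 output parts.
parts = [{"w": [0.0, 1.0, 2.0, 3.0]}, {"w": [4.0, 5.0, 6.0, 7.0]}]
merged = merge_parts_iteratively(lambda rank: parts[rank], 2, {"w": 4})
resharded = split_into_parts(merged, 4)
```

Because the output buffer is allocated up front and each input part is released before the next is loaded, the sketch never holds more than the full output plus one input part at once, which is the intuition behind the "closer to 1x" estimate.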

The new script produces the same output as metaseq/scripts/reshard_model_parallel.py, so we delete the old script to avoid duplication. The old script remains accessible in the internal repo (see this script).

Test Plan

  • Run the script with an OPT 2.7B checkpoint to reshard 4 MP parts into 8 MP parts and make sure the resulting checkpoint performs reasonably:
    seq 0 3 | parallel --line-buffer 'python metaseq/scripts/reshard_fsdp.py --input "/data/checkpoints/opt-2.7b/raw/checkpoint_last-model_part-{}-shard*.pt" --output "/data/checkpoints/opt-2.7b/reshard_no_os/reshard-model_part-{}.pt" --skip-optimizer-state True --unflatten-weights True --output-dtype fp16'
    python -m metaseq.scripts.reshard_mp --input "/data/checkpoints/opt-2.7b/reshard_no_os/reshard-model_part-*.pt" --output "/data/checkpoints/opt-2.7b/reshard_no_os_mp8/reshard-model_part-{i}.pt" --num-output-parts 8
    python metaseq/scripts/interactive.py --merges-filename /data/checkpoints/gpt2-merges.txt --vocab-filename /data/checkpoints/gpt2-vocab.json --path /data/checkpoints/opt-2.7b/reshard_no_os_mp8/reshard.pt --model-parallel-size 8 --distributed-world-size 8 --beam 3 --max-source-positions 4 --max-target-positions 128
    
    > Prompt: What is the meaning of life?
    Output: To be happy.
  • We compare performance with metaseq/scripts/reshard_model_parallel.py while resharding an OPT-175B checkpoint from 8 MP parts into 16 MP parts. The old script takes 849.40 seconds and results in a peak RSS delta of 668,301 MB, while the new script takes 891.65 seconds and has an RSS delta of 458,185 MB. That is a 31% reduction in RAM usage (equivalently, the old script uses 46% more RAM than the new one) at the cost of about 5% more run time.
    python metaseq/scripts/reshard_model_parallel.py --pth_prefix /data/checkpoints/opt-175b/reshard_no_os_unflat/reshard.pt --new-model-parts 16 --save-prefix /data/checkpoints/opt-175b/reshard_no_os_unflat_mp16_ref/reshard.pt
    python -m metaseq.scripts.reshard_mp --input "/data/checkpoints/opt-175b/reshard_no_os_unflat/reshard-model_part-*.pt" --output "/data/checkpoints/opt-175b/reshard_no_os_unflat_mp16/reshard-model_part-{i}.pt" --num-output-parts 16
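The relative figures for the OPT-175B comparison can be recomputed from the raw measurements above. The variable names below are only for illustration; the numbers are the ones reported in the test plan.

```python
# Measurements reported for resharding OPT-175B from 8 to 16 MP parts.
old_seconds, new_seconds = 849.40, 891.65
old_rss_mb, new_rss_mb = 668_301, 458_185

# Reduction relative to the old script's RSS delta.
ram_reduction = 1 - new_rss_mb / old_rss_mb
# The same gap expressed relative to the new script's RSS delta.
old_over_new = old_rss_mb / new_rss_mb - 1
# Extra wall-clock time of the new script relative to the old one.
slowdown = new_seconds / old_seconds - 1

print(f"RAM reduction vs. old script: {ram_reduction:.1%}")   # 31.4%
print(f"Old script uses more RAM by:  {old_over_new:.1%}")    # 45.9%
print(f"New script is slower by:      {slowdown:.1%}")        # 5.0%
```

The 46% figure comes from dividing by the new script's usage, while the reduction in RAM relative to the old script is about 31%.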
Source branch: reshard-mp