metaseq issue #31 · Closed · created May 04, 2022

How to load sharded checkpoints?

Created by: patrickvonplaten

❓ Questions and Help

After setting up the libraries as described in https://github.com/facebookresearch/metaseq/blob/main/docs/setup.md , it is possible to load the 350m checkpoint, since it is not sharded, as follows:

wget https://dl.fbaipublicfiles.com/opt/v1_20220502/350m/reshard.pt -P ./
  1. Next we need to comment out one line in the Megatron-LM library which is only relevant for training (it initializes different random seeds across pp ranks). Comment out this line: https://github.com/ngoyal2707/Megatron-LM/blob/ae0b844c1f6725c3433a95e42cac760b3885170b/megatron/initialize.py#L65 in your local clone of Megatron-LM.
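
If you prefer to script that edit, a tiny sketch like the following would do it; the clone path is an assumption, and the line number 65 simply mirrors the link above.

from pathlib import Path

# Comment out line 65 of megatron/initialize.py in a local Megatron-LM clone.
# The path below is an assumption -- adjust it to wherever your clone lives.
init_file = Path("Megatron-LM/megatron/initialize.py")
lines = init_file.read_text().splitlines(keepends=True)
lines[64] = "# " + lines[64]  # line 65 is index 64 (0-based)
init_file.write_text("".join(lines))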

  2. Now we write the following Python script to a run_model.py file:

import os

from transformers import GPT2Tokenizer
from megatron.initialize import initialize_megatron
from metaseq import checkpoint_utils

path = "./"

# model arguments for the 350m checkpoint, taken from: https://arxiv.org/pdf/2205.01068.pdf | table 1
initialize_megatron(args_defaults={
    "micro_batch_size": 1,
    "num_layers": 24,
    "hidden_size": 1024,
    "num_attention_heads": 16,
    "max_position_embeddings": 2048,
    "encoder_seq_length": 2048
})

# save the GPT2-style vocab.json / merges.txt next to the checkpoint so metaseq can find them
tokenizer = GPT2Tokenizer.from_pretrained("patrickvonplaten/opt_gpt2_tokenizer")
tokenizer.save_pretrained(path)

# load the (unsharded) 350m checkpoint
checkpoint = checkpoint_utils.load_model_ensemble_and_task(
    [os.path.join(path, "reshard.pt")],
    arg_overrides={
        "vocab_filename": os.path.join(path, "vocab.json"),
        "merges_filename": os.path.join(path, "merges.txt"),
    }
)

# load_model_ensemble_and_task returns the model ensemble first; take its first (and only) model
model = checkpoint[0][0].eval()
  3. We can load the checkpoint by running the command below (a rough forward-pass sketch follows it):
torchrun run_model.py --pipeline-model-parallel-size 1 --tensor-model-parallel-size 1
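
As a sanity check, a forward pass could be appended to run_model.py roughly as follows. This is an untested sketch: it assumes the tokenizer saved above matches the model's vocabulary and that the loaded metaseq module takes a batch of token IDs and returns logits as its first output (fairseq-style), which we have not verified.

import torch

# `path` and `model` come from run_model.py above
input_ids = torch.tensor([tokenizer.encode("Hello, my name is")], dtype=torch.long)

with torch.no_grad():
    output = model(input_ids)  # hypothetical forward call, signature not verified
    logits = output[0] if isinstance(output, tuple) else output

# greedily pick the next token as a quick smoke test
next_token_id = int(logits[0, -1].argmax())
print(tokenizer.decode([next_token_id]))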

Problem: this only works for the 350m checkpoint! For the other checkpoints this doesn't work. E.g. when replacing [os.path.join(path, "reshard.pt")] with [os.path.join(path, "reshard-model_part-0.pt"), os.path.join(path, "reshard-model_part-1.pt")] (part 0 and part 1 of the 125M model), we get an error because the weights are all flattened into 1D arrays.
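
To illustrate what we mean by "flattened", the shards can be inspected directly with something like the sketch below. It assumes each shard file is a plain torch.load-able dict with a "model" sub-dict mapping parameter names to tensors; those key names are an assumption on our side, not something taken from the metaseq code.

import torch

for shard_file in ["reshard-model_part-0.pt", "reshard-model_part-1.pt"]:
    state = torch.load(shard_file, map_location="cpu")
    model_state = state.get("model", state)  # "model" key is an assumption
    for name, tensor in model_state.items():
        if torch.is_tensor(tensor):
            # the FSDP-flattened parameters show up here as large 1-D tensors
            print(shard_file, name, tuple(tensor.shape))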

Using https://github.com/facebookresearch/metaseq/pull/29 sadly also doesn't help, since the checkpoints don't seem to be in the *shard* format as required here: https://github.com/facebookresearch/metaseq/blob/48b9b6c083237f9b95c2eb67afc10005e10d67ee/metaseq/distributed/stitch_fsdp_ckpt.py#L45

The parameter flattening seems to come from Fairscale and we've found some functionality to unflatten it here: https://github.com/facebookresearch/fairscale/blob/51b53ddb6c3aa77426c7d5cc0b543b79628053c4/fairscale/nn/misc/flatten_params_wrapper.py#L358 , but we haven't managed to wrap our heads around how to make it work exactly.
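
Conceptually, unflattening just splits the flat 1-D buffer back into per-parameter views using the original shapes. The sketch below only shows that generic mechanism with made-up metadata; it is not the fairscale API, and the real names/shapes would have to come from whatever metadata the checkpoints store.

import torch

def unflatten(flat_param, names, shapes):
    # split one flat 1-D buffer back into named, shaped tensors;
    # `names`/`shapes` stand in for per-parameter metadata (an assumption here)
    numels = [int(torch.Size(shape).numel()) for shape in shapes]
    chunks = torch.split(flat_param, numels)
    return {name: chunk.view(shape) for name, chunk, shape in zip(names, chunks, shapes)}

# hypothetical example with made-up metadata
flat = torch.arange(12, dtype=torch.float32)
params = unflatten(flat, ["weight", "bias"], [(2, 5), (2,)])
print({name: tuple(p.shape) for name, p in params.items()})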

@stephenroller @suchenzang @zhiqwang - any pointers on how we could load the 125M model (and the others) into a model instance of metaseq?
