metaseq issue #186 (Closed)
Issue created Jun 27, 2022 by Administrator @root (Owner)

OPT 125m model training from scratch with language modeling task and OSCAR dataset

Created by: chelseajohn

❓ Questions and Help

What is your question?

I am trying to train the OPT 125m model from scratch using the OSCAR dataset (1 GB) and the language modeling task, and I get the following error:

Traceback (most recent call last):
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T09_30_24.704852/metaseq_cli/train.py", line 589, in <module>
    cli_main()
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T09_30_24.704852/metaseq_cli/train.py", line 585, in cli_main
    distributed_utils.call_main(cfg, main)
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T09_30_24.704852/metaseq/distributed/utils.py", line 267, in call_main
    return distributed_main(
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T09_30_24.704852/metaseq/distributed/utils.py", line 205, in distributed_main
    main(cfg, **kwargs)
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T09_30_24.704852/metaseq_cli/train.py", line 167, in main
    valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
  File "/p/software/juwelsbooster/stages/2020/software/Python/3.8.5-GCCcore-10.3.0/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T09_30_24.704852/metaseq_cli/train.py", line 307, in train
    for i, samples in enumerate(progress):
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T09_30_24.704852/metaseq/logging/progress_bar/json_progress_bar.py", line 38, in __iter__
    for i, obj in enumerate(self.iterable, start=self.n):
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T09_30_24.704852/metaseq/data/iterators.py", line 60, in __iter__
    for x in self.iterable:
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T09_30_24.704852/metaseq/data/iterators.py", line 737, in __next__
    raise item
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T09_30_24.704852/metaseq/data/iterators.py", line 668, in run
    for item in self._source:
  File "/p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 530, in __next__
    data = self._next_data()
  File "/p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1224, in _next_data
    return self._process_data(data)
  File "/p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1250, in _process_data
    data.reraise()
  File "/p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/_utils.py", line 457, in reraise
    raise exception
ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T09_30_24.704852/metaseq/data/monolingual_dataset.py", line 97, in __getitem__
    source, future_target, _ = self.dataset[index]
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T09_30_24.704852/metaseq/data/token_block_dataset.py", line 166, in __getitem__
    [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T09_30_24.704852/metaseq/data/token_block_dataset.py", line 166, in <listcomp>
    [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T09_30_24.704852/metaseq/data/indexed_dataset.py", line 514, in __getitem__
    np_array = np.frombuffer(
ValueError: offset must be non-negative and no greater than buffer length (468678968)

I would really appreciate it if anyone has an idea of how to tackle this error. I have run out of ideas, so any suggestions or pointers are welcome.
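The last frame fails inside `np.frombuffer`, which raises exactly this `ValueError` when the requested byte offset lies past the end of the memory-mapped data buffer. That may suggest the index file and the `.bin` data file disagree (for example, one was regenerated or truncated without the other), though the traceback alone does not prove it. The sketch below is a minimal consistency check under stated assumptions: it takes per-item byte offsets (`pointers`) and token counts (`sizes`) as plain arrays, however you extract them from the index, plus the bytes per token and the `.bin` length. `find_out_of_range_items` is a hypothetical helper, not part of metaseq.

```python
import numpy as np


def find_out_of_range_items(pointers, sizes, itemsize, buffer_len):
    """Return the indices of items whose byte range does not fit in the buffer.

    pointers:   byte offset of each item inside the .bin buffer
    sizes:      number of tokens in each item
    itemsize:   bytes per token (e.g. 2 for uint16, 4 for int32)
    buffer_len: total length of the .bin buffer in bytes
    """
    pointers = np.asarray(pointers, dtype=np.int64)
    sizes = np.asarray(sizes, dtype=np.int64)
    ends = pointers + sizes * itemsize  # exclusive end offset of each item
    bad = (pointers < 0) | (sizes < 0) | (ends > buffer_len)
    return np.nonzero(bad)[0]


# Reproduce the error itself: read one uint16 at byte offset 16 of an 8-byte buffer.
buf = np.arange(4, dtype=np.uint16).tobytes()  # 4 tokens * 2 bytes = 8 bytes
try:
    np.frombuffer(buf, dtype=np.uint16, count=1, offset=16)
except ValueError as exc:
    print(exc)  # offset must be non-negative and no greater than buffer length (8)

# Two consistent entries, then one that points past the end of the 8-byte buffer.
print(find_out_of_range_items([0, 4, 16], [2, 2, 1], itemsize=2, buffer_len=8))  # [2]
```

If any indices are reported for the real files, regenerating the binarized data and its index together from the raw text would be a reasonable next step.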

Code

The command used to run the code is: opt-baselines -n 2 -g 4 --account opengptx-elm --partition develbooster --prefix opt125m_1 --model-size 125m --juwelsbooster --data /p/scratch/opengptx-elm/john2/opengpt/data/oscar --ntasks-per-node 4 --cpus-per-task 12 --checkpoints-dir "$CHECKPOINT_PATH" --tensorboard-logdir "$TENSORBOARD_PATH" --no-save-dir --snapshot-root "$ROOT_OUTPUT_DIR" --time 10 --no-wandb --cpu-bind none --salloc

The setup code for metaseq can be found here.

What have you tried?

I have tried the following:

  • varying the number of GPUs from 4 to 16
  • varying the number of data-loader workers from 0 to 8
  • varying the batch size
  • using a larger dataset (4 GB)
  • moving to PyTorch version 1.10.1
  • using the --reset-dataloader option
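Because the first error only surfaces inside a data-loader worker, one way to separate a data problem from the distributed setup is to index the dataset directly, outside of training, and record which indices raise. A minimal sketch, assuming only that the dataset supports `len()` and integer indexing like a PyTorch map-style dataset; `scan_dataset` and `ToyDataset` are hypothetical names used for illustration.

```python
def scan_dataset(dataset, max_errors=10):
    """Index every item of a map-style dataset and collect the ones that fail.

    Returns a list of (index, exception) pairs, stopping after `max_errors`.
    """
    failures = []
    for idx in range(len(dataset)):
        try:
            dataset[idx]
        except Exception as exc:  # record any error and keep scanning
            failures.append((idx, exc))
            if len(failures) >= max_errors:
                break
    return failures


class ToyDataset:
    """Hypothetical stand-in: item 3 simulates a corrupted entry."""

    def __len__(self):
        return 5

    def __getitem__(self, idx):
        if idx == 3:
            raise ValueError("offset must be non-negative and no greater than buffer length")
        return [idx]


for idx, exc in scan_dataset(ToyDataset()):
    print(idx, type(exc).__name__, exc)
```

If the same indices fail regardless of GPU count, worker count, or batch size, that would point at the data files rather than the training configuration.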

What's your environment?

  • metaseq version: forked repo
  • PyTorch version: 1.11.0
  • OS: Linux
  • How you installed metaseq (pip, source): python -m pip install -e .
  • Python version: 3.8.5
  • CUDA/cuDNN version: cuDNN/8.2.1.32-CUDA-11.3
  • GPU models and configuration: each node contains 4 × NVIDIA A100 Tensor Core GPUs (40 GB each), connected to each other via NVLink3
  • Any other relevant information: running on 4 × NVIDIA A100 Tensor Core GPUs (40 GB) with the number of workers set to 0 and CUDA_LAUNCH_BLOCKING=1 gives the following error:
Traceback (most recent call last):
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq_cli/train.py", line 589, in <module>
    cli_main()
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq_cli/train.py", line 585, in cli_main
    distributed_utils.call_main(cfg, main)
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq/distributed/utils.py", line 267, in call_main
    return distributed_main(
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq/distributed/utils.py", line 205, in distributed_main
    main(cfg, **kwargs)
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq_cli/train.py", line 167, in main
    valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
  File "/p/software/juwelsbooster/stages/2020/software/Python/3.8.5-GCCcore-10.3.0/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq_cli/train.py", line 321, in train
    valid_losses, should_stop = train(i, samples)
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq_cli/train.py", line 281, in train
    log_output = trainer.train_step(samples)
  File "/p/software/juwelsbooster/stages/2020/software/Python/3.8.5-GCCcore-10.3.0/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq/trainer.py", line 721, in train_step
    raise e
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq/trainer.py", line 689, in train_step
    loss, sample_size_i, logging_output = self.task.train_step(
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq/tasks/base_task.py", line 409, in train_step
    loss, sample_size, logging_output = criterion(model, sample)
  File "/p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq/model_parallel/criterions/vocab_parallel_cross_entropy.py", line 43, in forward
    net_output = model(**sample["net_input"])
  File "/p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/p/project/opengptx-elm/john2/opengpt/OPT/fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1152, in forward
    outputs = self.module(*args, **kwargs)
  File "/p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/p/project/opengptx-elm/john2/opengpt/OPT/fairscale/fairscale/nn/misc/flatten_params_wrapper.py", line 459, in forward
    return self.module(*inputs, **kwinputs)
  File "/p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq/models/base_model.py", line 371, in forward
    return self.decoder(src_tokens, **kwargs)
  File "/p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq/models/transformer.py", line 651, in forward
    x, extra = self.extract_features(
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq/models/transformer.py", line 676, in extract_features
    return self.extract_features_scriptable(
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq/models/transformer.py", line 740, in extract_features_scriptable
    x, layer_attn, _, l_aux_i = layer(
  File "/p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/p/project/opengptx-elm/john2/opengpt/OPT/fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1152, in forward
    outputs = self.module(*args, **kwargs)
  File "/p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/p/project/opengptx-elm/john2/opengpt/OPT/fairscale/fairscale/nn/misc/flatten_params_wrapper.py", line 459, in forward
    return self.module(*inputs, **kwinputs)
  File "/p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq/modules/checkpoint_activation_wrapper/checkpoint_activations.py", line 207, in _checkpointed_forward
    output = CheckpointFunction.apply(
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq/modules/checkpoint_activation_wrapper/checkpoint_activations.py", line 307, in forward
    outputs = run_function(*unpacked_args, **unpacked_kwargs)
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq/modules/transformer_layer.py", line 541, in forward
    x, attn = self.forward_attention(
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq/model_parallel/modules/transformer_layer.py", line 169, in forward_attention
    (attn_output, attn_bias), attn_weights = self.self_attn(
  File "/p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq/model_parallel/modules/multihead_attention.py", line 380, in forward
    with get_cuda_rng_tracker().fork():
  File "/p/software/juwelsbooster/stages/2020/software/Python/3.8.5-GCCcore-10.3.0/lib/python3.8/contextlib.py", line 113, in __enter__
    return next(self.gen)
  File "/p/project/opengptx-elm/john2/opengpt/OPT/Megatron-LM/megatron/mpu/random.py", line 161, in fork
    orig_cuda_rng_state = torch.cuda.get_rng_state()
  File "/p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/cuda/random.py", line 31, in get_rng_state
    return default_generator.get_state()
RuntimeError: CUDA error: an illegal memory access was encountered
   (attn_output, attn_bias), attn_weights = self.self_attn(
 File "/p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
   return forward_call(*input, **kwargs)
 File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq/model_parallel/modules/multihead_attention.py", line 380, in forward
   with get_cuda_rng_tracker().fork():
 File "/p/software/juwelsbooster/stages/2020/software/Python/3.8.5-GCCcore-10.3.0/lib/python3.8/contextlib.py", line 113, in __enter__
   return next(self.gen)
 File "/p/project/opengptx-elm/john2/opengpt/OPT/Megatron-LM/megatron/mpu/random.py", line 161, in fork
   orig_cuda_rng_state = torch.cuda.get_rng_state()
 File "/p/project/opengptx-elm/jo, in forward
   x, attn = self.forward_attention(
 File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq/model_parallel/modules/transformer_layer.py", line 169, in forward_attention
   (attn_output, attn_bias), attn_weights = self.self_attn(
 File "/p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
   return forward_call(*input, **kwargs)
 File "/p/scratch/opengptx-elm/john2/opengpt/OPT/slurm_snapshot_code_oss/2022-06-24T10_17_10.587237/metaseq/model_parallel/modules/multihead_attention.py", line 380, in forward
   with get_cuda_rng_tracker().fork():
 File "/p/software/juwelsbooster/stages/2020/software/Python/3.8.5-GCCcore-10.3.0/lib/python3.8/contextlib.py", line 113, in __enter__
   return next(self.gen)
 File "/p/project/opengptx-elm/john2/opengpt/OPT/Megatron-LM/megatron/mpu/random.py", line 161, in fork
   orig_cuda_rng_state = torch.cuda.get_rng_state()
 File "/p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/cuda/random.py", line 31, in get_rng_state
   return default_generator.get_state()
RuntimeError: CUDA error: an illegal memory access was encountered
hn2/opengpt/OPT/env/lib/python3.8/site-packages/torch/cuda/random.py", line 31, in get_rng_state
   return default_generator.get_state()
RuntimeError: CUDA error: an illegal memory access was encountered
terminate called without an active exception
terminate called after throwing an instance of 'c10::CUDAError'
 what():  CUDA error: an illegal memory access was encountered
Exception raised from create_event_internal at ../c10/cuda/CUDACachingAllocator.cpp:1230 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x1466732937d2 in /p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x239de (0x1466abe789de in /p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/lib/libc10_cuda.so)
frame #2: c10::cuda::CUDACachingAllocator::raw_delete(void*) + 0x22d (0x1466abe7a57d in /p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/lib/libc10_cuda.so)
frame #3: <unknown function> + 0x300568 (0x1467286ab568 in /p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #4: c10::TensorImpl::release_resources() + 0x175 (0x14667327c005 in /p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #5: <unknown function> + 0x1ee569 (0x146728599569 in /p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0x4d9c78 (0x146728884c78 in /p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #7: THPVariable_subclass_dealloc(_object*) + 0x292 (0x146728884f72 in /p/project/opengptx-elm/john2/opengpt/OPT/env/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #18: __libc_start_main + 0xf3 (0x14672fffe493 in /usr/lib64/libc.so.6)
frame #19: _start + 0x2e (0x4006de in /p/project/opengptx-elm/john2/opengpt/OPT/env/bin/python)