metaseq / Merge requests / !258

Flash Attention

Closed · Administrator requested to merge flashattn into main · Jul 26, 2022

Created by: stephenroller

Update: buggy. These results are invalid.

Patch Description

Switches us to FlashAttention. Did it quick and dirty to test.
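For readers less familiar with what the switch means in code: the change replaces the explicit softmax(QK^T / sqrt(d)) V computation in the attention module with a single fused kernel call. The sketch below is illustrative only; it uses torch.nn.functional.scaled_dot_product_attention (which dispatches to a FlashAttention-style kernel in recent PyTorch) as a stand-in for the flash-attn library call, so the fused-call API shown is an assumption, not what this MR actually ships.

```python
# Illustrative sketch only; not the code in this patch. Shows the kind of
# replacement the MR makes: explicit attention vs. a fused kernel call.
# F.scaled_dot_product_attention stands in for the flash-attn kernel here.
import math
import torch
import torch.nn.functional as F

def explicit_attention(q, k, v):
    # Megatron-style path: materializes the full (seq x seq) score matrix.
    # q, k, v: [batch, num_heads, seq_len, head_dim]
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    seq_len = q.size(-2)
    causal = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device), diagonal=1
    )
    scores = scores.masked_fill(causal, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

def fused_attention(q, k, v):
    # Fused path: the attention matrix is never materialized in GPU memory.
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

q, k, v = (torch.randn(2, 40, 128, 32) for _ in range(3))  # 40 heads in R^32
print((explicit_attention(q, k, v) - fused_attention(q, k, v)).abs().max())
```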

Experimental Setup

Tried launching the 2.7B OPT baseline on 64 GPUs. Observed that FlashAttention didn't like this, because the attention head dimension is 40, and FlashAttention only supports head dimensions of 32/64/128. Resolved this by swapping the number of heads and the head dimension, so that we have 40 heads in R^32 instead of 32 heads in R^40.
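For concreteness, the swap keeps the total attention width (num_heads x head_dim = 1280) fixed while moving to a head dimension the kernel supports. A minimal sketch using the numbers quoted above; the config keys are hypothetical, not metaseq's actual field names:

```python
# Hypothetical config sketch (keys are illustrative, not metaseq's schema).
# FlashAttention here supports head dims of 32/64/128 only, so the
# 32-heads-in-R^40 layout is swapped to 40-heads-in-R^32; the total
# attention width num_heads * head_dim = 1280 is unchanged.
SUPPORTED_HEAD_DIMS = {32, 64, 128}

configs = {
    "baseline (as above)": {"num_heads": 32, "head_dim": 40},
    "swapped for flash":   {"num_heads": 40, "head_dim": 32},
}

for name, cfg in configs.items():
    width = cfg["num_heads"] * cfg["head_dim"]
    ok = cfg["head_dim"] in SUPPORTED_HEAD_DIMS
    print(f"{name}: width={width}, head_dim supported={ok}")
# baseline (as above): width=1280, head_dim supported=False
# swapped for flash:   width=1280, head_dim supported=True
```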

To control for this, launched 3 versions:

  • (green) Our Megatron-based attention implementation, with the original OPT setup
  • (orange) A Megatron-based attention implementation with 40 heads in R^32
  • (purple) A FlashAttention-based implementation with 40 heads in R^32

Results

We observe that flash attention significantly reduces memory usage:

[figure: memory usage comparison]

We observe that flash attention significantly increases throughput:

[figure: throughput comparison]

I was quite surprised by how extreme this speedup is. Based on others' reports and the known FLOPS ratio, it should have been closer to 1.1x, not the 1.6x we're observing. Perhaps the key is that we get to avoid the transpose now? Our FLOPS/GPU aren't particularly high in any of these cases: the baseline is 77 TFLOPS/GPU and flash attention is 128 TFLOPS/GPU, neither of which is particularly great. We never spent much time optimizing the 2.7B, so it seems like we must be doing something really bad with it.
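As a quick arithmetic check, the quoted per-GPU throughput figures are self-consistent with the observed speedup:

```python
# Back-of-the-envelope check on the numbers quoted above.
baseline_tflops = 77   # Megatron-based attention, per GPU
flash_tflops = 128     # flash attention, per GPU
print(f"{flash_tflops / baseline_tflops:.2f}x")  # ~1.66x, in line with the ~1.6x observed
```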

But here's the downside. It seems to significantly hurt stability and convergence:

[figure: stability and convergence comparison]

Unfortunately, unless the stability issue is resolved, I can't recommend that flash attention replace our current implementation.

Next steps

A version based on the Triton implementation is likely preferable.

Source branch: flashattn