3️⃣Flash Attention

About Flash Attention-2

GPU는 메모리 대역폭과 병렬 처리에 최적화되어 있기 때문에 CPU와 달리 머신 러닝을 위한 표준 하드웨어로 선택됩니다. 최신 모델의 더 큰 크기를 따라잡거나 기존 및 구형 하드웨어에서 이러한 대형 모델을 실행하기 위해 GPU 추론 속도를 높이는 데 사용할 수 있는 몇 가지 최적화 방법이 있습니다.

Flash Attention(현재 FlashAttention-2)는 메모리 효율이 높은 주의 메커니즘으로 BetterTransformer(파이토치 기본 빠른 경로 실행)와 Bitsandbytes를 사용하여 모델을 더 낮은 정밀도로 정량화합니다.

FlashAttention-2는 standard attention 메커니즘을 더 빠르고 효율적으로 구현한 것으로, 추론 속도를 크게 높일 수 있습니다:

  • 시퀀스 길이에 대한 주의 계산을 추가로 병렬화합니다.

  • GPU 스레드 간 작업을 분할하여 스레드 간 통신 및 공유 메모리 읽기/쓰기를 감소시킵니다.

논문에서 40GB GPU의 타일형 플래시어텐션 계산 패턴과 메모리 계층구조를 보여줍니다. 오른쪽 차트는 어텐션 메커니즘의 여러 구성 요소를 융합하고 재정렬하여 얻을 수 있는 상대적인 속도 향상을 보여줍니다.

Flash Attention

설치 방법은 아래와 같습니다. 설치를 하면 Transformer 라이브러리와 함께 사용할 수 있습니다.

%pip install flash-attn --no-build-isolation
import torch
import bitsandbytes as bnb
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

Memory Test

Flash Attention-2를 사용했을 때 실제 GPU에서 메모리가 어떻게 최적화되고 소모하는지 결과를 비교하겠습니다. 아래는 테스트 조건 입니다:

  • Model: Mistral-7B-v0.3

  • Batchsize: 2

  • Sequence Lenght: 3072

테스트는 3가지 조건으로 실행할 예정입니다.

  1. AutoModelForCausalML

  2. AutoModelForCausalML + BitsAndBytes

  3. AutoModelForCausalML + BitsAndBytes + flash_attention_2

1. Standard Inference

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.3",
    device_map="auto",
    use_flash_attention_2=False,
)

unit_gib = 1024 ** 3

curr_mem = torch.cuda.memory_reserved() / unit_gib
print(f"Initial: {curr_mem:.4f} GiB")

batch_size = 2
seq_len = 3072
inn = torch.LongTensor([[0] * seq_len] * batch_size)

try:
    with torch.no_grad():
        out = model(inn)

    infer_mem = torch.cuda.memory_reserved() / unit_gib
    print(f"Inference: {infer_mem:.4f} GiB")
except:
    print(f"Inference: OOM")
Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Initial: 7.5684 GiB
Inference: 8.7402 GiB

2. Inference with Bitsandbytes

bnb_config = BitsAndBytesConfig(  
    load_in_4bit= True,
    bnb_4bit_quant_type= "nf4",
    bnb_4bit_compute_dtype= torch.bfloat16,
    bnb_4bit_use_double_quant= False,
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.3",
    device_map="auto",
    use_flash_attention_2=False,
)

unit_gib = 1024 ** 3

curr_mem = torch.cuda.memory_reserved() / unit_gib
print(f"Initial: {curr_mem:.4f} GiB")

batch_size = 2
seq_len = 3072
inn = torch.LongTensor([[0] * seq_len] * batch_size)

try:
    with torch.no_grad():
        out = model(inn)

    infer_mem = torch.cuda.memory_reserved() / unit_gib
    print(f"Inference: {infer_mem:.4f} GiB")
except:
    print(f"Inference: OOM")
Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Initial: 7.5527 GiB
Inference: 8.7246 GiB

3. Inference with Bitsandbytes, FlashAttention2

bnb_config = BitsAndBytesConfig(  
    load_in_4bit= True,
    bnb_4bit_quant_type= "nf4",
    bnb_4bit_compute_dtype= torch.bfloat16,
    bnb_4bit_use_double_quant= False,
)

model: ModelCls = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.3",
    device_map="auto",
    quantization_config=bnb_config,
    use_flash_attention_2=True,
)

unit_gib = 1024 ** 3

curr_mem = torch.cuda.memory_reserved() / unit_gib
print(f"Initial: {curr_mem:.4f} GiB")

batch_size = 2
seq_len = 3072
inn = torch.LongTensor([[0] * seq_len] * batch_size)

try:
    with torch.no_grad():
        out = model(inn)

    infer_mem = torch.cuda.memory_reserved() / unit_gib
    print(f"Inference: {infer_mem:.4f} GiB")
except:
    print(f"Inference: OOM")
Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Initial: 9.3535 GiB
Inference: OOM

Conclusion

첫번째 일반적인 AutoModelForCausalML로 모델을 불러와 Inference 했을 때 8.742G의 GPU 메모리를 사용했습니다. 두번째 의 조합을 사용했을 때 8.724G의 메모리를 사용했습니다. 중요한 사실은 Initial Memory로 초기에 GPU 메모리 확보 입니다. 두 조건 모두 차이가 없습니다.

이제 마지막으로 flash_attention_2=True 로 했을 때 입니다. Initial Memory는 9.3535G로 일반적인 조건보다 더 많이 확보했고, Inference 후에는 0으로 GPU 공간을 잡아먹지 않았습니다.

Train이나 Inference 시 Flash Attention을 사용하면 메모리 최적화를 이룰 수 있습니다.

Last updated