Mixtral 8x7B Instruct with AWQ & Flash Attention 2 🔥
All in ~24GB GPU VRAM!
With the latest release of AutoAWQ, you can now run Mixtral 8x7B MoE with Flash Attention 2 for blazingly fast inference.
All in < 10 lines of code.
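Where does ~24GB come from? Here is a rough back-of-envelope sketch. It assumes Mixtral 8x7B has about 46.7B total parameters (the figure Mistral AI published), 4-bit AWQ weights with group size 128, and one fp16 scale plus one 4-bit zero point per group. It ignores the layers that stay unquantised (embeddings, lm_head, router), the KV cache and activations, so read it as an estimate of the quantised weights alone. `awq_weight_gb` and `fp16_weight_gb` are hypothetical helpers.

```python
def awq_weight_gb(n_params: float, w_bit: int = 4, group_size: int = 128,
                  scale_bits: int = 16) -> float:
    """Rough weight footprint in GB (1e9 bytes) for group-quantised weights.

    Each weight costs w_bit bits, and every group of `group_size` weights
    additionally stores one scale (scale_bits) and one zero point (w_bit).
    """
    bits_per_weight = w_bit + (scale_bits + w_bit) / group_size
    return n_params * bits_per_weight / 8 / 1e9


def fp16_weight_gb(n_params: float) -> float:
    """Same parameters stored as unquantised fp16 (2 bytes each)."""
    return n_params * 2 / 1e9


if __name__ == "__main__":
    mixtral_params = 46.7e9  # assumption: total parameter count published by Mistral AI
    print(f"fp16 weights: ~{fp16_weight_gb(mixtral_params):.1f} GB")
    print(f"4-bit AWQ weights: ~{awq_weight_gb(mixtral_params):.1f} GB")
```

This prints roughly 93.4 GB for fp16 versus about 24.3 GB (about 22.6 GiB) for 4-bit AWQ, which lines up with the ~24GB figure above.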
Apart from loading the AWQ weights, the only real change is passing attn_implementation="flash_attention_2" to the .from_pretrained call when loading the model.
Here's a full run-through:
1. Install AutoAWQ, transformers (from source) and flash-attn
pip install autoawq git+https://github.com/huggingface/transformers.git
pip install flash-attn --no-build-isolation
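Before starting a long weight download, it can help to check that everything is importable. Here is a minimal standard-library sketch. It assumes the import names `torch`, `transformers`, `awq` (AutoAWQ) and `flash_attn` (the flash-attn package, which attn_implementation="flash_attention_2" relies on). `missing_packages` is a hypothetical helper.

```python
import importlib.util

# Assumed import names for the packages this walkthrough relies on.
REQUIRED = ["torch", "transformers", "awq", "flash_attn"]


def missing_packages(names):
    """Return the subset of `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages found.")
```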
2. Initialise the tokeniser and the model
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
model_id = "casperhansen/mixtral-instruct-awq"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    device_map="cuda:0",
    attn_implementation="flash_attention_2",
)
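If flash-attn isn't installed, transformers raises an error on the from_pretrained call above instead of quietly falling back. Here is a hedged sketch of picking the attention backend up front. `pick_attn_implementation` is a hypothetical helper, and "eager" is transformers' standard, non-fused attention path.

```python
import importlib.util


def pick_attn_implementation() -> str:
    """Prefer Flash Attention 2 when the flash_attn package is importable,
    otherwise fall back to transformers' standard "eager" attention."""
    if importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "eager"


if __name__ == "__main__":
    # Pass the result as attn_implementation=... in from_pretrained.
    print(pick_attn_implementation())
```

The eager fallback keeps the script runnable, but it materialises the full attention matrix, so expect it to be slower and more memory-hungry on long prompts.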
3. Initialise the TextStreamer
streamer = TextStreamer(
    tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
)
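Under the hood, generate() passes each newly generated chunk of token ids to the streamer's put() method and calls end() once generation stops. The very first put() carries the prompt, which is what skip_prompt=True drops. The toy class below mirrors that put()/end() protocol so it runs without a GPU. `ListStreamer` and its `decode` argument are hypothetical stand-ins, not transformers' TextStreamer, which also holds back partial words until it can print them cleanly.

```python
class ListStreamer:
    """Toy streamer following the put()/end() protocol that generate() uses."""

    def __init__(self, decode, skip_prompt=True):
        self.decode = decode            # callable: list[int] -> str
        self.skip_prompt = skip_prompt
        self.next_is_prompt = True      # the first put() carries the prompt ids
        self.chunks = []

    def put(self, token_ids):
        if self.skip_prompt and self.next_is_prompt:
            self.next_is_prompt = False
            return
        self.next_is_prompt = False
        self.chunks.append(self.decode(list(token_ids)))

    def end(self):
        # Called once when generation stops; a real streamer would flush here.
        self.next_is_prompt = True

    @property
    def text(self):
        return "".join(self.chunks)


if __name__ == "__main__":
    vocab = {1: "Hello", 2: ",", 3: " world"}
    streamer = ListStreamer(lambda ids: "".join(vocab[i] for i in ids))
    streamer.put([1, 2])   # prompt ids: skipped
    streamer.put([1])      # generated tokens arrive one step at a time
    streamer.put([2])
    streamer.put([3])
    streamer.end()
    print(streamer.text)   # Hello, world
```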
4. Tokenise the inputs
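# Example prompt (illustrative), wrapped in Mixtral Instruct's [INST] ... [/INST] format
text = "[INST] Explain AWQ quantisation in two sentences. [/INST]"
tokens = tokenizer(
    text,
    return_tensors="pt",
).input_ids.to("cuda:0")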
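For multi-turn prompts, the text needs to follow Mixtral Instruct's instruction format. Here is a minimal sketch. It assumes the [INST] ... [/INST] convention from the model's chat template, with the BOS token left for the tokenizer to add. `build_mixtral_prompt` is a hypothetical helper. If the tokenizer ships a chat template, tokenizer.apply_chat_template(messages, return_tensors="pt") is the safer way to get this right.

```python
def build_mixtral_prompt(messages, eos_token="</s>"):
    """Render alternating user/assistant messages in the [INST] format.

    Assumption: mirrors the Mixtral Instruct chat template, minus the leading
    BOS token, which the tokenizer inserts by default.
    """
    parts = []
    for i, msg in enumerate(messages):
        expected = "user" if i % 2 == 0 else "assistant"
        if msg["role"] != expected:
            raise ValueError("messages must alternate user/assistant, starting with user")
        if msg["role"] == "user":
            parts.append(f"[INST] {msg['content']} [/INST]")
        else:
            parts.append(msg["content"] + eos_token)
    return "".join(parts)


if __name__ == "__main__":
    text = build_mixtral_prompt([{"role": "user", "content": "What is AWQ?"}])
    print(text)  # [INST] What is AWQ? [/INST]
```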
5. Generate!
generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=512,
)
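To put a number on "blazingly fast" for your own GPU, time a generate call and divide by the number of new tokens. Here is a minimal sketch. `measure_throughput` is a hypothetical helper that takes any zero-argument callable returning the full output sequence (prompt plus completion). With the objects above, you could pass `lambda: model.generate(tokens, max_new_tokens=512)[0]` and `prompt_len=tokens.shape[1]`. Leave the streamer out when timing, since printing adds overhead.

```python
import time


def measure_throughput(generate_fn, prompt_len):
    """Time one generation call and return (new_tokens, tokens_per_second).

    generate_fn: zero-argument callable returning the full output sequence
    (prompt + completion) for a single example.
    """
    start = time.perf_counter()
    output = generate_fn()
    elapsed = time.perf_counter() - start
    new_tokens = len(output) - prompt_len
    return new_tokens, new_tokens / elapsed


if __name__ == "__main__":
    # Stand-in for model.generate: sleeps briefly and "generates" 10 tokens.
    def fake_generate():
        time.sleep(0.05)
        return list(range(5 + 10))

    n, tps = measure_throughput(fake_generate, prompt_len=5)
    print(f"{n} new tokens at {tps:.1f} tokens/s")
```

The first call includes CUDA warm-up, so a second timed run is usually more representative.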
That's it! 🤗