GRIFFIN: Effective Token Alignment for Faster Speculative Decoding
GRIFFIN is a framework that accelerates inference in large language models (LLMs) via speculative decoding: a lightweight draft model proposes multiple tokens at once, which the target model then verifies. It addresses the token misalignment between the training and decoding phases of existing draft models through a token-alignable training strategy and a token-alignable draft model. As reported in the paper GRIFFIN: Effective Token Alignment for Faster Speculative Decoding, GRIFFIN achieves an average acceptance-length improvement of over 8% and a speedup-ratio improvement exceeding 7% compared to current state-of-the-art speculative decoding methods.
This repository provides the implementation of GRIFFIN, including its token-alignable training strategy and token-alignable draft model.
Code: https://github.com/hsj576/GRIFFIN
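For readers new to speculative decoding, the sketch below shows the generic draft-and-verify loop that methods like GRIFFIN build on (greedy variant). It is purely illustrative: `target_next` and `draft_next` are hypothetical single-step callables, and the sketch omits GRIFFIN's tree drafting and the token-alignment machinery that is the paper's actual contribution.
```python
def speculative_decode(target_next, draft_next, prompt, k=4, max_new=32):
    """Generic greedy draft-and-verify loop (illustrative only).
    target_next/draft_next: fn(list[int]) -> int, the greedy next token
    of the target / draft model for a given token sequence."""
    tokens = list(prompt)
    while len(tokens) - len(prompt) < max_new:
        # 1) The cheap draft model proposes k tokens autoregressively.
        draft = []
        for _ in range(k):
            draft.append(draft_next(tokens + draft))
        # 2) The target model verifies the draft; keep the longest prefix
        #    matching its own greedy choices, plus one corrected token.
        accepted = 0
        for i in range(k):
            t = target_next(tokens + draft[:i])
            if t != draft[i]:
                draft[i] = t  # first mismatch: substitute the target's token
                break
            accepted += 1
        tokens += draft[:accepted + 1]
        # In real systems all k verifications run in ONE batched target
        # forward pass, so a longer accepted prefix means fewer slow passes.
    return tokens
```
The per-step acceptance length is `accepted + 1` above; GRIFFIN's reported gains come from raising this quantity, which directly reduces the number of target-model forward passes.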
Usage
You can use the provided eagenerate function for accelerated generation, similar to using generate from Hugging Face's Transformers library.
```python
import torch
from model.ea_model_griffin import EaModel
from fastchat.model import get_conversation_template

# Point base_model_path at the base LLM and ea_model_path at the matching
# GRIFFIN draft model, e.g.:
#   base_model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
#   ea_model_path   = "husj576/GRIFFIN-llama3-instruct-8B"
model = EaModel.from_pretrained(
    base_model_path="lmsys/vicuna-7b-v1.5",          # replace with your base model path
    ea_model_path="husj576/GRIFFIN-Vicuna-7B-v1.5",  # replace with your GRIFFIN model path
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
    total_token=-1,  # -1 lets EAGLE-2 auto-configure the draft-token budget
)
model.eval()

your_message = "Hello"
conv = get_conversation_template("vicuna")  # match the template to your base model
conv.append_message(conv.roles[0], your_message)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = model.tokenizer([prompt]).input_ids
input_ids = torch.as_tensor(input_ids).cuda()
output_ids = model.eagenerate(input_ids, temperature=0.5, max_new_tokens=512)
output = model.tokenizer.decode(output_ids[0])
print(output)
```
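To check the acceleration on your own hardware, you can time eagenerate and compute throughput. This is a minimal sketch reusing the objects from the snippet above; it assumes output_ids includes the prompt tokens, as with Hugging Face's generate.
```python
import time

torch.cuda.synchronize()
start = time.time()
output_ids = model.eagenerate(input_ids, temperature=0.5, max_new_tokens=512)
torch.cuda.synchronize()
elapsed = time.time() - start

# Assumes output_ids contains prompt + generated tokens (HF convention).
new_tokens = output_ids.shape[1] - input_ids.shape[1]
print(f"Generated {new_tokens} tokens in {elapsed:.2f}s "
      f"({new_tokens / elapsed:.1f} tokens/s)")
```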
Note: When using chat models such as Vicuna, LLaMA2-Chat, or LLaMA3-Instruct, use the chat template that matches the base model; a mismatched template can produce abnormal output and lowers GRIFFIN's acceptance rate. An example for LLaMA3-Instruct-style models is sketched below.
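For LLaMA3-Instruct-style base models, one option is the tokenizer's built-in chat template rather than fastchat's conversation templates. A minimal sketch, assuming the base model's tokenizer ships a chat template (true for the official LLaMA3-Instruct checkpoints):
```python
# Build the prompt with the tokenizer's own chat template (if it has one),
# then decode with eagenerate exactly as in the Vicuna example above.
messages = [{"role": "user", "content": your_message}]
input_ids = model.tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # append the assistant header so the model answers
    return_tensors="pt",
).cuda()
output_ids = model.eagenerate(input_ids, temperature=0.5, max_new_tokens=512)
print(model.tokenizer.decode(output_ids[0]))
```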
Citation
```bibtex
@misc{hu2025griffineffectivetokenalignment,
      title={GRIFFIN: Effective Token Alignment for Faster Speculative Decoding},
      author={Shijing Hu and Jingyang Li and Xingyu Xie and Zhihui Lu and Kim-Chuan Toh and Pan Zhou},
      year={2025},
      eprint={2502.11018},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2502.11018},
}
```