GRIFFIN: Effective Token Alignment for Faster Speculative Decoding

GRIFFIN

GRIFFIN is a novel speculative decoding framework that accelerates large language model (LLM) inference by drafting multiple tokens at once, which the base LLM then verifies. It addresses token misalignment between the training and decoding phases through a token-alignable training strategy and a token-alignable draft model. As presented in the paper GRIFFIN: Effective Token Alignment for Faster Speculative Decoding, GRIFFIN improves the speedup ratio by more than 7% and the average acceptance length by more than 8% compared with current state-of-the-art speculative decoding methods.
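The draft-then-verify loop behind speculative decoding, and the "average acceptance length" metric quoted above, can be illustrated with a toy sketch. This is a simplified greedy, chain-shaped variant written for illustration only: it is not GRIFFIN's token-alignable draft model or training strategy, and EAGLE-style implementations (which this code's EaModel derives from) draft a tree of candidates and score it with one forward pass of the base model. The names draft_tokens, verify, speculative_generate, toy_base and toy_draft are hypothetical.

```python
from typing import Callable, List, Sequence, Tuple

# Toy stand-in for a language model: maps a token prefix to its greedy next token.
NextToken = Callable[[Sequence[int]], int]


def draft_tokens(draft: NextToken, prefix: List[int], k: int) -> List[int]:
    """Let the cheap draft model propose k tokens autoregressively."""
    proposal: List[int] = []
    for _ in range(k):
        proposal.append(draft(prefix + proposal))
    return proposal


def verify(base: NextToken, prefix: List[int], proposal: List[int]) -> List[int]:
    """Greedy verification: accept draft tokens while the base model agrees.

    On the first mismatch the base model's own token replaces the draft token
    and the cycle ends; if every draft token is accepted, the base model adds
    one bonus token. (A real system scores all positions in one parallel
    forward pass; this loop calls the toy model once per position for clarity.)
    """
    accepted: List[int] = []
    for tok in proposal:
        expected = base(prefix + accepted)
        accepted.append(expected)  # equals tok when the draft was right
        if tok != expected:
            return accepted
    accepted.append(base(prefix + accepted))  # bonus token
    return accepted


def speculative_generate(base: NextToken, draft: NextToken, prompt: List[int],
                         max_new_tokens: int, k: int = 4) -> Tuple[List[int], float]:
    """Return (new tokens, average tokens produced per verification cycle).

    The average counts every token the cycles produced, including any the
    last cycle generated beyond max_new_tokens.
    """
    tokens = list(prompt)
    cycles = 0
    while len(tokens) - len(prompt) < max_new_tokens:
        tokens += verify(base, tokens, draft_tokens(draft, tokens, k))
        cycles += 1
    produced = len(tokens) - len(prompt)
    avg_acceptance = produced / cycles if cycles else 0.0
    return tokens[len(prompt):len(prompt) + max_new_tokens], avg_acceptance


def toy_base(prefix: Sequence[int]) -> int:
    """Toy base model: always predicts (last token + 1) mod 10."""
    return (prefix[-1] + 1) % 10


def toy_draft(prefix: Sequence[int]) -> int:
    """Toy draft model: agrees with toy_base except right after token 5."""
    return 0 if prefix[-1] == 5 else (prefix[-1] + 1) % 10


if __name__ == "__main__":
    new_tokens, tau = speculative_generate(toy_base, toy_draft, [0], max_new_tokens=12, k=4)
    print(new_tokens, tau)
```

Because verification always keeps the base model's own greedy choice, the output is identical to plain greedy decoding with the base model. A better-aligned draft model only raises the number of tokens accepted per verification pass, and that is where the speedup comes from.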

This repository provides the implementation of GRIFFIN, including its token-alignable training strategy and token-alignable draft model.

Code: https://github.com/hsj576/GRIFFIN

Acceleration Demo


Acceleration demo of GRIFFIN for LLaMA3-8B on an RTX 4090 GPU.

Usage

You can use the provided eagenerate method for accelerated generation, in the same way you would call generate in Hugging Face's Transformers library.

import torch
from model.ea_model_griffin import EaModel
from fastchat.model import get_conversation_template

# Set base_model_path and ea_model_path to your own model directories or Hub IDs.
# Example for LLaMA3 (also switch to the matching chat template below):
# base_model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
# ea_model_path = "husj576/GRIFFIN-llama3-instruct-8B"

model = EaModel.from_pretrained(
    base_model_path="lmsys/vicuna-7b-v1.5", # Replace with your base model path
    ea_model_path="husj576/GRIFFIN-Vicuna-7B-v1.5", # Replace with your GRIFFIN model path
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
    total_token=-1 # -1 for auto-configuration by EAGLE-2
)
model.eval()

your_message = "Hello"
conv = get_conversation_template("vicuna") # Use appropriate conversation template for your base model
conv.append_message(conv.roles[0], your_message)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = model.tokenizer([prompt]).input_ids
input_ids = torch.as_tensor(input_ids).cuda()

output_ids = model.eagenerate(input_ids, temperature=0.5, max_new_tokens=512)
output = model.tokenizer.decode(output_ids[0])
print(output)

Note: When using chat models like Vicuna, LLaMA2-Chat, or LLaMA3-Instruct, you must use the correct chat template to avoid abnormal model output and ensure optimal GRIFFIN performance.
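To see what the template actually adds, the prompt that get_conversation_template("vicuna") builds in the usage example can be reproduced by hand. The sketch below is a hypothetical helper (build_vicuna_prompt and VICUNA_SYSTEM are illustrative names) that assumes FastChat's vicuna_v1.1 template: a fixed system message, USER/ASSISTANT roles, a space after user turns and </s> after assistant turns. Other FastChat versions or base models may use different strings.

```python
from typing import List, Optional, Tuple

# Assumed to match FastChat's "vicuna_v1.1" system message.
VICUNA_SYSTEM = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions."
)


def build_vicuna_prompt(messages: List[Tuple[str, Optional[str]]],
                        system: str = VICUNA_SYSTEM) -> str:
    """Join (role, message) turns in the assumed Vicuna v1.1 format."""
    seps = (" ", "</s>")  # user turns end with a space, assistant turns with </s>
    prompt = system + seps[0]
    for i, (role, message) in enumerate(messages):
        if message:
            prompt += f"{role}: {message}{seps[i % 2]}"
        else:
            prompt += f"{role}:"  # empty slot the model is expected to complete
    return prompt


if __name__ == "__main__":
    print(build_vicuna_prompt([("USER", "Hello"), ("ASSISTANT", None)]))
```

Leaving the final ASSISTANT turn empty (None) is what cues the model to write the reply. In practice, keep using FastChat's template as in the usage code; a hand-written copy like this is mainly useful for inspecting the exact string the model receives.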

Citation

@misc{hu2025griffineffectivetokenalignment,
      title={GRIFFIN: Effective Token Alignment for Faster Speculative Decoding}, 
      author={Shijing Hu and Jingyang Li and Xingyu Xie and Zhihui Lu and Kim-Chuan Toh and Pan Zhou},
      year={2025},
      eprint={2502.11018},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2502.11018}, 
}
Model size: 0.5B params (Safetensors, tensor type F32)