Tuned Lens for Gemma 3 27B
This repository contains tuned lens translators for google/gemma-3-27b-it, trained following Belrose et al. (2023).
What is a Tuned Lens?
Unlike the simpler "logit lens" which directly applies the unembedding matrix to intermediate hidden states, the tuned lens learns layer-specific affine transformations to correct for representational drift.
Key Idea
For each layer ℓ, we learn parameters (A_ℓ, b_ℓ) such that:
TunedLens_ℓ(h_ℓ) = LayerNorm[A_ℓ h_ℓ + b_ℓ] W_U
where h_ℓ is the hidden state at layer ℓ, A_ℓ is the learned change-of-basis matrix, b_ℓ is the learned bias vector, and W_U is the unembedding matrix (frozen, taken from the base model).
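In code, the contrast with the logit lens looks roughly like the following. This is a minimal sketch, not the repository's implementation: h stands for a layer-ℓ hidden-state tensor, and final_norm / unembed are placeholders for the model's frozen final normalization and unembedding modules.

import torch

def logit_lens_logits(h, final_norm, unembed):
    # Logit lens: decode the intermediate hidden state directly with W_U
    return unembed(final_norm(h))

def tuned_lens_logits(h, A, b, final_norm, unembed):
    # Tuned lens: first map the layer-ℓ residual stream into final-layer
    # space with the learned affine translator, then decode with W_U
    return unembed(final_norm(h @ A.T + b))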
Why Tuned Lens?
The logit lens is a useful tool, but it: (a) assumes all layers use the same representation space as the final layer, (b) is systematically biased for or against certain vocabulary items, and (c) can yield poor predictions for early and middle layers.
The tuned lens instead learns a transformation from each layer ℓ into the final-layer representation space, in an attempt to: (a) correct for non-zero residual stream expectations, (b) provide calibrated probability estimates, and (c) make early layers more interpretable.
Model Details
Architecture
- Base Model: google/gemma-3-27b-it
- Number of Layers: 62
- Hidden Dimension: 4096
- Translator Type: Affine transformation (A_ℓ ∈ ℝ^(4096×4096), b_ℓ ∈ ℝ^4096)
- Parameters per Layer: 16,781,312 (4096² + 4096)
- Total Parameters: 1,040,441,344 (~1B parameters; see the quick check below)
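As a quick sanity check of these counts (plain Python arithmetic):

hidden_dim = 4096
num_layers = 62
per_layer = hidden_dim ** 2 + hidden_dim   # A: 4096 x 4096, b: 4096
total = per_layer * num_layers
print(per_layer)  # 16781312
print(total)      # 1040441344 (~1.04B)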
Training Details
We corrected several critical hyperparameters to match the recommendations from Belrose et al. (2023). Most notably, we use a learning rate of 1.0 (the paper default), weight decay of 0.001 (10β»Β³), and gradient clipping to norm 1. We also implemented support for the Muon optimizer (Momentum Orthogonalized Update), which the paper strongly recommends for faster convergence and lower KL divergence. Our Muon implementation features Newton-Schulz orthogonalization, EMA-style momentum, decoupled weight decay, and learning rate adjustment based on matrix shape.
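The optimizer code itself is not shipped in this repository (only the translators are); the following is a rough sketch of a Muon-style update with the features listed above (EMA-style momentum, Newton-Schulz orthogonalization, decoupled weight decay, shape-based learning-rate scaling). The coefficient values follow the commonly used quintic Newton-Schulz iteration, and the helper names newton_schulz_orthogonalize and muon_step are illustrative, not taken from this repo.

import torch

def newton_schulz_orthogonalize(G, steps=5, eps=1e-7):
    # Approximately orthogonalize G with a quintic Newton-Schulz iteration
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + eps)
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

def muon_step(param, grad, momentum_buf, lr=1.0, beta=0.95, weight_decay=1e-3):
    # EMA-style momentum, orthogonalized update, decoupled weight decay,
    # and a learning-rate adjustment based on the matrix shape
    momentum_buf.mul_(beta).add_(grad, alpha=1 - beta)
    update = newton_schulz_orthogonalize(momentum_buf)
    scale = max(1.0, param.shape[0] / param.shape[1]) ** 0.5
    param.mul_(1 - lr * weight_decay)   # decoupled weight decay
    param.add_(update, alpha=-lr * scale)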
Dataset: C4 validation set (English) β Documents are tokenized to a maximum length of 2048 tokens, matching the paper's setup. We also use per-epoch shuffling with deterministic seeding for reproducibility when resuming from checkpoints.
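A sketch of the data preparation described above, assuming the Hugging Face datasets library and the allenai/c4 English config; the seed and shuffle buffer size here are illustrative, not the values used for training.

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-27b-it")

# Stream the C4 English validation split and tokenize to at most 2048 tokens
dataset = load_dataset("allenai/c4", "en", split="validation", streaming=True)

def tokenize(example):
    return tokenizer(example["text"], truncation=True, max_length=2048)

tokenized = dataset.map(tokenize, remove_columns=["text", "timestamp", "url"])

# Per-epoch shuffling with a deterministic seed for reproducible resumption
shuffled = tokenized.shuffle(seed=0, buffer_size=10_000)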
Training Hyperparameters:
- Training samples: ~362,000 (from C4 validation)
- Epochs: 4
- Batch size: 4
- Gradient accumulation steps: 4
- Effective batch size: 16
- Learning rate: 1.0 (linear decay to 0)
- Weight decay: 0.001 (10⁻³)
- Gradient clipping: norm 1
- Sequence length: 2048
- Optimizer: Muon (fallback to SGD with Nesterov momentum 0.95)
- Loss function: KL divergence D_KL(final || tuned_lens), per-token
Training Objective:
argmin_{A_ℓ, b_ℓ} E_x [D_KL(final_layer(x) || TunedLens_ℓ(h_ℓ(x)))]
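Concretely, the per-token loss can be computed as in the sketch below, where final_logits are the base model's final-layer logits and lens_logits are the tuned-lens logits for the same positions (the function name is illustrative):

import torch.nn.functional as F

def tuned_lens_kl_loss(final_logits, lens_logits):
    # D_KL(final || tuned_lens), computed per token and then averaged
    final_logprobs = F.log_softmax(final_logits, dim=-1)
    lens_logprobs = F.log_softmax(lens_logits, dim=-1)
    kl = F.kl_div(lens_logprobs, final_logprobs, log_target=True, reduction="none")
    return kl.sum(dim=-1).mean()   # sum over vocab, mean over batch and positions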
Usage
Installation
pip install torch transformers huggingface_hub
Quick Start
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
import torch.nn.functional as F
# Load base model
model_name = "google/gemma-3-27b-it"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Download and load translators
translator_path = hf_hub_download(
    repo_id="uzaymacar/gemma-3-27b-tuned-lens",
    filename="translators.pt"
)
translators_data = torch.load(translator_path, map_location=model.device, weights_only=False)
# Handle checkpoint format if needed
if 'translators' in translators_data:
    translators_data = translators_data['translators']
# Define translator module
class TunedLensTranslator(torch.nn.Module):
    def __init__(self, hidden_dim, device='cuda', dtype=torch.bfloat16):
        super().__init__()
        # A starts as the identity and b as zero, so an untrained translator
        # reduces to the plain logit lens
        self.A = torch.nn.Parameter(torch.eye(hidden_dim, device=device, dtype=dtype))
        self.b = torch.nn.Parameter(torch.zeros(hidden_dim, device=device, dtype=dtype))

    def forward(self, h):
        # Affine map into final-layer space: h -> h A^T + b
        return torch.matmul(h, self.A.t()) + self.b
# Initialize translators
translators = {}
for layer_idx, state_dict in translators_data.items():
    # Handle both int and string keys
    key = int(layer_idx) if isinstance(layer_idx, str) and layer_idx.isdigit() else layer_idx
    if not isinstance(key, int):
        continue  # Skip component translators
    translator = TunedLensTranslator(
        hidden_dim=4096,
        device=model.device,
        dtype=model.dtype
    )
    translator.load_state_dict(state_dict)
    translator.eval()
    translators[key] = translator
print(f"Loaded {len(translators)} translators")
Example: Predict from Layer 20
text = "The capital of France is"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
# Get layer 20 hidden state
layer_20_hidden = outputs.hidden_states[21] # +1 because hidden_states[0] is embeddings
# Apply tuned lens
translated = translators[20](layer_20_hidden)
normalized = model.model.model.norm(translated) # Final layer norm
logits = model.lm_head(normalized)
# Get top prediction
next_token_id = logits[0, -1].argmax()
next_token = tokenizer.decode(next_token_id)
print(f"Layer 20 predicts: {next_token}")
# Output: Layer 20 predicts: Paris
Example: Compare Predictions Across Layers
text = "The capital of France is"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
print(f"Input: '{text}'")
print("\nPredictions by layer:")
print("-" * 60)
for layer_idx in [0, 10, 20, 30, 40, 50, 61]:
    layer_hidden = outputs.hidden_states[layer_idx + 1]  # +1 for embedding offset
    translated = translators[layer_idx](layer_hidden)
    normalized = model.model.model.norm(translated)
    logits = model.lm_head(normalized)
    # Top-1 prediction
    next_token_id = logits[0, -1].argmax()
    next_token = tokenizer.decode(next_token_id)
    # Confidence
    probs = F.softmax(logits[0, -1], dim=-1)
    confidence = probs[next_token_id].item()
    print(f"Layer {layer_idx:2d}: '{next_token}' (confidence: {confidence:.2%})")
Files in This Repository
- translators.pt: Translator state dictionaries (62 layers, ~1B parameters)
- README.md: This file
Citation
If you use these tuned lens translators in your research, please cite:
@article{belrose2023eliciting,
title={Eliciting Latent Predictions from Transformers with the Tuned Lens},
author={Belrose, Nora and Ostrovsky, Igor and McKinney, Lev and Furman, Zach and Smith, Logan and Halawi, Danny and Biderman, Stella and Steinhardt, Jacob},
journal={arXiv preprint arXiv:2303.08112},
year={2023}
}
Acknowledgments
- Trained using the methodology from Belrose et al. (2023)
- Base model: google/gemma-3-27b-it
- Training data: C4 dataset
Contact
For questions or issues, please feel free to contact me at [email protected]
Model Card Author: Uzay Macar
Date: November 2025
Version: 1.0