Tuned Lens for Gemma 3 27B

This repository contains tuned lens translators for google/gemma-3-27b-it, trained following Belrose et al. (2023).

What is a Tuned Lens?

Unlike the simpler "logit lens", which applies the unembedding matrix directly to intermediate hidden states, the tuned lens learns layer-specific affine transformations to correct for representational drift.

Key Idea

For each layer ℓ, we learn parameters (A_ℓ, b_ℓ) such that:

TunedLens_ℓ(h_ℓ) = LayerNorm[A_ℓ h_ℓ + b_ℓ] W_U

where h_ℓ is the hidden state at layer ℓ, A_ℓ is the learned change-of-basis matrix, b_ℓ is the learned bias vector, and W_U is the unembedding matrix (frozen, from base model).
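
As a minimal sketch of the formula above, the following PyTorch snippet applies an affine translator, then a layer norm, then an unembedding matrix. The toy dimensions, the random `W_U`, and the use of `torch.nn.LayerNorm` are assumptions for illustration only; the real model supplies its own final norm and unembedding.

```python
import torch

torch.manual_seed(0)
hidden_dim, vocab_size = 8, 16  # toy sizes, not the real model's

# Stand-ins for the frozen pieces taken from the base model (assumed here).
final_norm = torch.nn.LayerNorm(hidden_dim)
W_U = torch.randn(hidden_dim, vocab_size)

def logit_lens(h):
    # Logit lens: decode the hidden state directly.
    return final_norm(h) @ W_U

def tuned_lens(h, A, b):
    # Tuned lens: apply the learned affine map first, then decode.
    return final_norm(h @ A.T + b) @ W_U

h = torch.randn(3, hidden_dim)  # three token positions
A = torch.eye(hidden_dim)       # identity initialization
b = torch.zeros(hidden_dim)

with torch.no_grad():
    same_at_init = torch.allclose(tuned_lens(h, A, b), logit_lens(h))
print(same_at_init)
```

With an identity matrix and a zero bias the tuned lens coincides with the logit lens, so training only has to learn the correction on top of that baseline.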

Why Tuned Lens?

The logit lens is a useful tool, but it: (a) assumes all layers share the same representation space, (b) is systematically biased toward certain vocabulary items, and (c) can yield poor predictions for early and middle layers.

The tuned lens learns a transformation that aligns each layer ℓ with the final-layer space in an attempt to: (a) correct for non-zero residual stream expectations, (b) provide calibrated probability estimates, and (c) make early layers more interpretable.

Model Details

Architecture

  • Base Model: google/gemma-3-27b-it
  • Number of Layers: 62
  • Hidden Dimension: 4096
  • Translator Type: Affine transformation (A_ℓ ∈ ℝ^(4096×4096), b_ℓ ∈ ℝ^4096)
  • Parameters per Layer: 16,781,312 (4096² + 4096)
  • Total Parameters: 1,040,441,344 (~1B parameters)
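
The parameter counts listed above follow directly from the shape of an affine map. The short sketch below recomputes them; `translator_param_count` is a hypothetical helper name, and the inputs are simply the values stated in this section. If the hidden size of the loaded checkpoint differs, substitute that value.

```python
def translator_param_count(hidden_dim: int, num_layers: int) -> tuple[int, int]:
    """Return (parameters per layer, total parameters) for affine translators."""
    per_layer = hidden_dim * hidden_dim + hidden_dim  # A is d x d, b has d entries
    return per_layer, per_layer * num_layers

per_layer, total = translator_param_count(hidden_dim=4096, num_layers=62)
print(f"{per_layer:,} per layer, {total:,} total")
# prints: 16,781,312 per layer, 1,040,441,344 total
```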

Training Details

We corrected several critical hyperparameters to match the recommendations of Belrose et al. (2023). Most notably, we use a learning rate of 1.0 (the paper default), a weight decay of 0.001 (10⁻³), and gradient clipping to norm 1. We also added support for the Muon optimizer (MomentUm Orthogonalized by Newton-Schulz), which the paper strongly recommends for faster convergence and lower KL divergence. Our Muon implementation features Newton-Schulz orthogonalization, EMA-style momentum, decoupled weight decay, and a learning rate adjustment based on matrix shape.
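
The orthogonalization step at the core of Muon can be sketched as follows. This is an illustrative version, not this repository's implementation: the quintic coefficients and the five-step default follow the publicly available Muon reference implementation and are assumptions here, as is the function name `newton_schulz_orthogonalize`.

```python
import torch

def newton_schulz_orthogonalize(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Push the singular values of G toward 1 while keeping its singular vectors."""
    a, b, c = 3.4445, -4.7750, 2.0315  # assumed quintic coefficients
    X = G / (G.norm() + eps)           # scale so all singular values are <= 1
    transposed = X.shape[0] > X.shape[1]
    if transposed:                     # iterate on the wide orientation
        X = X.T
    for _ in range(steps):
        S = X @ X.T
        X = a * X + (b * S + c * S @ S) @ X
    return X.T if transposed else X

G = torch.tensor([[3.0, 1.0], [1.0, 2.0], [0.0, 1.0]])  # toy "gradient"
O = newton_schulz_orthogonalize(G)
singular_values = torch.linalg.svdvals(O)
print(singular_values)
```

The singular values end up near 1 rather than exactly 1, because this iteration trades precision for speed. In an optimizer, `G` would be the momentum buffer of a weight matrix rather than a hand-written tensor.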

Dataset: C4 validation set (English). Documents are tokenized to a maximum length of 2048 tokens, matching the paper's setup. We also use per-epoch shuffling with deterministic seeding for reproducibility when resuming from checkpoints.
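
One way to get per-epoch shuffling that survives a restart is to derive the random seed from the epoch number. The sketch below shows that idea with the standard library only; `epoch_order` and `base_seed` are hypothetical names, and the actual training code may seed its shuffling differently.

```python
import random

def epoch_order(num_docs: int, epoch: int, base_seed: int = 0) -> list[int]:
    """Return a document order that depends only on (base_seed, epoch)."""
    order = list(range(num_docs))
    random.Random(base_seed + epoch).shuffle(order)  # private RNG, no global state
    return order

# Resuming in epoch 2 reproduces exactly the order used before the interruption.
before_crash = epoch_order(num_docs=10, epoch=2)
after_resume = epoch_order(num_docs=10, epoch=2)
print(before_crash == after_resume)
```

Because the order is a pure function of the seed and the epoch, a resumed run only needs to know which epoch and step it stopped at.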

Training Hyperparameters:

  • Training samples: ~362,000 (from C4 validation)
  • Epochs: 4
  • Batch size: 4
  • Gradient accumulation steps: 4
  • Effective batch size: 16
  • Learning rate: 1.0 (linear decay to 0)
  • Weight decay: 0.001 (10⁻³)
  • Gradient clipping: norm 1
  • Sequence length: 2048
  • Optimizer: Muon (fallback to SGD with Nesterov momentum 0.95)
  • Loss function: KL divergence D_KL(final || tuned_lens), per-token
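
To make the schedule concrete, the sketch below derives the effective batch size and an approximate number of optimizer steps from the values listed above, and evaluates a linear decay to zero. It assumes the decay spans all four epochs with no warmup, which this section does not state; `linear_decay_lr` is a hypothetical helper.

```python
def linear_decay_lr(step: int, total_steps: int, base_lr: float = 1.0) -> float:
    """Learning rate that falls linearly from base_lr at step 0 to 0 at total_steps."""
    remaining = max(total_steps - step, 0)
    return base_lr * remaining / total_steps

effective_batch = 4 * 4                       # batch size x accumulation steps
steps_per_epoch = 362_000 // effective_batch  # sample count above is approximate
total_steps = steps_per_epoch * 4             # 4 epochs

print(effective_batch, total_steps)
print(linear_decay_lr(0, total_steps), linear_decay_lr(total_steps // 2, total_steps))
```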

Training Objective:

argmin over (A_ℓ, b_ℓ) of E_x [D_KL(final_layer(x) || TunedLens_ℓ(h_ℓ(x)))]
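
A per-token version of this objective can be sketched in PyTorch as below. The function name `tuned_lens_kl` and the toy tensor shapes are assumptions; the snippet only illustrates the direction of the divergence, with the final-layer distribution as the target.

```python
import torch
import torch.nn.functional as F

def tuned_lens_kl(final_logits: torch.Tensor, lens_logits: torch.Tensor) -> torch.Tensor:
    """Mean over tokens of D_KL(final || tuned_lens), computed from raw logits."""
    log_p = F.log_softmax(final_logits.float(), dim=-1)  # target distribution
    log_q = F.log_softmax(lens_logits.float(), dim=-1)   # lens prediction
    per_token = (log_p.exp() * (log_p - log_q)).sum(dim=-1)
    return per_token.mean()

torch.manual_seed(0)
final_logits = torch.randn(2, 5, 11)  # (batch, sequence, vocab), toy sizes
lens_logits = torch.randn(2, 5, 11)

loss_different = tuned_lens_kl(final_logits, lens_logits)
loss_identical = tuned_lens_kl(final_logits, final_logits)
print(loss_different.item(), loss_identical.item())
```

The loss is zero when the lens reproduces the final distribution and positive otherwise. During training the final-layer logits would typically be computed without gradients, so that only the translator parameters are updated.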

Usage

Installation

pip install torch transformers huggingface_hub

Quick Start

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
import torch.nn.functional as F

# Load base model
model_name = "google/gemma-3-27b-it"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Download and load translators
translator_path = hf_hub_download(
    repo_id="uzaymacar/gemma-3-27b-tuned-lens",
    filename="translators.pt"
)
translators_data = torch.load(translator_path, map_location=model.device, weights_only=False)

# Handle checkpoint format if needed
if 'translators' in translators_data:
    translators_data = translators_data['translators']

# Define translator module
class TunedLensTranslator(torch.nn.Module):
    def __init__(self, hidden_dim, device='cuda', dtype=torch.bfloat16):
        super().__init__()
        self.A = torch.nn.Parameter(torch.eye(hidden_dim, device=device, dtype=dtype))
        self.b = torch.nn.Parameter(torch.zeros(hidden_dim, device=device, dtype=dtype))

    def forward(self, h):
        return torch.matmul(h, self.A.t()) + self.b

# Initialize translators
translators = {}
for layer_idx, state_dict in translators_data.items():
    # Handle both int and string keys
    key = int(layer_idx) if isinstance(layer_idx, str) and layer_idx.isdigit() else layer_idx
    if not isinstance(key, int):
        continue  # Skip component translators
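    # Read the hidden size from the checkpoint so it always matches the saved weights
    translator = TunedLensTranslator(
        hidden_dim=state_dict["A"].shape[0],
        device=model.device,
        dtype=model.dtype
    )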
    translator.load_state_dict(state_dict)
    translator.eval()
    translators[key] = translator

print(f"Loaded {len(translators)} translators")

Example: Predict from Layer 20

text = "The capital of France is"
inputs = tokenizer(text, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

    # Get layer 20 hidden state
    layer_20_hidden = outputs.hidden_states[21]  # +1 because hidden_states[0] is embeddings

    # Apply tuned lens
    translated = translators[20](layer_20_hidden)
    normalized = model.model.model.norm(translated)  # Final layer norm
    logits = model.lm_head(normalized)

    # Get top prediction
    next_token_id = logits[0, -1].argmax()
    next_token = tokenizer.decode(next_token_id)

    print(f"Layer 20 predicts: {next_token}")
    # Output: Layer 20 predicts:  Paris

Example: Compare Predictions Across Layers

text = "The capital of France is"
inputs = tokenizer(text, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

    print(f"Input: '{text}'")
    print("\nPredictions by layer:")
    print("-" * 60)

    for layer_idx in [0, 10, 20, 30, 40, 50, 61]:
        layer_hidden = outputs.hidden_states[layer_idx + 1]  # +1 for embedding offset
        translated = translators[layer_idx](layer_hidden)
        normalized = model.model.model.norm(translated)
        logits = model.lm_head(normalized)

        # Top-1 prediction
        next_token_id = logits[0, -1].argmax()
        next_token = tokenizer.decode(next_token_id)

        # Confidence
        probs = F.softmax(logits[0, -1], dim=-1)
        confidence = probs[next_token_id].item()

        print(f"Layer {layer_idx:2d}: '{next_token}' (confidence: {confidence:.2%})")
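
The loop above reports only the single most likely token per layer. When a layer is uncertain, it can help to look at several candidates. The helper below is a small sketch of that, demonstrated on toy logits so it runs without the model; `top_k_predictions` is a hypothetical name, and in practice its input would be `logits[0, -1]` from the loop.

```python
import torch
import torch.nn.functional as F

def top_k_predictions(logits_row: torch.Tensor, k: int = 3) -> list[tuple[int, float]]:
    """Return the k most likely token ids with their softmax probabilities."""
    probs = F.softmax(logits_row.float(), dim=-1)
    top = torch.topk(probs, k)
    return [(int(i), float(p)) for i, p in zip(top.indices, top.values)]

# Toy logits over a 4-token vocabulary; in practice pass logits[0, -1].
toy_logits = torch.tensor([0.0, 2.0, 1.0, -1.0])
predictions = top_k_predictions(toy_logits, k=2)
print(predictions)
```

Each returned token id can be passed to `tokenizer.decode` to obtain the corresponding text.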

Files in This Repository

  • translators.pt: Translator state dictionaries (62 layers, ~1B parameters)
  • README.md: This file

Citation

If you use these tuned lens translators in your research, please cite:

@article{belrose2023eliciting,
  title={Eliciting Latent Predictions from Transformers with the Tuned Lens},
  author={Belrose, Nora and Ostrovsky, Igor and McKinney, Lev and Furman, Zach and Smith, Logan and Halawi, Danny and Biderman, Stella and Steinhardt, Jacob},
  journal={arXiv preprint arXiv:2303.08112},
  year={2023}
}

Acknowledgments

Contact

For questions or issues, please feel free to contact me at [email protected]


Model Card Author: Uzay Macar
Date: November 2025
Version: 1.0
