Dataset: multimodal-reasoning-lab/Zebra-CoT
Compression: 144× (576 → 144 tokens, 1152 → 32 channels)
A Feature Auto-Encoder (FAE) that compresses Qwen3-VL-8B vision features using CNN spatial pooling (2× per dimension) combined with channel compression. Trained on the Chess subset of Zebra-CoT.
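The shipped architecture lives in `fae_spatial.py`; as a rough illustration only (not the actual implementation, which uses attention layers), the combination of 2× spatial pooling and 1152 → 32 channel compression can be sketched with a single strided convolution:

```python
import torch
import torch.nn as nn

class ToyPoolEncoder(nn.Module):
    """Illustrative sketch: [B, 576, 1152] tokens -> [B, 144, 32] latents
    via one strided conv (2x pooling per spatial dim + channel projection)."""
    def __init__(self, embed_dim=1152, latent_dim=32, pool_factor=2):
        super().__init__()
        # kernel=stride=2 halves each spatial dimension (24x24 -> 12x12)
        # while the output channel count compresses 1152 -> 32.
        self.conv = nn.Conv2d(embed_dim, latent_dim,
                              kernel_size=pool_factor, stride=pool_factor)

    def forward(self, x):                          # x: [B, 576, 1152]
        B, N, C = x.shape
        side = int(N ** 0.5)                       # 24
        x = x.transpose(1, 2).reshape(B, C, side, side)
        z = self.conv(x)                           # [B, 32, 12, 12]
        return z.flatten(2).transpose(1, 2)        # [B, 144, 32]

feats = torch.randn(2, 576, 1152)
z = ToyPoolEncoder()(feats)
print(z.shape)  # torch.Size([2, 144, 32])
```

The real FAE additionally trains a decoder back to the 1152-dim feature space and (with `use_vae=True`) a VAE-style latent head.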
| | Input | Latent | Compression |
|---|---|---|---|
| Tokens | 576 (24×24) | 144 (12×12) | 4× spatial |
| Channels | 1152 | 32 | 36× channel |
| Total values | 663,552 | 4,608 | 144× |
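The 144× figure is simply the product of the spatial and channel factors; a quick arithmetic check:

```python
tokens_in, tokens_out = 24 * 24, 12 * 12          # 576 -> 144
chans_in, chans_out = 1152, 32

spatial = tokens_in // tokens_out                 # 4x spatial
channel = chans_in // chans_out                   # 36x channel
total = (tokens_in * chans_in) // (tokens_out * chans_out)

print(spatial, channel, total)  # 4 36 144
assert tokens_in * chans_in == 663_552
assert tokens_out * chans_out == 4_608
assert total == spatial * channel == 144
```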
| Metric | Value |
|---|---|
| Eval CosSim (feature reconstruction) | 0.9776 |
| VLM Chess MCQ Accuracy | 3/20 (15%) |
| VLM Agreement with uncompressed baseline | 17/20 (85%) |
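"Eval CosSim" presumably means the mean per-token cosine similarity between original and reconstructed features; the repo's exact eval code is in `fae_spatial.py`, but a plausible sketch of the metric (an assumption, not the shipped implementation) is:

```python
import torch
import torch.nn.functional as F

def mean_token_cossim(orig: torch.Tensor, recon: torch.Tensor) -> float:
    """Mean cosine similarity over all tokens; orig/recon: [B, N, D]."""
    return F.cosine_similarity(orig, recon, dim=-1).mean().item()

# Sanity check: a perfect reconstruction scores 1.0
orig = torch.randn(2, 576, 1152)
perfect = mean_token_cossim(orig, orig.clone())
print(round(perfect, 4))  # 1.0
```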
The s2 model exceeds the channel-only d32 baseline (CosSim 0.9670) while being 4× more compressed.
Files:

- `fae_encoder.pt`: Spatial FAE encoder weights
- `feature_decoder.pt`: Spatial FAE decoder weights
- `training_state.pt`: Training metadata + feature normalization stats
- `fae_spatial.py`: Model architecture source code

Usage:

```python
import torch
from fae_spatial import FAESpatialEncoder, FAESpatialDecoder

# Load the feature normalization stats saved at training time
state = torch.load("training_state.pt", map_location="cpu")
feat_mean = state["feat_mean"].cuda()
feat_std = state["feat_std"].cuda()

encoder = FAESpatialEncoder(embed_dim=1152, latent_dim=32, num_heads=16, pool_factor=2, use_vae=True)
encoder.load_state_dict(torch.load("fae_encoder.pt", map_location="cpu"))
encoder = encoder.cuda().eval()

decoder = FAESpatialDecoder(latent_dim=32, output_dim=1152, num_layers=6, num_heads=16, ffn_mult=2.7, pool_factor=2)
decoder.load_state_dict(torch.load("feature_decoder.pt", map_location="cpu"))
decoder = decoder.cuda().eval()

# Compress ViT features; vit_features: [B, 576, 1152]
with torch.no_grad():
    vit_features_norm = (vit_features - feat_mean) / feat_std
    z, mu, logvar = encoder(vit_features_norm)   # [B, 144, 32]

    # Reconstruct and de-normalize
    reconstructed = decoder(z)                   # [B, 576, 1152]
    reconstructed = reconstructed * feat_std + feat_mean
```