"""AirRep model implementation."""
from typing import Optional

import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, PreTrainedModel


def mean_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool hidden states over the sequence dimension, ignoring padded positions."""
    # Zero out hidden states wherever the attention mask is 0 (padding).
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    # Average over real tokens only; clamp guards against division by zero for all-padding rows.
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None].clamp(min=1)
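
# For example (illustrative values, not from the original file): with one padded
# position in the second row, only the two real tokens contribute to that row's mean:
#
#     hidden = torch.randn(2, 3, 8)                # (batch, seq_len, hidden)
#     mask = torch.tensor([[1, 1, 1], [1, 1, 0]])  # 0 marks padding
#     pooled = mean_pooling(hidden, mask)          # shape: (2, 8)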


class AirRepConfig(BertConfig):
    """Configuration class for the AirRep model; inherits all BERT hyperparameters."""

    model_type = "airrep"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
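
# Sketch (an assumption, not part of the original file): to load this model through
# the Auto* APIs without trust_remote_code, register the config/model pair first:
#
#     from transformers import AutoConfig, AutoModel
#     AutoConfig.register("airrep", AirRepConfig)
#     AutoModel.register(AirRepConfig, AirRepModel)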


class AirRepModel(PreTrainedModel):
    """AirRep model: a BERT encoder followed by a linear projection layer.

    This is a standalone model, not a wrapper around an existing encoder.
    """

    config_class = AirRepConfig
    base_model_prefix = "airrep"

    def __init__(self, config: AirRepConfig):
        super().__init__(config)
        self.config = config
        # BERT encoder; the built-in [CLS] pooler is dropped because mean pooling is used instead.
        self.bert = BertModel(config, add_pooling_layer=False)
        # Projection head, constructed in bfloat16; forward() casts its input to match.
        self.projector = nn.Linear(
            config.hidden_size,
            config.hidden_size,
            dtype=torch.bfloat16,
        )
        # Initialize weights via the standard Hugging Face post-init hook.
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            input_ids: Input token IDs of shape (batch_size, seq_len).
            attention_mask: Attention mask; defaults to all ones (no padding).
            token_type_ids: Token type IDs.

        Returns:
            Pooled and projected embeddings of shape (batch_size, hidden_size).
        """
        # Build a default mask first so the encoder and the pooling use the same mask.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        # Mean-pool the final hidden states over the real (unmasked) tokens.
        pooled = mean_pooling(outputs.last_hidden_state, attention_mask)
        # Cast to the projector's parameter dtype (bfloat16 as constructed) to avoid a dtype mismatch.
        return self.projector(pooled.to(self.projector.weight.dtype))

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model and config."""
        super().save_pretrained(save_directory, **kwargs)
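

if __name__ == "__main__":
    # Minimal usage sketch (assumptions: a freshly initialized model and the
    # generic "bert-base-uncased" tokenizer as a stand-in; the tokenizer actually
    # shipped with this repo may differ).
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AirRepModel(AirRepConfig())

    batch = tokenizer(
        ["AirRep produces one embedding per input text."],
        return_tensors="pt",
        padding=True,
    )
    embeddings = model(**batch)
    print(embeddings.shape)  # torch.Size([1, 768]) with the default BERT hidden size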