"""AirRep model implementation."""

from typing import Optional

import torch
import torch.nn as nn
from transformers import BertModel, BertConfig, PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput


def mean_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool token embeddings over the sequence dimension, ignoring padding.

    Args:
        last_hidden_states: Token embeddings of shape (batch, seq_len, hidden).
        attention_mask: Mask of shape (batch, seq_len); nonzero for real
            tokens, zero for padding.

    Returns:
        Tensor of shape (batch, hidden): the mean of unmasked token embeddings
        per sequence.
    """
    # Zero out padded positions so they do not contribute to the sum.
    masked = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    # clamp(min=1) guards against division by zero for a fully-padded row
    # (the original code produced NaN in that case).
    denom = attention_mask.sum(dim=1)[..., None].clamp(min=1)
    return masked.sum(dim=1) / denom


class AirRepConfig(BertConfig):
    """Configuration class for AirRep model."""

    model_type = "airrep"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class AirRepModel(PreTrainedModel):
    """AirRep model with BERT encoder and projection layer.

    This is a standalone model, not a wrapper.
    """

    config_class = AirRepConfig
    base_model_prefix = "airrep"

    def __init__(self, config: AirRepConfig):
        super().__init__(config)
        self.config = config

        # BERT encoder; the pooler head is unused (we mean-pool instead).
        self.bert = BertModel(config, add_pooling_layer=False)

        # Projection layer. NOTE(review): kept in bfloat16 for checkpoint
        # compatibility; the encoder output is cast to this dtype in forward()
        # (the original code crashed on the float32/bfloat16 mismatch).
        self.projector = nn.Linear(
            config.hidden_size,
            config.hidden_size,
            dtype=torch.bfloat16,
        )

        # Initialize weights per the HF PreTrainedModel convention.
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Encode, mean-pool, and project a batch of token sequences.

        Args:
            input_ids: Input token IDs of shape (batch, seq_len).
            attention_mask: Optional attention mask; defaults to all ones.
            token_type_ids: Optional token type IDs.

        Returns:
            Pooled and projected embeddings of shape (batch, hidden_size).
        """
        # Only last_hidden_state is consumed, so output_hidden_states=True
        # (present in the original) is unnecessary work and has been dropped.
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )

        # Mean pooling over non-padded tokens.
        last_hidden_state = outputs.last_hidden_state
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        pooled = mean_pooling(last_hidden_state, attention_mask)

        # Cast to the projector's dtype (bfloat16) before projecting; the
        # encoder runs in float32 by default.
        projected = self.projector(pooled.to(self.projector.weight.dtype))

        return projected

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model and config (delegates to PreTrainedModel)."""
        super().save_pretrained(save_directory, **kwargs)