"""AirRep model implementation."""

from typing import Optional

import torch
import torch.nn as nn
from transformers import BertModel, BertConfig, PreTrainedModel


def mean_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Apply masked mean pooling over the sequence dimension."""
    # Zero out padding positions so they do not contribute to the sum.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    # Divide by the number of real (non-padding) tokens in each sequence.
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
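
# Example (illustrative values, shapes only):
#   hidden = torch.randn(2, 4, 8)
#   mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
#   mean_pooling(hidden, mask).shape  # -> torch.Size([2, 8])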


class AirRepConfig(BertConfig):
    """Configuration class for the AirRep model."""

    model_type = "airrep"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
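
# AirRepConfig accepts any standard BertConfig keyword argument; the values
# below are illustrative assumptions, not the released defaults:
#   config = AirRepConfig(hidden_size=256, num_hidden_layers=4)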


class AirRepModel(PreTrainedModel):
    """AirRep model: a BERT encoder with mean pooling and a linear projection head.

    This is a standalone model, not a wrapper.
    """

    config_class = AirRepConfig
    base_model_prefix = "airrep"

    def __init__(self, config: AirRepConfig):
        super().__init__(config)
        self.config = config

        # Pooling is done manually in forward(), so BERT's built-in pooler is disabled.
        self.bert = BertModel(config, add_pooling_layer=False)

        # Projection head kept in bfloat16; forward() casts its input to match.
        self.projector = nn.Linear(
            config.hidden_size,
            config.hidden_size,
            dtype=torch.bfloat16,
        )

        # Initialize weights and apply final processing (standard HF hook).
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Encode the inputs, mean-pool over tokens, and project.

        Args:
            input_ids: Input token IDs.
            attention_mask: Attention mask (1 for real tokens, 0 for padding).
            token_type_ids: Token type IDs.

        Returns:
            Pooled and projected embeddings of shape (batch_size, hidden_size).
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )

        last_hidden_state = outputs.last_hidden_state
        if attention_mask is None:
            # With no explicit mask, treat every position as a real token.
            attention_mask = torch.ones_like(input_ids)
        pooled = mean_pooling(last_hidden_state, attention_mask)

        # Cast to the projector's dtype (bfloat16) so the matmul does not fail
        # when the encoder runs in float32.
        projected = self.projector(pooled.to(self.projector.weight.dtype))

        return projected

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model and config."""
        super().save_pretrained(save_directory, **kwargs)
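

# A minimal smoke-test sketch, assuming random weights and illustrative
# config values (not a released checkpoint):
if __name__ == "__main__":
    config = AirRepConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
    )
    model = AirRepModel(config)
    ids = torch.randint(0, 100, (2, 8))
    emb = model(input_ids=ids, attention_mask=torch.ones_like(ids))
    print(emb.shape, emb.dtype)  # expected: torch.Size([2, 32]) torch.bfloat16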