Upload AuriStream base model code
Browse files- modeling_auristream.py +22 -9
modeling_auristream.py
CHANGED
|
@@ -253,7 +253,8 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 253 |
AuriStream speech language model.
|
| 254 |
|
| 255 |
A GPT-like transformer model for cochlear token prediction with optional
|
| 256 |
-
multi-token prediction (MTP) heads for speculative decoding.
|
|
|
|
| 257 |
|
| 258 |
Developed by Greta Tuckute and Klemen Kotar.
|
| 259 |
"""
|
|
@@ -302,6 +303,7 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 302 |
self,
|
| 303 |
input_ids: Optional[torch.LongTensor] = None,
|
| 304 |
labels: Optional[torch.LongTensor] = None,
|
|
|
|
| 305 |
output_hidden_states: Optional[bool] = False,
|
| 306 |
return_dict: Optional[bool] = True,
|
| 307 |
# Legacy arguments for compatibility
|
|
@@ -314,13 +316,16 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 314 |
Args:
|
| 315 |
input_ids: Input token IDs of shape (batch_size, seq_len)
|
| 316 |
labels: Target token IDs for computing loss
|
|
|
|
| 317 |
output_hidden_states: Whether to return all hidden states
|
| 318 |
return_dict: Whether to return a dict or tuple
|
| 319 |
seq: Legacy argument (alias for input_ids)
|
| 320 |
tgt: Legacy argument (alias for labels)
|
| 321 |
|
| 322 |
Returns:
|
| 323 |
-
CausalLMOutput with logits and optional loss
|
|
|
|
|
|
|
| 324 |
"""
|
| 325 |
# Handle legacy arguments
|
| 326 |
if seq is not None:
|
|
@@ -347,6 +352,15 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 347 |
# Final layer norm and output head
|
| 348 |
x = self.ln_f(x)
|
| 349 |
logits = self.lm_head(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
# Compute loss if labels provided
|
| 352 |
loss = None
|
|
@@ -358,21 +372,20 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 358 |
|
| 359 |
# Multi-token prediction loss
|
| 360 |
if self.future_heads is not None:
|
| 361 |
-
for i,
|
| 362 |
-
future_logits = head(x[:, :-(i+1)])
|
| 363 |
loss = loss + F.cross_entropy(
|
| 364 |
-
|
| 365 |
-
labels[:, (i+1):].reshape(-1),
|
| 366 |
)
|
| 367 |
|
| 368 |
if not return_dict:
|
| 369 |
if labels is not None:
|
| 370 |
-
return logits, loss
|
| 371 |
-
return logits, None
|
| 372 |
|
| 373 |
return CausalLMOutput(
|
| 374 |
loss=loss,
|
| 375 |
-
logits=logits,
|
| 376 |
hidden_states=all_hidden_states if output_hidden_states else None,
|
| 377 |
)
|
| 378 |
|
|
|
|
| 253 |
AuriStream speech language model.
|
| 254 |
|
| 255 |
A GPT-like transformer model for cochlear token prediction with optional
|
| 256 |
+
multi-token prediction (MTP) heads for future surprisal prediction and speculative decoding.
|
| 257 |
+
(These heads also improve representation learning).
|
| 258 |
|
| 259 |
Developed by Greta Tuckute and Klemen Kotar.
|
| 260 |
"""
|
|
|
|
| 303 |
self,
|
| 304 |
input_ids: Optional[torch.LongTensor] = None,
|
| 305 |
labels: Optional[torch.LongTensor] = None,
|
| 306 |
+
output_logits: Optional[bool] = False,
|
| 307 |
output_hidden_states: Optional[bool] = False,
|
| 308 |
return_dict: Optional[bool] = True,
|
| 309 |
# Legacy arguments for compatibility
|
|
|
|
| 316 |
Args:
|
| 317 |
input_ids: Input token IDs of shape (batch_size, seq_len)
|
| 318 |
labels: Target token IDs for computing loss
|
| 319 |
+
output_logits: Whether to return logits from all prediction heads
|
| 320 |
output_hidden_states: Whether to return all hidden states
|
| 321 |
return_dict: Whether to return a dict or tuple
|
| 322 |
seq: Legacy argument (alias for input_ids)
|
| 323 |
tgt: Legacy argument (alias for labels)
|
| 324 |
|
| 325 |
Returns:
|
| 326 |
+
CausalLMOutput with logits and optional loss. When
|
| 327 |
+
output_logits=True, logits is a list containing the main LM head logits first, followed
|
| 328 |
+
by each future prediction head.
|
| 329 |
"""
|
| 330 |
# Handle legacy arguments
|
| 331 |
if seq is not None:
|
|
|
|
| 352 |
# Final layer norm and output head
|
| 353 |
x = self.ln_f(x)
|
| 354 |
logits = self.lm_head(x)
|
| 355 |
+
all_logits = [logits] if output_logits else None
|
| 356 |
+
future_logits = []
|
| 357 |
+
|
| 358 |
+
if self.future_heads is not None and (labels is not None or output_logits):
|
| 359 |
+
for i, head in enumerate(self.future_heads):
|
| 360 |
+
head_logits = head(x[:, :-(i + 1)])
|
| 361 |
+
future_logits.append(head_logits)
|
| 362 |
+
if output_logits:
|
| 363 |
+
all_logits.append(head_logits)
|
| 364 |
|
| 365 |
# Compute loss if labels provided
|
| 366 |
loss = None
|
|
|
|
| 372 |
|
| 373 |
# Multi-token prediction loss
|
| 374 |
if self.future_heads is not None:
|
| 375 |
+
for i, head_logits in enumerate(future_logits):
|
|
|
|
| 376 |
loss = loss + F.cross_entropy(
|
| 377 |
+
head_logits.reshape(-1, self.config.vocab_size),
|
| 378 |
+
labels[:, (i + 1):].reshape(-1),
|
| 379 |
)
|
| 380 |
|
| 381 |
if not return_dict:
|
| 382 |
if labels is not None:
|
| 383 |
+
return (all_logits if output_logits else logits), loss
|
| 384 |
+
return (all_logits if output_logits else logits), None
|
| 385 |
|
| 386 |
return CausalLMOutput(
|
| 387 |
loss=loss,
|
| 388 |
+
logits=all_logits if output_logits else logits,
|
| 389 |
hidden_states=all_hidden_states if output_hidden_states else None,
|
| 390 |
)
|
| 391 |
|