klemenk committed on
Commit
9f478ca
·
verified ·
1 Parent(s): ee0c1d2

Upload AuriStream base model code

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +22 -9
modeling_auristream.py CHANGED
@@ -253,7 +253,8 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
253
  AuriStream speech language model.
254
 
255
  A GPT-like transformer model for cochlear token prediction with optional
256
- multi-token prediction (MTP) heads for speculative decoding.
 
257
 
258
  Developed by Greta Tuckute and Klemen Kotar.
259
  """
@@ -302,6 +303,7 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
302
  self,
303
  input_ids: Optional[torch.LongTensor] = None,
304
  labels: Optional[torch.LongTensor] = None,
 
305
  output_hidden_states: Optional[bool] = False,
306
  return_dict: Optional[bool] = True,
307
  # Legacy arguments for compatibility
@@ -314,13 +316,16 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
314
  Args:
315
  input_ids: Input token IDs of shape (batch_size, seq_len)
316
  labels: Target token IDs for computing loss
 
317
  output_hidden_states: Whether to return all hidden states
318
  return_dict: Whether to return a dict or tuple
319
  seq: Legacy argument (alias for input_ids)
320
  tgt: Legacy argument (alias for labels)
321
 
322
  Returns:
323
- CausalLMOutput with logits and optional loss
 
 
324
  """
325
  # Handle legacy arguments
326
  if seq is not None:
@@ -347,6 +352,15 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
347
  # Final layer norm and output head
348
  x = self.ln_f(x)
349
  logits = self.lm_head(x)
 
 
 
 
 
 
 
 
 
350
 
351
  # Compute loss if labels provided
352
  loss = None
@@ -358,21 +372,20 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
358
 
359
  # Multi-token prediction loss
360
  if self.future_heads is not None:
361
- for i, head in enumerate(self.future_heads):
362
- future_logits = head(x[:, :-(i+1)])
363
  loss = loss + F.cross_entropy(
364
- future_logits.reshape(-1, self.config.vocab_size),
365
- labels[:, (i+1):].reshape(-1),
366
  )
367
 
368
  if not return_dict:
369
  if labels is not None:
370
- return logits, loss
371
- return logits, None
372
 
373
  return CausalLMOutput(
374
  loss=loss,
375
- logits=logits,
376
  hidden_states=all_hidden_states if output_hidden_states else None,
377
  )
378
 
 
253
  AuriStream speech language model.
254
 
255
  A GPT-like transformer model for cochlear token prediction with optional
256
+ multi-token prediction (MTP) heads for future surprisal prediction and speculative decoding.
257
+ (These heads also improve representation learning).
258
 
259
  Developed by Greta Tuckute and Klemen Kotar.
260
  """
 
303
  self,
304
  input_ids: Optional[torch.LongTensor] = None,
305
  labels: Optional[torch.LongTensor] = None,
306
+ output_logits: Optional[bool] = False,
307
  output_hidden_states: Optional[bool] = False,
308
  return_dict: Optional[bool] = True,
309
  # Legacy arguments for compatibility
 
316
  Args:
317
  input_ids: Input token IDs of shape (batch_size, seq_len)
318
  labels: Target token IDs for computing loss
319
+ output_logits: Whether to return logits from all prediction heads
320
  output_hidden_states: Whether to return all hidden states
321
  return_dict: Whether to return a dict or tuple
322
  seq: Legacy argument (alias for input_ids)
323
  tgt: Legacy argument (alias for labels)
324
 
325
  Returns:
326
+ CausalLMOutput with logits and optional loss. When
327
+ output_logits=True, logits includes the main LM head first followed
328
+ by each future prediction head.
329
  """
330
  # Handle legacy arguments
331
  if seq is not None:
 
352
  # Final layer norm and output head
353
  x = self.ln_f(x)
354
  logits = self.lm_head(x)
355
+ all_logits = [logits] if output_logits else None
356
+ future_logits = []
357
+
358
+ if self.future_heads is not None and (labels is not None or output_logits):
359
+ for i, head in enumerate(self.future_heads):
360
+ head_logits = head(x[:, :-(i + 1)])
361
+ future_logits.append(head_logits)
362
+ if output_logits:
363
+ all_logits.append(head_logits)
364
 
365
  # Compute loss if labels provided
366
  loss = None
 
372
 
373
  # Multi-token prediction loss
374
  if self.future_heads is not None:
375
+ for i, head_logits in enumerate(future_logits):
 
376
  loss = loss + F.cross_entropy(
377
+ head_logits.reshape(-1, self.config.vocab_size),
378
+ labels[:, (i + 1):].reshape(-1),
379
  )
380
 
381
  if not return_dict:
382
  if labels is not None:
383
+ return (all_logits if output_logits else logits), loss
384
+ return (all_logits if output_logits else logits), None
385
 
386
  return CausalLMOutput(
387
  loss=loss,
388
+ logits=all_logits if output_logits else logits,
389
  hidden_states=all_hidden_states if output_hidden_states else None,
390
  )
391