"""Utils for training/fine-tuning scripts.""" import torch import os from .constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX def get_current_action_mask(token_ids): # Create a tensor marking positions of IGNORE_INDEX newline_positions = token_ids != IGNORE_INDEX # Calculate cumulative sum to identify regions between newlines cumsum = torch.cumsum(newline_positions, dim=1) # Create the mask mask = (1 <= cumsum) & (cumsum <= ACTION_DIM) # Extract the action part only action_tokens_only_mask = token_ids > ACTION_TOKEN_BEGIN_IDX mask = action_tokens_only_mask * mask return mask def get_next_actions_mask(token_ids): # Create a tensor marking positions of IGNORE_INDEX newline_positions = token_ids != IGNORE_INDEX # Calculate cumulative sum to identify regions between newlines cumsum = torch.cumsum(newline_positions, dim=1) # Create the mask mask = cumsum > ACTION_DIM # Extract the action part only action_tokens_only_mask = token_ids > ACTION_TOKEN_BEGIN_IDX mask = action_tokens_only_mask * mask return mask def compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask): correct_preds = (predicted_token_ids == ground_truth_token_ids) & mask accuracy = correct_preds.sum().float() / mask.sum().float() return accuracy def compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask): pred_continuous_actions = torch.tensor( action_tokenizer.decode_token_ids_to_actions(predicted_token_ids[mask].cpu().numpy()) ) true_continuous_actions = torch.tensor( action_tokenizer.decode_token_ids_to_actions(ground_truth_token_ids[mask].cpu().numpy()) ) l1_loss = torch.nn.functional.l1_loss(pred_continuous_actions, true_continuous_actions) return l1_loss def find_checkpoint_file(pretrained_checkpoint, file_pattern) : """ Find a specific checkpoint file matching a pattern. Args: pretrained_checkpoint: Path to the checkpoint directory file_pattern: String pattern to match in filenames Returns: str: Path to the matching checkpoint file Raises: AssertionError: If no files or multiple files match the pattern """ assert os.path.isdir(pretrained_checkpoint), f"Checkpoint path must be a directory: {pretrained_checkpoint}" checkpoint_files = [] for filename in os.listdir(pretrained_checkpoint): if file_pattern in filename and "checkpoint" in filename: full_path = os.path.join(pretrained_checkpoint, filename) checkpoint_files.append(full_path) assert len(checkpoint_files) == 1, ( f"Expected exactly 1 {file_pattern} checkpoint but found {len(checkpoint_files)} in directory: {pretrained_checkpoint}" ) return checkpoint_files[0] def load_component_state_dict(checkpoint_path) : """ Load a component's state dict from checkpoint and handle DDP prefix if present. Args: checkpoint_path: Path to the checkpoint file Returns: Dict: The processed state dictionary for loading """ state_dict = torch.load(checkpoint_path, weights_only=True) # If the component was trained with DDP, elements in the state dict have prefix "module." which we must remove new_state_dict = {} for k, v in state_dict.items(): if k.startswith("module."): new_state_dict[k[7:]] = v else: new_state_dict[k] = v return new_state_dict