| """Utils for training/fine-tuning scripts.""" |
|
|
| import torch |
| import os |
| from .constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX |
|
|
|
|
def get_current_action_mask(token_ids):
    """Boolean mask selecting the current action's tokens in each sequence.

    Positions whose label is not IGNORE_INDEX are ranked 1, 2, ... along the
    sequence; the first ACTION_DIM of them are treated as the current action,
    and the mask is further restricted to ids in the action-token range.

    Args:
        token_ids: Integer tensor of token ids, shape (batch, seq_len).

    Returns:
        Boolean tensor of the same shape, True at current-action positions.
    """
    # Positions carrying a real (non-ignored) label.
    labeled = token_ids != IGNORE_INDEX

    # 1-based rank of each labeled position within its sequence.
    rank = torch.cumsum(labeled, dim=1)

    # The current action occupies labeled ranks 1..ACTION_DIM.
    in_current_action = (rank >= 1) & (rank <= ACTION_DIM)

    # Keep only ids that actually fall in the action-token vocabulary range.
    is_action_token = token_ids > ACTION_TOKEN_BEGIN_IDX
    return in_current_action & is_action_token
|
|
|
|
def get_next_actions_mask(token_ids):
    """Boolean mask selecting all action tokens AFTER the current action.

    Complements get_current_action_mask: labeled (non-IGNORE_INDEX) positions
    are ranked along the sequence, and ranks beyond ACTION_DIM — i.e. tokens
    of subsequent actions — are selected, restricted to the action-token
    id range.

    Args:
        token_ids: Integer tensor of token ids, shape (batch, seq_len).

    Returns:
        Boolean tensor of the same shape, True at next-action positions.
    """
    # Positions carrying a real (non-ignored) label.
    labeled = token_ids != IGNORE_INDEX

    # 1-based rank of each labeled position within its sequence.
    rank = torch.cumsum(labeled, dim=1)

    # Everything past the first ACTION_DIM labeled tokens belongs to
    # future actions.
    in_next_actions = rank > ACTION_DIM

    # Keep only ids that actually fall in the action-token vocabulary range.
    is_action_token = token_ids > ACTION_TOKEN_BEGIN_IDX
    return in_next_actions & is_action_token
|
|
|
|
def compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask):
    """Fraction of masked positions where the prediction matches ground truth.

    Args:
        predicted_token_ids: Tensor of predicted token ids.
        ground_truth_token_ids: Tensor of target token ids, same shape.
        mask: Boolean tensor selecting which positions count.

    Returns:
        Scalar float tensor in [0, 1].

    NOTE(review): yields NaN if the mask selects no positions — presumably
    callers never pass an all-False mask; confirm.
    """
    matches = predicted_token_ids == ground_truth_token_ids
    hits = (matches & mask).sum().float()
    total = mask.sum().float()
    return hits / total
|
|
|
|
def compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask):
    """Mean L1 distance between decoded predicted and ground-truth actions.

    Args:
        action_tokenizer: Object exposing decode_token_ids_to_actions, which
            maps a numpy array of token ids to continuous action values.
        predicted_token_ids: Tensor of predicted token ids.
        ground_truth_token_ids: Tensor of target token ids, same shape.
        mask: Boolean tensor selecting which positions are action tokens.

    Returns:
        Scalar tensor: mean absolute error over the masked positions.
    """
    def decode(token_ids):
        # Tokenizer decodes on CPU numpy arrays; wrap the result back
        # in a tensor for the loss computation.
        selected = token_ids[mask].cpu().numpy()
        return torch.tensor(action_tokenizer.decode_token_ids_to_actions(selected))

    return torch.nn.functional.l1_loss(decode(predicted_token_ids), decode(ground_truth_token_ids))
|
|
def find_checkpoint_file(pretrained_checkpoint, file_pattern):
    """
    Locate the single checkpoint file in a directory matching a pattern.

    Args:
        pretrained_checkpoint: Path to the checkpoint directory.
        file_pattern: Substring that must appear in the filename.

    Returns:
        str: Full path to the unique matching checkpoint file.

    Raises:
        AssertionError: If the path is not a directory, or if zero or
            multiple files match the pattern.
    """
    assert os.path.isdir(pretrained_checkpoint), f"Checkpoint path must be a directory: {pretrained_checkpoint}"

    # A candidate must contain both the caller's pattern and "checkpoint".
    matches = [
        os.path.join(pretrained_checkpoint, name)
        for name in os.listdir(pretrained_checkpoint)
        if file_pattern in name and "checkpoint" in name
    ]

    assert len(matches) == 1, (
        f"Expected exactly 1 {file_pattern} checkpoint but found {len(matches)} in directory: {pretrained_checkpoint}"
    )
    return matches[0]
|
|
|
|
def load_component_state_dict(checkpoint_path):
    """
    Load a component's state dict from a checkpoint, stripping any DDP prefix.

    Checkpoints saved from a DistributedDataParallel-wrapped module carry a
    "module." prefix on every key; those prefixes are removed so the dict
    loads into an unwrapped module.

    Args:
        checkpoint_path: Path to the checkpoint file.

    Returns:
        Dict: State dictionary with any "module." key prefixes removed.
    """
    raw_state_dict = torch.load(checkpoint_path, weights_only=True)

    # Drop the 7-character "module." prefix where present; other keys
    # pass through unchanged.
    return {
        (key[7:] if key.startswith("module.") else key): value
        for key, value in raw_state_dict.items()
    }