# openvla-sparse-rl / train_utils.py
# Uploaded via huggingface_hub (revision 48168f2, verified)
"""Utils for training/fine-tuning scripts."""
import torch
import os
from .constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX
def get_current_action_mask(token_ids):
    """Boolean mask over the first ACTION_DIM action tokens in each sequence.

    A position is selected iff (a) its running index among non-IGNORE_INDEX
    labels falls in [1, ACTION_DIM] and (b) its token id lies in the action
    vocabulary (id > ACTION_TOKEN_BEGIN_IDX).
    """
    # Mark positions holding real (non-ignored) labels.
    valid = token_ids != IGNORE_INDEX
    # Running count of valid labels along the sequence axis gives each
    # valid token its 1-based position.
    position = torch.cumsum(valid, dim=1)
    # Keep only the first ACTION_DIM valid positions (the current action)...
    in_current_action = (position >= 1) & (position <= ACTION_DIM)
    # ...and restrict to ids that are actually action tokens.
    return (token_ids > ACTION_TOKEN_BEGIN_IDX) & in_current_action
def get_next_actions_mask(token_ids):
    """Boolean mask over action tokens AFTER the first ACTION_DIM of them.

    Complement of get_current_action_mask within the action region: selects
    positions whose running index among non-IGNORE_INDEX labels exceeds
    ACTION_DIM, restricted to ids in the action vocabulary
    (id > ACTION_TOKEN_BEGIN_IDX).
    """
    # Mark positions holding real (non-ignored) labels.
    valid = token_ids != IGNORE_INDEX
    # 1-based position of each valid label along the sequence axis.
    position = torch.cumsum(valid, dim=1)
    # Everything past the current action chunk...
    past_current_action = position > ACTION_DIM
    # ...intersected with ids that are actually action tokens.
    return (token_ids > ACTION_TOKEN_BEGIN_IDX) & past_current_action
def compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask):
    """Fraction of masked positions where the prediction equals ground truth.

    Only positions where ``mask`` is True count toward the denominator.
    NOTE(review): returns NaN if the mask selects no positions.
    """
    matches = predicted_token_ids == ground_truth_token_ids
    num_correct = (matches & mask).sum().float()
    return num_correct / mask.sum().float()
def compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask):
    """Mean L1 distance between predicted and ground-truth continuous actions.

    Masked token ids are detokenized back into continuous action values via
    ``action_tokenizer.decode_token_ids_to_actions`` before comparison.
    """

    def _to_continuous(token_ids):
        # Select masked ids, move to host memory, and detokenize to actions.
        selected = token_ids[mask].cpu().numpy()
        return torch.tensor(action_tokenizer.decode_token_ids_to_actions(selected))

    return torch.nn.functional.l1_loss(
        _to_continuous(predicted_token_ids),
        _to_continuous(ground_truth_token_ids),
    )
def find_checkpoint_file(pretrained_checkpoint, file_pattern):
    """
    Locate the single checkpoint file in a directory matching a pattern.

    A file matches when its name contains both ``file_pattern`` and the
    substring "checkpoint".

    Args:
        pretrained_checkpoint: Path to the checkpoint directory
        file_pattern: String pattern to match in filenames
    Returns:
        str: Path to the matching checkpoint file
    Raises:
        AssertionError: If no files or multiple files match the pattern
    """
    assert os.path.isdir(pretrained_checkpoint), f"Checkpoint path must be a directory: {pretrained_checkpoint}"
    checkpoint_files = [
        os.path.join(pretrained_checkpoint, name)
        for name in os.listdir(pretrained_checkpoint)
        if file_pattern in name and "checkpoint" in name
    ]
    # Ambiguity (zero or several matches) is an error rather than a guess.
    assert len(checkpoint_files) == 1, (
        f"Expected exactly 1 {file_pattern} checkpoint but found {len(checkpoint_files)} in directory: {pretrained_checkpoint}"
    )
    return checkpoint_files[0]
def load_component_state_dict(checkpoint_path):
    """
    Load a component's state dict from checkpoint, stripping any DDP prefix.

    Args:
        checkpoint_path: Path to the checkpoint file
    Returns:
        Dict: The processed state dictionary for loading
    """
    raw = torch.load(checkpoint_path, weights_only=True)
    ddp_prefix = "module."
    # Components trained under DDP save keys prefixed with "module.";
    # strip it so the dict loads into an unwrapped module.
    return {
        (key[len(ddp_prefix):] if key.startswith(ddp_prefix) else key): value
        for key, value in raw.items()
    }