| | import os
|
| | import torch
|
| | from safetensors.torch import save_file
|
| | import glob
|
| | import shutil
|
| |
|
def convert_model_to_safetensors(model_path, output_path):
    """Convert the first PyTorch checkpoint found in a directory to safetensors.

    Searches ``model_path`` for ``*.pt``, ``*.pth`` and ``pytorch_model.bin``
    files (in that order of preference), loads the first match onto the CPU
    and writes its tensors to ``output_path``.

    If the loaded checkpoint is a dict containing a ``model_state_dict`` or
    ``state_dict`` entry, that entry is used as the state dict; otherwise the
    checkpoint dict itself is used. Entries that are not ``torch.Tensor``
    instances are skipped.

    Args:
        model_path: Directory to search for checkpoint files.
        output_path: Destination ``.safetensors`` file. An existing file is
            replaced, but only after the checkpoint has been loaded and
            validated, so a failed conversion leaves the old file intact.

    Raises:
        FileNotFoundError: If no checkpoint file is found in ``model_path``.
        ValueError: If the checkpoint contains no tensors.
    """
    print(f"Looking for PyTorch model files in {model_path}")

    # glob's ordering depends on the file system, so sort each pattern's
    # matches to make the choice of "first" file deterministic while keeping
    # the *.pt -> *.pth -> pytorch_model.bin preference order.
    model_files = []
    for pattern in ("*.pt", "*.pth", "pytorch_model.bin"):
        model_files.extend(sorted(glob.glob(os.path.join(model_path, pattern))))

    if not model_files:
        raise FileNotFoundError(f"No PyTorch model files found in {model_path}")

    print(f"Found model file(s): {model_files}")
    model_file = model_files[0]

    print(f"Loading model from {model_file}")
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only run this on checkpoints from a trusted source.
    checkpoint = torch.load(model_file, map_location='cpu')

    print(f"Checkpoint type: {type(checkpoint)}")
    print(f"Checkpoint keys: {checkpoint.keys() if isinstance(checkpoint, dict) else 'Not a dict'}")

    model_state_dict = {}
    if isinstance(checkpoint, dict):
        # Unwrap the common training-checkpoint layouts.
        if 'model_state_dict' in checkpoint:
            checkpoint = checkpoint['model_state_dict']
        elif 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']
        print(f"After getting state dict - Keys available: {checkpoint.keys() if isinstance(checkpoint, dict) else 'Not a dict'}")

        # The unwrapped entry may itself not be a dict; in that case fall
        # through to the "no tensors" error instead of crashing on .items().
        if isinstance(checkpoint, dict):
            for key, value in checkpoint.items():
                if isinstance(value, torch.Tensor):
                    # save_file rejects non-contiguous tensors, so pack them.
                    # NOTE(review): tensors sharing storage (e.g. tied
                    # weights) may still be rejected by save_file - confirm.
                    model_state_dict[key] = value.contiguous()
                    print(f"Added tensor for key: {key} with shape {value.shape}")

    print(f"Total number of tensors to save: {len(model_state_dict)}")
    if not model_state_dict:
        raise ValueError("No tensors found in the checkpoint! Check the model structure.")

    # Only touch the destination once we know there is something to write.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        # dirname is "" for a bare filename; os.makedirs("") would raise.
        os.makedirs(output_dir, exist_ok=True)
    if os.path.exists(output_path):
        os.remove(output_path)

    print(f"Converting to safetensors and saving to {output_path}")
    save_file(model_state_dict, output_path)
    print("Conversion completed successfully!")
|
| |
|
if __name__ == "__main__":
    # Script entry point: convert the checkpoint found in ./checkpoints and
    # write the result next to it as model.safetensors.
    convert_model_to_safetensors(
        model_path="./checkpoints",
        output_path="./checkpoints/model.safetensors",
    )