import os
import tempfile
from urllib.parse import urlparse

import requests
import torch
import yaml
from omegaconf import OmegaConf
from safetensors.torch import load_file
|
|
def _checkpoint_suffix(path, is_url):
    """Return the lower-cased file extension of *path*.

    For URLs only the path component is inspected, so a query string or
    fragment (e.g. ``model.safetensors?download=true``) does not end up in
    the extension.
    """
    file_part = urlparse(path).path if is_url else path
    return os.path.splitext(file_part)[-1].lower()


def _load_checkpoint_file(file_path, suffix):
    """Deserialize the checkpoint at *file_path*, choosing the loader by *suffix*."""
    if suffix == ".safetensors":
        return load_file(file_path)
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code. Only load checkpoints from trusted sources,
    # especially when they were fetched from a URL.
    return torch.load(file_path, map_location="cpu", weights_only=False)


def get_ckpt(path, key="state_dict", *, timeout=(10, 60)):
    """Load a checkpoint from a local path or an HTTP(S) URL.

    Files whose extension is ``.safetensors`` are read with
    ``safetensors.torch.load_file``; everything else goes through
    ``torch.load`` mapped to the CPU.

    Args:
        path: Local file path or ``http://`` / ``https://`` URL.
        key: If not ``None`` and present in the loaded object, that entry is
            returned instead of the whole checkpoint.
        timeout: Timeout forwarded to ``requests.get`` for URL downloads
            (``(connect, read)`` seconds). Ignored for local paths.

    Returns:
        The loaded checkpoint, or ``checkpoint[key]`` when *key* is found.

    Raises:
        requests.HTTPError: If the download returns an error status.
    """
    is_url = path.startswith(("http://", "https://"))
    suffix = _checkpoint_suffix(path, is_url)

    if is_url:
        print(f"Loading checkpoint from URL: {path}")
        # A temporary directory (rather than an open NamedTemporaryFile) lets
        # the file be fully closed before the loader reopens it by name, which
        # an open NamedTemporaryFile does not allow on Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            ckpt_path = os.path.join(tmp_dir, f"checkpoint{suffix}")
            # Stream to disk in chunks so the whole checkpoint is never held
            # in memory as a single bytes object.
            with requests.get(path, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                with open(ckpt_path, "wb") as tmp_file:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        tmp_file.write(chunk)
            checkpoint = _load_checkpoint_file(ckpt_path, suffix)
    else:
        print(f"Loading checkpoint from local path: {path}")
        checkpoint = _load_checkpoint_file(path, suffix)

    if key is not None and key in checkpoint:
        checkpoint = checkpoint[key]

    return checkpoint
|
|
|
|
def get_yaml_config(path, *, timeout=30):
    """Load a YAML configuration from a local file or an HTTP(S) URL.

    Args:
        path: Local file path or ``http://`` / ``https://`` URL of the YAML
            document.
        timeout: Timeout in seconds forwarded to ``requests.get`` for URL
            downloads. Ignored for local paths.

    Returns:
        The configuration object produced by ``OmegaConf.create`` (URL) or
        ``OmegaConf.load`` (local file).

    Raises:
        requests.HTTPError: If the download returns an error status.
    """
    if path.startswith(("http://", "https://")):
        response = requests.get(path, timeout=timeout)
        response.raise_for_status()
        # Without an explicit charset, requests decodes text/* bodies as
        # ISO-8859-1; YAML is UTF-8 by default, so decode it as such.
        if "charset" not in response.headers.get("Content-Type", "").lower():
            response.encoding = "utf-8"
        return OmegaConf.create(response.text)

    # Explicit UTF-8 so the result does not depend on the platform's locale.
    with open(path, "r", encoding="utf-8") as f:
        return OmegaConf.load(f)
|
|