| | import os |
| | import json |
| | import pickle |
| | import torch |
| | import numpy as np |
| | from tqdm import tqdm |
| | from transformers import AutoTokenizer, AutoModel |
| | import torch.multiprocessing as mp |
| |
|
| | |
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
INPUT_JSON = "Pretrain.json"  # JSON file holding the items to embed
mean_shift = True  # when true, main() also writes mean-shifted embeddings
CKPT = "/root/autodl-tmp/model/siglip2"  # checkpoint given to from_pretrained
BATCH_SIZE = 512  # number of captions tokenized and encoded per step
LOAD_LIMIT = None  # optional cap on the number of items; None keeps them all

# ---------------------------------------------------------------------------
# Output locations
# ---------------------------------------------------------------------------
RAW_DIR = "raw_embeds"  # per-device pickles of unmodified embeddings
SHIFTED_DIR = "shifted_embeds"  # per-device pickles after mean subtraction

os.makedirs(RAW_DIR, exist_ok=True)
os.makedirs(SHIFTED_DIR, exist_ok=True)

# ---------------------------------------------------------------------------
# Input data
# ---------------------------------------------------------------------------
with open(INPUT_JSON, "r", encoding="utf-8") as f:
    items = json.load(f)
if LOAD_LIMIT is not None:
    items = items[:LOAD_LIMIT]

# The tokenizer is created once at module level and used by the workers.
tokenizer = AutoTokenizer.from_pretrained(CKPT)

# ---------------------------------------------------------------------------
# Work partitioning: one chunk of items per visible CUDA device
# ---------------------------------------------------------------------------
num_gpus = torch.cuda.device_count()
if num_gpus == 0:
    # np.array_split rejects a section count of 0 with an unrelated-looking
    # ValueError, so fail early with a message that names the real problem.
    raise RuntimeError(
        "No CUDA devices are visible; at least one GPU is required to "
        "compute embeddings."
    )
chunks = np.array_split(items, num_gpus)
| |
|
| | |
def compute_raw_embeddings(device, data_chunk, gpu_id):
    """Encode the captions of ``data_chunk`` on one device and pickle them.

    The model is loaded from ``CKPT``, moved to ``device`` and put in eval
    mode. Captions are processed in slices of ``BATCH_SIZE``; an item with
    no ``'caption'`` key is encoded as the empty string. The result, a list
    of ``{'id': ..., 'embed': ...}`` dicts, is written to
    ``RAW_DIR/raw_embeds_device_{gpu_id}.pkl``.

    Args:
        device: Device specifier accepted by ``torch.device``.
        data_chunk: Sliceable sequence of dict-like items, each with an
            ``'id'`` key and optionally a ``'caption'`` key.
        gpu_id: Identifier used in the progress-bar label, the log line
            and the output file name.
    """
    target = torch.device(device)
    model = AutoModel.from_pretrained(CKPT)
    model = model.to(target)
    model = model.eval()

    records = []
    batch_starts = range(0, len(data_chunk), BATCH_SIZE)

    for start in tqdm(batch_starts, desc=f"Device {gpu_id} Raw Batches"):
        batch_items = data_chunk[start:start + BATCH_SIZE]
        batch_ids = [entry['id'] for entry in batch_items]
        texts = [entry.get('caption', '') for entry in batch_items]

        # Every caption is padded or truncated to exactly 64 tokens.
        encoded = tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt",
        )
        encoded = encoded.to(target)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            features = model.get_text_features(**encoded)
        vectors = features.cpu().numpy()

        # Pair each id with the embedding row at the same batch position.
        records.extend(
            {'id': item_id, 'embed': vectors[row]}
            for row, item_id in enumerate(batch_ids)
        )

    out_path = os.path.join(RAW_DIR, f"raw_embeds_device_{gpu_id}.pkl")
    with open(out_path, 'wb') as sink:
        pickle.dump(records, sink, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"Device {gpu_id} saved {len(records)} raw embeddings to {out_path}")
| |
|
| | |
def apply_mean_shift_and_save(global_mean, gpu_id):
    """Subtract ``global_mean`` from one device's raw embeddings and save.

    Reads ``RAW_DIR/raw_embeds_device_{gpu_id}.pkl``, replaces every
    record's ``'embed'`` value with ``embed - global_mean`` and pickles the
    updated records to ``SHIFTED_DIR/embeds_device_{gpu_id}.pkl``.

    Args:
        global_mean: Value subtracted from every stored embedding.
        gpu_id: Identifier selecting the input and output file names.
    """
    source_path = os.path.join(RAW_DIR, f"raw_embeds_device_{gpu_id}.pkl")
    target_path = os.path.join(SHIFTED_DIR, f"embeds_device_{gpu_id}.pkl")

    with open(source_path, 'rb') as source:
        records = pickle.load(source)

    # Rebind each embedding to a new, shifted value (not an in-place edit).
    for record in records:
        shifted = record['embed'] - global_mean
        record['embed'] = shifted

    with open(target_path, 'wb') as target:
        pickle.dump(records, target, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"Device {gpu_id} saved {len(records)} shifted embeddings to {target_path}")
| |
|
| | |
def main():
    """Compute raw embeddings on every GPU, then optionally mean-shift them.

    One worker process is started per device; each runs
    ``compute_raw_embeddings`` on its chunk and writes a pickle into
    ``RAW_DIR``. When ``mean_shift`` is true, the per-device pickles are
    read back, the mean over all embeddings is computed, and
    ``apply_mean_shift_and_save`` writes the shifted result for each device.

    Raises:
        RuntimeError: If any worker process ends with a non-zero exit code.
    """
    procs = []
    for i in range(num_gpus):
        p = mp.Process(
            target=compute_raw_embeddings,
            args=(f"cuda:{i}", chunks[i], i),
        )
        p.start()
        procs.append(p)
    for p in procs:
        p.join()

    # A crashed worker leaves its pickle missing, or stale from an earlier
    # run. Stop here instead of computing a mean over incomplete data.
    failed = [i for i, p in enumerate(procs) if p.exitcode != 0]
    if failed:
        raise RuntimeError(
            f"Embedding worker(s) for device(s) {failed} exited abnormally; "
            "raw embeddings are incomplete."
        )

    if not mean_shift:
        return

    # Gather every embedding from the per-device files.
    all_embeds = []
    for i in range(num_gpus):
        raw_file = os.path.join(RAW_DIR, f"raw_embeds_device_{i}.pkl")
        with open(raw_file, 'rb') as f:
            data = pickle.load(f)
        all_embeds.extend(item['embed'] for item in data)

    if not all_embeds:
        # np.stack raises ValueError on an empty list, and there is nothing
        # to shift anyway.
        print("No embeddings were produced; skipping mean shift")
        return

    global_mean = np.mean(np.stack(all_embeds, axis=0), axis=0)
    print("Computed global mean of shape", global_mean.shape)

    for i in range(num_gpus):
        apply_mean_shift_and_save(global_mean, i)


if __name__ == "__main__":
    main()
| |
|