import os
import json
import pickle
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import argparse

def load_json(file_path):
    """Read the JSON document at *file_path* and return the decoded object.

    The file is opened as UTF-8 explicitly so that non-ASCII text decodes the
    same way on every platform instead of depending on the locale's default
    encoding.

    Args:
        file_path: Path to a JSON file.

    Returns:
        Whatever ``json.load`` produces for the file's contents.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the contents are not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def prepare_context_single_caption(captions_dict, max_entries=10, index_range=None):
    """Group captions by filename and join them into one context string each.

    Args:
        captions_dict: Mapping whose values are dicts read via ``.get`` for the
            keys ``"filename"`` and ``"caption"``. The mapping's own keys are
            not used.
        max_entries: Maximum number of captions joined per filename; captions
            are kept in the order they are encountered.
        index_range: Optional container of ints. When given (and non-empty),
            the part of the filename before the first ``'.'`` is parsed as an
            int and the record is kept only if that number is in the container.

    Returns:
        Dict mapping each filename to its captions joined with single spaces.
        Records without a filename or without a caption are skipped.

    Raises:
        ValueError: If ``index_range`` is in effect and a filename's leading
            component is not an integer.
    """
    grouped = {}
    for entry in captions_dict.values():
        filename = entry.get("filename")
        if not filename:
            continue
        if index_range and int(filename.split('.')[0]) not in index_range:
            continue
        caption = entry.get("caption")
        if caption is None:
            # A missing caption would otherwise put None into the list and make
            # the str.join below raise TypeError for the whole file.
            continue
        grouped.setdefault(filename, []).append(caption)
    return {
        filename: " ".join(captions[:max_entries])
        for filename, captions in grouped.items()
    }

def get_device():
    """Return the torch device to run on.

    Backends are probed in a fixed order: MPS first, then CUDA. The first one
    that reports itself available wins; otherwise the CPU device is returned.
    """
    probes = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for backend_name, is_available in probes:
        if is_available():
            return torch.device(backend_name)
    return torch.device("cpu")

def extract_text_embeddings_single(model_name, suffixes, prefix, hf_token, base_json_dir, base_output_dir, pad_token_as_eos=False, allow_tf32=False, offload_folder=None):
    """Run a Hugging Face model over caption contexts and pickle the outputs.

    For every suffix in *suffixes*, ``<base_json_dir>/<suffix>.json`` is loaded,
    its captions are grouped per filename (at most two captions per file), each
    resulting context string is tokenized and passed through the model, and the
    mapping ``filename -> model output`` is pickled to
    ``<base_output_dir>/<prefix>_<suffix>.pkl``.

    Args:
        model_name: Model identifier passed to ``AutoTokenizer`` / ``AutoModel``.
        suffixes: Iterable of JSON base names (without ``.json``) to process.
        prefix: Prefix for output file names; also used in log messages.
        hf_token: Token forwarded to both ``from_pretrained`` calls.
        base_json_dir: Directory containing the ``<suffix>.json`` files.
        base_output_dir: Directory receiving the ``.pkl`` files; created if it
            does not exist.
        pad_token_as_eos: If true, the tokenizer's pad token is set to its EOS
            token before tokenizing.
        allow_tf32: If true, sets ``torch.backends.cuda.matmul.allow_tf32``.
            This is a process-wide flag and is not reset afterwards.
        offload_folder: Optional ``offload_folder`` forwarded to
            ``AutoModel.from_pretrained``.
    """
    # Create the output directory up front (as the vision extractors do) so a
    # bad path fails before the model is loaded rather than after inference.
    if base_output_dir:
        os.makedirs(base_output_dir, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
    if pad_token_as_eos:
        tokenizer.pad_token = tokenizer.eos_token
    model_args = {"pretrained_model_name_or_path": model_name, "token": hf_token, "device_map": "auto"}
    if offload_folder:
        model_args["offload_folder"] = offload_folder
    model = AutoModel.from_pretrained(**model_args)
    model.eval()
    if allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
    # NOTE(review): the model is placed by device_map="auto" while the inputs
    # are moved to get_device(); confirm the two agree on the target machine.
    device = get_device()
    print(f"Using device for {prefix}: {device}")
    for suf in suffixes:
        json_path = os.path.join(base_json_dir, f"{suf}.json")
        out_name = f"{prefix}_{suf}.pkl"
        pkl_path = os.path.join(base_output_dir, out_name)
        print(f"\nProcessing {os.path.basename(json_path)} -> {os.path.basename(pkl_path)}")
        captions = load_json(json_path)
        contexts = prepare_context_single_caption(captions, max_entries=2)
        embeddings = {}
        for fn, ctx in tqdm(contexts.items(), desc=f"→ {os.path.basename(pkl_path)}"):
            enc = tokenizer(ctx, return_tensors="pt", padding=True, truncation=True).to(device)
            with torch.no_grad():
                out = model(
                    input_ids=enc["input_ids"],
                    attention_mask=enc["attention_mask"],
                    output_hidden_states=True
                )
            # NOTE(review): the complete model output object is kept per file
            # (nothing is reduced or moved to CPU here), so memory grows with
            # the number of files and the pickle holds whatever the model
            # returned. Left as-is because downstream readers of the .pkl may
            # depend on this layout — TODO confirm.
            embeddings[fn] = out
        with open(pkl_path, 'wb') as f:
            pickle.dump(embeddings, f)
        print(f"Saved {len(embeddings)} entries to {pkl_path}")

# Vision feature extraction (DINOv2); this part is kept the same as in the original script.
# Preprocessing applied by load_and_preprocess_image_hidden: force a fixed
# 518x518 input, convert to a tensor, then normalize each channel.
hidden_preprocess = transforms.Compose(
    [
        transforms.Resize(size=(518, 518)),
        # After the exact 518x518 resize above this crop should leave the
        # image size unchanged; it is kept to mirror the original pipeline.
        transforms.CenterCrop(size=518),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

# Preprocessing applied by load_and_preprocess_image_backbone: resize with a
# single int size, center-crop to 224x224, convert to a tensor, and normalize
# with the same per-channel mean/std as above.
backbone_preprocess = transforms.Compose(
    [
        transforms.Resize(size=224),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)
def load_and_preprocess_image_hidden(path):
    """Load the image at *path* and preprocess it for hidden-layer extraction.

    The image is converted to RGB, passed through ``hidden_preprocess`` and
    given a leading batch dimension via ``unsqueeze(0)``.

    Args:
        path: Path to an image file readable by PIL.

    Returns:
        The preprocessed image tensor with an extra leading dimension of size 1.
    """
    # Image.open leaves the underlying file open; the context manager closes it
    # as soon as convert() has produced an independent RGB copy.
    with Image.open(path) as img:
        rgb = img.convert('RGB')
    return hidden_preprocess(rgb).unsqueeze(0)
def load_and_preprocess_image_backbone(path):
    """Load the image at *path* and preprocess it for backbone-output extraction.

    The image is converted to RGB, passed through ``backbone_preprocess`` and
    given a leading batch dimension via ``unsqueeze(0)``.

    Args:
        path: Path to an image file readable by PIL.

    Returns:
        The preprocessed image tensor with an extra leading dimension of size 1.
    """
    # Image.open leaves the underlying file open; the context manager closes it
    # as soon as convert() has produced an independent RGB copy.
    with Image.open(path) as img:
        rgb = img.convert('RGB')
    return backbone_preprocess(rgb).unsqueeze(0)
class DinoVisionTransformerWrapper(torch.nn.Module):
    """Module wrapper that runs a wrapped model's token pipeline step by step.

    The wrapped model must provide ``prepare_tokens_with_masks``, an iterable
    ``blocks`` of callables, and ``norm``.
    """

    def __init__(self, model):
        """Store *model* as ``self.model``."""
        super().__init__()
        self.model = model

    def forward(self, x):
        """Prepare tokens from *x*, apply every block in order, then ``norm``.

        Returns the output of ``self.model.norm`` on the final block's result.
        """
        backbone = self.model
        tokens = backbone.prepare_tokens_with_masks(x)
        for block in backbone.blocks:
            tokens = block(tokens)
        return backbone.norm(tokens)
def extract_dino_hidden(cls_model, image_dir, save_dir, image_files):
    """Save per-layer token features for each image as ``<stem>.npy``.

    Args:
        cls_model: Object whose ``.model`` exposes ``get_intermediate_layers``.
        image_dir: Directory the image file names are relative to.
        save_dir: Output directory; created if missing.
        image_files: Iterable of image file names inside *image_dir*.

    Images whose ``.npy`` file already exists in *save_dir* are skipped, so an
    interrupted run can be resumed.
    """
    os.makedirs(save_dir, exist_ok=True)
    # Layers 0..11 are requested for every image.
    # presumably this matches the depth of the model in use — TODO confirm
    layer_indices = list(range(12))
    for name in image_files:
        stem = os.path.splitext(name)[0]
        target = os.path.join(save_dir, stem + '.npy')
        if os.path.exists(target):
            print(f"Skipping {name}, already exists.")
            continue
        batch = load_and_preprocess_image_hidden(os.path.join(image_dir, name))
        with torch.no_grad():
            per_layer = cls_model.model.get_intermediate_layers(batch, n=layer_indices)
            # NOTE(review): index 0 along dim 1 is treated as the CLS token.
            # Upstream DINOv2's get_intermediate_layers appears to strip the
            # class token unless return_class_token=True, in which case this
            # would be the first patch token instead — TODO confirm.
            stacked = np.stack(
                [feat[:, 0, :].squeeze().cpu().numpy() for feat in per_layer],
                axis=0,
            )
        np.save(target, stacked)
    print("Hidden-layer CLS extraction complete.")
def extract_dino_backbone(bt_model, image_dir, save_dir, image_files):
    """Save the model's forward output for each image as ``<stem>.npy``.

    Args:
        bt_model: Callable model applied directly to the preprocessed image.
        image_dir: Directory the image file names are relative to.
        save_dir: Output directory; created if missing.
        image_files: Iterable of image file names inside *image_dir*.

    Images whose ``.npy`` file already exists in *save_dir* are skipped, so an
    interrupted run can be resumed.
    """
    os.makedirs(save_dir, exist_ok=True)
    for name in image_files:
        stem = os.path.splitext(name)[0]
        target = os.path.join(save_dir, stem + '.npy')
        if os.path.exists(target):
            print(f"Skipping {name}, already exists.")
            continue
        batch = load_and_preprocess_image_backbone(os.path.join(image_dir, name))
        with torch.no_grad():
            features = bt_model(batch)
        np.save(target, features.cpu().numpy())
    print("Backbone output extraction complete.")
def merge_npy_to_pkl(npy_dir, pkl_filename):
    """Collect every ``.npy`` file in *npy_dir* into one pickled dict.

    Each array is stored under its file name with the trailing ``.npy``
    replaced by ``.png``. The pickle is written to ``npy_dir/pkl_filename``.

    Args:
        npy_dir: Directory containing the ``.npy`` files; also receives the
            pickle.
        pkl_filename: File name of the pickle to write.

    Note:
        Keys always end in ``.png`` regardless of the source image's real
        extension.
    """
    npy_suffix = '.npy'
    representations = {}
    # Sorted so the pickled dict has a reproducible key order; os.listdir
    # returns entries in arbitrary order.
    for npy_file in sorted(os.listdir(npy_dir)):
        if not npy_file.endswith(npy_suffix):
            continue
        # Swap only the trailing suffix; str.replace would also rewrite any
        # earlier ".npy" occurring inside the name.
        file_key = npy_file[:-len(npy_suffix)] + '.png'
        representations[file_key] = np.load(os.path.join(npy_dir, npy_file))
    pkl_save_path = os.path.join(npy_dir, pkl_filename)
    with open(pkl_save_path, 'wb') as f:
        pickle.dump(representations, f)
    print(f"Dictionary saved to {pkl_save_path}")

if __name__ == "__main__":
    # Command line: an access token plus the two working directories. Outputs
    # are written next to their inputs (text -> json_dir, vision -> img_dir).
    cli = argparse.ArgumentParser(description="Extract features for random_data_easynonmatch format.")
    cli.add_argument("--hf_token", required=True, help="HuggingFace token for model access.")
    cli.add_argument("--json_dir", required=True, help="Directory containing JSON caption files and where text outputs will be saved.")
    cli.add_argument("--img_dir", required=True, help="Directory containing image files for vision feature extraction and where vision outputs will be saved.")
    args = cli.parse_args()

    JSON_DIR = args.json_dir
    OUT_DIR = JSON_DIR
    IMG_DIR = args.img_dir
    VISION_OUT_DIR = IMG_DIR
    vision_model_name = 'dinov2_vitb14'

    # Use base scale configs.
    # Each entry: (output prefix, model id, JSON suffixes, unused flag, extra kwargs or None).
    text_configs = [
        ("gemma2", "google/gemma-2-2b", ['random_text_gatsby','random_text_wiki'], False, None),
        ("olmo2", "allenai/OLMo-2-1124-7B", ['random_text_gatsby','random_text_wiki'], False, {'allow_tf32':True}),
        ("llama3","meta-llama/Llama-3.2-1B", ['random_text_gatsby','random_text_wiki'], False, {'pad_token_as_eos':True})
    ]

    # Text extraction (single caption)
    for prefix, model_name, suffixes, _, extras in text_configs:
        call_kwargs = {
            'hf_token': args.hf_token,
            'base_json_dir': JSON_DIR,
            'base_output_dir': OUT_DIR,
            'model_name': model_name,
            'suffixes': suffixes,
            'prefix': prefix,
        }
        # Per-model extras are applied last so they take precedence.
        call_kwargs.update(extras or {})
        extract_text_embeddings_single(**call_kwargs)

    # Vision extraction: regular files in IMG_DIR with a known image extension.
    image_exts = ('.jpg','.jpeg','.png','.bmp','.tiff')
    image_files = []
    for entry in os.listdir(IMG_DIR):
        if not os.path.isfile(os.path.join(IMG_DIR, entry)):
            continue
        if entry.lower().endswith(image_exts):
            image_files.append(entry)

    hub_model = torch.hub.load('facebookresearch/dinov2', vision_model_name)
    dino_model = DinoVisionTransformerWrapper(hub_model)
    dino_model.eval()
    # Inference only: turn off gradient tracking on every parameter.
    for param in dino_model.parameters():
        param.requires_grad = False

    HIDDEN_DIR = os.path.join(VISION_OUT_DIR, f'{vision_model_name}_hidden')
    BACKBONE_DIR = os.path.join(VISION_OUT_DIR, f'{vision_model_name}_backbone_output')
    extract_dino_hidden(dino_model, IMG_DIR, HIDDEN_DIR, image_files)
    extract_dino_backbone(dino_model, IMG_DIR, BACKBONE_DIR, image_files)

    # Bundle the per-image .npy files of each run into a single pickle.
    merge_npy_to_pkl(HIDDEN_DIR, f'{vision_model_name}_hidden.pkl')
    merge_npy_to_pkl(BACKBONE_DIR, f'{vision_model_name}_backbone_output.pkl')