# -*- coding: utf-8 -*-
"""
features_extraction.py

Combined language and vision feature extraction for GEMMA2, OLMO2, LLaMA3 (text embeddings) and DINOv2 (vision features).
"""
import os
import json
import pickle
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import argparse


def load_json(file_path):
    """Read a JSON file and return the decoded object.

    The file is decoded as UTF-8 explicitly rather than with the platform's
    locale encoding, so captions containing non-ASCII characters load the
    same way on every system.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file's contents.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def prepare_context(captions_dict, max_entries=10, index_range=None, use_negative=False):
    """Group caption texts by image filename and join them into one string.

    Args:
        captions_dict: Mapping whose values are dicts read with the keys
            ``"filename"``, ``"caption"`` and ``"negative_caption"``.
        max_entries: Maximum number of texts joined per filename (the first
            ones encountered are kept).
        index_range: Optional container of ints. When truthy, only filenames
            whose part before the first ``'.'`` parses to an int contained in
            it are kept.
        use_negative: If True, use ``"negative_caption"`` instead of
            ``"caption"``.

    Returns:
        Dict mapping filename to the space-joined texts. Entries without a
        filename or without the selected text are skipped, so a filename with
        no usable text does not appear in the result.

    Raises:
        ValueError: If ``index_range`` is given and a filename's leading part
            is not an integer.
    """
    text_key = "negative_caption" if use_negative else "caption"
    grouped = {}
    for data in captions_dict.values():
        fn = data.get("filename")
        if not fn:
            continue
        if index_range:
            num = int(fn.split('.')[0])
            if num not in index_range:
                continue
        text = data.get(text_key)
        if text is None:
            # A missing caption would otherwise reach " ".join() as None and
            # raise TypeError for the whole file.
            continue
        grouped.setdefault(fn, []).append(text)
    return {fn: " ".join(sentences[:max_entries]) for fn, sentences in grouped.items()}


def get_device():
    """Return the torch device to use: MPS if available, else CUDA, else CPU."""
    # Probed in priority order; the first backend reporting availability wins.
    candidates = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for backend, is_available in candidates:
        if is_available():
            return torch.device(backend)
    return torch.device("cpu")

def extract_text_embeddings(model_name,
                            suffixes,
                            prefix,
                            add_neg_suffix,
                            hf_token,
                            base_json_dir,
                            base_output_dir,
                            pad_token_as_eos=False,
                            allow_tf32=False,
                            offload_folder=None,
                            use_negative=False):
    """Run a text model over caption files and pickle the per-image outputs.

    For every suffix, ``<base_json_dir>/<suffix>.json`` is loaded, captions are
    grouped per filename (at most two per image), each joined text is passed
    through the model with ``output_hidden_states=True``, and the resulting
    model output object is stored under the image filename. One pickle per
    suffix is written to
    ``<base_output_dir>/<prefix>_<suffix>[_neg].pkl``.

    Args:
        model_name: Identifier passed to ``AutoTokenizer`` / ``AutoModel``.
        suffixes: Iterable of caption-file stems to process.
        prefix: Prefix used in the output file name and log messages.
        add_neg_suffix: If True, ``_neg`` is appended to the output file stem.
        hf_token: Token forwarded to ``from_pretrained``.
        base_json_dir: Directory containing the ``<suffix>.json`` files.
        base_output_dir: Directory receiving the pickles; created if missing.
        pad_token_as_eos: If True, the tokenizer's pad token is set to its
            EOS token.
        allow_tf32: If True, enables TF32 for CUDA matmuls (global setting).
        offload_folder: Optional folder forwarded to ``from_pretrained``.
        use_negative: If True, negative captions are embedded instead of the
            regular captions.
    """
    # Create the output directory before any expensive work: otherwise a
    # missing directory is only discovered at open() time, after the whole
    # embedding loop for the first suffix has already run.
    if base_output_dir:
        os.makedirs(base_output_dir, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
    if pad_token_as_eos:
        tokenizer.pad_token = tokenizer.eos_token
    model_args = {"pretrained_model_name_or_path": model_name, "token": hf_token, "device_map": "auto"}
    if offload_folder:
        model_args["offload_folder"] = offload_folder
    model = AutoModel.from_pretrained(**model_args)
    model.eval()

    if allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # NOTE(review): the model is placed by device_map="auto" while inputs go
    # to get_device(); presumably these agree on the target machines - confirm.
    device = get_device()
    print(f"Using device for {prefix}: {device}")

    neg_tag = '_neg' if add_neg_suffix else ''
    for suf in suffixes:
        json_path = os.path.join(base_json_dir, f"{suf}.json")
        out_name = f"{prefix}_{suf}{neg_tag}.pkl"
        pkl_path = os.path.join(base_output_dir, out_name)
        print(f"\nProcessing {os.path.basename(json_path)} -> {os.path.basename(pkl_path)}")

        captions = load_json(json_path)
        contexts = prepare_context(captions, max_entries=2, use_negative=use_negative)

        embeddings = {}
        for fn, ctx in tqdm(contexts.items(), desc=f"→ {os.path.basename(pkl_path)}"):
            enc = tokenizer(ctx, return_tensors="pt", padding=True, truncation=True).to(device)
            with torch.no_grad():
                out = model(
                    input_ids=enc["input_ids"],
                    attention_mask=enc["attention_mask"],
                    output_hidden_states=True
                )
            # NOTE(review): the complete model output object is kept (and later
            # pickled) as-is, so its tensors stay on whatever device produced
            # them until the file is written. Downstream readers may rely on
            # this format, so it is deliberately left unchanged here.
            embeddings[fn] = out

        with open(pkl_path, 'wb') as f:
            pickle.dump(embeddings, f)
        print(f"Saved {len(embeddings)} entries to {pkl_path}")


# Vision Feature Extraction
# Per-channel normalisation constants shared by both pipelines below
# (these match the commonly used ImageNet statistics).
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Pipeline used for per-layer extraction: resize to 518x518, centre-crop to
# 518, convert to a tensor and normalise.
hidden_preprocess = transforms.Compose([
    transforms.Resize((518, 518)),
    transforms.CenterCrop(518),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Pipeline used for the backbone output: Resize(224), centre-crop to 224,
# convert to a tensor and normalise.
backbone_preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])


def load_and_preprocess_image_hidden(path):
    """Load an image as RGB and return it as a preprocessed single-item batch.

    Args:
        path: Path of the image file to open.

    Returns:
        The output of ``hidden_preprocess`` with a leading dimension of size
        one added via ``unsqueeze(0)``.
    """
    # The context manager closes the underlying file deterministically;
    # convert() yields an independent in-memory image that remains usable.
    with Image.open(path) as img:
        rgb = img.convert('RGB')
    return hidden_preprocess(rgb).unsqueeze(0)


def load_and_preprocess_image_backbone(path):
    """Load an image as RGB and return it as a preprocessed single-item batch.

    Args:
        path: Path of the image file to open.

    Returns:
        The output of ``backbone_preprocess`` with a leading dimension of size
        one added via ``unsqueeze(0)``.
    """
    # The context manager closes the underlying file deterministically;
    # convert() yields an independent in-memory image that remains usable.
    with Image.open(path) as img:
        rgb = img.convert('RGB')
    return backbone_preprocess(rgb).unsqueeze(0)

class DinoVisionTransformerWrapper(torch.nn.Module):
    """Module that runs a wrapped model's blocks and returns the normed tokens.

    The wrapped object must provide ``prepare_tokens_with_masks``, an iterable
    ``blocks`` of callables, and a ``norm`` callable.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Tokenise ``x``, apply every block in order, then apply ``norm``."""
        backbone = self.model
        tokens = backbone.prepare_tokens_with_masks(x)
        for block in backbone.blocks:
            tokens = block(tokens)
        return backbone.norm(tokens)


def extract_dino_hidden(cls_model, image_dir, save_dir, image_files, layers=None):
    """Save the per-layer CLS token of each image as ``<stem>.npy``.

    Args:
        cls_model: Wrapper whose ``.model`` exposes
            ``get_intermediate_layers``.
        image_dir: Directory containing the image files.
        save_dir: Directory receiving the ``.npy`` files; created if missing.
        image_files: Iterable of image file names inside ``image_dir``.
        layers: Block indices to extract. Defaults to the first 12 blocks,
            as before.
            NOTE(review): models with more than 12 blocks need an explicit
            value here to cover every layer.

    Each saved array stacks one CLS vector per requested layer along axis 0.
    Images whose output file already exists are skipped, so files written by
    an earlier run are not regenerated.
    """
    os.makedirs(save_dir, exist_ok=True)
    # Loop-invariant, so computed once instead of per image.
    layer_ids = list(range(12)) if layers is None else list(layers)
    for img_file in image_files:
        inp = os.path.join(image_dir, img_file)
        outp = os.path.join(save_dir, os.path.splitext(img_file)[0] + '.npy')
        if os.path.exists(outp):
            print(f"Skipping {img_file}, already exists.")
            continue
        img_t = load_and_preprocess_image_hidden(inp)
        with torch.no_grad():
            # In upstream DINOv2, get_intermediate_layers returns only patch
            # tokens by default, so indexing position 0 would pick the first
            # patch rather than the CLS token. Requesting the class token
            # yields (patch_tokens, class_token) pairs per layer.
            outputs = cls_model.model.get_intermediate_layers(
                img_t, n=layer_ids, return_class_token=True
            )
            tokens = [cls_tok.squeeze().cpu().numpy() for _patches, cls_tok in outputs]
            stack = np.stack(tokens, axis=0)
        np.save(outp, stack)
    print("Hidden-layer CLS extraction complete.")


def extract_dino_backbone(bt_model, image_dir, save_dir, image_files):
    """Run ``bt_model`` on each image and save its output as ``<stem>.npy``.

    ``save_dir`` is created if needed. Images whose output file already exists
    are skipped with a message.
    """
    os.makedirs(save_dir, exist_ok=True)
    for name in image_files:
        stem, _ext = os.path.splitext(name)
        target = os.path.join(save_dir, stem + '.npy')
        if os.path.exists(target):
            print(f"Skipping {name}, already exists.")
            continue
        source = os.path.join(image_dir, name)
        batch = load_and_preprocess_image_backbone(source)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            features = bt_model(batch)
        np.save(target, features.cpu().numpy())
    print("Backbone output extraction complete.")

def merge_npy_to_pkl(npy_dir, pkl_filename, key_ext='.png'):
    """Collect every ``.npy`` file of a directory into one pickled dict.

    Args:
        npy_dir: Directory scanned for ``.npy`` files; the pickle is written
            into this same directory.
        pkl_filename: File name of the pickle to create.
        key_ext: Extension appended to each file stem to form its dict key.
            Defaults to ``'.png'``.

    The dict maps ``<stem><key_ext>`` to the array loaded from
    ``<stem>.npy``. Files are visited in sorted order so the resulting dict
    order does not depend on the directory listing order.
    """
    representations = {}
    for npy_file in sorted(os.listdir(npy_dir)):
        # splitext only touches the final extension; str.replace would also
        # rewrite a ".npy" occurring in the middle of the name.
        stem, ext = os.path.splitext(npy_file)
        if ext != '.npy':
            continue
        representations[stem + key_ext] = np.load(os.path.join(npy_dir, npy_file))
    pkl_save_path = os.path.join(npy_dir, pkl_filename)
    with open(pkl_save_path, 'wb') as f:
        pickle.dump(representations, f)
    print(f"Dictionary saved to {pkl_save_path}")

if __name__ == "__main__":
    # ---- Command-line interface ----
    cli = argparse.ArgumentParser(description="Extract language and vision features using specified model scale.")
    cli.add_argument("--scale", choices=["base","large"], default="base",
                     help="Model scale: 'base' uses smaller models, 'large' uses larger ones.")
    cli.add_argument("--hf_token", required=True,
                     help="HuggingFace token for model access.")
    cli.add_argument("--json_dir", required=True,
                     help="Directory containing JSON caption files and where text outputs will be saved.")
    cli.add_argument("--img_dir", required=True,
                     help="Directory containing image files for vision feature extraction and where vision outputs will be saved.")
    args = cli.parse_args()

    # ---- Model selection by scale ----
    all_suffixes = ['swap_att', 'swap_obj', 'add_att', 'add_obj',
                    'replace_att', 'replace_obj', 'replace_rel']
    is_base = args.scale == "base"

    # Each entry: (prefix, model name, caption-file suffixes, neg flag, extra kwargs).
    # The fourth field is unpacked below but not used by the extraction loop.
    text_configs = [
        ("gemma2",
         "google/gemma-2-2b" if is_base else "google/gemma-2-9b",
         all_suffixes if is_base else ['replace_obj', 'replace_rel'],
         True, None),
        ("olmo2",
         "allenai/OLMo-2-1124-7B" if is_base else "allenai/OLMo-2-1124-13B",
         all_suffixes, False, {'allow_tf32': True}),
        ("llama3",
         "meta-llama/Llama-3.2-1B" if is_base else "meta-llama/Llama-3.2-3B",
         all_suffixes, True, {'pad_token_as_eos': True}),
    ]
    vision_model_name = 'dinov2_vitb14' if is_base else 'dinov2_vitl14'

    # ---- Paths: outputs are written next to their inputs ----
    JSON_DIR = args.json_dir
    OUT_DIR = JSON_DIR
    IMG_DIR = args.img_dir
    VISION_OUT_DIR = IMG_DIR

    # ---- Text extraction: one pass on captions, one on negative captions ----
    for prefix, model_name, suffixes, _add_neg, extras in text_configs:
        for use_negative in (False, True):
            call_kwargs = dict(
                hf_token=args.hf_token,
                base_json_dir=JSON_DIR,
                base_output_dir=OUT_DIR,
                model_name=model_name,
                suffixes=suffixes,
                prefix=prefix,
                add_neg_suffix=use_negative,
                use_negative=use_negative,
            )
            # Per-model extras override / extend the common arguments.
            call_kwargs.update(extras or {})
            extract_text_embeddings(**call_kwargs)

    # ---- Vision extraction ----
    image_exts = ('.jpg','.jpeg','.png','.bmp','.tiff')
    image_files = []
    for entry in os.listdir(IMG_DIR):
        if os.path.isfile(os.path.join(IMG_DIR, entry)) and entry.lower().endswith(image_exts):
            image_files.append(entry)

    # Load the DINO model from torch hub, wrap it, and freeze it for inference.
    base = torch.hub.load('facebookresearch/dinov2', vision_model_name)
    dino_model = DinoVisionTransformerWrapper(base)
    dino_model.eval()
    dino_model.requires_grad_(False)

    HIDDEN_DIR = os.path.join(VISION_OUT_DIR, f'{vision_model_name}_hidden')
    BACKBONE_DIR = os.path.join(VISION_OUT_DIR, f'{vision_model_name}_backbone_output')

    # Per-layer CLS tokens first, then the backbone output.
    extract_dino_hidden(dino_model, IMG_DIR, HIDDEN_DIR, image_files)
    extract_dino_backbone(dino_model, IMG_DIR, BACKBONE_DIR, image_files)

    # Bundle the per-image .npy files of each output directory into one .pkl.
    merge_npy_to_pkl(HIDDEN_DIR, f'{vision_model_name}_hidden.pkl')
    merge_npy_to_pkl(BACKBONE_DIR, f'{vision_model_name}_backbone_output.pkl')
