import argparse
import json
import os
import random

import h5py
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from transformers import CLIPModel, CLIPProcessor

# NOTE(review): ImageProcessor is used by the script body but was never imported;
# it is assumed to be exported by `sae` next to ModelLoader/SAEVisualizer — confirm.
from sae import ImageProcessor, ModelLoader, SAEVisualizer
from utils import (
    get_cls_acc,
    get_high_low_group_acc,
    get_image_text_cos,
    get_text_embedding,
    load_class_names,
    load_data,
    load_pred_results,
    masking,
)

def parse_args():
    """Parse the command-line options of the alignment-score script.

    Returns:
        argparse.Namespace with the options defined below. Option names and
        defaults are unchanged; some of them (e.g. ``--top_k``, ``--save_path``)
        are accepted but not read by this script.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--sae_name", type=str, default="clip-vit-large-patch14")
    parser.add_argument("--dataset_name", type=str, default="imagenet")
    parser.add_argument("--split", type=str, default="val")
    parser.add_argument("--num_classes", type=int, default=1000)
    parser.add_argument("--label_name", type=str, default="label")
    parser.add_argument("--root", type=str, default=".")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--save_path", type=str, default="./out/sae_latent_names/")
    parser.add_argument("--top_latents", type=int, default=20)
    parser.add_argument("--top_k", type=int, default=10)
    parser.add_argument(
        "--sae_path",
        type=str,
        default="./out/sae_weights/clip-vit-large-patch14/clip-vit-large-patch14_-2_resid_32768.pt",
    )
    return parser.parse_args()


def select_top_latents(mean_activation, candidate_indices, k):
    """Return the ``k`` candidate latent ids with the largest mean activation.

    Args:
        mean_activation: 1-D NumPy array indexed by latent id.
        candidate_indices: 1-D integer NumPy array of latent ids to rank.
        k: maximum number of ids to return.

    Returns:
        A new integer array of at most ``k`` latent ids, ordered from the
        highest to the lowest mean activation.
    """
    descending = np.argsort(mean_activation[candidate_indices])[::-1]
    return candidate_indices[descending][:k].copy()


def alignment_scores(image_sim, concept_only_sim, concept_excluded_sim):
    """Compute sufficiency, necessity and their average from similarity scores.

    Args:
        image_sim: per-image similarity between the full image and the label.
        concept_only_sim: similarity when only the concept region is kept.
        concept_excluded_sim: similarity when the concept region is masked out.

    Returns:
        Tuple ``(sufficiency, necessity, alignment)`` where sufficiency is the
        mean of ``concept_only_sim / image_sim``, necessity is the mean of
        ``image_sim / concept_excluded_sim`` and alignment is their average.
    """
    sufficiency = (concept_only_sim / image_sim).mean()
    necessity = (image_sim / concept_excluded_sim).mean()
    return sufficiency, necessity, (sufficiency + necessity) / 2


def sample_class_images(dataset, class_indices, image_key, max_images=100, size=256):
    """Randomly pick up to ``max_images`` images of one class and resize them.

    Args:
        dataset: training dataset; ``dataset[class_indices][image_key]`` must
            yield a sequence of PIL images or image file paths.
        class_indices: row indices of the class inside ``dataset``.
        image_key: name of the column that stores the images.
        max_images: upper bound on the number of sampled images.
        size: side length of the square the images are resized to.

    Returns:
        List of resized PIL images.
    """
    candidates = dataset[class_indices][image_key]
    picked = random.sample(candidates, min(max_images, len(candidates)))

    resized = []
    for image in picked:
        if isinstance(image, str):
            # Path entries are opened lazily by PIL; the context manager closes the
            # file handle once resize() has produced an independent copy.
            with Image.open(image) as opened:
                resized.append(opened.resize((size, size)))
        else:
            resized.append(image.resize((size, size)))
    return resized


def main():
    """Score how well the top SAE latents of each class align with its label.

    For every class, a random sample of training images is pushed through the
    SAE, the most active named latents are selected, and for each latent the
    CLIP image/label similarity is compared between the original images, the
    images restricted to the latent's activation mask, and the images with that
    region masked out. One record per (class, latent) is appended to a JSON
    file that is rewritten after every class.
    """
    args = parse_args()

    clip_name = "openai/clip-vit-large-patch14"
    image_size = 256
    max_images_per_class = 100

    sae, vit, cfg = ModelLoader.get_sae_and_vit(args.sae_path, args.device, clip_name)
    visualizer = SAEVisualizer(sae, vit, cfg, args.root, "clip-vit-large-patch14", args.device)

    clip_model = CLIPModel.from_pretrained(clip_name).to(args.device)
    clip_processor = CLIPProcessor.from_pretrained(clip_name)

    with open("./out/sae_latent_names/clip-vit-large-patch14_gpt.json", "r") as f:
        latent_names = json.load(f)

    # NOTE(review): candidates are all named latents. The script used to also read
    # ./out/dataset_bias/<sae_name>/imagenet/major_concepts.json, but its indices
    # were overwritten by this set before any use — restrict here if that was the intent.
    candidate_latents = np.array(list(latent_names.keys())).astype(int)

    train_test_data = load_data(args.dataset_name, args.sae_name, args.split, args.label_name)
    class_names = load_class_names(args.root, args.dataset_name)
    pred_results = load_pred_results(args.root, args.dataset_name)
    # NOTE(review): load_data is called here with a different argument list than
    # above; this looks like it should be a dedicated training-stats loader — confirm.
    training_stats = load_data(args.root, args.dataset_name)
    class_mean_activation = training_stats["train_mean_var_activation"][0]

    # The image column name does not change between classes, so resolve it once.
    train_dataset = train_test_data["train_dataset"]
    image_key = next((key for key in ("image", "jpg", "webp") if key in train_dataset.features), None)
    if image_key is None:
        raise KeyError("train dataset has none of the image columns: image, jpg, webp")

    # Make sure the output location exists before the first checkpoint is written.
    output_dir = f"./out/{args.dataset_name}"
    os.makedirs(output_dir, exist_ok=True)
    output_path = f"{output_dir}/aligment_scores_per_class.json"

    error_slices = []

    for selected_class in tqdm(range(args.num_classes)):
        train_class_indices = np.where(train_test_data["train_labels"] == selected_class)[0]

        label_name = class_names[selected_class]
        with torch.no_grad():
            label_embedding = get_text_embedding(clip_model, clip_processor, label_name, args.device)
        # A single label may come back as a flat vector; promote it to one row.
        if getattr(label_embedding, "ndim", None) == 1:
            label_embedding = label_embedding.reshape(1, -1)

        test_class_indices = np.where(train_test_data["test_labels"] == selected_class)[0]
        test_class_activation = train_test_data["test_activations"][test_class_indices, :]

        rand_images = {
            "image": sample_class_images(
                train_dataset, train_class_indices, image_key, max_images_per_class, image_size
            )
        }

        with torch.no_grad():
            inputs = ImageProcessor.process_model_inputs(rand_images, visualizer.vit, visualizer.device)
            sae_act = ImageProcessor.get_sae_activations(
                visualizer.sae,
                visualizer.vit,
                inputs,
                visualizer.cfg.block_layer,
                visualizer.cfg.module_name,
                visualizer.cfg.class_token,
                get_mean=False,
            )
            image_label_sim = get_image_text_cos(
                clip_model, clip_processor, rand_images["image"], label_embedding, args.device
            )

        top_latents = select_top_latents(
            class_mean_activation[selected_class], candidate_latents, args.top_latents
        )
        # Keep only the selected latents; position i below maps to top_latents[i].
        # assumes sae_act is (images, tokens, latents) with token 0 the class token — TODO confirm
        sae_act = sae_act[:, :, top_latents]
        feature_size = int(np.sqrt(sae_act.shape[1] - 1))

        # Class-level accuracy does not depend on the latent, so compute it once per class.
        cls_acc = get_cls_acc(pred_results, test_class_indices, test_class_activation)

        for i, latent_idx in enumerate(tqdm(top_latents)):
            selected_act = sae_act[:, :, i]
            # Drop the first token and lay the remaining tokens out as a square map.
            masks = torch.Tensor(selected_act[:, 1:].reshape(sae_act.shape[0], 1, feature_size, feature_size))
            masks = (
                torch.nn.functional.interpolate(masks, (image_size, image_size), mode="bilinear")
                .squeeze(1)
                .cpu()
                .numpy()
            )

            with torch.no_grad():
                concept_only_image, _ = masking(rand_images, masks, image_size, blend_rate=0, gamma=0.001)
                concept_exclude_image, _ = masking(
                    rand_images, masks, image_size, blend_rate=0, gamma=0.001, reverse=True
                )

                concept_only_label_sim = get_image_text_cos(
                    clip_model, clip_processor, concept_only_image, label_embedding, args.device
                )
                concept_exclude_image_sim = get_image_text_cos(
                    clip_model, clip_processor, concept_exclude_image, label_embedding, args.device
                )

            sufficiency, necessity, alignment_score = alignment_scores(
                image_label_sim, concept_only_label_sim, concept_exclude_image_sim
            )

            latent_key = str(latent_idx)
            latent_name = latent_names[latent_key]["imagenet"] if latent_key in latent_names else "unknown"

            group_acc = get_high_low_group_acc(pred_results, test_class_indices, test_class_activation, latent_idx)

            # Output keys (including the "sufficienct" spelling) are kept as-is
            # because downstream readers may depend on them.
            slice_dict = {
                "selected_class": selected_class,
                "latent_idx": int(latent_idx),
                "latent_name": latent_name,
                "org_image_label_sim": float(image_label_sim.mean()),
                "retain_sim": float(concept_only_label_sim.mean()),
                "masked_out_sim": float(concept_exclude_image_sim.mean()),
                "sufficienct": float(sufficiency),
                "necessity": float(necessity),
                "alignment_score": float(alignment_score),
            }
            slice_dict.update(cls_acc)
            slice_dict.update(group_acc)
            error_slices.append(slice_dict)

        # Rewritten after every class so partial results survive an interrupted run.
        with open(output_path, "w") as f:
            json.dump(error_slices, f, indent=4)


if __name__ == "__main__":
    main()
