import torch
from model_packs import get_model_and_processors
import os
import json
from torch.utils.data import DataLoader
import importlib
import argparse
import yaml
from custom_datasets.eval.cifar100 import CIFAR100
from custom_datasets.eval.objectnet import ObjectNet
from custom_datasets.eval.imagenet1k import ImageNet1k
from custom_datasets.eval.imagenets import ImageNetS
from custom_datasets.eval.imagenetr import ImageNetR
from custom_datasets.eval.food101 import Food101

from PIL import Image

# Environment flag consumed by the tokenizer backend; set before any
# tokenizer is used.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Target the highest-indexed visible GPU. With no CUDA device the count is 0
# and the bare expression would yield the invalid device string "cuda:-1",
# so fall back to the CPU in that case.
DEVICE = (
    f"cuda:{torch.cuda.device_count() - 1}"
    if torch.cuda.device_count() > 0
    else "cpu"
)


def get_collate_fn(processor, max_length):
    """Build a DataLoader ``collate_fn`` that runs ``processor`` on a batch.

    Args:
        processor: Callable accepting ``images=``, ``text=`` and the
            tokenization keyword arguments used below.
        max_length: Value forwarded to the processor as ``max_length``;
            text is padded/truncated to this length.

    Returns:
        A function mapping a list of ``(image, label, ...)`` items to a
        ``(processor_output, labels_tensor)`` pair.
    """
    # Options that are identical for every batch.
    processor_options = {
        "return_tensors": "pt",
        "padding": "max_length",
        "truncation": True,
        "max_length": max_length,
    }

    def collate_fn(batch):
        pictures = []
        targets = []
        for sample in batch:
            pictures.append(sample[0])
            targets.append(sample[1])
        label_tensor = torch.tensor(targets)

        # Only images matter here; a single dummy caption is passed
        # alongside them as the text input.
        encoded = processor(images=pictures, text=["NA"], **processor_options)
        return encoded, label_tensor

    return collate_fn


def _enable_mc_dropout(model, p):
    """Set dropout rate ``p`` on every submodule that has a ``dropout`` attribute."""
    for name, module in model.named_modules():
        if not hasattr(module, "dropout"):
            continue
        current = module.dropout
        if isinstance(current, torch.nn.Module):
            # ``dropout`` is a registered child module: assigning a float over
            # it raises TypeError in nn.Module.__setattr__, so update the rate
            # in place when the child exposes one (e.g. nn.Dropout.p).
            if not hasattr(current, "p"):
                continue
            current.p = p
        else:
            # Plain attribute (presumably a float rate read in forward()).
            module.dropout = p
        print(f"Set dropout for {name} to {p}")


def _mc_embedding_uncertainty(model, inputs, output_attr, num_samples,
                              log_progress=False):
    """Run ``num_samples`` stochastic passes and return per-row uncertainty.

    Each pass reads ``output_attr`` from ``model(**inputs)``, L2-normalizes it
    along the last dimension and moves it to the CPU. The result is the
    standard deviation across passes, averaged over the last dimension.
    """
    samples = []
    for n in range(num_samples):
        if log_progress:
            print(f"Sample {n+1}/{num_samples}", flush=True)
        emb = getattr(model(**inputs), output_attr)
        emb = emb / emb.norm(dim=-1, keepdim=True)
        samples.append(emb.cpu())
    # [num_samples, B, D] -> std over passes -> mean over features -> [B]
    return torch.stack(samples, dim=0).std(dim=0).mean(dim=-1)


@torch.no_grad()
def main(eval_ds, vlm, seed, uncer_levels=None, num_samples=10, dropout_p=0.2):
    """Estimate MC-dropout uncertainty of embeddings and store i2t results.

    Dropout is forced on and the model is kept in train mode, so repeated
    forward passes over the same input differ. The spread of those passes is
    used as an uncertainty score for every image and every class prompt, and
    the scores are passed to ``i2t_acc`` together with the cached embeddings
    loaded from ``embeddings/{eval_ds}/``.

    Args:
        eval_ds: Dataset key; one of the keys of ``data_func`` below.
        vlm: Model identifier forwarded to ``get_model_and_processors``.
        seed: RNG seed; also used as the result file name.
        uncer_levels: Forwarded to ``i2t_acc``. When ``None``, the value is
            taken from the module-level ``args`` created by the script entry
            point.
        num_samples: Number of stochastic forward passes (at least 2).
        dropout_p: Dropout rate applied to the model.

    Raises:
        NotImplementedError: If ``eval_ds`` is not a known dataset key.
        ValueError: If ``num_samples < 2`` or ``uncer_levels`` cannot be
            resolved.
    """
    if num_samples < 2:
        # The sample standard deviation of a single pass is NaN.
        raise ValueError("num_samples must be at least 2")
    if uncer_levels is None:
        cli_args = globals().get("args")
        if cli_args is None:
            raise ValueError(
                "uncer_levels must be passed when main() is not run "
                "through the command-line entry point")
        uncer_levels = cli_args.uncer_levels

    # dataset registry
    data_func = {
        "food101": Food101,
        "cifar100": CIFAR100,
        "imagenet1k": ImageNet1k,
        "imagenetr": ImageNetR,
        "imagenets": ImageNetS,
        "objectnet": ObjectNet,
    }
    # Validate before loading the model so a typo fails fast.
    if eval_ds not in data_func:
        raise NotImplementedError(
            f"Unknown eval dataset {eval_ds!r}; "
            f"expected one of {sorted(data_func)}")

    # Make the dropout masks reproducible for a given seed.
    torch.manual_seed(seed)

    # model and processor
    model_pack = get_model_and_processors(vlm)
    model = model_pack["model"].to(DEVICE)
    processor = model_pack["processor"]
    max_length = model_pack["max_length"]

    _enable_mc_dropout(model, dropout_p)
    # Train mode keeps dropout active during the forward passes below. Note
    # that it also affects any other layer that behaves differently between
    # train and eval mode.
    model.train()

    dataset = data_func[eval_ds]()
    dataloader = DataLoader(
        dataset, batch_size=256, shuffle=False, num_workers=4, pin_memory=True,
        collate_fn=get_collate_fn(processor, max_length))

    # compute uncertainty for image embeddings; the dataloader is iterated
    # only once and each batch is reduced to [B] right away, so the full
    # [num_samples, N, D] tensor is never held in memory.
    img_uncertainty_batches = []
    for i, (inputs, _) in enumerate(dataloader):
        print(f"Processing batch {i+1}", flush=True)
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        img_uncertainty_batches.append(
            _mc_embedding_uncertainty(
                model, inputs, "image_embeds", num_samples))
    uncer_img_embs = torch.cat(img_uncertainty_batches, dim=0)  # [N]

    # compute uncertainty for the text embeddings
    # Order class names so that class_names[i] is the name for label i.
    class_mapping = dataset.class_to_idx
    class_names = [""] * len(class_mapping)
    for name, idx in class_mapping.items():
        class_names[idx] = name
    prompts = [f"a photo of a {c.replace('_', ' ')}" for c in class_names]
    # Dummy image passed alongside the prompts; only text_embeds is read.
    image_placeholder = [Image.new("RGB", (10, 10), (0, 0, 0))]

    prompt_input = processor(
        images=image_placeholder,
        text=prompts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )
    prompt_input = {k: v.to(DEVICE) for k, v in prompt_input.items()}

    uncer_text_embs = _mc_embedding_uncertainty(
        model, prompt_input, "text_embeds", num_samples,
        log_progress=True)  # [C]

    # Cached embeddings; map_location lets files saved on another device load.
    emb_dir = f"embeddings/{eval_ds}"
    img_embs = torch.load(f"{emb_dir}/image.pth", map_location=DEVICE)
    text_embs = torch.load(f"{emb_dir}/prompt.pth", map_location=DEVICE)
    targets = torch.load(f"{emb_dir}/target.pth", map_location=DEVICE)

    # NOTE(review): the uncertainty tensors live on the CPU while the cached
    # embeddings are on DEVICE -- confirm i2t_acc accepts mixed devices.
    from eval.unc_rmv import i2t_acc
    i2t_accs = i2t_acc(
        img_embs, uncer_img_embs, text_embs, uncer_text_embs,
        targets, uncer_levels)

    # save the results; the same payload is written under each of these
    # sub-directories
    duplicates = [
        "cc",
        "datacomp",
        "laion",
    ]
    for duplicate in duplicates:
        result_dir = f"results/cls/{eval_ds}/{duplicate}/mcdo"
        os.makedirs(result_dir, exist_ok=True)
        with open(f"{result_dir}/{seed}.json", "w") as f:
            json.dump({
                "i2t": i2t_accs
            }, f)


if __name__ == '__main__':
    # Command-line entry point: parse options, read the base model name from
    # configs.yaml and run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument('--eval_ds', type=str, default='cifar100')
    parser.add_argument('--uncer_levels', type=int, default=10)
    parser.add_argument('--seed', type=int, default=0)
    # ``args`` must stay a module-level name: main() looks it up for options
    # that are not passed to it explicitly.
    args = parser.parse_args()

    eval_ds = args.eval_ds
    seed = args.seed
    # Use a context manager so the config file handle is closed promptly
    # instead of being left to the garbage collector.
    with open('configs.yaml') as config_file:
        vlm = yaml.safe_load(config_file)['base_model']
    main(eval_ds, vlm, seed)
