import torch
from model_packs import get_model_and_processors
import os
from torch.utils.data import DataLoader
import argparse
import yaml

from custom_datasets.eval.objectnet import ObjectNet
from custom_datasets.eval.imagenet1k import ImageNet1k
from custom_datasets.eval.imagenetr import ImageNetR
from custom_datasets.eval.imagenets import ImageNetS
from custom_datasets.eval.food101 import Food101
from custom_datasets.eval.cifar100 import CIFAR100

# Set the HuggingFace `tokenizers` parallelism flag explicitly, at import time.
# NOTE(review): presumably set so the tokenizers library does not warn about
# parallelism once DataLoader worker processes are started — confirm.
os.environ.update(TOKENIZERS_PARALLELISM="true")


def get_collate_fn(processor, max_length):
    """Build a DataLoader ``collate_fn`` that preprocesses a batch of images.

    Args:
        processor: callable invoked as
            ``processor(images=<list of RGB images>, return_tensors="pt")``;
            its return value is passed through unchanged.
        max_length: accepted for interface compatibility; the returned
            function does not use it.

    Returns:
        A function that takes a list of samples (indexable, with the image at
        position 0 and the label at position 1) and returns
        ``(processor_output, labels)``, where ``labels`` is a ``torch.tensor``
        of the labels in batch order.
    """
    def collate_fn(batch):
        # Normalise every image to RGB before handing the batch to the processor.
        rgb_images = [sample[0].convert("RGB") for sample in batch]
        label_tensor = torch.tensor([sample[1] for sample in batch])
        return processor(images=rgb_images, return_tensors="pt"), label_tensor

    return collate_fn


@torch.no_grad()
def main(dataset_name, vlm):
    """Compute and save L2-normalised image and class-prompt embeddings.

    Loads the model/processor pack for ``vlm`` via ``get_model_and_processors``,
    embeds every image of the requested evaluation dataset, embeds one
    ``"a photo of a <class>"`` prompt per class, normalises both to unit L2
    norm, and writes three tensors under ``embeddings/<dataset_name>/``:

    * ``image.pth``  - image embeddings, one row per sample, in dataset order
      (the DataLoader uses ``shuffle=False``).
    * ``target.pth`` - the labels produced for those samples, same order.
    * ``prompt.pth`` - prompt embeddings; row ``i`` belongs to the class whose
      ``dataset.class_to_idx`` value is ``i``.

    Args:
        dataset_name: one of the keys of the dataset table below
            (e.g. ``"cifar100"``).
        vlm: model identifier forwarded to ``get_model_and_processors``.

    Raises:
        NotImplementedError: if ``dataset_name`` is not a supported dataset.
        RuntimeError: if the dataset yields no samples.
    """
    # Use the last visible GPU; with no GPU the index would be -1, which is
    # not a valid device, so fall back to the CPU instead.
    gpu_count = torch.cuda.device_count()
    device = f"cuda:{gpu_count - 1}" if gpu_count > 0 else "cpu"

    # model and processor
    model_pack = get_model_and_processors(vlm)
    model = model_pack["model"].to(device)
    processor = model_pack["processor"]

    data_func = {
        "food101": Food101,
        "cifar100": CIFAR100,
        "imagenet1k": ImageNet1k,
        "imagenetr": ImageNetR,
        "imagenets": ImageNetS,
        "objectnet": ObjectNet,
    }
    if dataset_name not in data_func:
        raise NotImplementedError(
            f"Unsupported dataset {dataset_name!r}; "
            f"expected one of {sorted(data_func)}"
        )
    dataset = data_func[dataset_name]()

    dataloader = DataLoader(
        dataset,
        batch_size=256,
        shuffle=False,
        num_workers=4,
        prefetch_factor=2,
        collate_fn=get_collate_fn(processor, model_pack["max_length"]))

    img_embs = []
    targets = []
    model.eval()
    # Gradients are already disabled for this whole function by the
    # @torch.no_grad() decorator, so no inner no_grad context is needed.
    for batch_idx, (img_inputs, target) in enumerate(dataloader):
        print(f'Processing batch {batch_idx+1}', flush=True)
        # Only the pixel values are consumed by the model, so only they are
        # moved to the device.
        pixel_values = img_inputs['pixel_values'].to(device)
        emb = model.get_image_features(pixel_values=pixel_values).pooler_output
        emb = emb / emb.norm(p=2, dim=-1, keepdim=True)

        img_embs.append(emb.cpu())
        targets.append(target)

    if not img_embs:
        # torch.cat on an empty list fails with an unhelpful message.
        raise RuntimeError(f"Dataset {dataset_name!r} yielded no samples")
    img_embs = torch.cat(img_embs, dim=0)
    targets = torch.cat(targets, dim=0)

    # Build class_names so that class_names[i] is the name for label i; this
    # keeps row i of the prompt embeddings aligned with target value i.
    class_mapping = dataset.class_to_idx
    class_names = [""] * len(class_mapping)
    for name, label in class_mapping.items():
        class_names[label] = name
    print(len(class_names))

    prompts = [f"a photo of a {c.replace('_', ' ')}" for c in class_names]

    # Preview the first prompts; slicing also works with fewer than 3 classes.
    print(prompts[:3])

    prompt_input = processor(
        text=prompts,
        padding="max_length",
        truncation=True,
        max_length=model_pack["max_length"],
        return_tensors="pt"
    )
    prompt_input = {
        k: v.to(device) for k, v in prompt_input.items()
    }

    prompt_emb = model.get_text_features(
        input_ids=prompt_input['input_ids'],
        attention_mask=prompt_input['attention_mask']).pooler_output.cpu()
    prompt_emb = prompt_emb / prompt_emb.norm(p=2, dim=-1, keepdim=True)

    save_path = os.path.join('embeddings', dataset_name)
    os.makedirs(save_path, exist_ok=True)
    torch.save(img_embs, os.path.join(save_path, 'image.pth'))
    torch.save(targets, os.path.join(save_path, 'target.pth'))
    torch.save(prompt_emb, os.path.join(save_path, 'prompt.pth'))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Extract image and class-prompt embeddings for a dataset.")
    # nargs='?' makes the positional argument optional so that `default`
    # actually takes effect; without it argparse ignores the default and
    # always requires the argument.
    parser.add_argument('dataset', type=str, nargs='?', default='cifar100',
                        help="dataset name (default: %(default)s)")
    args = parser.parse_args()

    dataset = args.dataset
    # Read the base model identifier from the config, closing the file promptly.
    with open('configs.yaml', encoding='utf-8') as config_file:
        vlm = yaml.safe_load(config_file)['base_model']

    main(dataset, vlm)
