import os
import shutil
import torch
from pytorch_lightning import seed_everything
from nn_core.common import PROJECT_ROOT
from transformers import AutoModel, PreTrainedModel, AutoImageProcessor
from datasets import DatasetDict, load_dataset, load_from_disk
from itertools import product
from tqdm import tqdm
from torch.utils.data import DataLoader
import functools
from layskip.utils.dictionaries import DATASET2IMAGE_COLUMN, DATASET2LABEL_COLUMN, DATASET_NAME2HF_NAME
from layskip.utils.utils import image_encode, extract_specific_layers
from layskip.modules.module import SkipModel
import fire

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Seed the global RNGs (via Lightning) once, at import time, with a fixed value.
seed_everything(seed=0)


@torch.no_grad()
def encode_data(loader, encoder, translator_name, mode, skip, max_samples=3000) -> list:
    """Embed every batch of ``loader`` with ``encoder`` wrapped in a ``SkipModel``.

    ``extract_specific_layers`` is first run over ``loader`` (with ``max_samples``
    and ``skip``). Its output is handed to ``SkipModel`` as ``precomputed_embeddings``.
    The resulting model is then applied to every batch of ``loader``. Gradients are
    disabled for the whole call.

    NOTE(review): the precomputed embeddings are extracted from the *same* loader
    that is being encoded. When this is called with a test loader, whatever
    ``SkipModel`` derives from them is based on test images. Confirm this is
    intended and not train/test leakage.

    Args:
        loader: iterable of dict batches, each carrying an ``"images"`` tensor.
        encoder: backbone model wrapped by ``SkipModel``.
        translator_name: forwarded to ``SkipModel``.
        mode: forwarded to ``SkipModel``.
        skip: layer pairs forwarded to ``extract_specific_layers`` and to
            ``SkipModel(skips=...)``.
        max_samples: forwarded to ``extract_specific_layers``.

    Returns:
        A list with one entry per row of each ``SkipModel`` output, converted to
        plain Python values on the CPU. Presumably this is one embedding vector
        per sample; confirm against ``SkipModel``.
    """
    layer_embeddings = extract_specific_layers(encoder, max_samples, loader, skip)

    skip_encoder = SkipModel(
        encoder=encoder,
        skips=skip,
        mode=mode,
        precomputed_embeddings=layer_embeddings,
        translator_name=translator_name,
    ).to(device)

    embeddings = []
    for batch in loader:
        images = batch["images"].to(device)
        # Convert to plain lists on the CPU. That is the form the caller hands to
        # Dataset.add_column, and it keeps accumulated results off the accelerator.
        embeddings.extend(skip_encoder(images).cpu().tolist())

    return embeddings


def _load_or_init_dataset(dataset_dir, dataset_name: str) -> DatasetDict:
    """Return the cached DatasetDict in ``dataset_dir``, or fresh train/test splits from the hub.

    A directory only counts as a valid cache when it holds the ``dataset_dict.json``
    manifest written by ``DatasetDict.save_to_disk``. An empty or partial directory
    left behind by an interrupted run would otherwise make ``load_from_disk`` fail.
    """
    if (dataset_dir / "dataset_dict.json").exists():
        return load_from_disk(dataset_path=str(dataset_dir))

    dataset_dir.mkdir(parents=True, exist_ok=True)
    hf_name = DATASET_NAME2HF_NAME[dataset_name]
    return DatasetDict(
        train=load_dataset(hf_name, split="train"),
        test=load_dataset(hf_name, split="test"),
    )


def _make_loader(split_data, processor, dataset_name: str, batch_size: int, num_workers: int) -> DataLoader:
    """Build an unshuffled loader whose collate_fn preprocesses images with ``processor``."""
    return DataLoader(
        split_data,
        batch_size=batch_size,
        pin_memory=True,
        # Order must be stable so embeddings line up with the dataset rows they are added to.
        shuffle=False,
        num_workers=num_workers,
        collate_fn=functools.partial(
            image_encode,
            processor=processor,
            image_name=DATASET2IMAGE_COLUMN[dataset_name],
            label_name=DATASET2LABEL_COLUMN[dataset_name],
        ),
    )


def _save_dataset_dict(data: DatasetDict, dataset_dir) -> None:
    """Persist ``data`` to ``dataset_dir``, replacing whatever was there before.

    HF datasets refuses to ``save_to_disk`` into a directory holding the dataset's
    own cache files, which is the case when ``data`` was loaded from ``dataset_dir``.
    So the data is written to a scratch directory first and then swapped into place.
    """
    temp_dir = PROJECT_ROOT / "data" / "embeddings" / "temp"
    # Drop leftovers from an interrupted run so stale files cannot end up in the new copy.
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    data.save_to_disk(str(temp_dir))

    if dataset_dir.exists():
        shutil.rmtree(dataset_dir)
    dataset_dir.parent.mkdir(parents=True, exist_ok=True)
    shutil.move(str(temp_dir), str(dataset_dir))


def run_encoding(
    translator_name: str = "linear",
    mode: int = 1,
    batch_size: int = 256,
    num_workers: int = 8,
):
    """Compute and cache SkipModel embeddings for every (dataset, encoder, skip) combination.

    Results are stored per (dataset, encoder) under
    ``PROJECT_ROOT/data/embeddings/<dataset>/<encoder id after the last "/">``.
    Each skip is stored as one column named ``str(skip)`` in both the train and the
    test split. Columns that already exist are not recomputed. The dataset is saved
    after every skip that added data, so an interrupted run keeps its progress.

    Args:
        translator_name: forwarded to ``encode_data``.
        mode: forwarded to ``encode_data``.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes.
    """
    # Trailing commas matter. Without them, uncommenting an entry would silently
    # concatenate adjacent string literals into one bogus name.
    datasets = [
        # "mnist",
        # "fashion-mnist",
        # "cifar10",
        # "cifar100-fine",
        "cifar100-coarse",
    ]
    encoders = [
        "google/vit-base-patch16-224",
        # "WinKawaks/vit-small-patch16-224",
        # "facebook/dinov2-small",
        # "facebook/deit-small-patch16-224",
    ]

    # Each entry is passed as ``skip`` to encode_data / SkipModel.
    skips = [
        [(0, 1)],
        [(1, 2)],
        [(2, 3)],
        [(3, 4)],
        [(4, 5)],
        [(5, 6)],
        [(6, 7)],
        [(7, 8)],
        [(8, 9)],
        [(9, 10)],
        [(10, 11)],
        #
        # [(2, 5)],
        # [(3, 5)],
        # [(2, 4)],
        # [(1, 5)],
        # [(8, 10)],
        # [(9, 11)],
    ]

    splits = ("train", "test")
    combos = list(product(datasets, encoders))

    for dataset_name, encoder_name in tqdm(combos, desc="Generating embeddings"):
        print(f"Dataset: {dataset_name}, Encoder: {encoder_name}")

        # Use the segment after the last "/" so hub ids without an org prefix also work.
        dataset_dir = PROJECT_ROOT / "data" / "embeddings" / dataset_name / encoder_name.split("/")[-1]
        data = _load_or_init_dataset(dataset_dir, dataset_name)

        encoder: PreTrainedModel = (
            AutoModel.from_pretrained(encoder_name, output_hidden_states=True, return_dict=True)
            .eval()
            .requires_grad_(False)
        )

        # The fast image processor is only opted into for these ViT checkpoints.
        if encoder_name in ["WinKawaks/vit-small-patch16-224", "google/vit-base-patch16-224"]:
            processor = AutoImageProcessor.from_pretrained(encoder_name, use_fast=True)
        else:
            processor = AutoImageProcessor.from_pretrained(encoder_name)

        loaders = {
            split: _make_loader(data[split], processor, dataset_name, batch_size, num_workers)
            for split in splits
        }

        for skip in tqdm(skips, desc="Encoding"):
            print(f"Skip: {skip}")
            column = str(skip)

            # Re-encoding a split that already has this column is expensive, and
            # the result would be discarded anyway.
            missing = [split for split in splits if column not in data[split].column_names]
            if not missing:
                continue

            # NOTE(review): encode_data extracts its precomputed layer embeddings
            # from the loader it is given, so the test split uses test images for
            # that step. Confirm this is intended.
            for split in missing:
                encoding = encode_data(
                    loader=loaders[split],
                    encoder=encoder,
                    translator_name=translator_name,
                    mode=mode,
                    skip=skip,
                )
                data[split] = data[split].add_column(column, encoding)

            # Presumably frees accelerator memory: SkipModel(...).to(device) in
            # encode_data likely moved the wrapped encoder there. Confirm.
            encoder.cpu()

            _save_dataset_dict(data, dataset_dir)


if __name__ == "__main__":
    # Expose run_encoding as the command-line entry point through python-fire.
    fire.Fire(component=run_encoding)
