import fire
import torch
from nn_core.common import PROJECT_ROOT
from pytorch_lightning import seed_everything
from datasets import load_dataset, DownloadConfig, VerificationMode
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoModel, AutoImageProcessor
import functools
from typing import List
from layskip.modules.module import HFwrapper, SkipModel
from layskip.utils.dictionaries import DATASET2IMAGE_COLUMN, DATASET2LABEL_COLUMN, DATASET_NAME2HF_NAME
from layskip.utils.utils import image_encode, extract_specific_layers

# Prefer the GPU whenever CUDA is usable; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def save_translators(
    dataset_name: str,
    model_name: str,
    layers_to_approximate: List,
    seed: int,
    max_samples: int = 3000,
):
    """Extract layer embeddings from a pretrained encoder and build the skip translators.

    The function loads the training split of ``dataset_name``, loads the
    already-trained classifier checkpoint for ``model_name``, extracts the
    embeddings of ``layers_to_approximate`` with ``extract_specific_layers``
    and hands them to ``SkipModel`` together with the output path
    ``PROJECT_ROOT / "latent_translators" / <model dir> / dataset_name``
    (presumably ``SkipModel`` fits and persists the translators there --
    confirm against its implementation).

    Args:
        dataset_name: Key used for ``DATASET2IMAGE_COLUMN`` and
            ``DATASET2LABEL_COLUMN``. ``"imagenet-1k"`` is special-cased: its
            validation archive is loaded and split 80/20; every other name is
            resolved through ``DATASET_NAME2HF_NAME``.
        model_name: Identifier passed to the Hugging Face ``from_pretrained``
            loaders. The segment after the first ``/`` (or the whole name when
            it contains no ``/``) names the checkpoint and output directories.
        layers_to_approximate: Non-empty sequence of layer specifications,
            forwarded unchanged to ``extract_specific_layers`` and
            ``SkipModel``; its first element, stringified, is the translator key.
        seed: Seed passed to ``seed_everything``.
        max_samples: Sample budget forwarded to ``extract_specific_layers``.

    Raises:
        ValueError: If ``layers_to_approximate`` is empty.
    """
    # Fail fast: the first entry is required for the translator key at the very
    # end, so an empty list would otherwise only crash after the expensive
    # dataset loading and embedding extraction.
    if not layers_to_approximate:
        raise ValueError("layers_to_approximate must contain at least one layer specification")

    seed_everything(seed)

    print(f"Dataset: {dataset_name}, model: {model_name}, layers: {layers_to_approximate}")

    # "org/name" -> "name". Identifiers without a namespace are used as-is
    # instead of raising IndexError.
    name_parts = model_name.split("/")
    model_dir_name = name_parts[1] if len(name_parts) > 1 else name_parts[0]

    trained_classifier_path = PROJECT_ROOT / "models" / model_dir_name / (dataset_name + "_classifier.ckpt")

    if dataset_name == "imagenet-1k":
        val_data = load_dataset(
            dataset_name,
            split="validation",
            data_files={"val": "data/val_images.tar.gz"},
            revision="refs/pr/20",
            trust_remote_code=True,
            download_config=DownloadConfig(resume_download=True),
            verification_mode=VerificationMode.NO_CHECKS,
        )
        # Only the validation archive is downloaded, so the train/test splits
        # are carved out of it.
        dataset = val_data.train_test_split(test_size=0.2)
    else:
        dataset = load_dataset(DATASET_NAME2HF_NAME[dataset_name])

    train_dataset = dataset["train"]

    image_name = DATASET2IMAGE_COLUMN[dataset_name]
    label_name = DATASET2LABEL_COLUMN[dataset_name]

    processor = AutoImageProcessor.from_pretrained(model_name)

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=256,
        shuffle=True,
        num_workers=0,
        # Pinned host memory only helps host-to-GPU copies; requesting it
        # without CUDA just triggers a warning.
        pin_memory=device.type == "cuda",
        collate_fn=functools.partial(image_encode, processor=processor, image_name=image_name, label_name=label_name),
    )

    # Load the trained classifier head. nn.Linear is allow-listed so the
    # checkpoint can be unpickled by torch.load's restricted loader, and
    # map_location lets a checkpoint saved from a GPU load on a CPU-only host.
    with torch.serialization.safe_globals([torch.nn.modules.linear.Linear]):
        trained_classifier = torch.load(trained_classifier_path, map_location=device)

    config = AutoConfig.from_pretrained(model_name, output_hidden_states=True, return_dict=True)
    encoder = AutoModel.from_pretrained(model_name, config=config)

    # The wrapper is not evaluated here; it is built and moved to the device,
    # which presumably also moves the wrapped encoder used below
    # (NOTE(review): confirm HFwrapper registers the encoder as a submodule).
    model = HFwrapper(encoder=encoder, classifier=trained_classifier)
    model.to(device)

    # Extract the embeddings of the requested layers from the training data.
    layer_embeddings = extract_specific_layers(encoder, max_samples, train_dataloader, layers_to_approximate)

    to_save_translator_path = PROJECT_ROOT / "latent_translators" / model_dir_name / dataset_name

    # Constructed only for its side effects; the instance itself is discarded.
    _ = SkipModel(
        encoder=encoder,
        skips=layers_to_approximate,
        mode=1,
        precomputed_embeddings=layer_embeddings,
        translator_factory_name="linear",
        translator_key=str(layers_to_approximate[0]),
        to_save_translator_path=to_save_translator_path,
    )


if __name__ == "__main__":
    # Expose save_translators as a command-line interface: Fire maps CLI
    # arguments onto the function's parameters.
    fire.Fire(component=save_translators)
