from torch import nn
import os
import time
import torch
from nn_core.common import PROJECT_ROOT
from pytorch_lightning import seed_everything
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoModel, AutoImageProcessor
import functools
import pandas as pd
from typing import List
from layskip.pl_modules.train_NN import eval_classifier
from layskip.modules.module import HFwrapper, SkipModel
from layskip.utils.dictionaries import DATASET2IMAGE_COLUMN, DATASET2LABEL_COLUMN
from layskip.utils.utils import image_encode, extract_specific_layers
import fire

# Run on the default CUDA device when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# Column order of the results CSV; also used when the file is created for the first time.
_RESULT_COLUMNS = [
    "seed",
    "mode",
    "dataset",
    "model",
    "reference_points",
    "translator",
    "skip_layers",
    "num_skip_layers",
    "criterion",
    "optimizer",
    "lr",
    "batch_size",
    "classifier",
    "original_accuracy",
    "original_execution",
    "skip_accuracy",
    "skip_execution",
    "accuracy_diff",
]


def _make_dataloader(split, processor, image_name: str, label_name: str, batch_size: int, shuffle: bool) -> DataLoader:
    """Build a DataLoader over ``split`` whose collate_fn encodes each batch with ``image_encode``."""
    return DataLoader(
        split,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
        pin_memory=True,
        collate_fn=functools.partial(image_encode, processor=processor, image_name=image_name, label_name=label_name),
    )


def _format_elapsed(elapsed_seconds: float) -> str:
    """Format a duration as ``<minutes>'<seconds>''`` using whole, unpadded numbers (e.g. ``2'5''``)."""
    minutes, seconds = divmod(elapsed_seconds, 60)
    return f"{int(minutes)}'{int(seconds)}''"


def _timed_eval(model, test_dataloader, criterion):
    """Run ``eval_classifier`` on ``test_dataloader`` and return ``(accuracy, formatted_elapsed_time)``.

    Only the ``eval_classifier`` call itself is timed.
    """
    start = time.perf_counter()
    _, accuracy = eval_classifier(
        model=model,
        test_data_loader=test_dataloader,
        criterion=criterion,  # only used by eval_classifier for the reported loss
        embedding_column_name="images",
        label_column_name="labels",
    )
    return accuracy, _format_elapsed(time.perf_counter() - start)


def _append_results(results: dict, results_path) -> None:
    """Append one result row to the CSV at ``results_path``, creating the file and its directory if absent."""
    new_row = pd.DataFrame([results], columns=_RESULT_COLUMNS)
    if results_path.exists():
        # Read the existing file right before writing (not at start-up) to narrow the window in
        # which rows appended by a concurrently running experiment could be overwritten.
        combined = pd.concat([pd.read_csv(results_path), new_row])
    else:
        results_path.parent.mkdir(parents=True, exist_ok=True)
        combined = new_row
    combined.to_csv(results_path, index=False)


def skip_layers(
    dataset_name: str,
    model_name: str,
    layers_to_skip: List,
    seed: int = 0,
    batch_size: int = 256,
    max_samples: int = 3000,
):
    """Evaluate a trained classifier with and without skipping encoder layers, and log the results.

    The encoder ``model_name`` is evaluated on the test split together with the classifier
    checkpoint stored at ``PROJECT_ROOT/models/<model>/<dataset_name>_classifier.ckpt``.
    The layers in ``layers_to_skip`` are then replaced through ``SkipModel``, using embeddings
    produced by ``extract_specific_layers`` on the train split, and the evaluation is repeated.
    Accuracies and evaluation times of both runs are appended as one row to
    ``PROJECT_ROOT/results/skipping_results.csv``.

    Args:
        dataset_name: Hugging Face dataset id; must be a key of ``DATASET2IMAGE_COLUMN`` and
            ``DATASET2LABEL_COLUMN`` and provide ``"train"`` and ``"test"`` splits.
        model_name: Hugging Face model id, e.g. ``"org/model"``; its last path component names
            the checkpoint directory.
        layers_to_skip: sequence of index pairs; ``pair[1] - pair[0]`` is counted as the number of
            skipped layers (presumably half-open ``(start, end)`` ranges — NOTE(review): confirm
            against ``SkipModel``).
        seed: global random seed passed to ``seed_everything``.
        batch_size: batch size of both the train and the test dataloader.
        max_samples: sample budget passed to ``extract_specific_layers``; recorded as
            ``reference_points``.
    """
    seed_everything(seed)

    print(f"Dataset: {dataset_name}, model: {model_name}")

    # ``[-1]`` instead of ``[1]`` so model ids without an "org/" prefix do not raise IndexError.
    trained_classifier_path = PROJECT_ROOT / "models" / model_name.split("/")[-1] / f"{dataset_name}_classifier.ckpt"
    dataset = load_dataset(dataset_name)

    image_name = DATASET2IMAGE_COLUMN[dataset_name]
    label_name = DATASET2LABEL_COLUMN[dataset_name]
    processor = AutoImageProcessor.from_pretrained(model_name)

    # Both loaders share ``batch_size``; the test loader previously hard-coded 256.
    train_dataloader = _make_dataloader(dataset["train"], processor, image_name, label_name, batch_size, shuffle=True)
    test_dataloader = _make_dataloader(dataset["test"], processor, image_name, label_name, batch_size, shuffle=False)

    ### Original performance
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints. The checkpoint
    # holds a full module, so torch>=2.6 (weights_only=True by default) may require weights_only=False.
    trained_classifier = torch.load(trained_classifier_path, map_location=device)

    config = AutoConfig.from_pretrained(model_name, output_hidden_states=True, return_dict=True)
    encoder = AutoModel.from_pretrained(model_name, config=config)

    model = HFwrapper(encoder=encoder, classifier=trained_classifier)
    model.to(device)
    criterion = nn.CrossEntropyLoss()

    original_accuracy, original_execution = _timed_eval(model, test_dataloader, criterion)

    #### Extract internal information
    layer_embeddings = extract_specific_layers(encoder, max_samples, train_dataloader, layers_to_skip)

    ### Skip Layers
    mode = 2
    # Originally annotated "linear" — NOTE(review): confirm which mapping SGDMLPAligner fits.
    translator_name = "SGDMLPAligner"

    # The skip model reuses the same ``encoder`` and classifier instances as the baseline model.
    skip_encoder = SkipModel(
        encoder=encoder,
        skips=layers_to_skip,
        mode=mode,
        precomputed_embeddings=layer_embeddings,
        translator_name=translator_name,
    )
    skip_encoder.to(device)

    skip_model = HFwrapper(encoder=skip_encoder, classifier=trained_classifier)
    skip_model.to(device)

    skip_accuracy, skip_execution = _timed_eval(skip_model, test_dataloader, criterion)

    results = {
        "seed": seed,
        "mode": mode,
        "dataset": dataset_name,
        "model": model_name,
        "reference_points": max_samples,
        "translator": translator_name,
        "skip_layers": layers_to_skip,
        "num_skip_layers": sum(pair[1] - pair[0] for pair in layers_to_skip),
        "criterion": "CrossEntropyLoss",
        # Recorded as metadata only; not used in this function.
        "optimizer": "Adam",
        "lr": 0.01,
        "batch_size": batch_size,
        "classifier": trained_classifier.__class__.__name__,
        "original_accuracy": original_accuracy,
        "skip_accuracy": skip_accuracy,
        "original_execution": original_execution,
        "skip_execution": skip_execution,
        # Magnitude only (kept for compatibility with existing rows); the signed difference is
        # recoverable from the two accuracy columns.
        "accuracy_diff": round(abs(skip_accuracy - original_accuracy), 3),
    }

    _append_results(results, PROJECT_ROOT / "results" / "skipping_results.csv")


if __name__ == "__main__":
    # Expose skip_layers on the command line: positional arguments and --flags map onto its parameters.
    fire.Fire(component=skip_layers)
