from torch import nn, optim
import os
import torch
from nn_core.common import PROJECT_ROOT
from pytorch_lightning import seed_everything
from torch.utils.data import DataLoader
import pandas as pd
from datasets import DatasetDict
from layskip.pl_modules.train_NN import train_classifier
from layskip.modules.module import HFwrapper, NoEncoder
from layskip.utils.dictionaries import (
    DATASET2LABEL_COLUMN,
    DATASET2NUM_CLASSES,
)
from typing import List
from transformers import AutoConfig, AutoModel
import fire

# Device used for the encoder wrapper and classifier: CUDA when torch reports it
# is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# Training hyperparameters per classifier head. Adam's default weight_decay is 0,
# so the "linear" entry is equivalent to constructing Adam without weight decay.
_CLASSIFIER_HPARAMS = {
    "MLP": {"lr": 0.001, "num_epochs": 50, "weight_decay": 1e-5},
    "linear": {"lr": 0.01, "num_epochs": 5, "weight_decay": 0.0},
}

# Column layout used when the results CSV does not exist yet.
_RESULTS_COLUMNS = [
    "seed",
    "dataset",
    "model",
    "optimizer",
    "lr",
    "classifier",
    "batch_size",
    "num_epochs",
    "approx_layer",
    "num_layers",
    "original_accuracy",
    "accuracy",
    "delta_acc",
]


def _select_layer_split(split, layers_key: str, label_column: str):
    """Keep only the `layers_key` embedding column and the label column, renamed to "images"/"labels"."""
    return (
        split.select_columns([layers_key, label_column])
        .rename_column(layers_key, "images")
        .rename_column(label_column, "labels")
    )


def _build_classifier(classifier_type: str, hidden_size: int, num_classes: int) -> nn.Module:
    """Return a fresh classification head for `classifier_type` ("MLP" or "linear")."""
    if classifier_type == "MLP":
        return nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Dropout(0.5),
            nn.LayerNorm(hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, num_classes),
        )
    return nn.Linear(hidden_size, num_classes)


def _load_results(results_path) -> pd.DataFrame:
    """Read the results CSV, or return an empty frame with the expected columns if it does not exist."""
    if os.path.exists(results_path):
        return pd.read_csv(results_path)
    return pd.DataFrame(columns=_RESULTS_COLUMNS)


def _lookup_original_accuracy(
    results_df: pd.DataFrame,
    *,
    dataset_name: str,
    model_name: str,
    classifier_name: str,
    num_epochs: int,
    seed: int,
    batch_size: int,
    lr: float,
):
    """Return the accuracy of the baseline run (approx_layer == "[]") with the same settings.

    Raises:
        ValueError: if no matching baseline row is recorded in `results_df`.
    """
    mask = (
        (results_df["approx_layer"] == "[]")
        & (results_df["dataset"] == dataset_name)
        & (results_df["model"] == model_name)
        & (results_df["classifier"] == classifier_name)
        & (results_df["num_epochs"] == num_epochs)
        & (results_df["seed"] == seed)
        & (results_df["batch_size"] == batch_size)
        & (results_df["lr"] == lr)
    )
    matches = results_df.loc[mask, "accuracy"]
    if matches.empty:
        raise ValueError(
            "No baseline run (layers_to_approximate=[]) found in the results file for "
            f"dataset={dataset_name!r}, model={model_name!r}, classifier={classifier_name!r}, "
            f"num_epochs={num_epochs}, seed={seed}, batch_size={batch_size}, lr={lr}; "
            "run the baseline first."
        )
    # If the baseline was recorded more than once, use the most recent row.
    return matches.iloc[-1]


def skip_and_train_run(
    dataset_name: str, model_name: str, layers_to_approximate: List, seed: int, classifier_type: str
):
    """Train a classifier on precomputed embeddings and append the run to the results CSV.

    The embeddings are loaded from PROJECT_ROOT/data/embeddings/<dataset_name>/<model short name>.
    The column named ``str(layers_to_approximate)`` is used as input features. Passing ``[]``
    records a baseline run. Any other value is compared against the matching baseline row
    already stored in results/train_approximated_blocks.csv.

    Args:
        dataset_name: Key into DATASET2LABEL_COLUMN / DATASET2NUM_CLASSES and the embeddings
            directory name.
        model_name: Hugging Face model identifier. Its config supplies the hidden size. Its
            second "/"-separated component names the embeddings subdirectory.
        layers_to_approximate: Sequence of index pairs. Its ``str()`` selects the embedding
            column, and ``num_layers`` is the sum of ``pair[1] - pair[0]``.
        seed: Seed passed to ``seed_everything``; also part of the baseline lookup key.
        classifier_type: ``"MLP"`` or ``"linear"``.

    Raises:
        ValueError: If ``classifier_type`` is unknown, or if a non-baseline run finds no
            matching baseline row in the results file.
    """
    print(
        f"Dataset: {dataset_name}, model: {model_name}, approximating:{layers_to_approximate}, seed: {seed}, classifier_type: {classifier_type}"
    )

    # Fail fast, before any data or weights are loaded.
    if classifier_type not in _CLASSIFIER_HPARAMS:
        raise ValueError(
            f"Unknown classifier_type {classifier_type!r}; expected one of {sorted(_CLASSIFIER_HPARAMS)}"
        )
    hparams = _CLASSIFIER_HPARAMS[classifier_type]
    lr = hparams["lr"]
    num_epochs = hparams["num_epochs"]

    seed_everything(seed)

    results_path = PROJECT_ROOT / "results" / "train_approximated_blocks.csv"

    EMBEDDINGS_DIR = str(PROJECT_ROOT / "data" / "embeddings" / dataset_name / model_name.split("/")[1])
    embeddings = DatasetDict.load_from_disk(EMBEDDINGS_DIR)
    embeddings.set_format("torch")

    batch_size = 256
    layers_key = str(layers_to_approximate)
    label_column = DATASET2LABEL_COLUMN[dataset_name]

    config = AutoConfig.from_pretrained(model_name, output_hidden_states=True, return_dict=True)
    # Only encoder.config.hidden_size is read below. The load is kept in its original position
    # so the RNG state after seed_everything is unchanged.
    # NOTE(review): confirm whether weight loading consumes RNG; if not, this could use config.
    encoder = AutoModel.from_pretrained(model_name, config=config)
    hidden_size = encoder.config.hidden_size

    hf_train_embeddings = _select_layer_split(embeddings["train"], layers_key, label_column)
    hf_test_embeddings = _select_layer_split(embeddings["test"], layers_key, label_column)

    train_dataloader = DataLoader(
        hf_train_embeddings,
        shuffle=True,
        batch_size=batch_size,
        num_workers=8,
        pin_memory=True,
    )
    test_dataloader = DataLoader(
        hf_test_embeddings,
        shuffle=False,
        batch_size=batch_size,
        num_workers=8,
        pin_memory=True,
    )

    no_encoder = NoEncoder(embeddings=hf_train_embeddings)
    no_encoder.to(device)

    classifier = _build_classifier(classifier_type, hidden_size, DATASET2NUM_CLASSES[dataset_name])

    skip_model = HFwrapper(encoder=no_encoder, classifier=classifier)
    skip_model.to(device)
    skip_model.freeze_encoder()

    optimizer = optim.Adam(skip_model.parameters(), lr=lr, weight_decay=hparams["weight_decay"])
    criterion = nn.CrossEntropyLoss()

    _, _, _, eval_accuracies, _ = train_classifier(
        num_epochs=num_epochs,
        model=skip_model,
        train_data_loader=train_dataloader,
        optimizer=optimizer,
        criterion=criterion,
        test_data_loader=test_dataloader,
        embedding_column_name="images",
        label_column_name="labels",
    )
    accuracy = eval_accuracies[-1]

    # Load the CSV only now, after training, so rows appended by other runs in the meantime
    # are neither missed by the baseline lookup nor overwritten by the write below.
    results_df = _load_results(results_path)

    classifier_name = classifier.__class__.__name__
    if layers_key == "[]":
        original_accuracy = accuracy
    else:
        original_accuracy = _lookup_original_accuracy(
            results_df,
            dataset_name=dataset_name,
            model_name=model_name,
            classifier_name=classifier_name,
            num_epochs=num_epochs,
            seed=seed,
            batch_size=batch_size,
            lr=lr,
        )

    row = {
        "seed": seed,
        "dataset": dataset_name,
        "model": model_name,
        "optimizer": optimizer.__class__.__name__,
        "lr": lr,
        "classifier": classifier_name,
        "batch_size": batch_size,
        "num_epochs": num_epochs,
        "approx_layer": layers_to_approximate,
        "original_accuracy": original_accuracy,
        "num_layers": sum(block[1] - block[0] for block in layers_to_approximate),
        "accuracy": accuracy,
        "delta_acc": original_accuracy - accuracy,
    }

    results_df = pd.concat([results_df, pd.DataFrame([row])])
    # to_csv does not create missing directories.
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    results_df.to_csv(results_path, index=False)


if __name__ == "__main__":
    # Expose skip_and_train_run as a command-line interface via python-fire;
    # CLI arguments are mapped onto the function's parameters.
    fire.Fire(skip_and_train_run)
