import torch
import os
import pandas as pd
from nn_core.common import PROJECT_ROOT
from pytorch_lightning import seed_everything
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import (
    AutoConfig,
    AutoModel,
    AutoImageProcessor,
)
import functools
from torch import nn, optim

from tqdm import tqdm
from itertools import product

from layskip.pl_modules.train_NN import train_classifier
from layskip.modules.module import HFwrapper
from layskip.utils.dictionaries import (
    DATASET2IMAGE_COLUMN,
    DATASET2LABEL_COLUMN,
    DATASET2NUM_CLASSES,
)
from layskip.utils.utils import image_encode

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Datasets to sweep over. Alternatives that were listed for this script:
# "cifar100-coarse", "cifar100-fine", "mnist", "cifar10", "cifar100", "ILSVRC/imagenet-1k"
DATASET_NAMES = ["cifar100-fine"]

# Hugging Face encoder checkpoints to sweep over; uncomment an entry to include it.
MODEL_NAMES = [
    "facebook/dinov2-small",
    # "WinKawaks/vit-small-patch16-224",
    # "facebook/deit-small-patch16-224",
    # "facebook/dino-vits8",
    # "facebook/dino-vitb16",
    # "facebook/dinov2-base",
    # "microsoft/swinv2-tiny-patch4-window8-256",
    # "microsoft/beit-base-patch16-224",
    # "snap-research/efficientformer-l1-300",
]

# Random seeds to sweep over.
SEEDS = [0]

# All runs are accumulated in a single CSV. When the file already exists its rows are
# loaded so that new results are appended to them; otherwise start from an empty table.
results_path = PROJECT_ROOT / "results" / "train_classifier_results.csv"
if os.path.exists(results_path):
    results_df = pd.read_csv(results_path)
else:
    # Schema of one result row. "lr" is part of every row the training loop writes;
    # it is listed last because that is where pandas places a column that is missing
    # from the first frame when concatenating, so the CSV column order is unchanged.
    results_df = pd.DataFrame(
        columns=[
            "seed",
            "dataset",
            "model",
            "criterion",
            "optimizer",
            "classifier",
            "num_epochs",
            "train_accuracy",
            "test_accuracy",
            "lr",
        ]
    )

# Train one classifier head per (dataset, encoder, seed) combination, append the final
# train/test accuracy to the results CSV and save the trained classifier to disk.
# NOTE(review): the DataLoaders below start worker processes (num_workers=8) while this
# loop runs at module level; on platforms where multiprocessing uses the "spawn" start
# method this usually needs an `if __name__ == "__main__":` guard - confirm the target OS.
for dataset_name, model_name, seed_value in tqdm(
    product(DATASET_NAMES, MODEL_NAMES, SEEDS),
    # product() has no len(), so tell tqdm how many runs there are.
    total=len(DATASET_NAMES) * len(MODEL_NAMES) * len(SEEDS),
    desc="Training classifiers",
):

    seed_everything(seed_value)

    # dataset stuff
    # Both CIFAR-100 variants are loaded from the same "cifar100" dataset; they differ
    # only in the label column selected through DATASET2LABEL_COLUMN below.
    # NOTE: trust_remote_code=True allows the dataset's own loading script to execute
    # locally - only keep it for datasets you trust.
    if dataset_name in ("cifar100-fine", "cifar100-coarse"):
        dataset = load_dataset("cifar100", trust_remote_code=True)
    else:
        dataset = load_dataset(dataset_name, trust_remote_code=True)

    train_dataset = dataset["train"]
    # ImageNet-1k is evaluated on its "validation" split, every other dataset on "test".
    eval_split = "validation" if dataset_name == "ILSVRC/imagenet-1k" else "test"
    test_dataset = dataset[eval_split]

    image_name = DATASET2IMAGE_COLUMN[dataset_name]
    label_name = DATASET2LABEL_COLUMN[dataset_name]
    num_classes = DATASET2NUM_CLASSES[dataset_name]

    print(f"Dataset: {dataset_name}, Encoder: {model_name}")

    # HF config stuff
    config = AutoConfig.from_pretrained(model_name, output_hidden_states=True, return_dict=True)
    encoder = AutoModel.from_pretrained(model_name, config=config)

    if model_name == "WinKawaks/vit-small-patch16-224":
        processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)
    else:
        processor = AutoImageProcessor.from_pretrained(model_name)

    # model stuff
    # Classifier head: Linear -> Dropout -> LayerNorm -> SiLU -> Linear, mapping the
    # encoder's hidden size to the number of classes.
    # (Simpler alternative: classifier = nn.Linear(hidden_size, num_classes))
    hidden_size = encoder.config.hidden_size
    classifier = nn.Sequential(
        nn.Linear(hidden_size, hidden_size),
        nn.Dropout(0.5),
        nn.LayerNorm(hidden_size),
        nn.SiLU(),
        nn.Linear(hidden_size, num_classes),
    )

    model = HFwrapper(encoder=encoder, classifier=classifier)

    criterion = nn.CrossEntropyLoss()
    lr = 0.001
    # (Alternative without weight decay: optim.Adam(model.parameters(), lr=lr))
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    num_epochs = 50

    # The same collate function is used for both loaders: it receives the processor and
    # the names of the image/label columns of the current dataset.
    collate_fn = functools.partial(
        image_encode, processor=processor, image_name=image_name, label_name=label_name
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=256,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=256,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    print("Starting training")
    model.to(device)

    train_losses, eval_losses, train_accuracies, eval_accuracies, eval_indexes = train_classifier(
        num_epochs=num_epochs,
        model=model,
        train_data_loader=train_dataloader,
        optimizer=optimizer,
        criterion=criterion,
        test_data_loader=test_dataloader,
        embedding_column_name="images",
        label_column_name="labels",
    )

    print("Saving results...")
    # Drop a Hub namespace prefix ("ILSVRC/imagenet-1k" -> "imagenet-1k") so the name is
    # usable as a file name; names without a "/" are left unchanged.
    output_dataset_name = dataset_name.split("/")[-1]

    # csv
    # Only the last recorded train/eval accuracy is stored for the run.
    results = {
        "seed": seed_value,
        "dataset": output_dataset_name,
        "model": model_name,
        "criterion": criterion.__class__.__name__,
        "optimizer": optimizer.__class__.__name__,
        "lr": lr,
        "num_epochs": num_epochs,
        "classifier": classifier.__class__.__name__,
        "train_accuracy": round(train_accuracies[-1], 3),
        "test_accuracy": round(eval_accuracies[-1], 3),
    }

    # Save the current result
    new_results_df = pd.DataFrame([results])
    if results_df.empty:
        # Concatenating with an empty frame is deprecated in pandas (FutureWarning), so
        # start from the new row and only take over the existing column order; columns
        # the table does not know yet are appended at the end, as concat would do.
        ordered_columns = list(results_df.columns) + [
            column for column in new_results_df.columns if column not in results_df.columns
        ]
        results_df = new_results_df.reindex(columns=ordered_columns)
    else:
        results_df = pd.concat([results_df, new_results_df], ignore_index=True)

    # Make sure the target directory exists so a finished training run is not lost to a
    # failing write; the CSV is rewritten after every run.
    results_path.parent.mkdir(parents=True, exist_ok=True)
    results_df.to_csv(results_path, index=False)

    # saving the model
    # [-1] keeps the part after the namespace and also works for names without a "/".
    model_dir_path = PROJECT_ROOT / "models" / model_name.split("/")[-1]
    model_dir_path.mkdir(parents=True, exist_ok=True)
    path_to_save = model_dir_path / (output_dataset_name + "_classifier.ckpt")

    # NOTE: this pickles the whole classifier module, not just its state_dict, so the
    # checkpoint has to be loaded with full unpickling enabled (torch.load with
    # weights_only=False on recent PyTorch releases). Kept as is for existing consumers.
    torch.save(model.classifier, path_to_save)
