import lightning as L
from lightning.pytorch import seed_everything
import torch.nn as nn
import torch
from lightning.pytorch.callbacks import ModelCheckpoint
from nesim.lightning.cifar100 import (
    Cifar100Dataset,
    Cifar100HyperParams,
    Cifar100LightningModule,
)
from typing import Union
import torchvision.models as models
from nesim.configs import NesimConfig
from ..bimt.loss import BIMTConfig
from nesim.utils.checkpoint import load_and_filter_state_dict_keys

from pydantic import BaseModel, Extra
import json


def get_untrained_model(weights=None, num_classes: int = 100):
    """Build a torchvision ResNet-18 with a replaced classification head.

    Args:
        weights: Forwarded unchanged to ``models.resnet18``.
        num_classes: Number of outputs of the replacement ``fc`` layer.
            Defaults to 100, the value that was previously hard-coded.

    Returns:
        The ResNet-18 module whose ``fc`` attribute is a newly constructed
        ``nn.Linear`` with ``num_classes`` outputs.
    """
    model = models.resnet18(weights=weights)
    # Take the input width from the layer being replaced instead of
    # hard-coding it, so the head always matches the backbone.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


class Cifar100TrainingConfig(BaseModel, extra=Extra.forbid):
    """Configuration for a single ``Cifar100Training`` run.

    Declared with ``extra=Extra.forbid``, so fields that are not listed below
    are rejected during validation. Instances can be written to and read
    from JSON files via ``save_json`` / ``from_json``.
    """

    # Hyperparameters handed to the Lightning module.
    hyperparams: Cifar100HyperParams
    # Nesim configuration handed to the Lightning module.
    nesim_config: NesimConfig
    # Optional BIMT configuration handed to the Lightning module.
    bimt_config: Union[None, BIMTConfig]
    # Forwarded to the Lightning module as its ``wandb_log`` argument.
    wandb_log: bool
    # Path of a state dict to load before training; the literal string
    # "None" disables loading.
    weights: Union[None, str]
    # Directory passed to the Lightning module as its checkpoint directory.
    checkpoint_dir: str = "./checkpoints/cifar100"
    # Passed to the trainer as ``max_epochs``.
    max_epochs: int = 10

    def save_json(self, filename: str):
        """Serialize this config to ``filename`` as indented JSON.

        Args:
            filename: Destination path; an existing file is overwritten.
        """
        # Explicit encoding keeps the file independent of the platform locale.
        with open(filename, "w", encoding="utf-8") as file:
            json.dump(self.dict(), file, indent=4)

    @classmethod
    def from_json(cls, filename: str):
        """Load and validate a config from the JSON file at ``filename``.

        Args:
            filename: Path of a JSON file, e.g. one written by ``save_json``.

        Returns:
            A validated instance of ``cls``.
        """
        # JSON is UTF-8 by specification; do not rely on the locale default.
        with open(filename, "r", encoding="utf-8") as file:
            json_data = json.load(file)
        return cls.parse_obj(json_data)


class Cifar100Training:
    """Trains a ResNet-18 on CIFAR-100 as described by a training config."""

    def __init__(self, config: Cifar100TrainingConfig):
        """
        trains a simple resnet18 on the CIFAR100 dataset
        - saves the model with the best validation loss

        Args:
            config: Validated training configuration; stored as ``self.config``.
        """
        assert isinstance(config, Cifar100TrainingConfig)
        self.config = config

    def run(self) -> None:
        """Run the full training job.

        Seeds the RNGs, builds the datasets, model and Lightning module,
        optionally loads an initial state dict, then fits for
        ``config.max_epochs`` epochs while keeping the single checkpoint
        with the lowest ``val_loss``.
        """
        # 0. seed everything to make training runs deterministic
        seed_everything(0)

        # 1. set up the train and validation datasets
        # (the "test" slice is used for validation)
        train_dataset = Cifar100Dataset(slice_name="train")
        validation_dataset = Cifar100Dataset(slice_name="test")

        # 2. build the model and wrap it in the Lightning module
        model = get_untrained_model(weights=None)

        lightning_module = Cifar100LightningModule(
            model=model,
            hyperparams=self.config.hyperparams,
            nesim_config=self.config.nesim_config,
            checkpoint_dir=self.config.checkpoint_dir,
            train_dataset=train_dataset,
            validation_dataset=validation_dataset,
            wandb_log=self.config.wandb_log,
            bimt_config=self.config.bimt_config,
            # NOTE(review): device is hard-coded while the trainer below uses
            # accelerator="auto" -- confirm this is intended on CPU-only hosts.
            nesim_device="cuda:0",
        )

        # 3. optionally start from existing weights
        # ``weights`` is typed Union[None, str]: skip loading both for a real
        # ``None`` and for the legacy "None" string sentinel, instead of
        # handing ``None`` to ``torch.load``.
        weights_path = self.config.weights
        if weights_path is not None and weights_path != "None":
            # Load onto CPU first so a checkpoint saved on a GPU can still be
            # read on another machine; ``load_state_dict`` copies the values
            # into the module's existing parameters.
            state_dict = torch.load(weights_path, map_location="cpu")
            lightning_module.load_state_dict(state_dict)
            print(f"\nLoading lightning checkpoint: {self.config.weights}\n")

        # 4. setup callback
        # NOTE: best model = model with the lowest validation loss
        checkpoint_callback = ModelCheckpoint(
            save_top_k=1,
            monitor="val_loss",
            mode="min",
            dirpath=lightning_module.best_checkpoint_folder,
            filename="best_model",
        )

        trainer = L.Trainer(
            accelerator="auto",
            devices=1,
            max_epochs=self.config.max_epochs,
            logger=None,
            default_root_dir=lightning_module.checkpoint_dir,
            # saves top-K checkpoints based on "val_loss" metric
            callbacks=[checkpoint_callback],
        )

        # 5. train + validate after each epoch
        trainer.fit(
            lightning_module,
            lightning_module.train_dataloader,
            lightning_module.validation_dataloader,
        )

        print(f"Best model: {checkpoint_callback.best_model_path}")

        print("EXPERIMENT COMPLETE")
