import lightning as L
from lightning.pytorch import seed_everything
import torch.nn as nn
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from lightning.pytorch.callbacks import ModelCheckpoint
from nesim.lightning.mnist import MNISTLightningModule, MNISTHyperParams
from ..bimt.loss import BIMTConfig
import wandb
from typing import Union
from nesim.configs import NesimConfig

from pydantic import BaseModel, Extra
import json
import torch
import os


def get_untrained_model(hidden_size):
    """
    Build a freshly initialised MLP with 28 * 28 inputs and 10 outputs.

    Layout: flatten -> linear -> relu -> dropout(0.1) -> linear -> relu
    -> dropout(0.1) -> linear. The three linear layers end up at indices
    1, 4 and 7 of the returned `nn.Sequential`.

    - `hidden_size`: number of neurons in each of the two hidden layers
    """
    input_features = 28 * 28
    num_outputs = 10
    dropout_p = 0.1

    # Layers are instantiated in forward order so that parameter
    # initialisation draws from the RNG in a fixed sequence.
    layers = [
        nn.Flatten(),
        nn.Linear(input_features, hidden_size),  # index 1
        nn.ReLU(),
        nn.Dropout(dropout_p),
        nn.Linear(hidden_size, hidden_size),  # index 4
        nn.ReLU(),
        nn.Dropout(dropout_p),
        nn.Linear(hidden_size, num_outputs),  # index 7
    ]
    return nn.Sequential(*layers)


class MnistTrainingConfig(BaseModel, extra=Extra.forbid):
    """
    Settings for a single MNIST training run (unknown fields are forbidden).

    - `hidden_size`: number of neurons in hidden layer
    - `nesim_config`: specifies nesim stuff
    - `bimt_config`: a `BIMTConfig`, or `None`
    - `hyperparams`: training hyperparams
    - `wandb_log`: set to True if you want to log stuff to wandb
    - `max_epochs`: number of epochs to train
    - `checkpoint_dir`: folder within which we save best model and checkpoints after every n steps (specified in hyperparams)
    - `data_dir`: folder where the MNIST dataset is located
    - `load_checkpoint`: path of a saved state dict to load into the model before training, or `None` to start from scratch
    """

    # required fields
    hidden_size: int
    nesim_config: NesimConfig
    bimt_config: Union[None, BIMTConfig]
    hyperparams: MNISTHyperParams
    wandb_log: bool
    max_epochs: int

    # fields with defaults
    checkpoint_dir: str = "./checkpoints/mnist"
    data_dir: str = "./data"
    load_checkpoint: Union[str, None] = None

    def save_json(self, filename: str):
        """Dump this config as JSON (indented by 4 spaces) into `filename`."""
        with open(filename, "w") as out_file:
            payload = self.dict()
            json.dump(payload, out_file, indent=4)

    @classmethod
    def from_json(cls, filename: str):
        """Read the JSON file `filename` and build a validated config from it."""
        with open(filename, "r") as in_file:
            return cls.parse_obj(json.load(in_file))


class MNISTTraining:
    """
    Trains a simple 3 layered MLP on the MNIST dataset.

    - saves the model with the best validation loss
    """

    def __init__(self, config: MnistTrainingConfig):
        """
        - `config`: settings for the run; must be a `MnistTrainingConfig`
        """
        assert isinstance(config, MnistTrainingConfig)
        self.config = config

    def run(self):
        """
        Run the full experiment: seed, build the model, optionally load a
        checkpoint, validate once, then train for `config.max_epochs` epochs.

        Prints the path of the best checkpoint when training finishes.
        Fails with an `AssertionError` if `config.load_checkpoint` is set but
        the path does not exist.
        """
        config = self.config

        # 0. seed everything to make training runs deterministic
        seed_everything(0)

        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # 1. build the model
        # The architecture lives in `get_untrained_model` so that it is defined
        # in exactly one place instead of being duplicated here.
        model = get_untrained_model(hidden_size=config.hidden_size)

        if config.load_checkpoint is not None:
            assert os.path.exists(
                config.load_checkpoint
            ), f"Invalid checkpoint path: {config.load_checkpoint}"
            print(f"Loading checkpoint: {config.load_checkpoint}")
            # NOTE(review): torch.load may unpickle arbitrary objects (depending
            # on the installed torch version's `weights_only` default) -- only
            # point `load_checkpoint` at files from a trusted source.
            model.load_state_dict(
                torch.load(config.load_checkpoint, map_location="cpu")
            )

        # 2. setup dataset (downloaded into `data_dir` if not already present)
        train_dataset = MNIST(
            config.data_dir, train=True, transform=transform, download=True
        )
        validation_dataset = MNIST(
            config.data_dir, train=False, transform=transform, download=True
        )

        # 3. init MNIST lightning module with hyperparams
        lightning_module = MNISTLightningModule(
            model=model,
            hyperparams=config.hyperparams,
            nesim_config=config.nesim_config,
            checkpoint_dir=config.checkpoint_dir,
            train_dataset=train_dataset,
            validation_dataset=validation_dataset,
            wandb_log=config.wandb_log,
            bimt_config=config.bimt_config,
        )

        # 4. setup callback
        # NOTE: best model = model with the lowest validation loss
        # (save_top_k=1 keeps only that single checkpoint)
        checkpoint_callback = ModelCheckpoint(
            save_top_k=1,
            monitor="val_loss",
            mode="min",
            dirpath=lightning_module.best_checkpoint_folder,
            filename="mnist-{epoch:02d}-{val_loss:.3f}-{val_acc:.3f}",
        )

        trainer = L.Trainer(
            accelerator="auto",
            devices=1,
            max_epochs=config.max_epochs,
            logger=None,
            default_root_dir=lightning_module.checkpoint_dir,
            callbacks=[checkpoint_callback],
        )

        # 5. validate once before training
        trainer.validate(lightning_module)

        # 6. train + validate after each epoch
        trainer.fit(
            lightning_module,
            lightning_module.train_dataloader,
            lightning_module.validation_dataloader,
        )

        print(f"Best model: {checkpoint_callback.best_model_path}")

        print("EXPERIMENT COMPLETE")
