import lightning as L
from lightning.pytorch import seed_everything
import torch.nn as nn
from lightning.pytorch.callbacks import ModelCheckpoint
from nesim.lightning.imagenet import (
    ConvertedImagenetDataset,
    ImageNetHyperParams,
    ImageNetLightningModule,
)
from typing import Union, Literal
import torchvision.models as models
from nesim.configs import NesimConfig
from ..bimt.loss import BIMTConfig
from ..losses.cross_layer_correlation.loss import CrossLayerCorrelationLossConfig
from ..utils.checkpoint import load_and_filter_state_dict_keys
from ..weights_init.sorted_weights import SortedWeightsInit
from ..utils.json_stuff import load_json_as_dict
from fastapi.encoders import jsonable_encoder

from pydantic import BaseModel, Extra
import json
import os
import torch


class ImageNetTrainingState:
    """Bundle of model/optimizer/scheduler plus run metadata that can be saved and restored.

    ``save`` writes the three ``state_dict``s together with the wandb run id,
    the global step and the nesim config filename into a single file via
    ``torch.save``. ``load`` restores the state dicts *in place* into the
    model/optimizer/scheduler held by this instance and returns them together
    with the metadata stored in the file.
    """

    def __init__(
        self,
        model,
        optimizer,
        scheduler,
        wandb_run_id: Union[str, None] = None,
        global_step: Union[int, None] = None,
        nesim_config_filename: Union[str, None] = None,
    ):
        """
        Args:
            model: object exposing ``state_dict`` / ``load_state_dict`` (e.g. an ``nn.Module``).
            optimizer: optimizer exposing ``state_dict`` / ``load_state_dict``.
            scheduler: LR scheduler exposing ``state_dict`` / ``load_state_dict``.
            wandb_run_id: id of the wandb run, stored alongside the weights.
            global_step: training step at which the state is saved.
            nesim_config_filename: path of the nesim config used for the run.
        """
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.wandb_run_id = wandb_run_id
        self.global_step = global_step
        self.nesim_config_filename = nesim_config_filename

    def save(self, filename):
        """Write all state dicts and the run metadata to ``filename`` with ``torch.save``."""
        save_dict = {
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "wandb_run_id": self.wandb_run_id,
            "global_step": self.global_step,
            "nesim_config_filename": self.nesim_config_filename,
        }
        torch.save(save_dict, filename)
        print(f"Saved ImageNetTrainingState: {filename}")

    def load(
        self,
        filename,
        map_location=None,
    ):
        """Restore a state written by ``save`` into this instance's objects.

        Args:
            filename: path of a file produced by ``save``.
            map_location: forwarded to ``torch.load``. For example, use ``"cpu"``
                to load a GPU-saved state on a machine without that device.
                ``None`` keeps ``torch.load``'s default behaviour.

        Returns:
            dict with the (now updated) ``model``, ``optimizer`` and ``scheduler``
            plus the stored ``wandb_run_id``, ``global_step`` and
            ``nesim_config_filename``.

        Raises:
            FileNotFoundError: if ``filename`` does not exist.
        """
        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
        if not os.path.exists(filename):
            raise FileNotFoundError(f"Invalid filename: {filename}")
        train_state = torch.load(filename, map_location=map_location)

        self.model.load_state_dict(train_state["model_state_dict"])
        self.optimizer.load_state_dict(train_state["optimizer_state_dict"])
        self.scheduler.load_state_dict(train_state["scheduler_state_dict"])

        return {
            "model": self.model,
            "optimizer": self.optimizer,
            "scheduler": self.scheduler,
            "wandb_run_id": train_state["wandb_run_id"],
            "global_step": train_state["global_step"],
            "nesim_config_filename": train_state["nesim_config_filename"],
        }


class ImageNetTrainingConfig(BaseModel, extra=Extra.forbid):
    """Validated configuration for an ImageNet training run.

    Unknown keys are rejected (``extra=Extra.forbid``), so a typo in a JSON
    config file fails at load time instead of being silently ignored.
    """

    # Hyperparameters forwarded to ImageNetLightningModule.
    hyperparams: ImageNetHyperParams
    # Nesim configuration forwarded to ImageNetLightningModule.
    nesim_config: NesimConfig
    bimt_config: Union[None, BIMTConfig]
    cross_layer_correlation_loss_config: Union[None, CrossLayerCorrelationLossConfig]
    # Flag forwarded to ImageNetLightningModule as ``wandb_log``.
    wandb_log: bool
    # Passed as ``weights=`` to the torchvision ResNet constructor.
    weights: Union[None, str]
    checkpoint_dir: str = "./checkpoints/imagenet"
    # Root folder expected to contain "train" and "validation" dataset subfolders.
    cache_dir: str = "/mindhive/nklab3/users/XXXX-1/datasets/imagenet"
    max_epochs: int = 10
    skip_initial_validation_run: bool = False
    # Lightning checkpoint path to resume from (``ckpt_path`` of ``Trainer.fit``).
    resume_from_checkpoint: Union[str, None] = None
    # JSON file with the layer names for SortedWeightsInit; None disables it.
    apply_sorted_weights_init_filename: Union[str, None] = None
    model_name: Literal["resnet50", "resnet18"] = "resnet50"
    wandb_project_name: Union[str, None] = None
    # When set, an existing wandb run with this id is resumed.
    wandb_run_id: Union[str, None] = None

    def save_json(self, filename: str):
        """Write this configuration to ``filename`` as JSON indented by 4 spaces."""
        with open(filename, "w") as handle:
            serialized = self.dict()
            json.dump(serialized, handle, indent=4)

    @classmethod
    def from_json(cls, filename: str):
        """Read ``filename`` as JSON and validate it into a new configuration."""
        with open(filename, "r") as handle:
            return cls.parse_obj(json.load(handle))


class ImageNetTraining:
    """Run a Lightning training job for a torchvision ResNet on converted ImageNet.

    Everything is driven by an ``ImageNetTrainingConfig``. A ``ModelCheckpoint``
    callback keeps the single checkpoint with the lowest ``val_loss`` as
    ``best_model`` inside the lightning module's ``best_checkpoint_folder``.
    """

    def __init__(self, config: ImageNetTrainingConfig):
        """Store the training configuration.

        Args:
            config: configuration describing model, data, losses and logging.

        Raises:
            AssertionError: if ``config`` is not an ``ImageNetTrainingConfig``.
        """
        assert isinstance(config, ImageNetTrainingConfig)
        self.config = config

    def _build_model(self):
        """Return the configured torchvision ResNet, optionally re-initialized with SortedWeightsInit."""
        model_constructors = {
            "resnet50": models.resnet50,
            "resnet18": models.resnet18,
        }
        model_name = self.config.model_name
        if model_name not in model_constructors:
            raise NameError(f"Invalid config.model_name:{model_name}")
        model = model_constructors[model_name](weights=self.config.weights)

        sorted_init_filename = self.config.apply_sorted_weights_init_filename
        if sorted_init_filename is not None:
            print(
                f"Applying SortedWeightsInit as defined on file: {sorted_init_filename}"
            )
            # The JSON file supplies the ``layer_names`` argument of SortedWeightsInit.
            layer_names = load_json_as_dict(sorted_init_filename)
            sorted_weights_init = SortedWeightsInit(layer_names=layer_names)
            model = sorted_weights_init.apply(model=model)
        return model

    def _init_wandb(self, run_name: str):
        """Start a new wandb run, or resume ``config.wandb_run_id`` if it is set.

        Does nothing when ``config.wandb_log`` is False, so no wandb run is
        created or resumed when logging is disabled.
        """
        if not self.config.wandb_log:
            print("wandb_log is False: skipping wandb.init")
            return

        # Imported lazily so wandb is only needed when logging is enabled.
        import wandb

        if self.config.wandb_run_id is not None:
            # resume="must": per the wandb docs, fail rather than start a new run.
            wandb.init(
                project=self.config.wandb_project_name,
                id=self.config.wandb_run_id,
                resume="must",
            )
            print(f"Resuming wandb run ID: {self.config.wandb_run_id}")
        else:
            wandb_run = wandb.init(
                project=self.config.wandb_project_name,
                name=run_name,
                config=jsonable_encoder(self.config),
            )
            print(f"Starting wandb run ID: {wandb_run.id}")

    def run(self, run_name: str):
        """Build model, data and lightning module from the config, then validate and train.

        Args:
            run_name: display name for a newly started wandb run. It is unused
                when resuming an existing wandb run or when ``wandb_log`` is False.
        """
        # 0. seed the RNGs so training runs are reproducible
        seed_everything(0)

        # 1. model
        model = self._build_model()

        # 2. train and validation datasets (subfolders of cache_dir)
        train_dataset = ConvertedImagenetDataset(
            slice_name="train", folder=os.path.join(self.config.cache_dir, "train")
        )
        validation_dataset = ConvertedImagenetDataset(
            slice_name="validation",
            folder=os.path.join(self.config.cache_dir, "validation"),
        )

        # 3. lightning module
        lightning_module = ImageNetLightningModule(
            model=model,
            hyperparams=self.config.hyperparams,
            nesim_config=self.config.nesim_config,
            checkpoint_dir=self.config.checkpoint_dir,
            train_dataset=train_dataset,
            validation_dataset=validation_dataset,
            wandb_log=self.config.wandb_log,
            bimt_config=self.config.bimt_config,
            cross_layer_correlation_loss_config=self.config.cross_layer_correlation_loss_config,
            # NOTE(review): hard-coded to the first GPU regardless of Trainer accelerator.
            nesim_device="cuda:0",
        )

        # 4. experiment tracking
        self._init_wandb(run_name=run_name)

        # 5. checkpoint callback
        # NOTE: best model = model with the lowest validation loss
        checkpoint_callback = ModelCheckpoint(
            save_top_k=1,
            monitor="val_loss",
            mode="min",
            dirpath=lightning_module.best_checkpoint_folder,
            filename="best_model",
        )

        trainer = L.Trainer(
            accelerator="auto",
            devices=1,
            max_epochs=self.config.max_epochs,
            logger=None,
            default_root_dir=lightning_module.checkpoint_dir,
            # saves top-K checkpoints based on "val_loss" metric
            callbacks=[checkpoint_callback],
        )

        # 6. validate once before training
        if not self.config.skip_initial_validation_run:
            trainer.validate(lightning_module)
        else:
            print("Skipping initial validation run")

        # 7. train + validate after each epoch (optionally resuming a Lightning checkpoint)
        trainer.fit(
            lightning_module,
            lightning_module.train_dataloader,
            lightning_module.validation_dataloader,
            ckpt_path=self.config.resume_from_checkpoint,
        )

        print(f"Best model: {checkpoint_callback.best_model_path}")

        print("EXPERIMENT COMPLETE")
