import torch
import os
import wandb
import lightning as L
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from pydantic import BaseModel, Extra
from torchmetrics import Accuracy

from ..configs import NesimConfig
from ..utils import make_folder_if_does_not_exist
from ..losses.nesim_loss import NesimLoss
from ..bimt.loss import BIMTLoss, BIMTConfig


class MNISTHyperParams(BaseModel, extra="forbid"):
    """Validated training hyperparameters consumed by ``MNISTLightningModule``.

    Unknown fields are rejected at construction time (``extra="forbid"``).
    The string literal is used instead of ``Extra.forbid`` because the
    ``Extra`` enum is deprecated in pydantic v2, while the literal is accepted
    by both pydantic v1 and v2.
    """

    ## learning rate handed to the Adam optimizer
    lr: float
    ## batch size used for both the training and the validation DataLoader
    batch_size: int
    ## weight decay handed to the Adam optimizer
    weight_decay: float
    ## a checkpoint is written whenever train_step_idx % this == 0
    ## (used as a modulus, so it must be non-zero)
    save_checkpoint_every_n_steps: int
    ## the nesim loss is computed whenever train_step_idx % this == 0
    ## (used as a modulus, so it must be non-zero when nesim is enabled)
    apply_nesim_every_n_steps: int


class MNISTLightningModule(L.LightningModule):
    """LightningModule that trains ``model`` on MNIST with a cross-entropy loss.

    Two optional auxiliary losses can be added to the cross-entropy loss during
    training: a nesim loss (enabled when ``nesim_config`` is not None) and a
    BIMT loss (enabled when ``bimt_config`` is not None).

    Folders created under ``checkpoint_dir``::

        checkpoint_dir/
            best/       (created here; nothing in this class writes to it)
            all/
                train_step_idx_0.pth
                train_step_idx_<n>.pth
                ...
    """

    def __init__(
        self,
        model,
        hyperparams: MNISTHyperParams,
        nesim_config: NesimConfig,
        checkpoint_dir: str,
        train_dataset: MNIST,
        validation_dataset: MNIST,
        wandb_log: bool,
        nesim_device: str = "cuda:0",
        bimt_config: BIMTConfig = None,
    ) -> None:
        """Set up the model, dataloaders, metrics, folders and auxiliary losses.

        Args:
            model: network to train; called as ``model(x)`` to produce logits.
            hyperparams: optimizer, batch-size and scheduling settings.
            nesim_config: nesim loss configuration, or None to disable it.
            checkpoint_dir: existing directory that receives the checkpoints.
            train_dataset: MNIST dataset wrapped in the training DataLoader.
            validation_dataset: MNIST dataset wrapped in the validation
                DataLoader.
            wandb_log: when True, metrics are sent to ``wandb.log``.
            nesim_device: device string forwarded to ``NesimLoss``.
            bimt_config: BIMT loss configuration, or None to disable it.

        Raises:
            AssertionError: if ``hyperparams`` or the datasets have the wrong
                type, or if ``checkpoint_dir`` does not exist.
        """
        super().__init__()
        assert isinstance(hyperparams, MNISTHyperParams)
        self.wandb_log = wandb_log
        self.model = model
        self.nesim_config = nesim_config
        self.hyperparams = hyperparams
        self.checkpoint_dir = checkpoint_dir

        ## if bimt config is not None then apply bimt loss; the model that is
        ## trained is whatever init_modules_for_training returns
        if bimt_config is not None:
            self.bimt = BIMTLoss.from_config(
                config=bimt_config,
            )
            self.model = self.bimt.init_modules_for_training(model=self.model)
        else:
            self.bimt = None

        assert os.path.exists(
            self.checkpoint_dir
        ), f"Expected checkpoint_dir to exist: {checkpoint_dir}"
        assert isinstance(train_dataset, MNIST)
        assert isinstance(validation_dataset, MNIST)

        ## NOTE(review): `train_dataloader` is also the name of a Lightning
        ## hook; this instance attribute shadows it. Presumably callers pass
        ## this loader to the trainer explicitly -- confirm before renaming.
        self.train_dataloader = DataLoader(
            train_dataset, batch_size=self.hyperparams.batch_size
        )
        self.validation_dataloader = DataLoader(
            validation_dataset, batch_size=self.hyperparams.batch_size
        )

        ## todos:
        ## 0. dynamically decide where to save based on hyperparams + nesim config
        ## 1. save model checkpoint every n steps DONE
        ## 2. save best checkpoint in a separate folder DONE
        ## 3. implement nesim loss
        ## 4. wandb logging

        ## Change this section to switch to a different dataset
        ## (num_classes is hard-coded to the 10 MNIST classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=10, top_k=1)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=10, top_k=1)

        ## number of training_step calls completed so far
        self.train_step_idx = 0

        ## checkpoint_dir:
        ##     best/
        ##     all/
        ##         train_step_idx_0.pth
        ##         train_step_idx_<n>.pth
        ##         ...
        self.all_checkpoints_folder = os.path.join(self.checkpoint_dir, "all")
        self.best_checkpoint_folder = os.path.join(self.checkpoint_dir, "best")
        make_folder_if_does_not_exist(self.all_checkpoints_folder)
        make_folder_if_does_not_exist(self.best_checkpoint_folder)

        ## per-batch validation results, reduced and cleared at epoch end
        self.validation_step_losses_single_epoch = []
        self.validation_step_acc_single_epoch = []

        ## always define the attribute so training_step can test it; it is
        ## built on self.model, i.e. after any BIMT module initialisation
        if nesim_config is not None:
            self.nesim_loss = NesimLoss(
                model=self.model, config=nesim_config, device=nesim_device
            )
        else:
            self.nesim_loss = None

    def save_checkpoint(self, filename: str) -> None:
        """Write the state dict of ``self.model`` to ``filename``."""
        torch.save(self.model.state_dict(), filename)
        print(f"[saved checkpoint] {filename}")

    def training_step(self, batch, batch_idx):
        """Run one training step and return the total loss.

        The total is the cross-entropy loss, plus the nesim loss on scheduled
        steps (when nesim is enabled and ``compute`` returns a value), plus the
        BIMT loss (when BIMT is enabled).

        Args:
            batch: pair ``(x, y)`` of model inputs and class targets.
            batch_idx: index of the batch within the epoch (unused).

        Returns:
            The loss tensor to back-propagate.
        """
        step = self.train_step_idx

        ## the checkpoint is written before this step's forward pass, so
        ## train_step_idx_<n>.pth holds the weights as they are on entry to
        ## the n-th call of this method
        if step % self.hyperparams.save_checkpoint_every_n_steps == 0:
            self.save_checkpoint(
                filename=os.path.join(
                    self.all_checkpoints_folder,
                    f"train_step_idx_{step}.pth",
                )
            )

        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)

        if self.wandb_log:
            wandb.log({"training_loss": loss.item()})

        ## nesim is skipped entirely when it was not configured; the original
        ## code dereferenced self.nesim_loss unconditionally and crashed here
        nesim_is_due = (
            self.nesim_loss is not None
            and step % self.hyperparams.apply_nesim_every_n_steps == 0
        )
        if nesim_is_due:
            nesim_loss_item = self.nesim_loss.compute(reduce_mean=True)

            if self.wandb_log:
                self.nesim_loss.wandb_log()

            ## compute() may return None, in which case nothing is added
            if nesim_loss_item is not None:
                loss = loss + nesim_loss_item

        if self.bimt is not None:
            bimt_loss = self.bimt.forward(model=self.model)
            loss = loss + bimt_loss

            if self.wandb_log:
                ## divide by scale to get actual loss value
                wandb.log({"bimt_loss": bimt_loss.item() / self.bimt.scale})

        self.train_step_idx += 1
        return loss

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch and record its loss and accuracy.

        Args:
            batch: pair ``(x, y)`` of model inputs and class targets.
            batch_idx: index of the batch within the epoch (unused).
        """
        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        ## saves the best model based on this self.log
        ## (presumably via a checkpoint callback configured outside this file)
        self.log("val_loss", loss)
        val_acc = self.val_accuracy(preds, y)
        self.log("val_acc", val_acc)

        self.validation_step_losses_single_epoch.append(loss)
        self.validation_step_acc_single_epoch.append(val_acc)

    def on_validation_epoch_end(self):
        """Average the recorded per-batch validation results and report them.

        The averages are unweighted means over batches. The per-batch buffers
        are cleared afterwards. If no validation step was recorded, nothing is
        reported.
        """
        ## torch.stack raises on an empty list, so bail out when no
        ## validation batch was seen in this epoch
        if not self.validation_step_losses_single_epoch:
            return

        metrics = {
            "validation_loss": torch.stack(self.validation_step_losses_single_epoch)
            .mean()
            .item(),
            "validation_acc": torch.stack(self.validation_step_acc_single_epoch)
            .mean()
            .item(),
        }

        self.validation_step_acc_single_epoch.clear()  # free memory
        self.validation_step_losses_single_epoch.clear()  # free memory

        if self.wandb_log:
            wandb.log(metrics)
        print(metrics)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of this module."""
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hyperparams.lr,
            weight_decay=self.hyperparams.weight_decay,
        )
        return optimizer

    ## required for trainer.validate(lightning_module) to work
    def val_dataloader(self):
        """Return the validation DataLoader built in ``__init__``."""
        return self.validation_dataloader
