import torch
import os
import wandb
import lightning as L
import torch.nn.functional as F
from torch.utils.data import DataLoader
from pydantic import BaseModel, Extra
from torchmetrics import Accuracy
import torchvision.transforms as transforms
from ..configs import NesimConfig
from ..utils import make_folder_if_does_not_exist
from ..losses.nesim_loss import NesimLoss
from datasets import load_dataset
from ..bimt.loss import BIMTLoss, BIMTConfig


class Cifar100Dataset:
    """
    Map-style wrapper around the HuggingFace "cifar100" dataset.

    Each raw record is a dict; indexing this wrapper yields an
    ``(transformed_image, coarse_label)`` tuple instead. The same transform
    pipeline (random augmentations included) is used for every slice.
    """

    def __init__(self, slice_name="train", cache_dir="./huggingface_datasets"):
        """
        Load the requested slice and build the image transform pipeline.

        Args:
            slice_name: key of the split to select from the loaded dataset.
            cache_dir: directory handed to ``load_dataset`` as its cache.
        """
        all_splits = load_dataset("cifar100", cache_dir=cache_dir)
        self.dataset = all_splits[slice_name]

        # random augmentations come first ...
        augmentations = [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(20),
            transforms.RandomAffine(degrees=5, translate=(0.1, 0.1)),
        ]
        # ... followed by tensor conversion and per-channel normalisation
        to_normalised_tensor = [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
        self.transforms = transforms.Compose(augmentations + to_normalised_tensor)

    def __len__(self):
        """Number of records in the selected slice."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(transformed image, coarse label)`` for record ``idx``."""
        record = self.dataset[idx]
        image = self.transforms(record["img"])
        label = record["coarse_label"]
        return image, label


class Cifar100HyperParams(BaseModel, extra=Extra.forbid):
    """
    Training hyperparameters consumed by Cifar100LightningModule.

    Declared with ``extra=Extra.forbid`` so that fields not listed below
    are rejected by pydantic validation.
    """

    # learning rate handed to the Adam optimizer in configure_optimizers
    lr: float
    # batch size used for both the train and validation DataLoaders
    batch_size: int
    # weight_decay handed to the Adam optimizer in configure_optimizers
    weight_decay: float
    # a checkpoint is written whenever the train step counter is a multiple of this
    save_checkpoint_every_n_steps: int
    # the nesim loss is computed whenever the train step counter is a multiple of this
    apply_nesim_every_n_steps: int


class Cifar100LightningModule(L.LightningModule):
    """
    Lightning module that trains ``model`` with a cross-entropy loss and,
    optionally, an additional NeSim loss term and/or BIMT loss term.

    Besides the usual Lightning hooks it also:
      - builds the train / validation DataLoaders from the given datasets,
      - periodically saves its own ``state_dict`` during training,
      - optionally mirrors losses and accuracy to wandb.
    """

    def __init__(
        self,
        model,
        hyperparams: Cifar100HyperParams,
        nesim_config: NesimConfig,
        checkpoint_dir: str,
        train_dataset: Cifar100Dataset,
        validation_dataset: Cifar100Dataset,
        wandb_log: bool,
        nesim_device: str = "cuda:0",
        bimt_config: BIMTConfig = None,
    ) -> None:
        """
        Args:
            model: the network to train; called as ``model(x)`` to get logits.
            hyperparams: optimizer / batching / scheduling settings.
            nesim_config: config for ``NesimLoss``. ``None`` disables the
                NeSim term entirely.
            checkpoint_dir: existing directory under which the ``all`` and
                ``best`` sub-folders are created.
            train_dataset: dataset wrapped in a shuffled DataLoader.
            validation_dataset: dataset wrapped in a non-shuffled DataLoader.
            wandb_log: when true, metrics are also sent through ``wandb.log``.
            nesim_device: device string forwarded to ``NesimLoss``.
            bimt_config: config for ``BIMTLoss``. ``None`` disables BIMT.

        Raises:
            AssertionError: if ``hyperparams`` or the datasets have the wrong
                type, or if ``checkpoint_dir`` does not exist.
        """
        super().__init__()
        assert isinstance(hyperparams, Cifar100HyperParams)
        self.wandb_log = wandb_log
        self.model = model
        self.nesim_config = nesim_config
        self.hyperparams = hyperparams
        self.checkpoint_dir = checkpoint_dir

        ## if bimt config is not None then apply bimt loss
        if bimt_config is not None:
            self.bimt = BIMTLoss.from_config(
                config=bimt_config,
            )
            # BIMT may hand back a modified model, so this has to happen
            # before NesimLoss is built around self.model further down
            self.model = self.bimt.init_modules_for_training(model=self.model)
        else:
            self.bimt = None

        assert os.path.exists(
            self.checkpoint_dir
        ), f"Expected checkpoint_dir to exist: {checkpoint_dir}"
        assert isinstance(train_dataset, Cifar100Dataset)
        assert isinstance(validation_dataset, Cifar100Dataset)

        # NOTE(review): the `train_dataloader` attribute shadows the
        # LightningModule.train_dataloader() hook of the same name. Presumably
        # callers pass this loader to the trainer explicitly -- confirm before
        # renaming, external code may read this attribute.
        self.train_dataloader = DataLoader(
            train_dataset, batch_size=self.hyperparams.batch_size, shuffle=True
        )
        self.validation_dataloader = DataLoader(
            validation_dataset, batch_size=self.hyperparams.batch_size, shuffle=False
        )

        ## Change this section to switch to a different dataset
        self.val_accuracy = Accuracy(task="multiclass", num_classes=100, top_k=1)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=100, top_k=1)
        self.train_step_idx = 0

        ## folders created under checkpoint_dir:
        ##     best/   (only created here; nothing in this class writes to it)
        ##     all/
        ##         train_step_idx_<n>.pth   (written by training_step)
        self.all_checkpoints_folder = os.path.join(self.checkpoint_dir, "all")
        self.best_checkpoint_folder = os.path.join(self.checkpoint_dir, "best")
        make_folder_if_does_not_exist(self.all_checkpoints_folder)
        make_folder_if_does_not_exist(self.best_checkpoint_folder)

        ## store per-batch validation results until the epoch ends
        self.validation_step_losses_single_epoch = []
        self.validation_step_acc_single_epoch = []

        # Always define the attribute so training_step can test it instead of
        # hitting an AttributeError when nesim is disabled.
        if nesim_config is not None:
            self.nesim_loss = NesimLoss(
                model=self.model, config=nesim_config, device=nesim_device
            )
        else:
            self.nesim_loss = None

    def save_checkpoint(self, filename):
        """Save this module's ``state_dict`` to ``filename`` and print the path."""
        torch.save(self.state_dict(), filename)
        print(f"[saved lightning checkpoint] {filename}")

    def training_step(self, batch, batch_idx):
        """
        Run one optimisation step.

        Increments the step counter, saves a periodic checkpoint, computes the
        cross-entropy loss and adds the NeSim / BIMT terms when enabled.

        Args:
            batch: an ``(x, y)`` pair of inputs and targets.
            batch_idx: index of the batch within the epoch (unused).

        Returns:
            The total loss tensor to backpropagate.
        """
        self.train_step_idx += 1
        if self.train_step_idx % self.hyperparams.save_checkpoint_every_n_steps == 0:
            self.save_checkpoint(
                filename=os.path.join(
                    self.all_checkpoints_folder,
                    f"train_step_idx_{self.train_step_idx}.pth",
                )
            )

        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)

        # logged before the extra terms are added, so this is the pure
        # cross-entropy value
        if self.wandb_log:
            wandb.log({"training_loss": loss.item()})

        # nesim is skipped entirely when no nesim_config was supplied
        nesim_due = (
            self.nesim_loss is not None
            and self.train_step_idx % self.hyperparams.apply_nesim_every_n_steps == 0
        )
        if nesim_due:
            nesim_loss_item = self.nesim_loss.compute(reduce_mean=True)

            if self.wandb_log:
                self.nesim_loss.wandb_log()

            # compute() may return None, in which case nothing is added
            if nesim_loss_item is not None:
                loss = loss + nesim_loss_item.to(loss.device)

        if self.bimt is not None:
            # the bimt term is only added when a scale is set
            if self.bimt.scale is not None:
                bimt_loss = self.bimt.forward(model=self.model)
                loss = loss + bimt_loss

            if self.wandb_log:
                self.bimt.wandb_log(model=self.model)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Evaluate one validation batch.

        Logs ``val_loss`` and ``val_acc`` through Lightning and stores both
        values for the epoch-level average computed in
        ``on_validation_epoch_end``.

        Args:
            batch: an ``(x, y)`` pair of inputs and targets.
            batch_idx: index of the batch within the epoch (unused).
        """
        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        ## saves the best model based on this self.log
        self.log("val_loss", loss)
        val_acc = self.val_accuracy(preds, y)
        self.log("val_acc", val_acc)

        self.validation_step_losses_single_epoch.append(loss)
        self.validation_step_acc_single_epoch.append(val_acc)

    def on_validation_epoch_end(self):
        """
        Average the stored per-batch loss / accuracy, clear the buffers and,
        when enabled, send the averages to wandb.

        The averages are plain means over batches (not weighted by batch size).
        """
        # torch.stack raises on an empty list; nothing to report in that case
        if not self.validation_step_losses_single_epoch:
            return

        validation_loss_epoch_average = (
            torch.stack(self.validation_step_losses_single_epoch).mean().item()
        )
        validation_acc_epoch_average = (
            torch.stack(self.validation_step_acc_single_epoch).mean().item()
        )

        self.validation_step_acc_single_epoch.clear()  # free memory
        self.validation_step_losses_single_epoch.clear()  # free memory
        if self.wandb_log:
            wandb.log(
                {
                    "validation_loss": validation_loss_epoch_average,
                    "validation_acc": validation_acc_epoch_average,
                }
            )

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of this module."""
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hyperparams.lr,
            weight_decay=self.hyperparams.weight_decay,
        )
        return optimizer

    ## required for trainer.validate(lightning_module) to work
    def val_dataloader(self):
        """Return the validation DataLoader built in ``__init__``."""
        return self.validation_dataloader
