import torch
import os
import wandb
import lightning as L
import torch.nn.functional as F
from torch.utils.data import DataLoader
from pydantic import BaseModel, Extra
from torchmetrics import Accuracy
import torchvision.transforms as transforms
from ..configs import NesimConfig
from ..utils import make_folder_if_does_not_exist
from ..losses.nesim_loss import NesimLoss
from datasets import load_dataset
from ..bimt.loss import BIMTLoss, BIMTConfig
from ..losses.cross_layer_correlation.loss import (
    CrossLayerCorrelationLoss,
    CrossLayerCorrelationLossConfig,
)
from torch.optim.lr_scheduler import StepLR
from ..utils.json_stuff import load_json_as_dict
import os
from PIL import Image
from typing import Union
from torchvision.transforms.functional import InterpolationMode


"""
Taken from the original pytorch imagenet training recipe:
XXXX
"""


def get_module(use_v2):
    """Return the torchvision transforms namespace selected by ``use_v2``.

    The import happens lazily inside the chosen branch: the v2 module is only
    imported when it is actually requested, which avoids the V2 warning when
    just the V1 API is used.
    """
    if not use_v2:
        import torchvision.transforms

        return torchvision.transforms

    import torchvision.transforms.v2

    return torchvision.transforms.v2


class ClassificationPresetTrain:
    """Training-time preprocessing pipeline (random crop, flip, optional augmentation).

    Note: this transform assumes that the input to ``__call__`` is always a PIL
    image, regardless of the ``backend`` parameter. ``backend`` only decides
    whether the PIL -> tensor conversion happens before ("tensor") or after
    ("pil") the geometric / augmentation transforms.
    """

    def __init__(
        self,
        *,
        crop_size,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        interpolation=InterpolationMode.BILINEAR,
        hflip_prob=0.5,
        auto_augment_policy=None,
        ra_magnitude=9,
        augmix_severity=3,
        random_erase_prob=0.0,
        backend="pil",
        use_v2=False,
    ):
        """Assemble the composed transform.

        Args:
            crop_size: output size handed to ``RandomResizedCrop``.
            mean, std: per-channel statistics handed to ``Normalize``.
            interpolation: interpolation mode for the crop and augmentations.
            hflip_prob: horizontal flip probability; the flip is skipped when
                this is not greater than 0.
            auto_augment_policy: ``None`` (no policy), ``"ra"``, ``"ta_wide"``,
                ``"augmix"``, or any value accepted by ``AutoAugmentPolicy``.
            ra_magnitude: magnitude used when the policy is ``"ra"``.
            augmix_severity: severity used when the policy is ``"augmix"``.
            random_erase_prob: ``RandomErasing`` probability; skipped when not
                greater than 0.
            backend: ``"pil"`` or ``"tensor"`` (case-insensitive).
            use_v2: build the pipeline from ``torchvision.transforms.v2``.

        Raises:
            ValueError: if ``backend`` is neither ``"pil"`` nor ``"tensor"``.
        """
        T = get_module(use_v2)

        backend = backend.lower()
        if backend not in ("tensor", "pil"):
            raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")

        # Tensor backend: convert up front so every later op runs on tensors.
        steps = [T.PILToTensor()] if backend == "tensor" else []

        steps.append(
            T.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True)
        )
        if hflip_prob > 0:
            steps.append(T.RandomHorizontalFlip(hflip_prob))
        if auto_augment_policy is not None:
            steps.append(
                self._build_auto_augment(
                    T,
                    auto_augment_policy,
                    interpolation,
                    ra_magnitude,
                    augmix_severity,
                )
            )

        # PIL backend: convert only after the PIL-based ops above have run.
        if backend == "pil":
            steps.append(T.PILToTensor())

        if use_v2:
            to_float = T.ToDtype(torch.float, scale=True)
        else:
            to_float = T.ConvertImageDtype(torch.float)
        steps += [to_float, T.Normalize(mean=mean, std=std)]

        if random_erase_prob > 0:
            steps.append(T.RandomErasing(p=random_erase_prob))

        if use_v2:
            steps.append(T.ToPureTensor())

        self.transforms = T.Compose(steps)

    @staticmethod
    def _build_auto_augment(T, policy, interpolation, ra_magnitude, augmix_severity):
        """Instantiate the augmentation transform named by ``policy`` from module ``T``."""
        if policy == "ra":
            return T.RandAugment(interpolation=interpolation, magnitude=ra_magnitude)
        if policy == "ta_wide":
            return T.TrivialAugmentWide(interpolation=interpolation)
        if policy == "augmix":
            return T.AugMix(interpolation=interpolation, severity=augmix_severity)
        # Anything else is interpreted as an AutoAugmentPolicy value.
        return T.AutoAugment(
            policy=T.AutoAugmentPolicy(policy), interpolation=interpolation
        )

    def __call__(self, img):
        """Apply the composed training transforms to ``img``."""
        return self.transforms(img)


class ClassificationPresetEval:
    """Deterministic evaluation preprocessing: resize, center crop, normalize.

    ``backend`` decides whether the PIL -> tensor conversion happens before
    ("tensor") or after ("pil") the resize and crop.
    """

    def __init__(
        self,
        *,
        crop_size,
        resize_size=256,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        interpolation=InterpolationMode.BILINEAR,
        backend="pil",
        use_v2=False,
    ):
        """Assemble the composed transform.

        Args:
            crop_size: output size handed to ``CenterCrop``.
            resize_size: size handed to ``Resize`` before cropping.
            mean, std: per-channel statistics handed to ``Normalize``.
            interpolation: interpolation mode used by ``Resize``.
            backend: ``"pil"`` or ``"tensor"`` (case-insensitive).
            use_v2: build the pipeline from ``torchvision.transforms.v2``.

        Raises:
            ValueError: if ``backend`` is neither ``"pil"`` nor ``"tensor"``.
        """
        T = get_module(use_v2)

        backend = backend.lower()
        if backend not in ("tensor", "pil"):
            raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")

        # Tensor backend: convert first so resize/crop operate on tensors.
        steps = [T.PILToTensor()] if backend == "tensor" else []

        steps.append(T.Resize(resize_size, interpolation=interpolation, antialias=True))
        steps.append(T.CenterCrop(crop_size))

        # PIL backend: convert after the PIL-based resize/crop.
        if backend == "pil":
            steps.append(T.PILToTensor())

        if use_v2:
            to_float = T.ToDtype(torch.float, scale=True)
        else:
            to_float = T.ConvertImageDtype(torch.float)
        steps.append(to_float)
        steps.append(T.Normalize(mean=mean, std=std))

        if use_v2:
            steps.append(T.ToPureTensor())

        self.transforms = T.Compose(steps)

    def __call__(self, img):
        """Apply the composed evaluation transforms to ``img``."""
        return self.transforms(img)


"""
And now my stuff below
"""


def get_transforms_for_slice(slice_name: str):
    """Return the preprocessing preset for a dataset slice.

    ``"train"`` gets the augmenting training preset; every other slice name
    gets the evaluation preset (no image augmentations on the validation set).
    Both presets are built with a 224-pixel crop.
    """
    preset_cls = (
        ClassificationPresetTrain
        if slice_name == "train"
        else ClassificationPresetEval
    )
    return preset_cls(crop_size=224)


class ConvertedImagenetDataset:
    """
    Faster alternative to the usual huggingface dataset.
    Loads stuff ~34x faster :)

    Reads from a folder containing ``labels.json`` and an ``images/``
    sub-folder in which sample ``idx`` is stored as ``images/{idx}.jpg``.
    """

    def __init__(self, folder, slice_name="train"):
        """Validate the folder layout, load the labels and pick the transforms.

        Args:
            folder: directory holding ``labels.json`` and ``images/``.
            slice_name: ``"train"`` or ``"validation"``; selects the transform
                preset via ``get_transforms_for_slice``.

        Raises:
            AssertionError: on an unknown slice name, or when the labels file
                or the images folder does not exist.
        """
        valid_slices = ("train", "validation")
        assert (
            slice_name in valid_slices
        ), f"Invalid slice name for ImageNetDataset: {slice_name}"

        labels_path = os.path.join(folder, "labels.json")
        self.labels_filename = labels_path
        assert os.path.exists(
            labels_path
        ), f"Expected labels file to exist: {self.labels_filename}"
        self.labels = load_json_as_dict(labels_path)

        images_dir = os.path.join(folder, "images/")
        self.image_folder = images_dir
        assert os.path.exists(
            images_dir
        ), f"Expected images folder to exist: {self.image_folder}"

        self.transforms = get_transforms_for_slice(slice_name=slice_name)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for sample ``idx``."""
        jpg_path = os.path.join(self.image_folder, f"{idx}.jpg")
        # The image file's existence is not checked up front; a missing file
        # surfaces as the error raised by Image.open below.
        target = self.labels[idx]

        rgb_image = Image.open(jpg_path).convert("RGB")
        sample = self.transforms(rgb_image)

        return sample, target

    def __len__(self):
        """Number of samples, i.e. the number of entries in the labels object."""
        return len(self.labels)


class ImageNetDataset:
    """
    Thin wrapper around the huggingface ``imagenet-1k`` dataset that turns each
    dict item into an ``(x, y)`` tuple.
    """

    def __init__(
        self, slice_name="train", cache_dir="/research/datasets/imagenet_huggingface"
    ):
        """Load one slice of ``imagenet-1k`` and pick the matching transforms.

        Args:
            slice_name: ``"train"`` or ``"validation"``.
            cache_dir: cache directory forwarded to ``load_dataset``.

        Raises:
            AssertionError: if ``slice_name`` is not one of the two valid names.
        """
        valid_slices = ("train", "validation")
        assert (
            slice_name in valid_slices
        ), f"Invalid slice name for ImageNetDataset: {slice_name}"

        all_slices = load_dataset("imagenet-1k", cache_dir=cache_dir)
        self.dataset = all_slices[slice_name]

        self.transforms = get_transforms_for_slice(slice_name=slice_name)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for sample ``idx``."""
        record = self.dataset[idx]
        # Convert to RGB before the transforms are applied.
        rgb_image = record["image"].convert("RGB")
        return self.transforms(rgb_image), record["label"]

    def __len__(self):
        """Number of samples in the selected slice."""
        return len(self.dataset)


# Training hyperparameters consumed by ImageNetLightningModule.
# extra=Extra.forbid makes pydantic reject any field that is not declared
# below, so a misspelled key in a config fails at construction time.
class ImageNetHyperParams(BaseModel, extra=Extra.forbid):
    # learning rate passed to torch.optim.SGD
    lr: float
    # batch size used for both the training and the validation DataLoader
    batch_size: int
    # weight decay passed to torch.optim.SGD
    weight_decay: float
    # momentum passed to torch.optim.SGD
    momentum: float
    # a Lightning checkpoint is written whenever the training step index is a
    # multiple of this value (including step 0)
    save_checkpoint_every_n_steps: int
    # the nesim loss is computed and added whenever the training step index is
    # a multiple of this value
    apply_nesim_every_n_steps: int
    # step_size passed to StepLR (the scheduler is stepped once per epoch)
    scheduler_step_size: int
    # gamma (LR decay factor) passed to StepLR
    scheduler_gamma: float


class ImageNetLightningModule(L.LightningModule):
    """Lightning module for ImageNet classification with optional auxiliary losses.

    The training objective is the cross-entropy of the model's logits plus any
    of the following terms that are configured:

    - a ``NesimLoss`` (added every ``apply_nesim_every_n_steps`` training steps),
    - a ``BIMTLoss``,
    - a ``CrossLayerCorrelationLoss``.

    Lightning checkpoints are written periodically under ``checkpoint_dir`` and
    metrics are optionally sent to Weights & Biases.
    """

    def __init__(
        self,
        model,
        hyperparams: ImageNetHyperParams,
        nesim_config: Union[NesimConfig, None],
        checkpoint_dir: str,
        train_dataset: Union[ImageNetDataset, ConvertedImagenetDataset],
        validation_dataset: Union[ImageNetDataset, ConvertedImagenetDataset],
        wandb_log: bool,
        nesim_device: str = "cuda:0",
        bimt_config: Union[BIMTConfig, None] = None,
        cross_layer_correlation_loss_config: Union[
            CrossLayerCorrelationLossConfig, None
        ] = None,
    ) -> None:
        """Wire up the model, auxiliary losses, dataloaders, metrics and folders.

        Args:
            model: the network to train; called as ``model(x)`` to get logits.
            hyperparams: optimizer / scheduler / cadence settings.
            nesim_config: config for ``NesimLoss``; ``None`` disables that loss.
            checkpoint_dir: existing directory under which ``all/`` and
                ``best/`` are created.
            train_dataset: dataset wrapped in the shuffled training DataLoader.
            validation_dataset: dataset wrapped in the validation DataLoader.
            wandb_log: when True, losses and metrics are logged via ``wandb``.
            nesim_device: device string forwarded to ``NesimLoss``.
            bimt_config: config for ``BIMTLoss``; ``None`` disables that loss.
            cross_layer_correlation_loss_config: config for
                ``CrossLayerCorrelationLoss``; ``None`` disables that loss.

        Raises:
            AssertionError: if ``hyperparams`` or a dataset has an unexpected
                type, or if ``checkpoint_dir`` does not exist.
        """
        super().__init__()
        assert isinstance(hyperparams, ImageNetHyperParams)
        self.wandb_log = wandb_log
        self.model = model
        self.nesim_config = nesim_config
        self.hyperparams = hyperparams
        self.checkpoint_dir = checkpoint_dir

        ## if bimt config is not None then apply bimt loss
        if bimt_config is not None:
            self.bimt = BIMTLoss.from_config(
                config=bimt_config,
            )
            # BIMT may replace/wrap the model, so keep whatever it returns.
            self.model = self.bimt.init_modules_for_training(model=self.model)
        else:
            self.bimt = None

        if cross_layer_correlation_loss_config is not None:
            self.cross_layer_correlation_loss = CrossLayerCorrelationLoss.from_config(
                model=self.model, config=cross_layer_correlation_loss_config
            )
        else:
            self.cross_layer_correlation_loss = None

        assert os.path.exists(
            self.checkpoint_dir
        ), f"Expected checkpoint_dir to exist: {checkpoint_dir}"

        # A plain tuple of classes works with isinstance() on every Python
        # version, whereas isinstance(x, typing.Union[...]) raises TypeError
        # before Python 3.10.
        allowed_dataset_types = (ImageNetDataset, ConvertedImagenetDataset)
        assert isinstance(train_dataset, allowed_dataset_types)
        assert isinstance(validation_dataset, allowed_dataset_types)

        # NOTE(review): this attribute has the same name as the
        # LightningModule.train_dataloader() hook and therefore shadows it on
        # the instance. Presumably callers pass this DataLoader to the trainer
        # explicitly -- confirm before renaming or relying on the hook.
        self.train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.hyperparams.batch_size,
            shuffle=True,
            num_workers=4,
        )
        self.validation_dataloader = DataLoader(
            validation_dataset,
            batch_size=self.hyperparams.batch_size,
            shuffle=False,
            num_workers=4,
        )

        # Change this section to switch to a different dataset
        # (num_classes is hard-coded to the 1000 ImageNet classes).
        self.val_accuracy = Accuracy(task="multiclass", num_classes=1000)
        # note for top 5 acc:
        # top_k (int) – Number of highest probability or logit score predictions
        # considered to find the correct label.
        # Only works when preds contain probabilities/logits.
        #
        # source: XXXX
        self.val_accuracy_top_5 = Accuracy(task="multiclass", num_classes=1000, top_k=5)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=1000)
        self.train_step_idx = 0

        # Folders created under checkpoint_dir:
        #     best/
        #         (created here; this class does not write into it)
        #     all/
        #         train_step_idx_0.pth
        #         train_step_idx_<N>.pth
        #         ...
        self.all_checkpoints_folder = os.path.join(self.checkpoint_dir, "all")
        self.best_checkpoint_folder = os.path.join(self.checkpoint_dir, "best")
        make_folder_if_does_not_exist(self.all_checkpoints_folder)
        make_folder_if_does_not_exist(self.best_checkpoint_folder)

        ## store per-step validation results until the end of the epoch
        self.validation_step_losses_single_epoch = []
        self.validation_step_acc_single_epoch = []
        self.validation_step_acc_top_5_single_epoch = []

        # Always define nesim_loss (None when disabled) so training_step can
        # test for it, the same way bimt and cross_layer_correlation_loss are
        # handled. Previously the attribute was missing entirely when
        # nesim_config was None, and training_step failed with AttributeError.
        if nesim_config is not None:
            self.nesim_loss = NesimLoss(
                model=self.model, config=nesim_config, device=nesim_device
            )
        else:
            self.nesim_loss = None

    def save_checkpoint(self, filename):
        """Write a full Lightning checkpoint to ``filename`` via the attached trainer."""
        self.trainer.save_checkpoint(filename)
        print(f"[saved lightning checkpoint] {filename}")

    def training_step(self, batch, batch_idx):
        """Run one optimisation step and return the total loss.

        The total is cross-entropy plus whichever auxiliary losses are enabled.
        A checkpoint is written first whenever the step index is a multiple of
        ``save_checkpoint_every_n_steps`` (which includes step 0).
        """
        if self.train_step_idx % self.hyperparams.save_checkpoint_every_n_steps == 0:
            self.save_checkpoint(
                filename=os.path.join(
                    self.all_checkpoints_folder,
                    f"train_step_idx_{self.train_step_idx}.pth",
                )
            )

        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)

        if self.wandb_log:
            # Logged before the auxiliary terms are added, so this is the pure
            # classification loss.
            wandb.log(
                {"training_loss": loss.item(), "train/global_step": self.train_step_idx}
            )

        apply_nesim_now = (
            self.nesim_loss is not None
            and self.train_step_idx % self.hyperparams.apply_nesim_every_n_steps == 0
        )
        if apply_nesim_now:
            nesim_loss_item = self.nesim_loss.compute(reduce_mean=True)

            if self.wandb_log:
                self.nesim_loss.wandb_log()

            if nesim_loss_item is not None:
                # The nesim loss may live on nesim_device; move it next to the
                # classification loss before summing.
                loss = loss + nesim_loss_item.to(loss.device)

        if self.bimt is not None:
            if self.bimt.scale is not None:
                bimt_loss = self.bimt.forward(model=self.model)
                loss = loss + bimt_loss

            if self.wandb_log:
                self.bimt.wandb_log(model=self.model)

        if self.cross_layer_correlation_loss is not None:
            cross_layer_corr_loss = self.cross_layer_correlation_loss.forward()
            if cross_layer_corr_loss is not None:
                loss = loss + cross_layer_corr_loss

            if self.wandb_log:
                self.cross_layer_correlation_loss.wandb_log()

        self.train_step_idx += 1
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute loss, top-1 and top-5 accuracy for one validation batch."""
        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        ## saves the best model based on this self.log
        self.log("val_loss", loss)
        val_acc = self.val_accuracy(preds, y)
        self.log("val_acc", val_acc)

        # Top-k accuracy needs a score per class, so the raw logits are passed
        # here instead of the argmax predictions used for top-1.
        val_acc_top_5 = self.val_accuracy_top_5(logits, y)
        self.log("val_acc_top_5", val_acc_top_5)

        self.validation_step_losses_single_epoch.append(loss)
        self.validation_step_acc_single_epoch.append(val_acc)
        self.validation_step_acc_top_5_single_epoch.append(val_acc_top_5)

    def on_validation_epoch_end(self):
        """Average the per-batch validation results and optionally log them to wandb.

        The averages are unweighted means over batches.
        """
        validation_loss_epoch_average = (
            torch.stack(self.validation_step_losses_single_epoch).mean().item()
        )
        validation_acc_epoch_average = (
            torch.stack(self.validation_step_acc_single_epoch).mean().item()
        )
        validation_acc_top_5_epoch_average = (
            torch.stack(self.validation_step_acc_top_5_single_epoch).mean().item()
        )

        self.validation_step_acc_single_epoch.clear()  # free memory
        self.validation_step_losses_single_epoch.clear()  # free memory
        self.validation_step_acc_top_5_single_epoch.clear()  # free memory

        if self.wandb_log:
            wandb.log(
                {
                    "validation_loss": validation_loss_epoch_average,
                    "validation_acc": validation_acc_epoch_average,
                    "validation_acc_top_5": validation_acc_top_5_epoch_average,
                }
            )

    def configure_optimizers(self):
        """Return the SGD optimizer and a per-epoch StepLR scheduler config."""
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hyperparams.lr,
            weight_decay=self.hyperparams.weight_decay,
            momentum=self.hyperparams.momentum,
        )

        # Original pytorch recipe source:
        # - XXXX
        # - XXXX
        # An earlier, hard-coded variant of this scheduler used
        # step_size=30 and gamma=0.1; both are now hyperparameters.
        lr_scheduler = {
            "scheduler": StepLR(
                optimizer,
                step_size=self.hyperparams.scheduler_step_size,
                gamma=self.hyperparams.scheduler_gamma,
            ),
            # "monitor" only matters for metric-driven schedulers such as
            # ReduceLROnPlateau; StepLR decays on a fixed schedule. Kept so the
            # returned configuration is unchanged.
            "monitor": "val_loss",
            "interval": "epoch",  # Apply the scheduler after each epoch
            "frequency": 1,  # apply scheduler
        }

        return [optimizer], [lr_scheduler]

    ## required for trainer.validate(lightning_module) to work
    def val_dataloader(self):
        """Return the validation DataLoader built in ``__init__``."""
        return self.validation_dataloader
