"""
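Online linear-probe evaluation callbacks, adapted from the ``SSLOnlineEvaluator`` callback in PyTorch Lightning Bolts (``pl_bolts``).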
"""

from contextlib import contextmanager
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import math
import torch
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.utilities import rank_zero_warn
from torch import Tensor, nn
from torch.nn import functional as F
from torch.optim import Optimizer
import torchmetrics
from models.lie_ssl.shapes_model import SimCLRLieBottleneckModule, SimCLRLieModule

from pl_bolts.models.self_supervised.evaluator import SSLEvaluator


class SSLOnlineEvaluator(Callback):  # pragma: no cover
    """Attaches a linear probe for online evaluation using the standard self-supervised protocol.

    A ``torch.nn.Linear(z_dim, num_classes)`` probe is trained with Adam on
    representations produced (under ``torch.no_grad``) by
    ``pl_module.online_probe_forward`` after every training batch, and evaluated
    after every validation batch. One top-1 accuracy metric per dataloader name is
    attached to ``pl_module`` as ``online_{loader_name}_top_1_accuracy``.

    Requirements visible in this callback:
        * the datamodule exposes ``num_classes`` (unless passed explicitly),
          ``train_loader_names`` and ``val_loader_names``;
        * the module exposes ``z_dim`` (unless passed explicitly),
          ``online_probe_forward(x1, x2, g, stage)`` and ``g_matrix[stage]``.

    Example::

        # your datamodule must have 2 attributes
        dm = DataModule()
        dm.num_classes = ... # the num of classes in the datamodule
        dm.name = ... # name of the datamodule (e.g. ImageNet, STL10, CIFAR10)

        # your model must have 1 attribute
        model = Model()
        model.z_dim = ... # the representation dim

        online_eval = SSLOnlineEvaluator(
            z_dim=model.z_dim
        )
    """

    def __init__(
        self,
        z_dim: Optional[int] = None,
        num_classes: Optional[int] = None,
    ):
        """
        Args:
            z_dim: representation dimension; read from ``pl_module.z_dim`` in
                ``setup`` when ``None``.
            num_classes: number of classes; read from
                ``trainer.datamodule.num_classes`` in ``setup`` when ``None``.
        """
        super().__init__()

        self.z_dim = z_dim

        self.optimizer: Optional[Optimizer] = None
        self.online_evaluator: Optional[SSLEvaluator] = None
        self.num_classes: Optional[int] = num_classes
        # Populated in ``setup`` from the datamodule.
        self.train_loader_names: Optional[Sequence[str]] = None
        self.val_loader_names: Optional[Sequence[str]] = None

        # Callback state handed to ``on_load_checkpoint``; applied once the probe
        # and optimizer exist (in ``on_pretrain_routine_start``).
        self._recovered_callback_state: Optional[Dict[str, Any]] = None

    def setup(
        self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None
    ) -> None:
        """Resolve missing dims and attach one accuracy metric per dataloader name."""
        if self.num_classes is None:
            self.num_classes = trainer.datamodule.num_classes

        if self.z_dim is None:
            self.z_dim = pl_module.z_dim

        # Metrics are set as attributes of pl_module so they follow its device.
        self.train_loader_names = pl_module.datamodule.train_loader_names
        for train_loader_name in self.train_loader_names:
            setattr(
                pl_module,
                f"online_{train_loader_name}_top_1_accuracy",
                torchmetrics.Accuracy(),
            )

        self.val_loader_names = pl_module.datamodule.val_loader_names
        for val_loader_name in self.val_loader_names:
            setattr(
                pl_module,
                f"online_{val_loader_name}_top_1_accuracy",
                torchmetrics.Accuracy(),
            )

    def on_pretrain_routine_start(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        """Build the probe and its optimizer, restoring their state when resuming."""
        # must move to device after setup, as during setup, pl_module is still on cpu
        self.online_evaluator = torch.nn.Linear(self.z_dim, self.num_classes).to(
            pl_module.device
        )

        # distributed stuff copied from Pl.Bolts; the connector attribute name
        # differs between Lightning versions, hence the fallback.
        accel = (
            trainer.accelerator_connector
            if hasattr(trainer, "accelerator_connector")
            else trainer._accelerator_connector
        )
        if accel.is_distributed:
            from torch.nn.parallel import DistributedDataParallel as DDP

            # DDP accepts ``device_ids`` only for single-device GPU modules;
            # CPU modules (e.g. the gloo backend) must pass ``None``.
            device_ids = (
                [pl_module.device] if pl_module.device.type == "cuda" else None
            )
            self.online_evaluator = DDP(self.online_evaluator, device_ids=device_ids)
        # Single-process runs need no synchronisation, so no warning is emitted.

        self.optimizer = torch.optim.Adam(self.online_evaluator.parameters(), lr=1e-4)

        if self._recovered_callback_state is not None:
            self.online_evaluator.load_state_dict(
                self._recovered_callback_state["state_dict"]
            )
            self.optimizer.load_state_dict(
                self._recovered_callback_state["optimizer_state"]
            )

    def to_device(
        self, batch: Sequence, device: Union[str, torch.device]
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Extract the probe inputs and labels from ``batch`` and move them to ``device``.

        ``batch`` is unpacked as ``(x1, x2, y, fov1, fov2, delta)``. ``x1`` and
        ``x2`` are 3-sequences whose last element is the input used by the probe.

        Returns:
            ``(x1_online, x2_online, y)`` on ``device``.
        """
        x1, x2, y, fov1, fov2, delta = batch
        _, _, x1_online = x1
        _, _, x2_online = x2

        x1_online = x1_online.to(device)
        x2_online = x2_online.to(device)
        y = y.to(device)

        return x1_online, x2_online, y

    def shared_step(
        self,
        pl_module: LightningModule,
        batch: Sequence,
        stage: str = "train",
    ):
        """Run the probe on one batch, update/log accuracy and log the loss.

        Args:
            pl_module: module providing ``online_probe_forward`` and ``g_matrix``.
            batch: one dataloader batch (see ``to_device``).
            stage: dataloader name; selects ``g_matrix[stage]`` and the metric.
        Returns:
            The probe loss (only the probe's parameters receive gradients, because
            the representations are computed under ``torch.no_grad``).
        """
        with torch.no_grad():
            with set_training(pl_module, False):
                x1_online, x2_online, y = self.to_device(batch, pl_module.device)
                reps = pl_module.online_probe_forward(
                    x1_online, x2_online, pl_module.g_matrix[stage], stage
                )
        mlp_logits = self.online_evaluator(reps)
        if self._is_lie_module(pl_module):
            # mlp_logits is [batch_size, 2 or 4, num_classes]
            log_probs = F.log_softmax(mlp_logits, dim=-1)
            # Log of the mean class probability over views:
            # log(1/n * sum_i p_i) = logsumexp_i(log p_i) - log(n)
            # out is [batch_size, num_classes]
            out = torch.logsumexp(log_probs, dim=1) - math.log(reps.size(1))
            mlp_loss = F.nll_loss(out, y)
            # Give the metric probabilities, like the other branch does; the
            # argmax is the same as for the log-probabilities.
            preds = out.exp()
        else:
            # mlp_logits is [batch_size, 2, num_classes]; flatten the two views
            # into [2*batch_size, num_classes] (row order: b0v0, b0v1, b1v0, ...)
            batch_size = y.shape[0]
            mlp_logits = mlp_logits.reshape(2 * batch_size, -1)
            # labels in the matching order: y0, y0, y1, y1, ...
            y = y.repeat_interleave(2)
            mlp_loss = F.cross_entropy(mlp_logits, y)
            preds = F.softmax(mlp_logits, dim=-1)

        accuracy_metric_name = f"online_{stage}_top_1_accuracy"
        accuracy_metric = getattr(pl_module, accuracy_metric_name)
        accuracy_metric(preds, y)
        # TODO: No batchsize included? It is inferred?
        pl_module.log(
            accuracy_metric_name,
            accuracy_metric,
            on_step=False,
            on_epoch=True,
            # loader names are used instead
            add_dataloader_idx=False,
        )
        pl_module.log(
            f"online_{stage}_loss",
            mlp_loss,
            # loader names are used instead
            add_dataloader_idx=False,
        )

        return mlp_loss

    def _is_lie_module(self, pl_module: LightningModule) -> bool:
        """Return True when ``pl_module``'s class name contains "lie" (case-insensitive)."""
        return "lie" in type(pl_module).__name__.lower()

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """Take one probe optimisation step per training dataloader in ``batch``.

        ``batch`` is indexed by loader name (one sub-batch per training loader).
        """
        for loader_name in batch:
            mlp_loss = self.shared_step(
                pl_module, batch[loader_name], stage=loader_name
            )
            # update finetune weights
            mlp_loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """Evaluate the probe on a validation batch; no optimiser step is taken."""
        stage = self.val_loader_names[dataloader_idx]
        self.shared_step(pl_module, batch, stage=stage)

    def on_save_checkpoint(
        self, trainer: Trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]
    ) -> dict:
        """Return the probe and optimizer state to be stored as callback state."""
        return {
            "state_dict": self.online_evaluator.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
        }

    def on_load_checkpoint(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        callback_state: Dict[str, Any],
    ) -> None:
        """Stash saved state; it is applied in ``on_pretrain_routine_start``."""
        self._recovered_callback_state = callback_state


@contextmanager
def set_training(module: nn.Module, mode: bool):
    """Context manager that temporarily sets the training mode of ``module``.

    On entry, ``module.train(mode)`` is applied, which recursively sets ``mode``
    on every submodule. On exit, even if the body raised, every submodule's own
    original ``training`` flag is restored. A submodule that was in eval mode
    inside a training-mode parent (e.g. a frozen backbone) therefore stays in
    eval mode, instead of being reset to the parent's mode.

    Args:
        module: module to set training mode
        mode: whether to set training mode (True) or evaluation mode (False).

    Yields:
        ``module`` itself.
    """
    # Record each (sub)module's flag individually; module.train(flag) on exit
    # would overwrite all of them with the root's flag.
    original_modes = [(sub, sub.training) for sub in module.modules()]

    try:
        module.train(mode)
        yield module
    finally:
        for sub, was_training in original_modes:
            sub.training = was_training


class SSLOnlineEvaluatorBackpropLosses(SSLOnlineEvaluator):
    """Online evaluator that gives each view's logits its own cross-entropy term.

    The probe loss is the sum of one cross-entropy per view (2 views, or 4 when
    the ``g_``-prefixed logits are present). Accuracy is measured on the first
    two views only.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def unpack_logits(
        self, logits: Tensor
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        """Split probe logits, indexed as ``logits[:, view, :]``, into per-view tensors.

        Args:
            logits: probe output whose second dimension enumerates views; views
                2 and 3 are read only when ``logits.shape[1] > 2``.
        Returns:
            ``(z1_logits, z2_logits, g_z1_logits, g_z2_logits)``. The last two are
            ``None`` when only two views are present.
        """
        z1_logits = logits[:, 0, :]
        z2_logits = logits[:, 1, :]
        if logits.shape[1] > 2:
            g_z1_logits = logits[:, 2, :]
            g_z2_logits = logits[:, 3, :]
        else:
            g_z1_logits, g_z2_logits = None, None
        return z1_logits, z2_logits, g_z1_logits, g_z2_logits

    def shared_step(
        self,
        pl_module: LightningModule,
        batch: Sequence,
        stage: str = "train",
    ):
        """Run the probe on one batch, update/log accuracy and log the summed loss.

        Args:
            pl_module: module providing ``online_probe_forward`` and ``g_matrix``.
            batch: one dataloader batch (unpacked by ``to_device``).
            stage: dataloader name; selects ``g_matrix[stage]`` and the metric.
        Returns:
            Sum of the per-view cross-entropy losses.
        """
        with torch.no_grad():
            with set_training(pl_module, False):
                x1_online, x2_online, y = self.to_device(batch, pl_module.device)
                y = y.long()
                reps = pl_module.online_probe_forward(
                    x1_online, x2_online, pl_module.g_matrix[stage], stage
                )

        # mlp_logits is [batch_size, 2 or 4, num_classes]
        # 2 if stage is canonical
        mlp_logits = self.online_evaluator(reps)
        # each is of shape [batch_size, num_classes]
        z1_logits, z2_logits, g_z1_logits, g_z2_logits = self.unpack_logits(mlp_logits)

        z1_loss = F.cross_entropy(z1_logits, y)
        z2_loss = F.cross_entropy(z2_logits, y)

        # unpack_logits returns either both g-logits or neither.
        g_loss = 0.0
        if g_z1_logits is not None:
            g_loss = F.cross_entropy(g_z1_logits, y) + F.cross_entropy(g_z2_logits, y)
        loss = z1_loss + z2_loss + g_loss

        accuracy_metric_name = f"online_{stage}_top_1_accuracy"
        accuracy_metric = getattr(pl_module, accuracy_metric_name)

        # use only z_1 and z_2 for accuracy; each call accumulates into the metric
        for view_logits in (z1_logits, z2_logits):
            accuracy_metric(F.softmax(view_logits, dim=-1), y)

        pl_module.log(
            accuracy_metric_name,
            accuracy_metric,
            on_step=False,
            on_epoch=True,
            # loader names are used instead
            add_dataloader_idx=False,
        )
        pl_module.log(
            f"online_{stage}_loss",
            loss,
            # loader names are used instead
            add_dataloader_idx=False,
        )

        return loss

        