import math
from typing import List, Tuple

import torch
import torch.distributed as dist
import pytorch_lightning as pl
from torch import Tensor
from torch.nn import Identity
from torchvision.models import resnet50, resnet18

from lightly.models.modules import SimCLRProjectionHead
from lightly.models.utils import get_weight_decay_parameters
from lightly.transforms import SimCLRTransform
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.lars import LARS
from lightly.utils.scheduler import CosineWarmupScheduler

import math
from typing import List, Optional
from typing import Union

import torch.nn.functional as F

from model_zoo.utils import mlp
from model_zoo.base import BaseModel


class SimCLRModel(BaseModel):
    """SimCLR model with a ResNet backbone and an online linear classifier.

    The backbone output is fed to a ``SimCLRProjectionHead`` whose embeddings
    are trained with :class:`NTXEntLoss`. In parallel, an
    ``OnlineLinearClassifier`` is trained on the *detached* backbone features,
    so its loss does not influence the backbone.

    Parameters
    ----------
    lr : float or "auto"
        Learning rate. With ``"auto"`` the value is derived from the global
        batch size in :meth:`configure_optimizers`.
    momentum : float
        Momentum handed to the LARS optimizer.
    weight_decay : float
        Weight decay handed to the LARS optimizer.
    temperature : float
        Temperature of the NT-Xent loss.
    network : str
        Backbone architecture, either ``"resnet18"`` or ``"resnet50"``.
    low_res : bool
        If True, the ResNet stem is replaced by a 3x3 stride-1 convolution and
        the max-pooling layer is removed.
    weight_decay_trick : bool
        If True, parameters are split into groups so that some of them are
        optimized without weight decay (see :meth:`configure_optimizers`).
    datamodule : pl.LightningDataModule, optional
        Data module; its ``batch_size`` (and, if present, ``input_channels``)
        attributes are read by this class.

    Notes
    -----
    Every argument except ``temperature`` is forwarded to
    ``BaseModel.__init__``. The methods below read them back as attributes of
    the same name, and also read ``self.num_classes`` — presumably all set by
    ``BaseModel``; confirm against its implementation.
    """

    def __init__(
        self,
        lr: Union[float, str] = "auto",
        momentum: float = 0.9,
        weight_decay: float = 1e-6,
        temperature: float = 0.1,
        network: str = "resnet50",
        low_res: bool = False,
        weight_decay_trick: bool = True,
        datamodule: Optional[pl.LightningDataModule] = None,
    ) -> None:
        super().__init__(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            network=network,
            low_res=low_res,
            weight_decay_trick=weight_decay_trick,
            datamodule=datamodule,
        )
        self.temperature = temperature
        self.criterion = NTXEntLoss(temperature=self.temperature)

    def load_modules(self):
        """Create the backbone, projection head and online classifier.

        Raises
        ------
        ValueError
            If ``self.network`` is neither ``"resnet18"`` nor ``"resnet50"``.
        """
        if self.network == "resnet18":
            resnet = resnet18()
            self.projection_head = SimCLRProjectionHead(512, 512, 128)
            self.online_classifier = OnlineLinearClassifier(
                feature_dim=512, num_classes=self.num_classes
            )
        elif self.network == "resnet50":
            resnet = resnet50()
            # Library defaults are used here; presumably they match the
            # 2048-d ResNet-50 features — confirm against lightly.
            self.projection_head = SimCLRProjectionHead()
            self.online_classifier = OnlineLinearClassifier(
                feature_dim=2048, num_classes=self.num_classes
            )
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # ``resnet`` further down.
            raise ValueError(
                f"Unsupported network {self.network!r}; "
                "expected 'resnet18' or 'resnet50'."
            )
        # Drop the classification layer so the backbone returns features.
        resnet.fc = Identity()

        if self.low_res:
            # Fall back to 3 input channels when the datamodule (possibly
            # None) does not expose ``input_channels``.
            input_channels = (
                self.datamodule.input_channels
                if hasattr(self.datamodule, "input_channels")
                else 3
            )
            resnet.conv1 = torch.nn.Conv2d(
                input_channels,
                64,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False,
            )
            resnet.maxpool = Identity()

        self.backbone = resnet

    def forward(self, x: Tensor) -> Tensor:
        """Return the backbone output for ``x``."""
        return self.backbone(x)

    def training_step(
        self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        """Compute the contrastive loss plus the online classifier loss.

        Parameters
        ----------
        batch : tuple
            ``batch[0]`` is the list of augmented views and ``batch[1]`` the
            targets; any further entries are ignored.
        batch_idx : int
            Index of the batch, forwarded to the online classifier.

        Returns
        -------
        Tensor
            Sum of the NT-Xent loss and the online classifier loss.
        """
        views, targets = batch[0], batch[1]
        # All views go through the backbone in a single concatenated batch.
        features = self.forward(torch.cat(views)).flatten(start_dim=1)
        z = self.projection_head(features)
        # The two-name unpacking requires exactly two views.
        z0, z1 = z.chunk(len(views))
        loss = self.criterion(z0, z1)
        self.log(
            "train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
        )

        # Features are detached so the classifier does not train the backbone;
        # targets are repeated once per view to line up with the features.
        cls_loss, cls_log = self.online_classifier.training_step(
            (features.detach(), targets.repeat(len(views))), batch_idx
        )
        self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
        return loss + cls_loss

    def configure_optimizers(self):
        """Build the LARS optimizer and a per-step cosine warmup schedule.

        Returns
        -------
        tuple
            ``([optimizer], [scheduler_config])`` where the scheduler config
            uses ``"interval": "step"``.
        """
        # Set learning rate.
        if self.lr == "auto":
            global_batch_size = self.datamodule.batch_size * self.trainer.world_size
            # Square root learning rate scaling improves performance for small
            # batch sizes (<=2048) and few training epochs (<=200).
            if global_batch_size <= 2048 and self.trainer.max_epochs <= 200:
                lr = 0.075 * math.sqrt(global_batch_size)
            # Alternatively, linear scaling can be used for larger batches and
            # longer training.
            # See Appendix B.1. in the SimCLR paper https://arxiv.org/abs/2002.05709
            else:
                lr = 0.3 * global_batch_size / 256
        else:
            lr = self.lr

        # Don't use weight decay for batch norm, bias parameters, and classification
        # head to improve performance (for imagenet).
        if self.weight_decay_trick:
            params, params_no_weight_decay = get_weight_decay_parameters(
                [self.backbone, self.projection_head]
            )
            optimizer = LARS(
                [
                    {"name": "simclr", "params": params},
                    {
                        "name": "simclr_no_weight_decay",
                        "params": params_no_weight_decay,
                        "weight_decay": 0.0,
                    },
                    {
                        "name": "online_classifier",
                        "params": self.online_classifier.parameters(),
                        "weight_decay": 0.0,
                    },
                ],
                lr=lr,
                momentum=self.momentum,
                weight_decay=self.weight_decay,
            )
        else:
            optimizer = LARS(
                self.parameters(),
                lr=lr,
                momentum=self.momentum,
                weight_decay=self.weight_decay,
            )

        # The scheduler is stepped once per optimizer step, so both of its
        # "epochs" arguments are given in steps: warmup spans the number of
        # steps in 10 epochs and the cosine period covers the whole run.
        total_steps = self.trainer.estimated_stepping_batches
        scheduler = {
            "scheduler": CosineWarmupScheduler(
                optimizer=optimizer,
                warmup_epochs=int(total_steps / self.trainer.max_epochs * 10),
                max_epochs=int(total_steps),
            ),
            "interval": "step",
        }
        return [optimizer], [scheduler]



class NTXEntLoss(torch.nn.Module):
    """NT-Xent (normalized temperature-scaled cross entropy) contrastive loss.

    Proposed in the SimCLR paper :cite:`chen2020simple` and also used by
    MoCo :cite:`he2020momentum`.

    Parameters
    ----------
    temperature : float, optional
        Value the cosine similarities are divided by.
        Default is 0.5.
    """

    def __init__(self, temperature: float = 0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        """Evaluate the NT-Xent loss for two batches of paired embeddings.

        Parameters
        ----------
        z_i : torch.Tensor
            Embeddings of the first augmented view of the batch.
        z_j : torch.Tensor
            Embeddings of the second augmented view of the batch; row ``k``
            is treated as the positive partner of row ``k`` of ``z_i``.

        Returns
        -------
        torch.Tensor
            Scalar tensor holding the contrastive loss.
        """
        # Under an initialized process group the embeddings of all processes
        # are collected; otherwise the inputs are used unchanged.
        view_a = all_gather(z_i)
        view_b = all_gather(z_j)

        stacked = torch.cat([view_a, view_b], 0)
        total = stacked.size(0)

        # Temperature-scaled cosine similarity between every pair of rows.
        unit = F.normalize(stacked, dim=1)
        logits = torch.matmul(unit, unit.T) / self.temperature

        # Positive pairs sit on the two off-diagonals at distance total / 2.
        # ``-total // 2`` parses as ``(-total) // 2``.
        positives = torch.cat(
            (torch.diag(logits, total // 2), torch.diag(logits, -total // 2)),
            dim=0,
        )

        # Only the self-similarities are removed, so each row's positive
        # stays inside the log-sum-exp term.
        self_mask = torch.eye(total, dtype=bool).to(view_a.device)
        others = logits[~self_mask].reshape(total, -1)

        pull = -positives.mean()
        push = torch.logsumexp(others, dim=1).mean()
        return pull + push


class GatherLayer(torch.autograd.Function):
    """Autograd function that all-gathers a tensor from every process.

    The backward pass is defined as well, so gradients flow through the
    gathered tensors.
    """

    @staticmethod
    def forward(ctx, x):
        """Return a tuple with one tensor per process, gathered from ``x``."""
        world_size = dist.get_world_size()
        # Receive buffers, one per process, shaped like the local tensor.
        gathered = [torch.zeros_like(x) for _ in range(world_size)]
        dist.all_gather(gathered, x)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        """Return the all-reduced gradient slice belonging to this process."""
        stacked = torch.stack(grads)
        # Combine the gradients every process computed for each slice.
        dist.all_reduce(stacked)
        rank = dist.get_rank()
        return stacked[rank]


def all_gather(x: torch.Tensor):
    """Concatenate ``x`` from all processes along dim 0 when distributed.

    If ``torch.distributed`` is unavailable or no process group has been
    initialized, ``x`` itself is returned unchanged.
    """
    distributed = dist.is_available() and dist.is_initialized()
    if distributed:
        return torch.cat(GatherLayer.apply(x), dim=0)
    return x
