import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb


def create_projection_head(intermediate_dim, dim):
    """Build the projection head placed on top of backbone features.

    Args:
        intermediate_dim: size of the incoming features; also used as the
            width of the hidden layer.
        dim: size of the projected output. ``None`` disables the projection.

    Returns:
        ``nn.Identity()`` when ``dim`` is ``None``; otherwise an
        ``nn.Sequential`` of Linear -> ReLU -> Linear mapping
        ``intermediate_dim`` to ``dim``.
    """
    if dim is not None:
        # Layers are instantiated in the order they are applied so that
        # parameter initialisation consumes the RNG in a fixed order.
        hidden_layer = nn.Linear(intermediate_dim, intermediate_dim)
        activation = nn.ReLU()
        output_layer = nn.Linear(intermediate_dim, dim)
        return nn.Sequential(hidden_layer, activation, output_layer)
    return nn.Identity()


class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.

    It also supports the unsupervised contrastive loss in SimCLR.
    Code taken from https://github.com/HobbitLong/SupContrast

    Args:
        temperature: divisor applied to the similarity logits.
        contrast_mode: ``"all"`` uses every view of every sample as an
            anchor, ``"one"`` uses only the first view of each sample.
        base_temperature: the loss is scaled by
            ``temperature / base_temperature``.
        *args, **kwargs: accepted and ignored, so callers may forward extra
            configuration without filtering it.
    """

    def __init__(
        self,
        *args,
        temperature=0.07,
        contrast_mode="all",
        base_temperature=0.07,
        **kwargs,
    ):
        super().__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(
        self,
        features,
        labels=None,
        mask=None,
        distance="cos",
        positive_free_denominator=False,
    ):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...]. Trailing
                dimensions beyond the third are flattened.
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
            distance: ``"cos"`` L2-normalises the features and uses dot
                products as logits; ``"mse"`` uses the negated
                ``torch.cdist`` distance (plus 1) without normalisation.
            positive_free_denominator: if True, the denominator for each
                anchor/positive pair contains the anchor's negatives plus
                that single positive, instead of every non-self sample.
        Returns:
            A loss scalar.
        Raises:
            ValueError: if `features` has fewer than 3 dimensions, if both
                `labels` and `mask` are given, if the number of labels does
                not match the batch size, or if `contrast_mode` is unknown.
            NotImplementedError: if `distance` is not "cos" or "mse".
        """
        # Use the exact device of the input; torch.device("cuda") would
        # resolve to the *current* CUDA device, which may differ from the
        # one holding `features` in multi-GPU setups.
        device = features.device

        if features.dim() < 3:
            raise ValueError(
                "`features` needs to be [bsz, n_views, ...],"
                "at least 3 dimensions are required"
            )
        if features.dim() > 3:
            # reshape (not view) so non-contiguous inputs are accepted.
            features = features.reshape(features.shape[0], features.shape[1], -1)

        if distance == "cos":
            features = F.normalize(features, dim=-1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError("Cannot define both `labels` and `mask`")
        if labels is None and mask is None:
            # SimCLR: only the other views of the same sample are positives.
            mask = torch.eye(batch_size, dtype=torch.float32, device=device)
        elif labels is not None:
            labels = labels.reshape(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError("Num of labels does not match num of features")
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        # Stack the views along the batch axis: [n_views * bsz, dim].
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == "one":
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == "all":
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError("Unknown mode: {}".format(self.contrast_mode))

        # compute logits
        if distance == "cos":
            anchor_dot_contrast = torch.div(
                torch.matmul(anchor_feature, contrast_feature.T), self.temperature
            )
        elif distance == "mse":
            anchor_dot_contrast = -torch.div(
                torch.cdist(anchor_feature, contrast_feature) + 1.0, self.temperature
            )
        else:
            raise NotImplementedError(f"Distance metric {distance} not implemented")

        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases: zero at (i, i), one elsewhere
        self_index = torch.arange(batch_size * anchor_count, device=device).view(-1, 1)
        logits_mask = torch.scatter(torch.ones_like(mask), 1, self_index, 0)
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        if not positive_free_denominator:
            log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
        else:
            negative_exp_sum = (exp_logits * (1 - mask)).sum(1, keepdim=True)
            # Self-contrast entries have exp_logits == 0 and are never
            # positives. Pad their denominator with 1 so the log stays finite
            # when an anchor has no negatives; otherwise 0 * inf would turn
            # the whole loss (and its gradient) into NaN.
            denominator = negative_exp_sum + exp_logits + (1 - logits_mask)
            log_prob = logits - torch.log(denominator)

        # compute mean of log-likelihood over positive
        # modified to handle edge cases when there is no positive pair
        # for an anchor point.
        # Edge case e.g.:-
        # features of shape: [4,1,...]
        # labels:            [0,1,1,2]
        # loss before mean:  [nan, ..., ..., nan]
        mask_pos_pairs = mask.sum(1)
        mask_pos_pairs = torch.where(
            mask_pos_pairs < 1e-6, torch.ones_like(mask_pos_pairs), mask_pos_pairs
        )
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        # loss
        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss


class ContextContrastingLoss(nn.Module):
    """Contrastive loss over context labels with optional auxiliary terms.

    The context term applies ``SupConLoss`` to projected features using the
    context labels. Depending on ``content_loss_mode`` a second, label-free
    contrastive "content" term is added, and optionally a per-context
    alignment term and a cluster-compactness term.

    Args:
        projection_dim: output size of the projection heads (``None`` makes
            the MLP heads identities).
        backbone_out_dim: size of the incoming features.
        similarity_metric: ``distance`` passed to the criterion for the
            context term.
        content_loss_mode: one of "none", "separate", "shifting",
            "linear_projection", "content_only", "simclr".
        content_loss_weight: multiplier on the content term.
        content_annealing_start: first value of the "linear" schedule.
        content_annealing_epochs: schedule length; ``None`` disables annealing.
        content_annealing_schedule: "linear_denominator", "linear" or "none".
        align_contexts: enable the alignment term.
        n_context_augs: number of distinct contexts (sizes the per-context
            layers and the centre buffer).
        positive_free_denominator: forwarded to the criterion for the
            context term.
        compact_clusters: enable the compactness term.
        *args, **kwargs: forwarded to ``SupConLoss``.

    Raises:
        ValueError: if ``content_loss_mode`` is not a supported value.
    """

    def __init__(
        self,
        projection_dim=128,
        backbone_out_dim=512,
        similarity_metric="cos",
        content_loss_mode="none",
        content_loss_weight=1.0,
        content_annealing_start=0.0,
        content_annealing_epochs=None,
        content_annealing_schedule="linear_denominator",
        align_contexts=False,
        n_context_augs=4,
        positive_free_denominator=False,
        compact_clusters=False,
        *args,
        **kwargs,
    ):
        super().__init__()
        valid_modes = (
            "none",
            "separate",
            "shifting",
            "linear_projection",
            "content_only",
            "simclr",
        )
        if content_loss_mode not in valid_modes:
            raise ValueError(f"Invalid content loss mode: {content_loss_mode}")
        self.content_loss_mode = content_loss_mode
        self.context_projection_head = create_projection_head(
            backbone_out_dim, projection_dim
        )
        if self.content_loss_mode in ("separate", "content_only", "simclr", "none"):
            # Independent MLP head operating on the backbone features.
            self.content_projection_head = create_projection_head(
                backbone_out_dim, projection_dim
            )
        elif self.content_loss_mode == "linear_projection":
            # One linear map per context, applied to the context projection.
            self.content_projection_head = torch.nn.ModuleList(
                [
                    nn.Linear(projection_dim, projection_dim)
                    for _ in range(n_context_augs)
                ]
            )
        elif self.content_loss_mode == "shifting":
            # One non-affine batch norm per context, applied to the context
            # projection.
            self.content_projection_head = torch.nn.ModuleList(
                [
                    nn.BatchNorm1d(projection_dim, affine=False)
                    for _ in range(n_context_augs)
                ]
            )
        self.align_contexts = align_contexts
        if self.align_contexts:
            # The nn.Sequential wrapper is kept so parameter names in the
            # state dict ("alignment_layers.<i>.0.*") remain stable.
            self.alignment_layers = torch.nn.ModuleList(
                [
                    nn.Sequential(nn.BatchNorm1d(backbone_out_dim))
                    for _ in range(n_context_augs)
                ]
            )
        self.compact_clusters = compact_clusters
        if self.compact_clusters:
            self.register_buffer(
                "context_centers", torch.zeros(n_context_augs, backbone_out_dim)
            )
        self.criterion = SupConLoss(
            *args,
            **kwargs,
        )
        self.n_context_augs = n_context_augs
        self.positive_free_denominator = positive_free_denominator
        self.content_loss_weight = content_loss_weight
        self.content_annealing_start = content_annealing_start
        self.content_annealing_epochs = content_annealing_epochs
        self.content_annealing_schedule = content_annealing_schedule
        self.distance = similarity_metric

    @staticmethod
    def _apply_flattening_views(layer, x):
        """Apply ``layer`` to ``x``, folding a 3-D input to 2-D and back."""
        if x.dim() == 3:
            n, views, dim = x.shape
            return layer(x.reshape(-1, dim)).reshape(n, views, dim)
        return layer(x)

    def get_projection(self, z, context_labels, **kwargs):
        """Return both projections of ``z`` as ``{"content", "context"}``.

        Extra keyword arguments are accepted and ignored.
        """
        context_proj = self.context_projection_head(z)
        content_proj = self.get_content_projection(z, context_proj, context_labels)
        return {"content": content_proj, "context": context_proj}

    def get_content_projection(self, z, z_context, context_labels):
        """Compute the content projection for the configured mode.

        Args:
            z: backbone features.
            z_context: output of the context projection head for ``z``.
            context_labels: per-sample context index; used to select the
                per-context layer in "shifting" / "linear_projection" mode.

        Returns:
            ``content_projection_head(z)`` for the MLP-head modes; otherwise
            a tensor shaped like ``z_context`` where each sample has been
            passed through the layer belonging to its context.
        """
        if self.content_loss_mode in ("separate", "content_only", "simclr", "none"):
            return self.content_projection_head(z)

        # "shifting" / "linear_projection": route every sample through the
        # layer of its own context. Samples are processed context by context.
        z_content = torch.zeros_like(z_context)
        for label in torch.unique(context_labels).long().tolist():
            selected = context_labels == label
            z_content[selected] = self._apply_flattening_views(
                self.content_projection_head[label], z_context[selected]
            )
        return z_content

    def get_alignment_loss(self, z, context_labels):
        """Mean squared pairwise difference between aligned views.

        Features are L2-normalised, passed through the alignment layer of
        their context, the two halves of the batch are concatenated along
        dim 1, and the squared differences between all pairs along that axis
        are averaged.

        Returns:
            ``0.0`` when ``align_contexts`` is disabled, else a scalar tensor.
        """
        if not self.align_contexts:
            return 0.0
        z_aligned = torch.zeros_like(z)
        # Apply the alignment layer of each context to its own samples.
        for label in torch.unique(context_labels).long().tolist():
            selected = context_labels == label
            curr_z = F.normalize(z[selected], dim=-1).type_as(z_aligned)
            z_aligned[selected] = self._apply_flattening_views(
                self.alignment_layers[label], curr_z
            ).type_as(z_aligned)
        z_aligned = torch.cat(z_aligned.chunk(2, dim=0), dim=1)
        return (z_aligned[:, None] - z_aligned[:, :, None]).pow(2).mean()

    def compactness_loss(self, z):
        """Pull features towards their centre and push them from the others.

        In training mode the ``context_centers`` buffer is first updated from
        ``z.mean(dim=0)``: initialised on the first call (while it is still
        all zeros), afterwards an exponential moving average with factor 0.9.

        NOTE(review): assumes ``z`` is [batch, n_context_augs, dim] so that
        dim 1 lines up with the rows of ``context_centers`` - confirm.

        Returns:
            ``intra + 1 / (1e-6 + inter)`` where ``intra`` is the mean cosine
            distance to the matching centre and ``inter`` the mean of
            ``1 + cosine similarity`` to the non-matching centres.
        """
        if self.training:
            with torch.no_grad():
                batch_centers = z.mean(dim=0)
                if torch.all(self.context_centers == 0):
                    self.context_centers = batch_centers
                else:
                    # Update running mean
                    self.context_centers = (
                        0.9 * self.context_centers + 0.1 * batch_centers
                    )
        z_norm = F.normalize(z, dim=-1)
        center_norm = F.normalize(self.context_centers, dim=-1)
        intra_cluster_dist = (1 - (z_norm * center_norm[None]).sum(dim=-1)).mean()
        inter_cluster_dist = 1 + (z_norm[..., None, :] * center_norm[None, None]).sum(
            dim=-1
        )
        # Zero the diagonal so a feature's own centre does not count.
        off_diagonal = 1 - torch.eye(inter_cluster_dist.shape[-1]).type_as(
            inter_cluster_dist
        )
        inter_cluster_dist = (inter_cluster_dist * off_diagonal).mean()
        return intra_cluster_dist + 1.0 / (1e-6 + inter_cluster_dist)

    def _content_annealing_coef(self, epoch):
        """Return the multiplier applied to the content/alignment terms."""
        if self.content_annealing_epochs is None or epoch is None:
            return 1.0
        schedule = self.content_annealing_schedule
        if schedule == "linear_denominator":
            return 1.0 / max(1.0, self.content_annealing_epochs - epoch)
        if schedule == "linear":
            if epoch < self.content_annealing_epochs:
                return torch.linspace(
                    self.content_annealing_start,
                    1,
                    self.content_annealing_epochs,
                )[epoch]
            return 1.0
        if schedule == "none":
            return 1.0
        raise ValueError(
            f"Invalid content annealing schedule: {self.content_annealing_schedule}"
        )

    def forward(self, z, context_labels, logging_prefix="train/", epoch=None):
        """Compute the combined loss and optionally log its parts to wandb.

        Args:
            z: backbone features. NOTE(review): assumed to be
                [batch, n_views, backbone_out_dim] with the batch made of two
                stacked halves that ``chunk(2, dim=0)`` separates - confirm.
            context_labels: per-sample context labels for the context term.
            logging_prefix: prefix for the wandb keys; ``None`` disables
                logging.
            epoch: current epoch, used by the annealing schedule.

        Returns:
            The scalar loss. In "content_only" mode this is the content loss
            alone.

        Raises:
            ValueError: if annealing is active and the schedule is unknown.
        """
        context_z = self.context_projection_head(z)
        context_loss = self.criterion(
            context_z,
            labels=context_labels,
            distance=self.distance,
            positive_free_denominator=self.positive_free_denominator,
        )
        loss = context_loss
        if self.compact_clusters:
            compactness_loss = self.compactness_loss(z)
            # Out-of-place add: `loss += ...` would mutate `context_loss`
            # (same tensor object) and corrupt the logged context loss.
            loss = loss + compactness_loss
        if self.content_loss_mode != "none":
            content_z = self.get_content_projection(z, context_z, context_labels)
            # Reassemble sample independent of context
            content_z = torch.cat(content_z.chunk(2, dim=0), dim=1)
            content_loss = self.criterion(
                content_z,
                distance="cos",
                positive_free_denominator=self.content_loss_mode == "simclr",
            )
            annealing_coef = self._content_annealing_coef(epoch)
            # Get alignment loss
            alignment_loss = self.get_alignment_loss(z, context_labels)

            loss = (
                annealing_coef
                * (self.content_loss_weight * content_loss + alignment_loss)
                + loss
            )
            # Divide by the number of main terms that were combined.
            loss = loss / (3.0 if self.align_contexts else 2.0)
            if self.content_loss_mode == "content_only":
                loss = content_loss
        # Check the cheap flag first so wandb is not touched when logging is
        # disabled by the caller.
        if logging_prefix is not None and wandb.run is not None:
            with torch.no_grad():
                # The context loss is logged in every mode.
                log = {f"{logging_prefix}context_loss": context_loss}
                if self.content_loss_mode != "none":
                    if self.content_loss_mode != "content_only":
                        log["loss_annealing_coef"] = annealing_coef
                    log[f"{logging_prefix}content_loss"] = content_loss
                    if self.compact_clusters:
                        log[f"{logging_prefix}compactness_loss"] = compactness_loss
                    if self.align_contexts:
                        log[f"{logging_prefix}alignment_loss"] = alignment_loss
                    log[f"{logging_prefix}avg_content_context_loss"] = 0.5 * (
                        content_loss + context_loss
                    )
                else:
                    log[f"{logging_prefix}loss"] = loss
                wandb.log(
                    log,
                    commit=False,
                )

        return loss
