from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
import pdb
import torch.nn.functional as F

class SupConLossReg(nn.Module):
    """Weighted supervised contrastive loss (graph anchors vs. SMILES contrasts).

    Based on Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    Anchors are view 0 of ``features`` and contrasts are view 1 of the same
    batch. Every off-diagonal term of the softmax denominator is multiplied by
    ``gamma ** (1 - cos_sim(soft_targets_i, soft_targets_j))``. Pairs with
    dissimilar soft targets therefore weigh more in the denominator (for
    ``gamma > 1``).
    """

    def __init__(
        self,
        temperature: float = 0.07,
        base_temperature: float = 0.07,
        gamma1: int = 2,
        gamma2: int = 2,
        threshold: float = 0.8,
    ):
        """Weighted Supervised Contrastive Loss initialization.

        Args:
            temperature (float, optional): Divisor applied to the
                anchor/contrast dot products. Defaults to 0.07.
            base_temperature (float, optional): The final loss is scaled by
                ``temperature / base_temperature``. Defaults to 0.07.
            gamma1 (int, optional): Stored on the module; not read by
                ``forward`` in this class. Defaults to 2.
            gamma2 (int, optional): Stored on the module; not read by
                ``forward`` in this class. Defaults to 2.
            threshold (float, optional): Stored on the module; not read by
                ``forward`` in this class. Defaults to 0.8.
        """
        super().__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature
        self.gamma1 = gamma1
        self.gamma2 = gamma2
        self.threshold = threshold

    def forward(
        self,
        features: Tensor,
        gamma: int,
        soft_targets: Tensor,
        labels: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Compute the weighted supervised contrastive loss.

        Args:
            features (Tensor): hidden vectors of shape [bsz, n_views, ...].
                View 0 is used as the anchor (graph) and view 1 as the
                contrast (SMILES). Trailing dims are flattened.
            gamma (int): base of the per-pair denominator weight
                ``gamma ** (1 - cos_sim(soft_targets_i, soft_targets_j))``.
            soft_targets (Tensor): per-sample vectors, reshaped to
                [bsz, -1]. Their cosine similarity modulates the weights.
            labels (Optional[Tensor], optional): ground truth of shape [bsz].
                Samples with exactly equal labels are positives. Takes
                precedence over ``mask``.
            mask (Optional[Tensor], optional): contrastive mask of shape
                [bsz, bsz], mask_{i,j}=1 if sample j is a positive for
                anchor i. Can be asymmetric. Used only when ``labels`` is
                None. The diagonal is ignored and the tensor is not modified.

        Returns:
            A loss scalar. It is a zero still attached to the autograd graph
            when no anchor has any positive.

        Raises:
            ValueError: bad ``features``/``labels``/``mask`` shapes, or both
                ``labels`` and ``mask`` are None.
            FloatingPointError: the per-anchor loss contains NaN.
        """
        if features.dim() < 3 or features.shape[1] < 2:
            raise ValueError(
                "`features` must have shape [bsz, n_views >= 2, ...] "
                "(view 0 = graph, view 1 = SMILES)"
            )
        # Use the input's own device (not the default CUDA device) so that
        # multi-GPU / non-default-device runs keep every tensor together.
        device = features.device
        batch_size = features.shape[0]
        if features.dim() > 3:
            features = features.reshape(batch_size, features.shape[1], -1)
        anchor_feature = features[:, 0, :]  # graph view
        contrast_feature = features[:, 1, :]  # SMILES view

        # ----- positive-pair selection -----------------------------------
        if labels is not None:
            labels = labels.contiguous().view(-1, 1).to(device)
            if labels.shape[0] != batch_size:
                raise ValueError("Num of labels does not match num of features")
            raw_pos = torch.eq(labels, labels.T)  # exact label equality
        elif mask is not None:
            if tuple(mask.shape) != (batch_size, batch_size):
                raise ValueError("`mask` must have shape [bsz, bsz]")
            raw_pos = mask.to(device)
        else:
            raise ValueError("Either `labels` or `mask` must be provided")

        # ----- logits ----------------------------------------------------
        anchor_dot_contrast = (
            torch.matmul(anchor_feature, contrast_feature.T) / self.temperature
        )
        # Subtract the row max for numerical stability (cancels in the ratio).
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Exclude i == j pairs from both the denominator and the positives.
        logits_mask = 1.0 - torch.eye(batch_size, device=device, dtype=logits.dtype)
        # Out-of-place, so a caller-supplied `mask` is never mutated.
        pos_mask = raw_pos.to(logits.dtype) * logits_mask

        # ----- soft-target denominator weights ---------------------------
        soft_targets_norm = F.normalize(
            soft_targets.reshape(batch_size, -1), p=2, dim=1
        )  # [B, E]; reshape (not squeeze) keeps the batch axis when B == 1
        routing_sim = torch.matmul(soft_targets_norm, soft_targets_norm.T)  # [B, B]
        d_weight = torch.pow(gamma, 1 - routing_sim)

        weighted_exp = torch.exp(logits) * d_weight * logits_mask

        # ----- keep only anchors that have at least one positive ---------
        pos_count = pos_mask.sum(1)
        has_pos = pos_count > 0
        if not bool(has_pos.any()):
            # No positive pair anywhere: a zero loss (instead of mean([]) = NaN)
            # that stays connected to the graph so backward() still works.
            return logits.sum() * 0.0

        # log(exp(l) / S) == l - log(S). This avoids log(0) = -inf when exp(l)
        # underflows, which would otherwise give 0 * -inf = NaN below.
        log_prob = logits[has_pos] - torch.log(
            weighted_exp[has_pos].sum(1, keepdim=True)
        )
        mean_log_prob_pos = (pos_mask[has_pos] * log_prob).sum(1) / pos_count[has_pos]

        if torch.isnan(mean_log_prob_pos).any():
            raise FloatingPointError(
                "SupConLossReg: NaN encountered in mean_log_prob_pos"
            )

        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        return loss.mean()
