import torch
import torch.nn as nn
from torch.nn import functional as F

import numpy as np

# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    'CrossEntropyCustom',
    'BCECustom',
    'MSECustom',
    'CosineLoss',
    'LogLoss',
    'LogPosLoss',
    'SupConLoss',
]


class CrossEntropyCustom(nn.Module):
    """Thin wrapper around :func:`torch.nn.functional.cross_entropy`.

    ``param_dict`` and any extra keyword arguments are accepted only so that
    this class shares a constructor signature with the other losses in this
    module; they are ignored.
    """

    def __init__(self, param_dict=None, **kwargs):
        super().__init__()

    def forward(self, output, targets):
        """Return ``F.cross_entropy(output, targets)`` (default mean reduction).

        Args:
            output: unnormalised class scores (logits).
            targets: integer class indices.
        """
        loss = F.cross_entropy(output, targets)
        return loss


class BCECustom(nn.Module):
    """Binary cross-entropy between ``output`` and one-hot encoded targets.

    ``output`` is passed straight to :func:`F.binary_cross_entropy`, so it
    must already hold probabilities in [0, 1] (no sigmoid is applied). It must
    also have the shape of ``one_hot(targets)``, e.g. [N, num_classes] for
    targets of shape [N].

    Args:
        param_dict: mapping that provides ``'num_classes'``.
        **kwargs: consulted for ``num_classes`` only when ``param_dict`` is
            None; otherwise ignored.

    Raises:
        KeyError: if no ``num_classes`` value is available.
    """

    def __init__(self, param_dict=None, **kwargs):
        super().__init__()
        # Without this fallback, the default ``param_dict=None`` fails with
        # "'NoneType' object is not subscriptable".
        params = kwargs if param_dict is None else param_dict
        self.num_classes = params['num_classes']

    def forward(self, output, targets):
        """Return the mean BCE of ``output`` against one-hot ``targets``.

        Args:
            output: probabilities, same shape as ``one_hot(targets)``.
            targets: integer (int64) class indices.
        """
        # binary_cross_entropy expects input and target to share a dtype, so
        # build the one-hot target in ``output``'s dtype rather than float32.
        tgt = F.one_hot(targets, num_classes=self.num_classes).to(output.dtype)
        return F.binary_cross_entropy(output, tgt)


class MSECustom(nn.Module):
    """Mean-squared-error loss between ``features`` and ``weights``.

    Delegates to :func:`torch.nn.functional.mse_loss` with its default
    ``'mean'`` reduction. Constructor arguments are accepted for signature
    compatibility with the other losses in this module and are ignored.
    """

    def __init__(self, param_dict=None, **kwargs):
        super().__init__()

    def forward(self, features, weights):
        """Return the mean of ``(features - weights) ** 2``."""
        return F.mse_loss(input=features, target=weights)


class CosineLoss(nn.Module):
    """Mean over the batch of ``1 - output[i, targets[i]]``.

    ``output`` is presumably a matrix of cosine similarities between samples
    and class prototypes (shape [N, num_classes]); TODO confirm with callers.

    Args:
        param_dict: mapping that provides ``'num_classes'``.
        **kwargs: consulted for ``num_classes`` only when ``param_dict`` is
            None; otherwise ignored.

    Raises:
        KeyError: if no ``num_classes`` value is available.
    """

    def __init__(self, param_dict=None, **kwargs):
        super().__init__()
        # Fall back to keyword arguments so the default ``param_dict=None``
        # does not crash with a TypeError.
        params = kwargs if param_dict is None else param_dict
        self.num_classes = params['num_classes']

    def forward(self, output, targets):
        """Return ``mean_i(1 - output[i, targets[i]])``."""
        # The one-hot mask selects, per row, only the target-class entry.
        tgt = F.one_hot(targets, num_classes=self.num_classes).float()
        return torch.mean(torch.sum(tgt * (1. - output), dim=1))


class LogLoss(nn.Module):
    """Negative log-likelihood of the target class, ``-mean_i log(output[i, t_i] + eps)``.

    ``output`` is presumably a matrix of class probabilities
    [N, num_classes] (e.g. softmax outputs); TODO confirm with callers.

    Args:
        param_dict: mapping that provides ``'num_classes'``.
        **kwargs: consulted for ``num_classes`` only when ``param_dict`` is
            None; otherwise ignored.

    Raises:
        KeyError: if no ``num_classes`` value is available.
    """

    def __init__(self, param_dict=None, **kwargs):
        super().__init__()
        # Fall back to keyword arguments so the default ``param_dict=None``
        # does not crash with a TypeError.
        params = kwargs if param_dict is None else param_dict
        self.num_classes = params['num_classes']
        # Guards log(0). NOTE: in float16 this value is below the smallest
        # representable magnitude and has no effect.
        self.eps = 1e-18

    def forward(self, output, targets):
        """Return ``-mean_i log(output[i, targets[i]] + eps)``."""
        tgt = F.one_hot(targets, num_classes=self.num_classes).float()
        x = torch.log(output + self.eps)
        # The one-hot mask keeps only the target-class log-probability per row.
        return -torch.mean(torch.sum(tgt * x, dim=1))


class LogPosLoss(nn.Module):
    """Like ``LogLoss``, but first maps ``output`` affinely via ``(output + 1) / 2``.

    That mapping sends [-1, 1] to [0, 1], so ``output`` is presumably a score
    in [-1, 1] such as a cosine similarity; TODO confirm with callers.

    Args:
        param_dict: mapping that provides ``'num_classes'``.
        **kwargs: consulted for ``num_classes`` only when ``param_dict`` is
            None; otherwise ignored.

    Raises:
        KeyError: if no ``num_classes`` value is available.
    """

    def __init__(self, param_dict=None, **kwargs):
        super().__init__()
        # Fall back to keyword arguments so the default ``param_dict=None``
        # does not crash with a TypeError.
        params = kwargs if param_dict is None else param_dict
        self.num_classes = params['num_classes']
        # Guards log(0). NOTE: in float16 this value is below the smallest
        # representable magnitude and has no effect.
        self.eps = 1e-18

    def forward(self, output, targets):
        """Return ``-mean_i log((output[i, targets[i]] + 1) / 2 + eps)``."""
        tgt = F.one_hot(targets, num_classes=self.num_classes).float()
        output = (output + 1) / 2.
        x = torch.log(output + self.eps)
        # The one-hot mask keeps only the target-class term per row.
        return -torch.mean(torch.sum(tgt * x, dim=1))


# Supervised Contrastive Learning: 
# (paper) https://arxiv.org/pdf/2004.11362.pdf
# (official code) https://github.com/HobbitLong/SupContrast/
class SupConLoss(nn.Module):
    """Supervised contrastive loss (SupCon).

    When neither ``labels`` nor ``mask`` is given, each sample is positive
    only with its own other views (the unsupervised SimCLR-style case).

    Args:
        param_dict: unused; accepted for constructor compatibility with the
            other losses in this module.
        temperature: divisor applied to the anchor/contrast dot products.
        contrast_mode: ``'all'`` uses every view as an anchor; ``'one'`` uses
            only the first view of each sample.
        base_temperature: the loss is scaled by
            ``temperature / base_temperature``.
    """

    def __init__(
        self,
        param_dict=None,
        temperature=0.07,
        contrast_mode='all',
        base_temperature=0.07,
        **kwargs,
    ):
        super().__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    @staticmethod
    def _build_mask(batch_size, labels, mask, device):
        """Return the float [bsz, bsz] positive-pair mask on ``device``."""
        if labels is not None and mask is not None:
            raise ValueError("Cannot define both `labels` and `mask`")
        if labels is None and mask is None:
            # Unsupervised: a sample is positive only with itself, i.e. with
            # its other views once the mask is tiled in ``forward``.
            return torch.eye(batch_size, dtype=torch.float32, device=device)
        if labels is not None:
            labels = labels.contiguous().reshape(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError("# labels does not match # features")
            return torch.eq(labels, labels.T).float().to(device)
        return mask.float().to(device)

    def forward(self, features, labels=None, mask=None, device=None):
        """Compute the SupCon loss.

        Args:
            features: tensor of shape [bsz, n_views, ...]; dims after the
                second are flattened. The reference implementation expects
                L2-normalised features; that is not checked here.
            labels: optional labels; reshaped to [bsz, 1] internally.
            mask: optional [bsz, bsz] mask where ``mask[i, j] = 1`` marks
                sample j as a positive for sample i.
            device: device for the internally built masks; defaults to
                ``features.device``.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if ``features`` has fewer than 3 dims, both ``labels``
                and ``mask`` are given, the label count differs from the
                batch size, or ``contrast_mode`` is unknown.
        """
        # Build masks where the features live. A hard-coded CPU default would
        # crash on CUDA inputs when the masks are combined with the logits.
        device = features.device if device is None else device

        if features.dim() < 3:
            raise ValueError("`features` needs to be [bsz, n_views, ...], "
                             "at least 3 dimensions are required")
        if features.dim() > 3:
            features = features.reshape(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        mask = self._build_mask(batch_size, labels, mask, device)

        contrast_cnt = features.shape[1]
        # [n_views * bsz, dim]: view 0 of every sample, then view 1, ...
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_cnt = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_cnt = contrast_cnt
        else:
            raise ValueError("Unknown mode: {}".format(self.contrast_mode))

        anchor_dot_contrast = torch.matmul(anchor_feature, contrast_feature.T) / self.temperature
        # Subtract the per-row max for numerical stability. This does not
        # change the softmax; the max is detached so it carries no gradient.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Tile the [bsz, bsz] mask to [anchor_cnt * bsz, contrast_cnt * bsz].
        mask = mask.repeat(anchor_cnt, contrast_cnt)
        # Zero the diagonal so an anchor is never contrasted with itself.
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_cnt, device=device).reshape(-1, 1),
            0,
        )
        mask = mask * logits_mask

        # log-softmax over all non-self contrasts.
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # An anchor with no positives would give 0 / 0 = NaN. Treat its
        # positive count as 1 so its (zero) numerator yields zero loss, as
        # the upstream SupContrast implementation does.
        pos_per_anchor = mask.sum(1)
        pos_per_anchor = torch.where(
            pos_per_anchor < 1e-6, torch.ones_like(pos_per_anchor), pos_per_anchor
        )
        mean_log_prob_pos = (mask * log_prob).sum(1) / pos_per_anchor

        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        return loss.reshape(anchor_cnt, batch_size).mean()

