import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math


class RenyiSCLLoss(nn.Module):
    """Supervised contrastive loss that mixes classifier logits with feature similarities.

    ``forward`` prepends the supervised logits (shifted by the log of the class
    frequencies stored by :meth:`cal_weight_for_classes`) to the
    temperature-scaled feature similarity logits of each anchor, and evaluates
    one of three objectives selected by ``obj``:

    * ``'naive_cpc'`` -- weighted mean log-probability of the positives under a
      softmax over all non-masked logits, scaled by
      ``temperature / base_temperature``.
    * ``'cpc'`` -- negative weighted mean of the positive logits plus the log of
      a mixture of the mean exponentiated positive and negative logits.
    * ``'rcpc'`` -- a variant of ``'cpc'`` that uses a per-class exponent
      ``gamma`` (see :meth:`cal_gamma_for_classes`); rows whose ``gamma``
      equals 1 use the plain weighted mean of the positive logits for the
      first term.

    Usage: call :meth:`cal_weight_for_classes` before ``forward`` (every
    objective reads ``self.weight``). For ``obj='rcpc'`` also call
    :meth:`cal_gamma_for_classes` and, when ``alpha_schedule != 'batch'``,
    :meth:`cal_alpha_for_classes`.

    Args:
        alpha: Weight of same-label feature positives in the positive term.
        beta: Weight of the ground-truth class logit in the positive term.
        gamma: Multiplier applied to the (non-self) feature part of the logit mask.
        supt: Temperature dividing the supervised logits.
        temperature: Temperature dividing feature dot products.
        base_temperature: Reference temperature for the ``'naive_cpc'`` scale;
            defaults to ``temperature``.
        K: Number of rows of ``features`` beyond the ``2 * batch_size`` batch
            rows; anchors are the first ``(len(features) - K) // 2`` rows.
            (Presumably a feature queue -- confirm against the caller.)
        num_classes: Number of classes, i.e. the width of ``sup_logits``.
        obj: One of ``'naive_cpc'``, ``'cpc'``, ``'rcpc'``.
        gamma_renyi, gamma_schedule, gamma_min, gamma_max, gamma_many,
        gamma_medium, gamma_few, inv_pow: Configure :meth:`cal_gamma_for_classes`.
        alpha_schedule, alpha_renyi: Configure :meth:`cal_alpha_for_classes` and
            the mixture weight used by ``'rcpc'``.
        alpha_max, alpha_pow: Stored but not read by this class.
    """

    _OBJECTIVES = ('naive_cpc', 'cpc', 'rcpc')

    def __init__(self, alpha, beta=1.0, gamma=1.0, supt=1.0, temperature=1.0, base_temperature=None, K=128, num_classes=1000,
                 obj='cpc', gamma_renyi=2.0, gamma_schedule='constant', gamma_min=1.2, gamma_max=2.0,
                 gamma_many=1.2, gamma_medium=1.4, gamma_few=1.6, inv_pow=-0.25,
                 alpha_schedule='batch', alpha_renyi=0.001, alpha_max=0.001, alpha_pow=0):
        super().__init__()
        # Temperatures.
        self.temperature = temperature
        self.base_temperature = temperature if base_temperature is None else base_temperature
        self.supt = supt
        # Batch layout and label space.
        self.K = K
        self.num_classes = num_classes
        # Objective and the weights of its terms.
        self.obj = obj
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        # Per-class gamma schedule (read by cal_gamma_for_classes).
        self.gamma_renyi = gamma_renyi
        self.gamma_schedule = gamma_schedule
        self.gamma_min = gamma_min
        self.gamma_max = gamma_max
        self.inv_pow = inv_pow
        self.gamma_many = gamma_many
        self.gamma_medium = gamma_medium
        self.gamma_few = gamma_few
        # Mixture-weight schedule (read by cal_alpha_for_classes / 'rcpc').
        self.alpha_schedule = alpha_schedule
        self.alpha_renyi = alpha_renyi
        # Stored for configuration completeness; not read by this class.
        self.alpha_max = alpha_max
        self.alpha_pow = alpha_pow

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _counts_tensor(cls_num_list, device=None):
        """Return ``cls_num_list`` as a float32 tensor on ``device``.

        When ``device`` is None, CUDA is used if available and CPU otherwise.
        Accepts anything ``torch.as_tensor`` accepts (lists, arrays, tensors).
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        return torch.as_tensor(cls_num_list, dtype=torch.float32, device=device)

    @staticmethod
    def _normalize(values, lo, hi):
        """Map ``values`` linearly so that ``lo -> 0`` and ``hi -> 1``.

        A degenerate range (``hi == lo``) maps everything to 0 instead of
        producing NaN from a 0/0 division.
        """
        span = hi - lo
        if span == 0:
            return torch.zeros_like(values)
        return (values - lo) / span

    # ------------------------------------------------------- per-class tables

    def cal_weight_for_classes(self, cls_num_list, device=None):
        """Store normalised class frequencies in ``self.weight``.

        Args:
            cls_num_list: Per-class sample counts; must have ``num_classes`` entries.
            device: Target device for ``self.weight``; defaults to CUDA when
                available, otherwise CPU.

        Sets ``self.weight`` to a ``(1, num_classes)`` tensor summing to 1.
        """
        counts = self._counts_tensor(cls_num_list, device).view(1, self.num_classes)
        self.weight = counts / counts.sum()

    def cal_gamma_for_classes(self, cls_num_list, device=None):
        """Compute ``self.gamma_per_class`` from class counts per ``gamma_schedule``.

        Schedules:
            ``'constant'``: every class gets ``gamma_renyi``.
            ``'inv'``: ``count ** inv_pow`` rescaled linearly so the largest class
                gets ``gamma_min`` and the smallest gets ``gamma_max``.
            ``'linear'``: classes ranked by descending count, spaced linearly from
                ``gamma_min`` (largest) to ``gamma_max`` (smallest).
            ``'region'``: ``gamma_many`` for counts > 100, ``gamma_medium`` for
                20..100, ``gamma_few`` for counts < 20.
            ``'cos'``: cosine-shaped interpolation from ``gamma_max`` (smallest
                class) to ``gamma_min`` (largest class).

        When all counts are equal (or there is a single class) the
        range-based schedules assign ``gamma_min`` to every class.

        Args:
            cls_num_list: Per-class sample counts.
            device: Target device; defaults to CUDA when available, otherwise CPU.

        Raises:
            ValueError: If ``gamma_schedule`` is not one of the names above.
        """
        counts = self._counts_tensor(cls_num_list, device)
        gamma_range = self.gamma_max - self.gamma_min
        schedule = self.gamma_schedule

        if schedule == 'constant':
            gamma = torch.ones_like(counts) * self.gamma_renyi
        elif schedule == 'inv':
            powered = counts.pow(self.inv_pow)
            head = counts.max().pow(self.inv_pow)   # value for the largest class
            tail = counts.min().pow(self.inv_pow)   # value for the smallest class
            gamma = self._normalize(powered, head, tail) * gamma_range + self.gamma_min
        elif schedule == 'linear':
            # Double argsort turns counts into ranks (0 = largest class).
            rank = counts.argsort(descending=True).argsort()
            gamma = self._normalize(rank, 0, len(counts) - 1) * gamma_range + self.gamma_min
        elif schedule == 'region':
            gamma = ((counts > 100) * self.gamma_many
                     + ((counts >= 20) & (counts <= 100)) * self.gamma_medium
                     + (counts < 20) * self.gamma_few)
        elif schedule == 'cos':
            phase = self._normalize(counts, counts.min(), counts.max())
            cos = torch.cos(-math.pi * phase)
            gamma = self._normalize(cos, cos.min(), cos.max()) * gamma_range + self.gamma_min
        else:
            raise ValueError(
                f"Unknown gamma_schedule {schedule!r}; expected one of "
                "'constant', 'inv', 'linear', 'region', 'cos'.")
        self.gamma_per_class = gamma

    def cal_alpha_for_classes(self, cls_num_list, device=None):
        """Compute ``self.alpha_per_class`` according to ``alpha_schedule``.

        ``'batch'`` stores nothing: ``'rcpc'`` then derives the mixture weight
        per row from the batch masks. ``'constant'`` gives every class
        ``alpha_renyi``.

        Args:
            cls_num_list: Per-class sample counts (only their length is used).
            device: Target device; defaults to CUDA when available, otherwise CPU.

        Raises:
            ValueError: If ``alpha_schedule`` is neither ``'batch'`` nor ``'constant'``.
        """
        if self.alpha_schedule == 'batch':
            return
        if self.alpha_schedule == 'constant':
            counts = self._counts_tensor(cls_num_list, device)
            self.alpha_per_class = torch.ones_like(counts) * self.alpha_renyi
            return
        raise ValueError(
            f"Unknown alpha_schedule {self.alpha_schedule!r}; expected 'batch' or 'constant'.")

    # ----------------------------------------------------------------- forward

    def forward(self, features, labels=None, sup_logits=None, mask=None, epoch=None):
        """Compute the configured loss for one batch.

        Args:
            features: 2-D tensor whose first ``batch_size`` rows are the anchors,
                where ``batch_size = (features.shape[0] - K) // 2``. Rows are
                compared by dot product (presumably L2-normalised -- confirm
                against the caller).
            labels: Class index for every row of ``features``. Required.
            sup_logits: Classifier logits of shape ``(batch_size, num_classes)``
                for the anchors. Required.
            mask: Ignored; the positive mask is always rebuilt from ``labels``.
                Kept for interface compatibility.
            epoch: Ignored; kept for interface compatibility.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: If ``obj`` is unsupported or ``labels`` / ``sup_logits``
                is missing.
        """
        if self.obj not in self._OBJECTIVES:
            raise ValueError(f"Unknown obj {self.obj!r}; expected one of {self._OBJECTIVES}.")
        if labels is None or sup_logits is None:
            raise ValueError('RenyiSCLLoss.forward requires both `labels` and `sup_logits`.')

        device = features.device
        batch_size = (features.shape[0] - self.K) // 2

        labels = labels.contiguous().view(-1, 1)
        anchor_labels = labels[:batch_size]
        # mask[i, j] = 1 where row j has the same label as anchor i.
        mask = torch.eq(anchor_labels, labels.T).float().to(device)

        # Temperature-scaled similarity of each anchor to every row.
        sim_logits = torch.div(torch.matmul(features[:batch_size], features.T), self.temperature)

        # Prepend the supervised logits shifted by log class frequency.
        weight = self.weight.to(sup_logits.device)
        anchor_dot_contrast = torch.cat(
            ((sup_logits + torch.log(weight + 1e-9)) / self.supt, sim_logits), dim=1)

        # For numerical stability: subtract each row's (detached) maximum.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Zero the anchor-vs-itself entries (anchor i is column i of features.T).
        logits_mask = torch.scatter(
            torch.ones_like(mask), 1, torch.arange(batch_size, device=device).view(-1, 1), 0)
        mask = mask * logits_mask

        mask_pos, weights_mask, full_mask, mask_neg = self._build_masks(mask, logits_mask, anchor_labels)

        if self.obj == 'naive_cpc':
            return self._naive_cpc_loss(logits, weights_mask, full_mask)
        if self.obj == 'cpc':
            return self._cpc_loss(logits, mask_pos, weights_mask, full_mask, mask_neg)
        return self._rcpc_loss(logits, mask_pos, weights_mask, full_mask, mask_neg, anchor_labels)

    def _build_masks(self, mask, logits_mask, anchor_labels):
        """Extend the feature masks with a leading block for the class logits.

        Every returned mask has ``num_classes`` class columns followed by the
        feature columns of ``mask``.

        Returns:
            mask_pos: 1 for the anchor's own class and for same-label rows.
            weights_mask: ``mask_pos`` with the class block scaled by ``beta`` and
                the feature block scaled by ``alpha``.
            full_mask: all class logits plus the non-self feature entries scaled
                by ``gamma``.
            mask_neg: ``full_mask - mask_pos``. NOTE(review): when ``gamma != 1``
                its feature block is not binary (positive entries become
                ``gamma - 1``) -- confirm this is intended.
        """
        one_hot_label = F.one_hot(anchor_labels.view(-1), num_classes=self.num_classes).to(mask)
        class_ones = mask.new_ones(anchor_labels.shape[0], self.num_classes)
        mask_pos = torch.cat((one_hot_label, mask), dim=1)
        weights_mask = torch.cat((one_hot_label * self.beta, mask * self.alpha), dim=1)
        full_mask = torch.cat((class_ones, self.gamma * logits_mask), dim=1)
        mask_neg = full_mask - mask_pos
        return mask_pos, weights_mask, full_mask, mask_neg

    def _naive_cpc_loss(self, logits, weights_mask, full_mask):
        """Weighted mean log-likelihood of the positives under a masked softmax."""
        exp_logits = torch.exp(logits) * full_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)
        mean_log_prob_pos = (weights_mask * log_prob).sum(1) / weights_mask.sum(1)
        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        return loss.mean()

    def _cpc_loss(self, logits, mask_pos, weights_mask, full_mask, mask_neg):
        """Weighted positive term plus log of a positive/negative exp-mean mixture."""
        pos = logits * mask_pos
        neg = logits * mask_neg
        weighted_pos = logits * weights_mask

        # Equivalent to the original SupCon positive term (v2).
        loss_1 = -1 * (weighted_pos.sum(dim=1, keepdim=True) / weights_mask.sum(dim=1, keepdim=True)).mean()
        e_pos = (pos.exp() * mask_pos).sum(dim=1, keepdim=True) / mask_pos.sum(dim=1, keepdim=True)
        e_neg = (neg.exp() * mask_neg).sum(dim=1, keepdim=True) / mask_neg.sum(dim=1, keepdim=True)
        # Mixture weight: share of positives among the non-masked entries.
        pos_ratio = mask_pos.sum(dim=1, keepdim=True) / full_mask.sum(dim=1, keepdim=True)
        loss_2 = torch.log(pos_ratio * e_pos + (1 - pos_ratio) * e_neg).mean()
        return loss_1 + loss_2

    def _rcpc_loss(self, logits, mask_pos, weights_mask, full_mask, mask_neg, anchor_labels):
        """``'cpc'`` variant using the per-class exponent ``gamma_per_class``."""
        device = logits.device
        gamma = self.gamma_per_class.to(device)[anchor_labels]   # (batch_size, 1)

        pos = logits * mask_pos
        neg = logits * mask_neg
        weight_total = weights_mask.sum(dim=1, keepdim=True)

        if (gamma == 1).any():
            # gamma == 1 is the limit case; use the plain weighted mean there
            # to avoid dividing by (gamma - 1) == 0.
            is_one = gamma == 1
            cpc = (logits * weights_mask).sum(dim=1, keepdim=True) / weight_total
            rcpc = (((gamma - 1) * pos).exp() * weights_mask).sum(dim=1, keepdim=True) / weight_total
            loss_1 = -1 * ((torch.log(rcpc[~is_one]) / (gamma - 1)[~is_one]).sum()
                           + cpc[is_one].sum()) / gamma.shape[0]
        else:
            e_pos_1 = (((gamma - 1) * pos).exp() * weights_mask).sum(dim=1, keepdim=True) / weight_total
            loss_1 = -1 * (torch.log(e_pos_1) / (gamma - 1)).mean()

        e_pos = ((gamma * pos).exp() * mask_pos).sum(dim=1, keepdim=True) / mask_pos.sum(dim=1, keepdim=True)
        e_neg = ((gamma * neg).exp() * mask_neg).sum(dim=1, keepdim=True) / mask_neg.sum(dim=1, keepdim=True)
        if self.alpha_schedule == 'batch':
            pos_ratio = mask_pos.sum(dim=1, keepdim=True) / full_mask.sum(dim=1, keepdim=True)
        else:
            pos_ratio = self.alpha_per_class.to(device)[anchor_labels]
        loss_2 = (torch.log(pos_ratio * e_pos + (1 - pos_ratio) * e_neg) / gamma).mean()
        return loss_1 + loss_2
