import torch.nn as nn
import torch
import torch.nn.functional as F

from lightly.loss import NegativeCosineSimilarity, VICRegLoss

import numpy as np
import wandb

def load_loss(loss, config):
    """Build the loss module registered under the name ``loss``.

    Args:
        loss: Name of the loss. Supported values are ``'simclr'``,
            ``'simsiam'``, ``'vicreg'``, ``'cassle'``, ``'scale'`` and
            ``'super'``.
        config: Experiment configuration. ``'simclr'`` reads
            ``config.agent.n_views``, ``config.agent.temperature`` and
            ``config.device``; ``'cassle'`` and ``'scale'`` receive the whole
            object; the remaining losses do not use it.

    Returns:
        A newly constructed loss module.

    Raises:
        ValueError: If ``loss`` is not one of the supported names. Without
            this check the function would silently return ``None`` and the
            mistake would only surface when the "loss" is first called.
    """
    if loss == 'simclr':
        return MultiViewContrastive(config.agent.n_views,
                                    config.agent.temperature,
                                    config.device)
    if loss == 'simsiam':
        return NegativeCosineLoss()
    if loss == 'vicreg':
        return VICRLoss()
    if loss == 'cassle':
        return CaSSLeLoss(config)
    if loss == 'scale':
        return SCALELoss(config)
    if loss == 'super':
        return CrossEntLoss()

    raise ValueError(
        f"Unknown loss {loss!r}; expected one of "
        "'simclr', 'simsiam', 'vicreg', 'cassle', 'scale', 'super'."
    )
    

class CrossEntLoss(nn.Module):
    """Cross-entropy loss that reads its inputs from an output dictionary."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, outs):
        """Return the cross-entropy of ``outs['y_pred']`` against ``outs['y']``."""
        predictions = outs['y_pred']
        targets = outs['y']
        return self._forward(predictions, targets)

    def _forward(self, y_pred, y_true):
        """Apply the wrapped ``CrossEntropyLoss`` to explicit tensors."""
        return self.criterion(y_pred, y_true)

class NegativeCosineLoss(nn.Module):
    """Symmetric two-view loss built on lightly's ``NegativeCosineSimilarity``.

    Reads ``h0``, ``h1``, ``p0`` and ``p1`` from the output dictionary. The
    ``h*`` entries are detached before use, so gradients only flow through
    the ``p*`` entries.
    """

    def __init__(self):
        super().__init__()
        self.criterion = NegativeCosineSimilarity()

    def forward(self, outs):
        """Average the criterion over the (h0, p1) and (h1, p0) pairings."""
        first_term = self._forward(outs['h0'].detach(), outs['p1'])
        second_term = self._forward(outs['h1'].detach(), outs['p0'])
        return 0.5 * (first_term + second_term)

    def _forward(self, h, p):
        """Apply the wrapped criterion to one ``(h, p)`` pair."""
        return self.criterion(h, p)
    
class VICRLoss(nn.Module):
    """Two-view wrapper around lightly's ``VICRegLoss``.

    Args:
        lamda: Forwarded to ``VICRegLoss`` as ``lambda_param`` (weight of the
            invariance term in lightly's implementation).
        mu: Forwarded as ``mu_param`` (weight of the variance term).
        nu: Forwarded as ``nu_param`` (weight of the covariance term).

    The defaults equal lightly's own defaults, so ``VICRLoss()`` behaves as
    before; previously the three arguments were accepted but never used.
    """

    def __init__(self, lamda=25.0, mu=25.0, nu=1.0):
        super().__init__()
        # Pass the coefficients through instead of silently dropping them.
        self.criterion = VICRegLoss(lambda_param=lamda,
                                    mu_param=mu,
                                    nu_param=nu)

    def forward(self, outs):
        """Return the criterion evaluated on ``outs['h0']`` and ``outs['h1']``."""
        return self._forward(outs['h0'], outs['h1'])

    def _forward(self, h0, h1):
        """Apply the wrapped criterion to two explicit tensors."""
        return self.criterion(h0, h1)
       
class CaSSLeLoss(nn.Module):
    """Base loss plus a per-view term between frozen and predicted features.

    The base criterion is built with ``load_loss(config.agent.loss, config)``
    and must expose a ``_forward(a, b)`` method, which is reused for the
    second term.

    NOTE(review): ``config.agent.loss`` has to name the *base* loss; if it is
    ``'cassle'`` the constructor would call itself without end.
    """

    def __init__(self, config):
        super().__init__()
        self.criterion = load_loss(config.agent.loss, config)
        self.n_views = config.agent.n_views

    def forward(self, outs):
        """Return ``criterion(outs)`` plus the summed per-view extra terms.

        For every view index ``i`` below ``self.n_views`` the criterion's
        ``_forward`` is evaluated on ``outs['z{i}_frozen']`` and
        ``outs['z{i}_predicted']``; the results are stacked and summed.
        """
        base_loss = self.criterion(outs)

        per_view_terms = []
        for view in range(self.n_views):
            frozen = outs[f'z{view}_frozen']
            predicted = outs[f'z{view}_predicted']
            per_view_terms.append(self.criterion._forward(frozen, predicted))
        distill_loss = torch.stack(per_view_terms).sum()

        return base_loss + distill_loss

class MultiViewContrastive(nn.Module):
    """Contrastive cross-entropy loss over several views of the same batch.

    All views are concatenated and L2-normalised, pairwise similarities are
    divided by ``temperature``, and every row is scored with cross-entropy
    against class 0 of ``[positives | negatives]``, where positives are the
    other views of the same sample.

    Args:
        n_views: Number of views ``forward`` reads from the output dict
            (keys ``'h0'`` .. ``'h{n_views-1}'``).
        temperature: Divisor applied to the similarities.
        device: Stored as ``self.device`` for callers that read it; helper
            tensors are allocated on the device of the incoming features.
    """

    def __init__(self, n_views, temperature, device):
        super().__init__()
        self.n_views = n_views
        self.temperature = temperature
        self.device = device

        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, outs, n_views=None, return_ind=False):
        """Compute the loss from ``outs['h0']`` .. ``outs['h{n_views-1}']``.

        Args:
            outs: Mapping holding one feature tensor per view.
            n_views: Optional override of the number of views. When given it
                is stored on the instance and therefore also applies to
                later calls.
            return_ind: If true, additionally return the detached per-row
                statistic produced by ``_forward``.

        Returns:
            The scalar loss, or ``(loss, per_row_statistic)`` when
            ``return_ind`` is true.
        """
        if n_views is not None:
            self.n_views = n_views

        feats = [outs[f'h{i}'] for i in range(self.n_views)]
        loss, individual = self._forward(*feats)

        if return_ind:
            return loss, individual.detach()
        return loss

    def _forward(self, *feats):
        """Compute the loss for explicitly passed view tensors.

        The number of views is taken from ``len(feats)`` rather than from
        ``self.n_views``, so the method also works when it is called directly
        with a different number of tensors than the instance was configured
        for.

        Returns:
            Tuple ``(loss, mean_exp_negatives)`` where the second entry is,
            for every row, the mean of ``exp`` of its temperature-scaled
            negative similarities.
        """
        n_views = len(feats)
        batch_size = feats[0].shape[0]

        features = F.normalize(torch.cat(feats), dim=1)
        device = features.device

        # Rows i and j are positives when they come from the same sample.
        sample_ids = torch.arange(batch_size, device=device).repeat(n_views)
        labels = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)

        similarity_matrix = torch.matmul(features, features.T)

        # Discard the main diagonal from both the labels and the similarities.
        off_diagonal = ~torch.eye(labels.shape[0], dtype=torch.bool, device=device)
        labels = labels[off_diagonal].view(labels.shape[0], -1)
        similarity_matrix = similarity_matrix[off_diagonal].view(
            similarity_matrix.shape[0], -1) / self.temperature

        # Select and combine the (possibly multiple) positives of each row.
        positives = similarity_matrix[labels].view(labels.shape[0], -1)

        # Select only the negatives.
        negatives = similarity_matrix[~labels].view(similarity_matrix.shape[0], -1)

        # Positives come first, so the target class of every row is 0.
        logits = torch.cat([positives, negatives], dim=1)
        targets = torch.zeros(logits.shape[0], dtype=torch.long, device=device)

        # Per-row statistic returned alongside the loss for logging.
        mean_exp_negatives = torch.mean(torch.exp(negatives), dim=1)

        return self.criterion(logits, targets), mean_exp_negatives
  
class SCALELoss(nn.Module):
    """Contrastive loss with pseudo-positive mining plus a distillation term.

    ``forward`` combines

    * ``loss_contrast``: a masked contrastive loss over two views, whose
      positive pairs come from ``similarity_mask_old``; and
    * ``loss_distill``: a cross-entropy between the pairwise-similarity
      distributions of current and past features,

    weighted by ``distill_power_moving``, which is calibrated once from the
    running loss histories.

    Values read from ``config``: ``device``, ``scenario.batch_size``,
    ``agent.mem_batch_size``, ``agent.distill_power``, ``agent.temperature``,
    ``agent.base_temp``, ``agent.temp_tSNE``, ``agent.thresh_ratio``,
    ``agent.curr_temp`` and ``agent.past_temp``.
    """

    def __init__(self, config):
        super().__init__()
        self.device = config.device

        # Rows of one view: the first `stream_bs` rows are treated as stream
        # samples, the remaining rows as memory samples.
        self.batch_size = config.scenario.batch_size + config.agent.mem_batch_size
        self.stream_bs = config.scenario.batch_size

        # One float per forward() call; used to calibrate the distill weight.
        self.loss_distill_hist = []
        self.loss_contrast_hist = []

        self.distill_power_moving = 0
        self.distill_power = config.agent.distill_power

        self.temperature = config.agent.temperature
        self.base_temperature = config.agent.base_temp
        self.temp_tSNE = config.agent.temp_tSNE
        self.thres_ratio = config.agent.thresh_ratio
        self.curr_temp = config.agent.curr_temp
        self.past_temp = config.agent.past_temp

    def forward(self, outs, old_feat):
        """Return ``loss_contrast + distill_power_moving * loss_distill``.

        Args:
            outs: Mapping with the two views under ``'h0'`` and ``'h1'``.
            old_feat: Past features compared against ``outs['h0']`` in the
                distillation term.

        Side effects:
            Appends both loss values to the history lists and, while
            ``distill_power_moving`` is still non-positive and the
            distillation loss is positive, sets ``distill_power_moving`` to
            ``mean(contrast history) * distill_power / mean(distill history)``.
        """
        feat_1, feat_2 = outs['h0'], outs['h1']

        # NOTE(review): the mask is built from the features as given, while
        # both loss terms normalise them first - confirm the caller passes
        # unit-norm features.
        mask = self.similarity_mask_old(torch.cat((feat_1, feat_2)))

        loss_cont = self.loss_contrast(feat_1, feat_2, mask)
        self.loss_contrast_hist.append(loss_cont.item())

        loss_distill = self.loss_distill(feat_1, old_feat)
        self.loss_distill_hist.append(loss_distill.item())

        if self.distill_power_moving <= 0.0 and loss_distill > 0.0:
            # Stored as a plain float so it multiplies cleanly with tensors.
            self.distill_power_moving = float(
                np.mean(self.loss_contrast_hist) * self.distill_power
                / np.mean(self.loss_distill_hist))

        return loss_cont + self.distill_power_moving * loss_distill

    def similarity_mask_old(self, feat_all):
        """Build the positive-pair mask used by ``loss_contrast``.

        Args:
            feat_all: Features of all views stacked along dim 0; the number
                of rows must be a multiple of ``self.batch_size``.

        Returns:
            Float mask of shape ``[batch_size, batch_size]`` holding 1 where
            the averaged pairwise similarity exceeds
            ``mean + thres_ratio * (max - mean)`` of the off-diagonal
            similarities, and 0 elsewhere. Entries outside the leading
            ``stream_bs x stream_bs`` block have their similarity zeroed
            before the comparison.
        """
        feat_size = feat_all.size(0)
        n_views = feat_size // self.batch_size
        assert (n_views * self.batch_size == feat_size), "Unmatch feature sizes and batch size!"

        identity = torch.eye(self.batch_size).to(self.device)
        # Large penalty that pushes the diagonal out of the softmax.
        diag_penalty = 1000 * identity

        # Running average of the symmetrised similarity over all view pairs.
        simil_mat_avg = torch.zeros(self.batch_size, self.batch_size).to(self.device)
        mat_cnt = 0
        for i in range(n_views):
            view_i = feat_all[i * self.batch_size: (i + 1) * self.batch_size]
            for j in range(n_views):
                view_j = feat_all[j * self.batch_size: (j + 1) * self.batch_size]

                # 2 - 2 * <a, b>; equals the squared Euclidean distance when
                # the rows have unit norm.
                simil_mat = 2 - 2 * torch.matmul(view_i, view_j.T)
                tmp_distance = - torch.div(simil_mat, self.temp_tSNE)
                tmp_distance = tmp_distance - diag_penalty
                # Average of row-wise and column-wise softmax.
                simil_mat = 0.5 * torch.softmax(tmp_distance, 1) + 0.5 * torch.softmax(tmp_distance, 0)

                # Add the new probability to the average probability.
                simil_mat_avg = (mat_cnt * simil_mat_avg + simil_mat) / (mat_cnt + 1)
                mat_cnt += 1

        # Statistics over the off-diagonal entries only.
        logits_mask = torch.scatter(
            torch.ones_like(simil_mat_avg),
            1,
            torch.arange(simil_mat_avg.size(0)).view(-1, 1).to(self.device),
            0
        )
        off_diagonal = logits_mask.bool()
        simil_max = simil_mat_avg[off_diagonal].max()
        simil_mean = simil_mat_avg[off_diagonal].mean()

        # Set the diagonal of the similarity matrix to ones.
        simil_mat_avg = simil_mat_avg * (1 - identity) + identity

        # Mask out memory elements.
        stream_mask = torch.zeros_like(simil_mat_avg).float().to(self.device)
        stream_mask[:self.stream_bs, :self.stream_bs] = 1
        simil_mat_avg = simil_mat_avg * stream_mask

        contrast_mask = torch.zeros_like(simil_mat_avg).float().to(self.device)
        simil_thres = simil_mean + self.thres_ratio * (simil_max - simil_mean)
        contrast_mask[simil_mat_avg > simil_thres] = 1

        return contrast_mask

    def loss_distill(self, cur_features, past_features):
        """Cross-entropy between past and current similarity distributions.

        Both feature sets are L2-normalised. For every row, the similarities
        to all *other* rows are divided by the respective temperature
        (``curr_temp`` / ``past_temp``) and turned into a distribution; the
        result is ``mean_i( - sum_j past_ij * log(cur_ij) )``.

        ``softmax`` / ``log_softmax`` are used instead of an explicit
        ``exp / sum`` followed by ``log``: the value is mathematically the
        same, but a probability that underflows to zero no longer produces
        ``0 * -inf = NaN``.

        Args:
            cur_features: Current features, one row per sample.
            past_features: Past features with the same number of rows.
        """
        cur_features = F.normalize(cur_features, dim=1)
        past_features = F.normalize(past_features, dim=1)

        row_size = cur_features.size(0)
        # Each row is compared with every row except itself.
        off_diagonal = ~torch.eye(row_size, dtype=torch.bool).to(self.device)

        cur_features_sim = torch.div(torch.matmul(cur_features, cur_features.T),
                                     self.curr_temp)
        cur_log_prob = F.log_softmax(
            cur_features_sim[off_diagonal].view(row_size, -1), dim=1)

        past_features_sim = torch.div(torch.matmul(past_features, past_features.T),
                                      self.past_temp)
        past_prob = F.softmax(
            past_features_sim[off_diagonal].view(row_size, -1), dim=1)

        return (- past_prob * cur_log_prob).sum(1).mean()

    def loss_contrast(self, z_stu, z_tch, mask):
        """Masked contrastive loss over two views, averaged over stream rows.

        Args:
            z_stu: Features of the first view, one row per sample.
            z_tch: Features of the second view, same shape as ``z_stu``.
            mask: ``[batch, batch]`` positive-pair mask; it is tiled 2x2 to
                cover both views and the self-pairs are removed.

        Returns:
            Scalar loss: the per-anchor loss is scaled by
            ``temperature / base_temperature`` and averaged over the first
            ``stream_bs`` rows of each view.
        """
        z_stu = F.normalize(z_stu, dim=1)
        z_tch = F.normalize(z_tch, dim=1)

        batch_size = z_stu.shape[0]

        all_features = torch.cat((z_stu, z_tch), dim=0)

        mask = mask.float().to(self.device)

        # Compute logits.
        anchor_dot_contrast = torch.div(
            torch.matmul(all_features, all_features.T),
            self.temperature)
        # Subtract the row maximum for numerical stability.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Tile the mask so it covers both views.
        mask = mask.repeat(2, 2)
        # Mask out self-contrast cases (zero on the diagonal).
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * 2).view(-1, 1).to(self.device),
            0
        )
        mask = mask * logits_mask

        # Log-probability of each pair among all non-self pairs of the row.
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # Mean log-likelihood over the positives; the epsilon keeps rows
        # without any positive at zero instead of dividing by zero.
        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-10)

        # Only the stream rows of each view contribute to the final average.
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(2, batch_size)
        stream_mask = torch.zeros_like(loss).float().to(self.device)
        stream_mask[:, :self.stream_bs] = 1
        loss = (stream_mask * loss).sum() / stream_mask.sum()
        return loss




    