
import json
import math
import torch

import torch.nn as nn
import torch.nn.functional as F

from transformers import RobertaModel, RobertaForMaskedLM


def byol_loss(x, y):
    """Return the BYOL regression loss between two batches of vectors.

    Both inputs are L2-normalised along the last dimension; the result is
    ``2 - 2 * mean(cosine similarity)``, a scalar tensor.
    """
    x_unit = F.normalize(x, p=2, dim=-1)
    y_unit = F.normalize(y, p=2, dim=-1)
    cosine = torch.sum(x_unit * y_unit, dim=-1)
    return 2 - 2 * cosine.mean()


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Gather `tensor` from every process and concatenate along dim 0.

    *** Warning ***: torch.distributed.all_gather has no gradient, so the
    returned tensor is detached from the autograd graph.
    """
    world_size = torch.distributed.get_world_size()
    # One receive buffer per rank, each shaped like the local tensor.
    gathered = []
    for _ in range(world_size):
        gathered.append(torch.ones_like(tensor))
    torch.distributed.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)


class ModifiedSupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR.

    Compared with the reference loss, `forward` can append prototype rows as
    extra contrast candidates; only the loss of the original `features` rows
    is averaged into the returned scalar.
    """
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07, report_per_class=True):
        """
        Args:
            temperature: divisor applied to the dot-product logits.
            contrast_mode: 'all' uses every view as an anchor, 'one' only the
                first view.
            base_temperature: the loss is scaled by temperature / base_temperature.
            report_per_class: when True (default, the historical behaviour) the
                per-label average loss is printed whenever prototypes are given.
        """
        super().__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature
        self.report_per_class = report_per_class

    def forward(self, features, protos=None, labels=None, proto_labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            protos: optional prototype vectors of shape [n_protos, n_views, ...]
                concatenated after `features` along dim 0.
            labels: ground truth of shape [bsz].
            proto_labels: labels of `protos`, shape [n_protos]; required
                together with `labels` whenever `protos` is given.
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        Raises:
            ValueError: on inconsistent arguments (see messages).
        """
        feature_bz = features.shape[0]

        if protos is not None:
            if labels is None or proto_labels is None:
                raise ValueError('`labels` and `proto_labels` are required when `protos` is given')
            features = torch.cat([features, protos], dim=0)
            labels = torch.cat([labels, proto_labels], dim=0)

        # Use the tensor's own device so that non-default GPUs work as well.
        device = features.device

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32, device=device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count, device=device).view(-1, 1),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive; the small constant keeps
        # anchors without any positive at a loss of zero instead of NaN
        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-9)

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size)
        # keep only the columns of the original features (drop prototype rows)
        loss = loss[:, :feature_bz]

        if protos is not None and self.report_per_class:
            sample_labels = labels[:feature_bz].view(-1)
            for label_idx in range(protos.shape[0]):
                # select columns so this also works with several anchor views
                label_loss = loss[:, sample_labels == label_idx]
                print('emo {} avg: {}'.format(label_idx, label_loss.mean()))

        return loss.mean()


class LMWithEntityAttention(nn.Module):
    def __init__(self, args, mode):
        """Build the RoBERTa-large backbone and the heads selected by the config.

        Args:
            args: configuration object; every option listed below is read from
                it and mirrored onto the module.
            mode: 'pt' / 'pt_test' attach the pretrained masked-LM head; any
                other value builds the emotion / sentiment heads chosen by
                `args.objective` and `args.supcon_loss_fn`.
        """
        super().__init__()
        self.args = args
        self.mode = mode

        # Mirror the options this module consults onto the instance.
        for option in (
            'max_n_utterances', 'do_sentiment', 'objective', 'general_dropout_prob',
            'contrastive_dim', 'use_meld', 'use_iemocap', 'use_emory_nlp',
            'project_embds_for_contrast', 'supcon_loss_fn', 'freeze_roberta',
            'add_cls_to_entity', 'token_mask_pct', 'entity_attn', 'e2e_attn',
            'use_queue',
        ):
            setattr(self, option, getattr(self.args, option))

        self.T = 1.0
        self.eps = 1e-8

        # Special token ids used when masking inputs.
        self.CLS_TOKEN_ID = 0
        self.PAD_TOKEN_ID = 1
        self.MASK_TOKEN_ID = 50264

        self.roberta = RobertaModel.from_pretrained('roberta-large')

        # Number of emotion classes per (iemocap, emory_nlp, meld) combination;
        # every combination not listed falls back to 7.
        dataset_combo = (bool(self.use_iemocap), bool(self.use_emory_nlp), bool(self.use_meld))
        self.n_emotion_classes = {
            (True, True, True): 13,
            (True, False, True): 11,
            (False, True, True): 9,
        }.get(dataset_combo, 7)
        self.n_sentiment_classes = 3
        self.dropout = nn.Dropout(self.general_dropout_prob)

        if self.entity_attn:
            # The learnable entity embedding starts as a copy of row 0 of the
            # word-embedding table.
            roberta_embd_weights = self.roberta.embeddings.word_embeddings.weight
            roberta_cls_embd = roberta_embd_weights[0].clone().unsqueeze(0)
            print('roberta_embd_weights: {}'.format(roberta_embd_weights.shape))
            print('roberta_cls_embd: {}'.format(roberta_cls_embd.shape))
            self.entity_embd = nn.Embedding.from_pretrained(roberta_cls_embd, freeze=False)

            if self.e2e_attn:
                # Append `max_n_utterances` extra position rows, each a copy of
                # position row 2, and swap in the enlarged table.
                position_table = self.roberta.embeddings.position_embeddings
                roberta_pos_embd_weights = position_table.weight
                print('raw roberta_pos_embd_weights: {}'.format(roberta_pos_embd_weights.shape))
                extra_rows = position_table.weight[2].clone().unsqueeze(0)
                extra_rows = extra_rows.expand(self.max_n_utterances, -1)
                roberta_pos_embd_weights = torch.cat([roberta_pos_embd_weights, extra_rows], dim=0)
                print('new roberta_pos_embd_weights: {}'.format(roberta_pos_embd_weights.shape))
                self.roberta.embeddings.position_embeddings = nn.Embedding.from_pretrained(
                    roberta_pos_embd_weights, freeze=False)

        hidden_dim = 1024
        if self.mode in ('pt', 'pt_test'):
            print('Getting lm_head from pretrained model...')
            self.vocab_pred_head = RobertaForMaskedLM.from_pretrained('roberta-large').lm_head
            print('self.vocab_pred_head: {}'.format(self.vocab_pred_head))
        elif self.objective == 'xent':
            self.emotion_pred_head = nn.Linear(hidden_dim, self.n_emotion_classes)
            if self.do_sentiment:
                self.sent_pred_head = nn.Linear(hidden_dim, self.n_sentiment_classes)
        elif self.objective in ('xent_supcon', 'xent_cosine_sim'):
            self.emotion_pred_head = nn.Linear(hidden_dim, self.n_emotion_classes)
            if self.project_embds_for_contrast:
                self.emotion_projection = nn.Linear(hidden_dim, self.contrastive_dim)
        elif self.supcon_loss_fn == 'byol':
            self.projector = nn.Linear(hidden_dim, self.contrastive_dim)
            self.predictor = nn.Linear(self.contrastive_dim, self.contrastive_dim)
        else:
            if self.project_embds_for_contrast:
                self.emotion_projection = nn.Linear(hidden_dim, self.contrastive_dim)
            if self.do_sentiment:
                self.sentiment_projection = nn.Linear(hidden_dim, self.contrastive_dim)
                self.sentiment_prototypes = nn.Embedding(self.n_sentiment_classes, self.contrastive_dim)

    def forward(self, input_ids, attn_mask, emo_labels=None, sent_labels=None, entity_presence=None,
                emo_negs=None, sent_negs=None, position_ids=None, emotion_queue=None):
        input_id_labels = input_ids.clone().detach()
        if self.freeze_roberta:
            with torch.no_grad():
                model_out = self.process_data(input_ids, attn_mask, emo_labels, sent_labels,
                                              entity_presence, emo_negs, sent_negs, position_ids)
        else:
            model_out = self.process_data(input_ids, attn_mask, emo_labels, sent_labels,
                                          entity_presence, emo_negs, sent_negs, position_ids)

        if self.mode == 'pt' or self.mode == 'pt_test':
            masked_embds, masked_idxs = model_out
            # print('masked_idxs: {}'.format(masked_idxs))
            masked_preds = self.vocab_pred_head(self.dropout(masked_embds))
            masked_labels = input_id_labels[masked_idxs]

            xent = nn.CrossEntropyLoss()
            loss = xent(masked_preds.view(-1, masked_preds.shape[-1]), masked_labels.view(-1))

            to_return = [loss, masked_preds.view(-1, masked_preds.shape[-1]), masked_labels.view(-1), masked_idxs]

        else:
            present_entity_embeddings = model_out

            emo_preds = None
            emo_loss = None
            if emo_labels is not None:
                emo_labels = emo_labels[entity_presence > 0]
                emo_negs = emo_negs[entity_presence > 0]
                if self.objective == 'xent':
                    xent = nn.CrossEntropyLoss(reduction='none')
                    # print('emo_labels: {}'.format(emo_labels.shape))
                    emo_preds = self.emotion_pred_head(self.dropout(present_entity_embeddings))
                    # emo_preds = self.emotion_pred_head(present_entity_embeddings)
                    # print('emo_preds: {}'.format(emo_preds.shape))
                    emo_loss = xent(emo_preds, emo_labels)
                    if torch.isnan(emo_loss).any():
                        agg = torch.cat([emo_loss.view(-1, 1), emo_labels.view(-1, 1).float()], dim=-1)
                        print('agg:\n{}'.format(agg))

                    emo_loss = emo_loss.mean()
                elif self.objective == 'xent_supcon':
                    # print('xent_supcon loss')
                    xent = nn.CrossEntropyLoss(reduction='none')
                    # print('emo_labels: {}'.format(emo_labels.shape))
                    emo_preds = self.emotion_pred_head(self.dropout(present_entity_embeddings))
                    # emo_preds = self.emotion_pred_head(present_entity_embeddings)
                    # print('emo_preds: {}'.format(emo_preds.shape))
                    emo_xent_loss = xent(emo_preds, emo_labels)

                    if torch.isnan(emo_xent_loss).any():
                        agg = torch.cat([emo_xent_loss.view(-1, 1), emo_labels.view(-1, 1).float()], dim=-1)
                        print('agg:\n{}'.format(agg))

                    emo_xent_loss = emo_xent_loss.mean()

                    if self.project_embds_for_contrast:
                        emo_projs = self.emotion_projection(self.dropout(present_entity_embeddings))
                    else:
                        emo_projs = present_entity_embeddings
                    # print('emo_preds: {}'.format(emo_preds))

                    if self.use_queue:
                        with torch.no_grad():
                            emo_protos = self.make_emo_protos(
                                emotion_queue, emo_labels,
                                single=True if self.supcon_loss_fn in ['cosine_sim', 'supcon'] else False
                            )
                    else:
                        emo_protos = None

                    loss_fn = ModifiedSupConLoss(temperature=0.9, base_temperature=0.9)
                    emo_supcon_loss = loss_fn(
                        features=F.normalize(emo_projs.unsqueeze(1), dim=-1),
                        protos=None if emo_protos is None else F.normalize(emo_protos.unsqueeze(1), dim=-1),
                        labels=emo_labels,
                        proto_labels=None if emo_protos is None else torch.arange(emo_protos.shape[0],
                                                                                  device=emo_protos.device,
                                                                                  dtype=emo_labels.dtype)
                    )
                    emo_loss = (0.9 * emo_xent_loss) + (0.1 * emo_supcon_loss)
                elif self.objective == 'xent_cosine_sim':
                    xent = nn.CrossEntropyLoss(reduction='none')
                    emo_preds = self.emotion_pred_head(self.dropout(present_entity_embeddings))
                    emo_xent_loss = xent(emo_preds, emo_labels)

                    if torch.isnan(emo_xent_loss).any():
                        agg = torch.cat([emo_xent_loss.view(-1, 1), emo_labels.view(-1, 1).float()], dim=-1)
                        print('agg:\n{}'.format(agg))

                    emo_xent_loss = emo_xent_loss.mean()

                    if self.project_embds_for_contrast:
                        emo_projs = self.emotion_projection(self.dropout(present_entity_embeddings))
                    else:
                        emo_projs = present_entity_embeddings
                    # print('emo_preds: {}'.format(emo_preds))

                    if self.use_queue:
                        with torch.no_grad():
                            emo_protos = self.make_emo_protos(
                                emotion_queue, emo_labels,
                                single=True if self.supcon_loss_fn in ['cosine_sim', 'supcon'] else False
                            )
                    else:
                        emo_protos = None

                    emo_cosine_sim_loss = self.cosine_sim_loss_v2(emo_projs, emo_protos, emo_labels)
                    emo_loss = (0.9 * emo_xent_loss) + (0.1 * emo_cosine_sim_loss)
                else:
                    if self.supcon_loss_fn == 'byol':
                        emo_proj = self.projector(self.dropout(present_entity_embeddings))
                        emo_preds = self.predictor(emo_proj)

                        with torch.no_grad():
                            emo_protos = self.make_emo_protos(emotion_queue, emo_labels)
                            emo_proto_proj = self.projector(emo_protos).detach_()

                        emo_loss = byol_loss(emo_preds, emo_proto_proj.detach())
                        # print('emo_loss: {}'.format(emo_loss))

                    else:
                        if self.project_embds_for_contrast:
                            emo_preds = self.emotion_projection(self.dropout(present_entity_embeddings))
                        else:
                            emo_preds = present_entity_embeddings
                        # print('emo_preds: {}'.format(emo_preds))

                        if self.use_queue:
                            with torch.no_grad():
                                emo_protos = self.make_emo_protos(
                                    emotion_queue, emo_labels,
                                    single=True if self.supcon_loss_fn in ['cosine_sim', 'supcon'] else False
                                )
                        else:
                            emo_protos = None

                        if self.supcon_loss_fn == 'cosine_sim':
                            # emo_loss = self.cosine_sim_loss(emo_preds, emo_protos)
                            emo_loss = self.cosine_sim_loss_v2(emo_preds, emo_protos, emo_labels)
                            # print('cosine sim v2 loss: {}'.format(emo_loss))
                        elif self.supcon_loss_fn == 'supcon':
                            loss_fn = ModifiedSupConLoss(temperature=0.9, base_temperature=0.9)
                            emo_loss = loss_fn(
                                features=F.normalize(emo_preds.unsqueeze(1), dim=-1),
                                protos=None if emo_protos is None else F.normalize(emo_protos.unsqueeze(1), dim=-1),
                                labels=emo_labels,
                                proto_labels=None if emo_protos is None else torch.arange(emo_protos.shape[0],
                                                                                          device=emo_protos.device,
                                                                                          dtype=emo_labels.dtype)
                            )
                            # input('okty')
                        else:
                            # loss_fn = SupConLoss()
                            # emo_loss = loss_fn(F.normalize(emo_preds.unsqueeze(1), dim=-1),
                            #                    F.normalize(emo_protos.unsqueeze(1), dim=-1))
                            emo_loss = self.contrastive_loss(emo_preds, emo_protos)

                    # loss_fn = SupConLoss()
                    # print('raw emo_preds: {}'.format(emo_preds.shape))
                    # emo_preds = emo_preds.unsqueeze(1)
                    # print('unsqz emo_preds: {}'.format(emo_preds.shape))
                    # emo_loss = loss_fn(F.normalize(emo_preds.unsqueeze(1), dim=-1), emo_labels)

                    # emo_preds = self.emotion_projection(present_entity_embeddings)
                    # emo_labels = self.emotion_prototypes(emo_labels)
                    # emo_negs = self.emotion_prototypes(emo_negs)
                    # # print('emo_preds: {}'.format(emo_preds.shape))
                    # # print('emo_labels: {}'.format(emo_labels.shape))
                    # emo_loss = self.contrastive_loss(emo_preds, emo_labels, emo_negs)

            sent_preds = None
            sent_loss = None
            if sent_labels is not None and self.do_sentiment:
                sent_labels = sent_labels[entity_presence > 0]
                sent_negs = sent_negs[entity_presence > 0]
                if self.objective == 'xent':
                    xent = nn.CrossEntropyLoss()
                    sent_labels = sent_labels[entity_presence > 0]
                    # print('sent_labels: {}'.format(sent_labels.shape))
                    sent_preds = self.sent_pred_head(present_entity_embeddings)
                    # print('sent_preds: {}'.format(sent_preds.shape))
                    sent_loss = xent(sent_preds, sent_labels)
                else:
                    sent_preds = self.sentiment_projection(self.dropout(present_entity_embeddings))
                    sent_labels = self.sentiment_prototypes(sent_labels)
                    sent_negs = self.sentiment_prototypes(sent_negs)
                    sent_loss = self.contrastive_loss(sent_preds, sent_labels, sent_negs)

            to_return = [present_entity_embeddings, emo_preds, emo_loss, sent_preds, sent_loss]

        return to_return

    def process_data(self, input_ids, attn_mask, emo_labels=None, sent_labels=None, entity_presence=None,
                     emo_negs=None, sent_negs=None, position_ids=None, emotion_queue=None):
        mask_idxs = None
        if (self.mode == 'pt' or self.mode == 'pt_test') and self.token_mask_pct > 0.0:
            input_ids, mask_idxs = self.mask_input_ids(input_ids)

        input_embds = self.roberta.embeddings.word_embeddings(input_ids)

        if self.entity_attn:
            entity_embeddings = self.entity_embd(
                torch.zeros(input_ids.shape[0], self.max_n_utterances,
                            device=input_ids.device, dtype=input_ids.dtype)
            )
            input_embds = torch.concatenate(
                [entity_embeddings, input_embds], dim=1
            )
        else:
            pass

        token_type_ids = torch.zeros(input_embds.shape[0], input_embds.shape[1],
                                     dtype=input_ids.dtype, device=input_ids.device)
        lm_outputs = self.roberta(inputs_embeds=input_embds, attention_mask=attn_mask, token_type_ids=token_type_ids,
                                  position_ids=position_ids)
        # print('lm_outputslm_outputs: {}'.format(lm_outputs))
        lm_out_seq = lm_outputs.last_hidden_state
        # print('lm_out_seq: {}'.format(lm_out_seq.shape))
        # print('all_possible_entity_embeddings: {}'.format(all_possible_entity_embeddings.shape))
        if self.mode == 'pt' or self.mode == 'pt_test':
            if self.entity_attn:
                embds_to_return = lm_out_seq[:, self.max_n_utterances:, :][mask_idxs]
            else:
                embds_to_return = lm_out_seq[mask_idxs]
            to_return = [embds_to_return, mask_idxs]
        else:
            all_possible_entity_embeddings = lm_out_seq[:, :self.max_n_utterances, :]
            if self.add_cls_to_entity:
                cls_embd = lm_out_seq[:, self.max_n_utterances, :].unsqueeze(1)
                # print('present_entity_embeddings: {}'.format(all_possible_entity_embeddings.shape))
                # print('cls_embd: {}'.format(cls_embd.shape))
                all_possible_entity_embeddings += cls_embd

            to_return = all_possible_entity_embeddings[entity_presence > 0]

        return to_return

    def mask_input_ids(self, input_ids):
        mask_idxs = torch.zeros_like(input_ids, dtype=torch.bool)
        generator = None
        if self.mode == 'pt_test':
            generator = torch.Generator(device='cuda')
            generator.manual_seed(16)
        mask_probs = torch.rand(input_ids.shape, device=input_ids.device, generator=generator)
        # print('mask_probs: {}'.format(mask_probs.shape))

        input_ids[
            (mask_probs < self.token_mask_pct) & (input_ids != self.PAD_TOKEN_ID) & (input_ids != self.CLS_TOKEN_ID)
            ] = self.MASK_TOKEN_ID
        mask_idxs[
            (mask_probs < self.token_mask_pct) & (input_ids != self.PAD_TOKEN_ID) & (input_ids != self.CLS_TOKEN_ID)
            ] = True

        # print('input_ids:\n{}\n\t{} min: {} max: {}'.format(input_ids, input_ids.shape, input_ids.min(), input_ids.max()))
        # input('okty')
        return input_ids, mask_idxs

    def contrastive_loss(self, q, k):
        # normalize
        q = nn.functional.normalize(q, dim=1)
        k = nn.functional.normalize(k, dim=1)
        # gather all targets
        # k = concat_all_gather(k)
        # Einstein sum is more intuitive
        logits = torch.einsum('nc,mc->nm', [q, k]) / self.T
        N = logits.shape[0]  # batch size per GPU
        labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).cuda()
        return nn.CrossEntropyLoss()(logits, labels) * (2 * self.T)

    # def contrastive_loss(self, q, k, c):
    #     # score = F.cosine_similarity(q, k, dim=-1)
    #     # print('raw q:\n{}\n\t{}'.format(q, q.shape))
    #     q = q / torch.norm(q, p=2, dim=1, keepdim=True)
    #     # print('norm q:\n{}\n\t{}'.format(q, q.shape))
    #     # k = k / k.norm(2, -1)
    #     k = k / torch.norm(k, p=2, dim=1, keepdim=True)
    #     # print('k:\n{}\n\t{}'.format(k, k.shape))
    #     c = c / torch.norm(c, p=2, dim=1, keepdim=True)
    #
    #     score = (q * k) - (q * c)
    #
    #     # print('score:\n{}\n\t{}'.format(score, score.shape))
    #     score = score.sum(-1).mean()
    #
    #     return 1 - score

    def make_emo_protos(self, queue, labels, single=False):
        protos = []
        proto_targets = []

        for label_idx in range(queue.shape[0]):
            label_queue = queue[label_idx]
            label_queue = label_queue[label_queue.sum(-1) != 0]
            queue_order_perm = torch.randperm(label_queue.shape[0])
            label_queue = label_queue[queue_order_perm]
            queue_for_proto = label_queue[:64]
            # print('queue_for_proto: {}'.format(queue_for_proto.shape))
            if self.project_embds_for_contrast and self.supcon_loss_fn != 'byol':
                records_for_proto = self.emotion_projection(self.dropout(queue_for_proto))
            else:
                records_for_proto = queue_for_proto
            proto = records_for_proto.mean(dim=0).unsqueeze(0)
            protos.append(proto)
            # print('proto for label {}: {}'.format(label_idx, proto.shape))

        if single:
            for lbl in range(self.n_emotion_classes):
                proto_targets.append(protos[lbl])
        else:
            for lbl in labels:
                proto_targets.append(protos[lbl])

        proto_targets = torch.concat(proto_targets, dim=0)
        return proto_targets

    def cosine_sim_loss(self, q, k):
        score = F.cosine_similarity(q, k, dim=-1).mean()
        # print('raw cosine_sim_loss score: {}'.format(score.shape))

        return 1 - score

    def score_func(self, x, y):
        return (1+F.cosine_similarity(x, y, dim=-1))/2 + self.eps

    def cosine_sim_loss_v2(self, reps, protos, labels):
        batch_size = reps.shape[0]
        concated_reps = torch.cat([reps, protos], dim=0)
        concated_labels = torch.cat([labels, torch.arange(protos.shape[0], dtype=labels.dtype, device=labels.device)])

        concated_bsz = concated_reps.shape[0]
        mask1 = concated_labels.unsqueeze(0).expand(concated_labels.shape[0], concated_labels.shape[0])
        mask2 = concated_labels.unsqueeze(1).expand(concated_labels.shape[0], concated_labels.shape[0])
        mask = 1 - torch.eye(concated_bsz).to(reps.device)

        pos_mask = (mask1 == mask2).long()
        rep1 = concated_reps.unsqueeze(0).expand(concated_bsz, concated_bsz, concated_reps.shape[-1])
        rep2 = concated_reps.unsqueeze(1).expand(concated_bsz, concated_bsz, concated_reps.shape[-1])
        scores = self.score_func(rep1, rep2)
        scores *= 1 - torch.eye(concated_bsz).to(scores.device)

        scores = scores[:batch_size]
        pos_mask = pos_mask[:batch_size]
        mask = mask[:batch_size]
        scores -= torch.max(scores).item()

        # print('scores: {} pos_mask: {} mask: {}'.format(scores.shape, pos_mask.shape, mask.shape))

        scores = torch.exp(scores)
        pos_scores = scores * (pos_mask * mask)
        neg_scores = scores * (1 - pos_mask)
        probs = pos_scores.sum(-1) / (pos_scores.sum(-1) + neg_scores.sum(-1))
        probs /= (pos_mask * mask).sum(-1) + self.eps
        loss = - torch.log(probs + self.eps)
        # print('loss: {}'.format(loss))
        loss_mask = (loss > 0.3).long()
        loss = (loss * loss_mask).sum() / (loss_mask.sum().item() + self.eps)

        return loss


