

import csv
import os
import json
import math
import torch
import random

import numpy as np

from torch.utils.data import Dataset

from transformers import RobertaTokenizer


def make_entity_attn(u_ids, id_vec, max_n_utterances, max_len):
    """Build the utterance-entity -> token attention mask.

    Args:
        u_ids: iterable of utterance ids; the i-th id owns row ``i`` of the
            result.
        id_vec: 1-D tensor holding, per token position, the id of the
            utterance that token belongs to (the loaders in this file use -1
            for positions that belong to no utterance).
        max_n_utterances: number of rows of the returned mask.
        max_len: number of columns of the returned mask.

    Returns:
        Float tensor of shape ``(max_n_utterances, max_len)`` in which entry
        ``[i, j]`` is 1 when ``id_vec[j] == u_ids[i]`` and 0 otherwise. Rows
        past ``len(u_ids)`` stay all-zero.
    """
    entity_to_token = torch.zeros(max_n_utterances, max_len)

    for row, utterance_id in enumerate(u_ids):
        # Column indices of every token tagged with this utterance id.
        token_positions = torch.nonzero(id_vec == utterance_id)
        entity_to_token[row, token_positions] = 1

    return entity_to_token


def load_meld(data_fp, emo_map, sent_map, tokenizer, max_len, max_n_utterances,
              CLS_TOKEN_ID, PAD_TOKEN_ID, SEP_TOKEN_ID, entity_attn, e2e_attn):
    """Read a MELD CSV split and build one padded model input per dialogue.

    Consecutive rows that share a ``Dialogue_ID`` form one dialogue. Each
    dialogue is flattened to ``[CLS] utt_0 [SEP] utt_1 [SEP] ...`` and padded
    with ``PAD_TOKEN_ID`` up to ``max_len``.

    Args:
        data_fp: path to a CSV file with the columns ``Utterance``,
            ``Utterance_ID``, ``Emotion``, ``Sentiment`` and ``Dialogue_ID``.
        emo_map: mapping emotion string -> integer label.
        sent_map: mapping sentiment string -> integer label.
        tokenizer: callable invoked as ``tokenizer(text,
            add_special_tokens=False)`` that returns a mapping with
            ``'input_ids'`` and ``'attention_mask'`` lists.
        max_len: padded token length of every item.
        max_n_utterances: number of utterance (entity) slots per item.
        CLS_TOKEN_ID, PAD_TOKEN_ID, SEP_TOKEN_ID: special token ids.
        entity_attn: when true, ``max_n_utterances`` entity slots are placed in
            front of the tokens in ``position_ids`` and ``attn_mask``.
        e2e_attn: only read when ``entity_attn`` is true; lets the present
            entity slots attend to each other instead of only to themselves.

    Returns:
        A list of dicts with the keys ``input_ids``, ``position_ids``,
        ``attn_mask``, ``dialogue_id``, ``emo_labels``, ``sent_labels``,
        ``entity_presence`` and ``source`` (always ``'meld'``). Label vectors
        are padded with -1 up to ``max_n_utterances``. ``attn_mask`` is
        ``(max_len, max_len)``, or ``(max_n_utterances + max_len,) * 2`` when
        ``entity_attn`` is set. Dialogues that exceed ``max_len`` tokens or
        ``max_n_utterances`` utterances are reported on stdout and skipped.
    """
    dialogues = []
    curr_dialogue_id = None

    # The csv module asks for newline='' so that quoted fields containing line
    # breaks are parsed as a single field.
    with open(data_fp, 'r', newline='') as csv_file:
        for row in csv.DictReader(csv_file):
            if row['Dialogue_ID'] != curr_dialogue_id:
                dialogues.append([])
                curr_dialogue_id = row['Dialogue_ID']

            dialogues[-1].append(
                {
                    'utterance': row['Utterance'],
                    'utterance_id': int(row['Utterance_ID']),
                    'emotion': emo_map[row['Emotion']],
                    'sentiment': sent_map[row['Sentiment']],
                    'dialogue_id': int(row['Dialogue_ID'])
                }
            )

    built_items = []
    for dialogue in dialogues:
        dialogue_id = dialogue[0]['dialogue_id']

        input_ids = [CLS_TOKEN_ID]      # init w/ [CLS]
        token_attn = [1]                # one entry per real (non-pad) token
        utterance_ids = [-1]            # [CLS] is not part of an utterance
        unique_utterance_ids = []
        emo_labels = []
        sent_labels = []
        if entity_attn:
            if e2e_attn:
                # One distinct position per entity slot, starting at 514.
                # NOTE(review): presumably these index extra position
                # embeddings appended by the model - confirm.
                position_ids = [514 + i for i in range(max_n_utterances)]
                position_ids.append(2)                          # cls
            else:
                position_ids = [2] * (max_n_utterances + 1)     # entities + cls
        else:
            position_ids = [2]                                  # cls
        # Start at 3 so the first utterance token does not reuse position 2,
        # which is already assigned to [CLS] above.
        curr_position_idx = 3

        for utterance in dialogue:
            utt_id = utterance['utterance_id']
            unique_utterance_ids.append(utt_id)
            emo_labels.append(utterance['emotion'])
            sent_labels.append(utterance['sentiment'])

            utt_tokenized = tokenizer(utterance['utterance'], add_special_tokens=False)
            utt_input_ids = utt_tokenized['input_ids']
            n_utt_tokens = len(utt_input_ids)

            input_ids.extend(utt_input_ids)
            token_attn.extend(utt_tokenized['attention_mask'])
            utterance_ids.extend([utt_id] * n_utt_tokens)
            position_ids.extend(range(curr_position_idx, curr_position_idx + n_utt_tokens))
            curr_position_idx += n_utt_tokens

            input_ids.append(SEP_TOKEN_ID)      # sep token between utterances
            token_attn.append(1)                # sep token can be attended to
            utterance_ids.append(-1)            # sep belongs to no utterance
            position_ids.append(curr_position_idx)
            curr_position_idx += 1

        if len(input_ids) > max_len:
            print('!! Not processing record w/ dialogue ID {} b/c too long !!'.format(
                dialogue_id
            ))
            continue
        if len(unique_utterance_ids) > max_n_utterances:
            # More utterances than entity slots would produce label vectors
            # longer than max_n_utterances and overflow the entity mask.
            print('!! Not processing record w/ dialogue ID {} b/c too many utterances !!'.format(
                dialogue_id
            ))
            continue

        n_pad = max_len - len(input_ids)
        input_ids.extend([PAD_TOKEN_ID] * n_pad)
        utterance_ids.extend([-1] * n_pad)

        # Padding positions use position id 1.
        position_len = max_len + max_n_utterances if entity_attn else max_len
        position_ids.extend([1] * (position_len - len(position_ids)))

        n_label_pad = max_n_utterances - len(emo_labels)
        emo_labels.extend([-1] * n_label_pad)
        sent_labels.extend([-1] * n_label_pad)

        input_ids = torch.tensor(input_ids)
        position_ids = torch.tensor(position_ids)
        utterance_ids = torch.tensor(utterance_ids)
        emo_labels = torch.tensor(emo_labels)
        sent_labels = torch.tensor(sent_labels)

        n_present = len(unique_utterance_ids)
        entity_presence = torch.zeros(max_n_utterances)
        entity_presence[:n_present] = 1

        # Real tokens attend to each other; padding rows/columns stay zero.
        n_tokens = len(token_attn)
        attn_mask = torch.zeros(max_len, max_len)
        attn_mask[:n_tokens, :n_tokens] = 1

        if entity_attn:
            utt_entity_attn = make_entity_attn(unique_utterance_ids, utterance_ids, max_n_utterances, max_len)

            # Block layout (entity slots first, then tokens):
            #   [ entity<->entity (zeros) | entity->token ]
            #   [ token->entity           | token<->token ]
            entity_rows = torch.cat(
                [torch.zeros(max_n_utterances, max_n_utterances), utt_entity_attn], dim=1
            )
            token_rows = torch.cat([utt_entity_attn.T, attn_mask], dim=1)
            attn_mask = torch.cat([entity_rows, token_rows], dim=0)

            if e2e_attn:
                # Every present entity sees every other present entity.
                attn_mask[:n_present, :n_present] = 1
            else:
                # Each present entity only sees itself.
                diag = torch.arange(n_present)
                attn_mask[diag, diag] = 1

        built_items.append({
            'input_ids': input_ids,
            'position_ids': position_ids,
            'attn_mask': attn_mask,
            'dialogue_id': torch.tensor([dialogue_id]),
            'emo_labels': emo_labels,
            'sent_labels': sent_labels,
            'entity_presence': entity_presence,
            'source': 'meld'
        })

    return built_items


def load_iemocap(data_dir, emo_map, sent_map, tokenizer, max_len, max_n_utterances,
                 CLS_TOKEN_ID, PAD_TOKEN_ID, SEP_TOKEN_ID, entity_attn, e2e_attn):
    """Build one padded model input per dialogue file found in ``data_dir``.

    Every file in ``data_dir`` is read as one dialogue with one JSON object per
    line, each carrying a ``'text'`` string and a ``'labels'`` list of emotion
    names. Only the first ``max_n_utterances`` lines are considered, and an
    utterance is dropped when adding it (plus its separator) would not leave
    the sequence strictly shorter than ``max_len``.

    Args:
        data_dir: directory containing the dialogue files.
        emo_map: mapping emotion string -> integer label.
        sent_map: unused; kept so all loaders share one signature.
        tokenizer: callable invoked as ``tokenizer(text,
            add_special_tokens=False)`` that returns a mapping with
            ``'input_ids'`` and ``'attention_mask'`` lists.
        max_len: padded token length of every item.
        max_n_utterances: number of utterance (entity) slots per item.
        CLS_TOKEN_ID, PAD_TOKEN_ID, SEP_TOKEN_ID: special token ids.
        entity_attn: when true, ``max_n_utterances`` entity slots are placed in
            front of the tokens in ``position_ids`` and ``attn_mask``.
        e2e_attn: only read when ``entity_attn`` is true; lets the present
            entity slots attend to each other instead of only to themselves.

    Returns:
        A list of dicts with the keys ``input_ids``, ``position_ids``,
        ``attn_mask``, ``emo_labels``, ``entity_presence``, ``source`` (always
        ``'iemocap'``) and ``dialogue_id``. ``emo_labels`` is a plain list with
        one list of label ids per kept utterance, padded with the integer -1
        up to ``max_n_utterances``.
    """
    items = []
    # os.listdir returns entries in arbitrary order; sort so the dialogue ids
    # assigned below are reproducible across runs and machines.
    for dialog_idx, dialog_filename in enumerate(sorted(os.listdir(data_dir))):
        dialog_fp = os.path.join(data_dir, dialog_filename)

        input_ids = [CLS_TOKEN_ID]      # init w/ [CLS]
        token_attn = [1]                # one entry per real (non-pad) token
        utterance_ids = [-1]            # [CLS] is not part of an utterance
        unique_utterance_ids = []
        emo_labels = []
        if entity_attn:
            if e2e_attn:
                # One distinct position per entity slot, starting at 514.
                # NOTE(review): presumably these index extra position
                # embeddings appended by the model - confirm.
                position_ids = [514 + i for i in range(max_n_utterances)]
                position_ids.append(2)                          # cls
            else:
                position_ids = [2] * (max_n_utterances + 1)     # entities + cls
        else:
            position_ids = [2]                                  # cls
        # Start at 3 so the first utterance token does not reuse position 2,
        # which is already assigned to [CLS] above.
        curr_position_idx = 3

        with open(dialog_fp, 'r') as f:
            for utt_idx, line in enumerate(f):
                if utt_idx >= max_n_utterances:
                    break   # later lines can never be used

                utterance = json.loads(line.strip())
                utt_tokenized = tokenizer(utterance['text'], add_special_tokens=False)
                utt_input_ids = utt_tokenized['input_ids']
                n_utt_tokens = len(utt_input_ids)

                # Skip utterances that do not fit together with their [SEP].
                if len(input_ids) + n_utt_tokens + 1 >= max_len:
                    continue

                unique_utterance_ids.append(utt_idx)
                emo_labels.append([emo_map[label] for label in utterance['labels']])

                input_ids.extend(utt_input_ids)
                token_attn.extend(utt_tokenized['attention_mask'])
                utterance_ids.extend([utt_idx] * n_utt_tokens)
                position_ids.extend(range(curr_position_idx, curr_position_idx + n_utt_tokens))
                curr_position_idx += n_utt_tokens

                input_ids.append(SEP_TOKEN_ID)      # sep token between utterances
                token_attn.append(1)                # sep token can be attended to
                utterance_ids.append(-1)            # sep belongs to no utterance
                position_ids.append(curr_position_idx)
                curr_position_idx += 1

        n_pad = max_len - len(input_ids)
        input_ids.extend([PAD_TOKEN_ID] * n_pad)
        utterance_ids.extend([-1] * n_pad)

        # Padding positions use position id 1.
        position_len = max_len + max_n_utterances if entity_attn else max_len
        position_ids.extend([1] * (position_len - len(position_ids)))

        # Labels stay a Python list: each utterance may carry several labels.
        emo_labels.extend([-1] * (max_n_utterances - len(emo_labels)))

        input_ids = torch.tensor(input_ids)
        position_ids = torch.tensor(position_ids)
        utterance_ids = torch.tensor(utterance_ids)

        n_present = len(unique_utterance_ids)
        entity_presence = torch.zeros(max_n_utterances)
        entity_presence[:n_present] = 1

        # Real tokens attend to each other; padding rows/columns stay zero.
        n_tokens = len(token_attn)
        attn_mask = torch.zeros(max_len, max_len)
        attn_mask[:n_tokens, :n_tokens] = 1

        if entity_attn:
            utt_entity_attn = make_entity_attn(unique_utterance_ids, utterance_ids, max_n_utterances, max_len)

            # Block layout (entity slots first, then tokens):
            #   [ entity<->entity (zeros) | entity->token ]
            #   [ token->entity           | token<->token ]
            entity_rows = torch.cat(
                [torch.zeros(max_n_utterances, max_n_utterances), utt_entity_attn], dim=1
            )
            token_rows = torch.cat([utt_entity_attn.T, attn_mask], dim=1)
            attn_mask = torch.cat([entity_rows, token_rows], dim=0)

            if e2e_attn:
                # Every present entity sees every other present entity.
                attn_mask[:n_present, :n_present] = 1
            else:
                # Each present entity only sees itself.
                diag = torch.arange(n_present)
                attn_mask[diag, diag] = 1

        items.append({
            'input_ids': input_ids,
            'position_ids': position_ids,
            'attn_mask': attn_mask,
            'emo_labels': emo_labels,
            'entity_presence': entity_presence,
            'source': 'iemocap',
            'dialogue_id': torch.tensor([dialog_idx])
        })

    return items


def load_emory_nlp(data_fp, emo_map, sent_map, tokenizer, max_len, max_n_utterances,
                   CLS_TOKEN_ID, PAD_TOKEN_ID, SEP_TOKEN_ID, entity_attn, e2e_attn):
    """Read an Emory NLP emotion-detection JSON file, one item per scene.

    The file is expected to look like ``{'episodes': [{'scenes':
    [{'utterances': [{'transcript': ..., 'emotion': ...}, ...]}, ...]}, ...]}``.
    Emotion names are lower-cased and ``joyful``/``sad``/``scared``/``mad`` are
    renamed to ``joy``/``sadness``/``fear``/``anger`` before the ``emo_map``
    lookup. An utterance is dropped when adding it (plus its separator) would
    not leave the sequence strictly shorter than ``max_len``; at most
    ``max_n_utterances`` utterances are kept per scene.

    Args:
        data_fp: path to the JSON file.
        emo_map: mapping emotion string -> integer label.
        sent_map: unused; kept so all loaders share one signature.
        tokenizer: callable invoked as ``tokenizer(text,
            add_special_tokens=False)`` that returns a mapping with
            ``'input_ids'`` and ``'attention_mask'`` lists.
        max_len: padded token length of every item.
        max_n_utterances: number of utterance (entity) slots per item.
        CLS_TOKEN_ID, PAD_TOKEN_ID, SEP_TOKEN_ID: special token ids.
        entity_attn: when true, ``max_n_utterances`` entity slots are placed in
            front of the tokens in ``position_ids`` and ``attn_mask``.
        e2e_attn: only read when ``entity_attn`` is true; lets the present
            entity slots attend to each other instead of only to themselves.

    Returns:
        A list of dicts with the keys ``input_ids``, ``position_ids``,
        ``attn_mask``, ``emo_labels``, ``entity_presence``, ``source`` (always
        ``'emory_nlp'``) and ``dialogue_id``. ``dialogue_id`` is the running
        index of the scene within the returned list, so it is unique per file.
    """
    emotion_aliases = {
        'joyful': 'joy',
        'sad': 'sadness',
        'scared': 'fear',
        'mad': 'anger',
    }

    # Context manager so the file handle is closed deterministically.
    with open(data_fp) as json_file:
        data_j = json.load(json_file)

    items = []
    for episode_j in data_j['episodes']:
        for scene_j in episode_j['scenes']:
            input_ids = [CLS_TOKEN_ID]      # init w/ [CLS]
            token_attn = [1]                # one entry per real (non-pad) token
            utterance_ids = [-1]            # [CLS] is not part of an utterance
            unique_utterance_ids = []
            emo_labels = []
            if entity_attn:
                if e2e_attn:
                    # One distinct position per entity slot, starting at 514.
                    # NOTE(review): presumably these index extra position
                    # embeddings appended by the model - confirm.
                    position_ids = [514 + i for i in range(max_n_utterances)]
                    position_ids.append(2)                          # cls
                else:
                    position_ids = [2] * (max_n_utterances + 1)     # entities + cls
            else:
                position_ids = [2]                                  # cls
            curr_position_idx = 3           # first position after [CLS]

            for utt_idx, utter_j in enumerate(scene_j['utterances']):
                if len(unique_utterance_ids) >= max_n_utterances:
                    # All entity slots are taken; more utterances would make
                    # emo_labels longer than max_n_utterances and overflow the
                    # entity attention mask.
                    break

                emotion_str = utter_j['emotion'].lower()
                emotion_str = emotion_aliases.get(emotion_str, emotion_str)

                utt_tokenized = tokenizer(utter_j['transcript'], add_special_tokens=False)
                utt_input_ids = utt_tokenized['input_ids']
                n_utt_tokens = len(utt_input_ids)

                # Skip utterances that do not fit together with their [SEP].
                if len(input_ids) + n_utt_tokens + 1 >= max_len:
                    continue

                unique_utterance_ids.append(utt_idx)
                emo_labels.append(emo_map[emotion_str])

                input_ids.extend(utt_input_ids)
                token_attn.extend(utt_tokenized['attention_mask'])
                utterance_ids.extend([utt_idx] * n_utt_tokens)
                position_ids.extend(range(curr_position_idx, curr_position_idx + n_utt_tokens))
                curr_position_idx += n_utt_tokens

                input_ids.append(SEP_TOKEN_ID)      # sep token between utterances
                token_attn.append(1)                # sep token can be attended to
                utterance_ids.append(-1)            # sep belongs to no utterance
                position_ids.append(curr_position_idx)
                curr_position_idx += 1

            n_pad = max_len - len(input_ids)
            input_ids.extend([PAD_TOKEN_ID] * n_pad)
            utterance_ids.extend([-1] * n_pad)

            # Padding positions use position id 1.
            position_len = max_len + max_n_utterances if entity_attn else max_len
            position_ids.extend([1] * (position_len - len(position_ids)))

            emo_labels.extend([-1] * (max_n_utterances - len(emo_labels)))

            input_ids = torch.tensor(input_ids)
            position_ids = torch.tensor(position_ids)
            utterance_ids = torch.tensor(utterance_ids)
            emo_labels = torch.tensor(emo_labels)

            n_present = len(unique_utterance_ids)
            entity_presence = torch.zeros(max_n_utterances)
            entity_presence[:n_present] = 1

            # Real tokens attend to each other; padding rows/columns stay zero.
            n_tokens = len(token_attn)
            attn_mask = torch.zeros(max_len, max_len)
            attn_mask[:n_tokens, :n_tokens] = 1

            if entity_attn:
                utt_entity_attn = make_entity_attn(unique_utterance_ids, utterance_ids, max_n_utterances, max_len)

                # Block layout (entity slots first, then tokens):
                #   [ entity<->entity (zeros) | entity->token ]
                #   [ token->entity           | token<->token ]
                entity_rows = torch.cat(
                    [torch.zeros(max_n_utterances, max_n_utterances), utt_entity_attn], dim=1
                )
                token_rows = torch.cat([utt_entity_attn.T, attn_mask], dim=1)
                attn_mask = torch.cat([entity_rows, token_rows], dim=0)

                if e2e_attn:
                    # Every present entity sees every other present entity.
                    attn_mask[:n_present, :n_present] = 1
                else:
                    # Each present entity only sees itself.
                    diag = torch.arange(n_present)
                    attn_mask[diag, diag] = 1

            items.append({
                'input_ids': input_ids,
                'position_ids': position_ids,
                'attn_mask': attn_mask,
                'emo_labels': emo_labels,
                'entity_presence': entity_presence,
                'source': 'emory_nlp',
                # The per-episode scene index restarts at 0 for every episode,
                # so use the running item index to keep ids unique.
                'dialogue_id': torch.tensor([len(items)])
            })

    return items


class EmoSentDataset(Dataset):
    """Dialogue-level emotion dataset built from MELD, Emory NLP and IEMOCAP.

    All selected corpora are loaded eagerly in ``__init__`` into
    ``self.items``. ``__getitem__`` returns a per-access copy of an item with
    optional attention-mask tweaks and freshly sampled negative emotion labels.
    """

    def __init__(self, args, mode):
        """Load every enabled corpus for the requested split.

        Args:
            args: object exposing ``max_seq_len``, ``max_n_utterances``,
                ``use_meld``, ``use_iemocap``, ``use_emory_nlp``, ``v2_attn``,
                ``entity_attn``, ``e2e_attn``, ``e2e_attn_masking``,
                ``e2e_attn_mask_prob`` and, for the enabled corpora,
                ``meld_dir`` / ``emory_nlp_dir`` / ``iemocap_dir``.
            mode: split name. ``'inspect'`` and ``'pt_test'`` are treated as
                ``'test'`` and ``'pt'`` as ``'train'``; anything other than
                ``'train'``/``'test'`` reads the dev files.
        """
        self.CLS_TOKEN_ID = 0
        self.PAD_TOKEN_ID = 1
        self.SEP_TOKEN_ID = 2

        # Collapse auxiliary modes onto the split whose data they read.
        mode = {'inspect': 'test', 'pt': 'train', 'pt_test': 'test'}.get(mode, mode)

        self.args = args
        self.mode = mode

        self.max_len = self.args.max_seq_len
        self.max_n_utterances = self.args.max_n_utterances
        self.use_meld = self.args.use_meld
        self.use_iemocap = self.args.use_iemocap
        self.use_emory_nlp = self.args.use_emory_nlp
        self.v2_attn = self.args.v2_attn
        self.entity_attn = self.args.entity_attn
        self.e2e_attn = self.args.e2e_attn
        self.e2e_attn_masking = self.args.e2e_attn_masking
        self.e2e_attn_mask_prob = self.args.e2e_attn_mask_prob

        self.tokenizer = RobertaTokenizer.from_pretrained('roberta-large')

        meld_emotions = ['neutral', 'surprise', 'fear', 'sadness', 'joy', 'disgust', 'anger']
        iemocap_emotions = ['happiness', 'frustration', 'excited', 'other']
        emory_nlp_emotions = ['neutral', 'joy', 'peaceful', 'powerful', 'fear', 'anger', 'sadness']

        # Emotion ids are handed out in first-seen order across the enabled
        # corpora, so emotions shared between corpora share one id.
        self.emotion_map = {}
        self.source_mapping = {}
        if self.use_meld:
            self.source_mapping['meld'] = len(self.source_mapping)
            self._register_emotions(meld_emotions, 'MELD')

        if self.use_iemocap:
            self.source_mapping['iemocap'] = len(self.source_mapping)
            self._register_emotions(iemocap_emotions, 'IEMOCAP')

        if self.use_emory_nlp:
            self.source_mapping['emory_nlp'] = len(self.source_mapping)
            self._register_emotions(emory_nlp_emotions, 'Emory NLP')

        # Per-corpus masks over the joint label space; built only after the
        # map is complete so they have the final width.
        if self.use_meld:
            self.meld_pred_mask = self._make_pred_mask(meld_emotions)

        if self.use_emory_nlp:
            self.emory_nlp_pred_mask = self._make_pred_mask(emory_nlp_emotions)

        print('Dataset emotion map:\n{}\n\t{}'.format(json.dumps(self.emotion_map, indent=2), len(self.emotion_map)))

        self.sentiment_map = {
            'neutral': 0,
            'positive': 1,
            'negative': 2
        }

        self.items = []
        if self.use_meld:
            meld_basedir = self.args.meld_dir
            meld_fp = os.path.join(meld_basedir, '{}_sent_emo.csv'.format(
                mode if mode in ['train', 'test'] else 'dev'
            ))

            print('Reading MELD items...')
            meld_items = load_meld(
                meld_fp, self.emotion_map, self.sentiment_map, self.tokenizer, self.max_len, self.max_n_utterances,
                self.CLS_TOKEN_ID, self.PAD_TOKEN_ID, self.SEP_TOKEN_ID, self.entity_attn, self.e2e_attn
            )
            print('\rread {} MELD items...'.format(len(meld_items)))
            self.items.extend(meld_items)

        if self.use_emory_nlp:
            print('Reading Emory NLP items...')
            emory_nlp_basedir = self.args.emory_nlp_dir
            if mode == 'train':
                emory_nlp_fp = os.path.join(emory_nlp_basedir, 'emotion-detection-trn.json')
            elif mode == 'test':
                emory_nlp_fp = os.path.join(emory_nlp_basedir, 'emotion-detection-tst.json')
            else:
                emory_nlp_fp = os.path.join(emory_nlp_basedir, 'emotion-detection-dev.json')

            emory_nlp_items = load_emory_nlp(
                emory_nlp_fp, self.emotion_map, self.sentiment_map, self.tokenizer, self.max_len,
                self.max_n_utterances, self.CLS_TOKEN_ID, self.PAD_TOKEN_ID, self.SEP_TOKEN_ID, self.entity_attn,
                self.e2e_attn
            )
            print('\tread {} Emory NLP items...'.format(len(emory_nlp_items)))
            self.items.extend(emory_nlp_items)

        # IEMOCAP is only ever added to the training split.
        if self.use_iemocap and self.mode == 'train':
            print('Reading IEMOCAP items...')
            iemocap_basedir = self.args.iemocap_dir
            iemocap_items = load_iemocap(
                iemocap_basedir, self.emotion_map, self.sentiment_map, self.tokenizer, self.max_len,
                self.max_n_utterances, self.CLS_TOKEN_ID, self.PAD_TOKEN_ID, self.SEP_TOKEN_ID, self.entity_attn,
                self.e2e_attn
            )
            # Was '\read ...', i.e. a carriage return followed by "ead".
            print('\tread {} IEMOCAP items...'.format(len(iemocap_items)))
            self.items.extend(iemocap_items)

    def _register_emotions(self, emotions, dataset_name):
        """Give every not-yet-known emotion in ``emotions`` the next free id."""
        for emotion in emotions:
            if self.emotion_map.get(emotion, None) is None:
                print('Adding {} to emotion map for {}...'.format(emotion, dataset_name))
                self.emotion_map[emotion] = len(self.emotion_map)

    def _make_pred_mask(self, emotions):
        """Return a ``(1, n_emotions)`` tensor with 1 at the ids of ``emotions``."""
        pred_mask = torch.zeros(1, len(self.emotion_map))
        for emotion in emotions:
            pred_mask[0, self.emotion_map[emotion]] = 1
        return pred_mask

    def __len__(self):
        """Return the number of loaded dialogue items."""
        return len(self.items)

    def __getitem__(self, idx):
        """Return a copy of item ``idx`` with sampled negatives attached.

        The returned dict drops ``'source'`` (replaced by the integer tensor
        ``'item_source'``) and, for MELD items, ``'sent_labels'``. It gains
        ``'emo_negs'``: for each real label a random emotion id different from
        it (for IEMOCAP: not in the utterance's label set), and -1 on padding.
        For IEMOCAP items ``'emo_labels'`` becomes a tensor holding one
        randomly chosen label per utterance.
        """
        item = dict(self.items[idx])

        mask_e2e = self.e2e_attn_masking and self.mode == 'train'
        if self.v2_attn or mask_e2e:
            # The dict copy above is shallow: without a clone the in-place
            # edits below would land in the cached tensor, and the random
            # masking would pile up zeros across epochs.
            item['attn_mask'] = item['attn_mask'].clone()
            n_entities_present = int(item['entity_presence'].sum())

        if self.v2_attn:
            # Connect row/column ``max_n_utterances`` with every present
            # entity slot, in both directions.
            # NOTE(review): with entity_attn enabled that index is the [CLS]
            # token - confirm v2_attn is only used together with it.
            item['attn_mask'][self.max_n_utterances, :n_entities_present] = 1
            item['attn_mask'][:n_entities_present, self.max_n_utterances] = 1

        if mask_e2e:
            # Randomly drop entity<->entity links, then restore the diagonal
            # so each present entity always attends to itself.
            utt_presence = torch.zeros_like(item['attn_mask'])
            utt_presence[:n_entities_present, :n_entities_present] = 1
            mask_probs = torch.rand_like(item['attn_mask'])
            item['attn_mask'][
                (mask_probs <= self.e2e_attn_mask_prob) & (utt_presence > 0)
            ] = 0

            for entity_idx in range(n_entities_present):
                item['attn_mask'][entity_idx, entity_idx] = 1

        item_source = item['source']
        n_emotions = len(self.emotion_map)
        emo_negs = []
        if item_source != 'iemocap':
            for lbl in item['emo_labels']:
                lbl = int(lbl)
                if lbl != -1:
                    emo_negs.append(random.choice([i for i in range(n_emotions) if i != lbl]))
                else:
                    emo_negs.append(-1)     # padding slot

            item['emo_negs'] = torch.tensor(emo_negs)

            if item_source == 'meld':
                del item['sent_labels']
        else:
            # IEMOCAP stores a list of candidate labels per utterance (and the
            # integer -1 for padding slots): pick one positive, and a negative
            # from outside the whole candidate set.
            emo_labels = []
            for label_set in item['emo_labels']:
                if label_set == -1:
                    continue
                emo_labels.append(random.choice(label_set))
                emo_negs.append(random.choice([i for i in range(n_emotions) if i not in label_set]))

            n_label_pad = self.max_n_utterances - len(emo_labels)
            emo_labels.extend([-1] * n_label_pad)
            emo_negs.extend([-1] * n_label_pad)

            item['emo_labels'] = torch.tensor(emo_labels)
            item['emo_negs'] = torch.tensor(emo_negs)

        item['item_source'] = torch.tensor([self.source_mapping[item_source]])
        del item['source']

        return item

    def make_entity_attn(self, u_ids, id_vec):
        """Build a ``(max_n_utterances, max_len)`` entity -> token mask.

        Row ``i`` is 1 wherever ``id_vec`` equals ``u_ids[i]``. Unlike the
        module-level function of the same name, this variant also prints its
        inputs.
        """
        print('u_ids: {}'.format(u_ids))
        print('id_vec: {}'.format(id_vec))
        attn_mask = torch.zeros(
            self.max_n_utterances, self.max_len
        )

        for row, u_id in enumerate(u_ids):
            id_presence = (id_vec == u_id).nonzero()
            attn_mask[row, id_presence] = 1

        return attn_mask






