
import os
import time
import json
import math
import torch
import itertools

import numpy as np
import pandas as pd

import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from torch import optim
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup

from . import BaseRunner
from ..utils import jsonKeys2int
from ..datasets import EmoSentDataset
from ..models import LMWithEntityAttention


import torch._dynamo

def update_queue(queue, embeddings, labels):
    """Push new embeddings onto the front of each label's slot in ``queue``.

    For every label index along the first axis of ``queue``, the rows of
    ``embeddings`` whose entry in ``labels`` equals that index are selected.
    Existing entries for that label are shifted towards the back (the oldest
    ones fall off the end) and the selected rows are written, detached, at the
    front. If more rows are selected than the queue can hold, only the first
    ``queue.shape[1]`` of them are used.

    Args:
        queue: Tensor indexed as ``[label, slot, ...]``; modified in place.
        embeddings: Tensor whose first axis lines up with ``labels``.
        labels: Label index for each row of ``embeddings``.

    Returns:
        The same ``queue`` object that was passed in.
    """
    n_labels, capacity = queue.shape[0], queue.shape[1]

    for cls in range(n_labels):
        # Rows for this label, capped at what a single slot row can hold.
        fresh = embeddings[labels == cls][:capacity]
        n_fresh = fresh.shape[0]
        if n_fresh == 0:
            continue

        slots = queue[cls]  # view: writes below land in ``queue`` itself
        # Clone before shifting because source and destination overlap.
        slots[n_fresh:] = slots[:-n_fresh].clone()
        slots[:n_fresh] = fresh.detach()

    return queue


class EmoSentRunner(BaseRunner):
    def __init__(self, gpu, mode, args):
        """Build everything one process needs for a run, then start the run.

        In order, this sets up: the device and the process group, the main data
        loader (plus a dev loader when training with dev evaluation enabled),
        the model (optionally restored from a checkpoint and wrapped in DDP
        unless ``args.on_cpu``), a TensorBoard writer on rank 0, the optimizer
        and LR scheduler for the ``'train'``/``'pt'`` modes, and the per-emotion
        embedding queue for non-``'xent'`` objectives. The final statement
        calls ``self.run()``, so constructing the object executes the job.

        Args:
            gpu: Device index for this process; also used as the process rank.
            mode: Run mode. The values tested here are ``'train'``, ``'pt'``,
                ``'test'``, ``'pt_test'`` and ``'inspect'``.
            args: Run configuration object read via attribute access.
        """
        self.rank = gpu
        self.mode = mode
        self.args = args
        self.force_model_save = False
        torch._dynamo.config.verbose = True

        # The two modes for which an optimizer is built further down.
        is_training = self.mode in ('train', 'pt')

        print('Initializing EmoSentRunner on device {}...'.format(gpu))
        if self.args.on_cpu:
            self.device = torch.device('cpu')
        else:
            self.device = torch.device('cuda:{}'.format(self.rank))
            torch.cuda.set_device(self.device)

        self.world_size = len(self.args.gpus)
        print('** EmoSentRunner.world_size: {} **'.format(self.world_size))
        print('\ttorch.cuda.device_count(): {}'.format(torch.cuda.device_count()))
        # NOTE(review): the backend is always 'nccl', even when args.on_cpu is
        # set -- confirm whether CPU runs are expected to work.
        dist.init_process_group('nccl',
                                world_size=self.world_size,
                                rank=self.rank)

        self.out = args.out
        self.log_every = args.log_every
        self.summary_every = args.summary_every
        self.save_model_every = args.save_model_every
        self.print_every = args.print_every
        self.n_warmup_iters = self.args.n_warmup_iters
        self.max_n_utterances = self.args.max_n_utterances
        self.do_sentiment = self.args.do_sentiment
        self.objective = self.args.objective
        self.queue_size = self.args.queue_size
        self.project_embds_for_contrast = self.args.project_embds_for_contrast
        self.freeze_roberta = self.args.freeze_roberta
        self.supcon_loss_fn = self.args.supcon_loss_fn
        self.use_queue = self.args.use_queue

        if self.mode == 'inspect':
            if self.rank == 0:
                print('Making inspection out dir...')
            self.inspection_dir = os.path.join(args.out, 'inspection')
            # Every rank executes this, so tolerate another rank creating the
            # directory first instead of racing on an exists() check.
            os.makedirs(self.inspection_dir, exist_ok=True)

        self.pred_dir = os.path.join(self.args.out, 'preds')
        if mode == 'test':
            if self.rank == 0:
                print('== mode is test, setting log_every to 1 ==')
            self.log_every = 1

        self.torch_amp = getattr(self.args, 'torch_amp', True)
        self.lr = getattr(self.args, 'lr', 1e-5)
        self.l2 = getattr(self.args, 'l2', 0.0001)

        if self.torch_amp:
            print('** Using Torch AMP **')
            self.scaler = torch.cuda.amp.GradScaler()

        if self.rank == 0:
            os.makedirs(self.pred_dir, exist_ok=True)

        # ---- main dataset / loader -------------------------------------
        print('EmoSentRunner on device {} making dataset...'.format(self.rank))
        self.dataset = EmoSentDataset(args, mode)
        self.meld_pred_mask = self.dataset.meld_pred_mask
        self.emory_nlp_pred_mask = self.dataset.emory_nlp_pred_mask
        self.n_emotions = len(self.dataset.emotion_map)
        self.source_mapping = self.dataset.source_mapping
        self.reverse_source_mapping = {v: k for k, v in self.source_mapping.items()}
        print('\tlen(dataset): {}'.format(len(self.dataset)))
        if self.args.on_cpu:
            data_sampler = None
        else:
            data_sampler = torch.utils.data.distributed.DistributedSampler(
                self.dataset, num_replicas=args.world_size, rank=self.rank,
                shuffle=is_training,
            )
        self.data_loader = DataLoader(self.dataset, batch_size=self.args.batch_size, shuffle=False,
                                      num_workers=self.args.n_data_workers if self.mode == 'train'
                                      else self.args.n_data_workers_else,
                                      pin_memory=True, sampler=data_sampler,
                                      drop_last=is_training, persistent_workers=False)

        self.n_iters = int(math.ceil(len(self.dataset) / (self.args.batch_size * len(self.args.gpus))))

        # ---- optional dev dataset / loader used while training ---------
        self.aux_dataset = None
        self.aux_data_loader = None
        self.aux_n_iters = None
        if is_training and (self.args.dev or self.args.dev_every > 0):
            print('EmoSentRunner on device {} creating auxiliary dataset...'.format(self.rank))
            self.aux_dataset = EmoSentDataset(self.args, 'dev')

            if self.args.on_cpu:
                aux_data_sampler = None
            else:
                aux_data_sampler = torch.utils.data.distributed.DistributedSampler(self.aux_dataset,
                                                                                   num_replicas=args.world_size,
                                                                                   rank=self.rank,
                                                                                   shuffle=False)

            self.aux_data_loader = DataLoader(self.aux_dataset, batch_size=self.args.batch_size, shuffle=False,
                                              num_workers=int(self.args.n_data_workers_else), pin_memory=True,
                                              sampler=aux_data_sampler, drop_last=True, persistent_workers=False)
            self.aux_n_iters = int(math.ceil(len(self.aux_dataset) / (self.args.batch_size * len(self.args.gpus))))

        # ---- model and checkpoint --------------------------------------
        print('EmoSentRunner on device {} creating model...'.format(self.rank))
        self.model = LMWithEntityAttention(self.args, self.mode)
        self.start_epoch = 0
        self.n_epochs = 1
        if not is_training:
            if self.args.train or self.args.pt:
                # Look in model_save_dir for an epoch checkpoint, walking back
                # from the final epoch until a file exists.
                ckpt_epoch_offset = 1
                ckpt_file = os.path.join(self.args.model_save_dir,
                                         self.args.ckpt_file_tmplt.format(self.args.epochs - ckpt_epoch_offset))
                while not os.path.exists(ckpt_file) and self.args.epochs - ckpt_epoch_offset >= 0:
                    ckpt_epoch_offset += 1
                    ckpt_file = os.path.join(self.args.model_save_dir,
                                             self.args.ckpt_file_tmplt.format(self.args.epochs - ckpt_epoch_offset))
            else:
                ckpt_file = self.args.ckpt_file
        else:
            ckpt_file = self.args.ckpt_file

        if self.rank == 0:
            print('*** ckpt_file: {} ***'.format(ckpt_file))
        if ckpt_file is not None:
            if self.rank == 0:
                print('Loading model from ckpt...')
                print('\tckpt_file: {}'.format(ckpt_file))
            map_location = {'cuda:0': 'cpu'}
            state_dict = torch.load(ckpt_file, map_location=map_location)

            # Only the two test modes require an exact key match.
            self.model.load_state_dict(state_dict, strict=self.mode in ('test', 'pt_test'))
            # The epoch is parsed from the second '_'-separated token of the
            # file name, taking the text before the first 'e'.
            model_epoch = int(os.path.split(ckpt_file)[-1].split('_')[1].split('e')[0])
            if not is_training:
                self.start_epoch = model_epoch
                self.n_epochs = model_epoch + 1

        self.model = self.model.to(self.device)

        if not self.args.on_cpu:
            self.model = DDP(self.model, device_ids=[self.rank],
                             find_unused_parameters=True)

        # Only rank 0 owns a TensorBoard writer; it stays None elsewhere.
        self.summary_writer = None
        if self.rank == 0:
            self.summary_writer = SummaryWriter(log_dir=self.args.tb_dir)

        # ---- optimizer / scheduler -------------------------------------
        self.scheduler = None
        if is_training:
            self.n_epochs = self.args.epochs

            no_decay = ['layernorm', 'norm']
            param_optimizer = list(self.model.named_parameters())
            no_decay_parms = []
            reg_parms = []
            emo_pred_head_parms = []

            for n, p in param_optimizer:
                if 'emotion_pred_head' in n:
                    emo_pred_head_parms.append(p)
                # Compare case-insensitively so that names such as the
                # Hugging Face style ``LayerNorm.weight`` are also caught by
                # the lowercase patterns.
                elif any(nd in n.lower() for nd in no_decay):
                    no_decay_parms.append(p)
                else:
                    reg_parms.append(p)

            optimizer_grouped_parameters = [
                {'params': reg_parms, 'weight_decay': self.l2},
                {'params': emo_pred_head_parms, 'weight_decay': self.l2 / 10},
                {'params': no_decay_parms, 'weight_decay': 0.0},
            ]
            if self.rank == 0:
                print('n parms: {}'.format(len(param_optimizer)))
                for group_idx, group in enumerate(optimizer_grouped_parameters):
                    print('len(optimizer_grouped_parameters[{}]): {}'.format(group_idx, len(group['params'])))
                print('Making Adam optimizer...')
            self.optimizer = optim.AdamW(optimizer_grouped_parameters, lr=self.lr,
                                         betas=(0.9, 0.95))

            if self.n_warmup_iters > 0:
                # Alternative kept for reference:
                # get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=self.n_warmup_iters)
                self.scheduler = get_cosine_schedule_with_warmup(
                    self.optimizer,
                    num_warmup_steps=self.n_warmup_iters,
                    num_training_steps=self.n_iters * self.n_epochs
                )

        print('torch.backends.cuda.flash_sdp_enabled(): {}'.format(torch.backends.cuda.flash_sdp_enabled()))
        print('torch.backends.cuda.mem_efficient_sdp_enabled(): {}'.format(
            torch.backends.cuda.mem_efficient_sdp_enabled()))
        print('torch.backends.cuda.math_sdp_enabled(): {}'.format(torch.backends.cuda.math_sdp_enabled()))

        # Column layouts and file-name template for prediction CSVs.
        self.emo_headers = [
            'dialogue_id', 'utterance_id', 'p_neutral', 'p_surprise', 'p_fear', 'p_sadness', 'p_joy', 'p_disgust',
            'p_anger', 'label'
        ]
        self.sent_headers = [
            'dialogue_id', 'utterance_id', 'p_neutral', 'p_positive', 'p_negative', 'label'
        ]
        self.pred_fp_tmplt = os.path.join(
            self.pred_dir, '{}_{}_preds_e{}.csv'
        )
        self.original_protos = []

        # ---- per-emotion embedding queue -------------------------------
        if self.objective != 'xent' and self.mode not in ('pt', 'pt_test') and self.use_queue:
            self.emotion_queue = torch.zeros(self.n_emotions, self.queue_size, 1024, device=self.device)

            if self.mode == 'train':
                # Separate copy of the training set with e2e attention masking
                # switched off; the sampler and loader below must be built on
                # this copy, otherwise the flag has no effect on the preload.
                queue_dataset = EmoSentDataset(args, 'train')
                queue_dataset.e2e_attn_masking = False

                queue_sampler = torch.utils.data.distributed.DistributedSampler(
                    queue_dataset, num_replicas=args.world_size, rank=self.rank,
                    shuffle=True,
                )
                queue_data_loader = DataLoader(queue_dataset, batch_size=self.args.batch_size, shuffle=False,
                                               num_workers=self.args.n_data_workers,
                                               pin_memory=True, sampler=queue_sampler,
                                               drop_last=True, persistent_workers=False)
                # Fill the queue in eval mode without gradients, then restore
                # train mode.
                self.model.eval()
                with torch.no_grad():
                    self.preload_emo_queue(queue_data_loader)
                self.model.train()
        else:
            self.emotion_queue = None

        self.run()

    def preload_emo_queue(self, data_loader):
        """Fill ``self.emotion_queue`` from ``data_loader`` and snapshot its means.

        Each batch is run through the model's ``process_data`` and the
        resulting embeddings are pushed into the per-label queue with
        ``update_queue``. Afterwards the mean over the queue's slot axis is
        printed per label and stored, stacked, as ``self.original_protos``.

        This method does not itself switch the model to eval mode or disable
        gradient tracking.

        Args:
            data_loader: Iterable of batch dicts providing the ``input_ids``,
                ``attn_mask``, ``emo_labels``, ``position_ids`` and
                ``entity_presence`` tensors.
        """
        print('Loading queue...')
        # The model only exposes ``.module`` when it has been wrapped (DDP);
        # fall back to the model itself so the unwrapped case works too.
        core_model = getattr(self.model, 'module', self.model)

        for batch_idx, batch_data in enumerate(data_loader):
            # Only the tensors that are actually consumed are moved to the device.
            input_ids = batch_data['input_ids'].to(self.device, non_blocking=True)
            attn_mask = batch_data['attn_mask'].to(self.device, non_blocking=True)
            emo_labels = batch_data['emo_labels'].to(self.device, non_blocking=True)
            position_ids = batch_data['position_ids'].to(self.device, non_blocking=True)
            entity_presence = batch_data['entity_presence'].to(self.device, non_blocking=True)

            # With enabled=False the context manager is a no-op, so a single
            # call site covers both the AMP and the full-precision path.
            with torch.cuda.amp.autocast(enabled=bool(self.torch_amp)):
                processed_embeddings = core_model.process_data(
                    input_ids=input_ids, attn_mask=attn_mask, entity_presence=entity_presence,
                    position_ids=position_ids
                )

            # Keep only labels at positions where an entity is present.
            # presumably this lines them up with the rows returned by
            # process_data -- TODO confirm against the model.
            emo_labels = emo_labels[entity_presence > 0]
            self.emotion_queue = update_queue(self.emotion_queue, processed_embeddings, emo_labels)
            print('\tAdded batch {} of {} to queue...'.format(batch_idx, self.n_iters))

        q_avg = self.emotion_queue.mean(dim=1)
        print('q_avg: {}'.format(q_avg.shape))
        # Collect into a local list so the attribute is assigned exactly once
        # and the method can safely be called again.
        protos = []
        for label_idx in range(q_avg.shape[0]):
            print('Label {} avg: {}'.format(label_idx, q_avg[label_idx]))
            protos.append(q_avg[label_idx].view(1, -1))

        q_var = q_avg.var(dim=0).mean()
        print('Queue var: {}'.format(q_var))

        print('Queue is loaded!')
        self.original_protos = torch.cat(protos, dim=0)

    def run_one_epoch(self, epoch, mode):
        if mode == self.mode:
            dataset = self.data_loader
            n_iters = self.n_iters
            n_samples = len(self.dataset)
        else:
            dataset = self.aux_data_loader
            n_iters = self.aux_n_iters
            n_samples = len(self.aux_dataset)

        last_batch_end_time = None
        iter_since_grad_accum = 1
        emo_preds_to_write, sent_preds_to_write = [], []
        agg_emo_preds, agg_emo_labels, agg_sent_preds, agg_sent_labels = [], [], [], []
        agg_item_sources = []
        inspection_data = []
        losses, batch_sizes = [], []
        agg_probs, agg_labels = [], []
        agg_label_probs = []
        for batch_idx, batch_data in enumerate(dataset):
            global_item_idx = (epoch * n_iters) + batch_idx
            batch_start_time = time.time()

            input_ids = batch_data['input_ids'].to(self.device, non_blocking=True)
            attn_mask = batch_data['attn_mask'].to(self.device, non_blocking=True)
            dialogue_ids = batch_data['dialogue_id'].to(self.device, non_blocking=True)
            emo_labels = batch_data['emo_labels'].to(self.device, non_blocking=True)
            position_ids = batch_data['position_ids'].to(self.device, non_blocking=True)
            item_source = batch_data['item_source'].to(self.device, non_blocking=True)
            # sent_labels = batch_data['sent_labels'].to(self.device, non_blocking=True)
            # sent_negs = batch_data['sent_negs'].to(self.device, non_blocking=True)
            batch_size = input_ids.shape[0]
            if self.rank == 0 and epoch == 0 and batch_idx == 0:
                for k, v in batch_data.items():
                    print('k: {} v: {}'.format(k, v.shape))

                print('position_ids:\n{}'.format(position_ids))
                print('\tshape: {} min: {} max: {}'.format(position_ids.shape, position_ids.min(), position_ids.max()))

            sent_labels = None
            sent_negs = None
            entity_presence = batch_data['entity_presence'].to(self.device, non_blocking=True)
            emo_negs = batch_data['emo_negs'].to(self.device, non_blocking=True)
            # /home/czh/.local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:90: UserWarning:
            # TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting
            # `torch.set_float32_matmul_precision('high')` for better performance.
            if self.torch_amp:
                with torch.cuda.amp.autocast():
                    model_outputs = self.model(
                        input_ids=input_ids, attn_mask=attn_mask, emo_labels=emo_labels, sent_labels=sent_labels,
                        entity_presence=entity_presence, emo_negs=emo_negs, sent_negs=sent_negs,
                        position_ids=position_ids, emotion_queue=self.emotion_queue
                    )
            else:
                model_outputs = self.model(
                    input_ids=input_ids, attn_mask=attn_mask, emo_labels=emo_labels, sent_labels=sent_labels,
                    entity_presence=entity_presence, emo_negs=emo_negs, sent_negs=sent_negs,
                    position_ids=position_ids, emotion_queue=self.emotion_queue
                )

            if self.mode == 'pt' or self.mode == 'pt_test':
                loss, masked_preds, masked_labels, masked_idxs = model_outputs
                # print('masked_preds:\n{}\n\t{}'.format(masked_preds, masked_preds.shape))
                # print('masked_labels:\n{}\n\t{}'.format(masked_labels, masked_labels.shape))

                sm = nn.Softmax(dim=-1)
                pred_probs = sm(masked_preds)
                # print('pred_probs:\n{}\n\t{}'.format(pred_probs, pred_probs.shape))

                label_probs = pred_probs.gather(-1, masked_labels.view(1, -1))
                # print('label_probs:\n{}\n\t{}'.format(label_probs, label_probs.shape))

                losses.append(loss.clone().detach())
                agg_label_probs.append(label_probs.clone().detach().cpu().view(-1))
                batch_sizes.append(batch_size)
            else:
                raw_embeddings, emo_preds, emo_loss, sent_preds, sent_loss = model_outputs
                if self.rank == 0 and epoch == 0 and batch_idx == 0:
                    if emo_preds is not None:
                        print('\temo_preds.shape: {}'.format(emo_preds.shape))
                    if sent_preds is not None:
                        print('\tsent_preds.shape: {}'.format(sent_preds.shape))
                emo_labels = emo_labels[entity_presence > 0]
                if mode == 'test':
                    agg_probs.append(emo_preds.detach().cpu())
                    agg_labels.append(emo_labels.view(-1).cpu())

                if mode == 'train' and self.objective != 'xent' and self.use_queue:
                    self.emotion_queue = update_queue(self.emotion_queue, raw_embeddings, emo_labels)

                if self.do_sentiment:
                    loss = emo_loss + sent_loss
                else:
                    loss = emo_loss

            if mode == 'train' or mode == 'pt':
                if self.torch_amp:
                    self.scaler.scale(loss).backward()

                    if self.objective != 'xent' and self.mode != 'pt':
                        # torch.nn.utils.clip_grad_norm_(self.model.module.entity_embd.parameters(), 1e-2)
                        # if self.supcon_loss_fn == 'byol':
                        #     torch.nn.utils.clip_grad_norm_(self.model.module.projector.parameters(), 1e-6)
                        #     torch.nn.utils.clip_grad_norm_(self.model.module.predictor.parameters(), 1e-6)
                        #
                        # elif self.project_embds_for_contrast:
                        if self.supcon_loss_fn != 'byol' and self.project_embds_for_contrast:
                            self.scaler.unscale_(self.optimizer)
                            torch.nn.utils.clip_grad_norm_(self.model.module.emotion_projection.parameters(), 1e-6)
                else:
                    loss.backward()

            if self.mode == 'pt' or self.mode == 'pt_test':
                pass
            else:
                if mode == 'inspect':
                    # print('raw_embeddings: {}'.format(raw_embeddings.shape))
                    # print('emo_labels: {}'.format(emo_labels.shape))
                    this_inspection_data = torch.cat([raw_embeddings.cpu(), emo_labels.view(-1, 1).cpu()], dim=1)
                    # print('this_inspection_data: {}'.format(this_inspection_data.shape))
                    inspection_data.append(this_inspection_data)

                if (global_item_idx % self.log_every == 0 or mode != 'train') and \
                        (self.objective.startswith('xent')):

                    emo_preds = self.gather(emo_preds)
                    emo_labels = self.gather(emo_labels)

                    item_source = item_source.expand(-1, self.max_n_utterances)
                    item_source = item_source[entity_presence > 0]
                    item_source = self.gather(item_source)

                    if self.do_sentiment:
                        sent_labels = sent_labels[entity_presence > 0]
                        sent_preds = self.gather(sent_preds)
                        sent_labels = self.gather(sent_labels)

                    # if mode != 'train':
                    agg_emo_preds.append(emo_preds.cpu())
                    agg_emo_labels.append(emo_labels.cpu())
                    agg_item_sources.append(item_source.cpu())
                    if self.do_sentiment:
                        agg_sent_preds.append(sent_preds.cpu())
                        agg_sent_labels.append(sent_labels.cpu())

                    if global_item_idx % self.log_every == 0:
                        dialogue_ids = dialogue_ids.expand(-1, self.max_n_utterances)[entity_presence > 0]
                        utterance_ids = torch.arange(self.max_n_utterances, device=dialogue_ids.device).unsqueeze(0).expand(input_ids.shape[0], -1)[entity_presence > 0]

                        dialogue_ids = self.gather(dialogue_ids)
                        utterance_ids = self.gather(utterance_ids)
                        softmax = nn.Softmax(dim=-1)

                        emo_pred_data = torch.concatenate(
                            [
                                dialogue_ids.view(-1, 1), utterance_ids.view(-1, 1),
                                softmax(emo_preds), emo_labels.view(-1, 1)
                            ], dim=-1
                        ).cpu()
                        emo_preds_to_write.append(emo_pred_data)

                        if self.do_sentiment:
                            sent_pred_data = torch.concatenate(
                                [
                                    dialogue_ids.view(-1, 1), utterance_ids.view(-1, 1),
                                    softmax(sent_preds), sent_labels.view(-1, 1)
                                ], dim=-1
                            ).cpu()
                            sent_preds_to_write.append(sent_pred_data)

            if global_item_idx % self.print_every == 0:
                batch_elapsed_time = time.time() - batch_start_time
                if last_batch_end_time is not None:
                    time_btw_batches = batch_start_time - last_batch_end_time
                else:
                    time_btw_batches = 0.0
                print_str = '{0}- epoch: {1}/{2} iter: {3:4d}/{4} loss: {5:6.2f}'.format(
                    mode, epoch, self.n_epochs, batch_idx, n_iters, loss

                )
                if self.mode != 'pt' and self.mode != 'pt_test':
                    print_str = '{0} emo loss: {1:.2f} sent loss: {2:.2f}'.format(
                        print_str, emo_loss, sent_loss if self.do_sentiment else -1.0
                    )
                print_str = '{0} Time: {1:4.2f}s ({2:1.2f}s)'.format(print_str, batch_elapsed_time, time_btw_batches)
                print(print_str)

            if global_item_idx % self.summary_every == 0 or mode != 'train':
                self.summary_writer.add_scalar('loss/{}'.format(mode), loss, global_item_idx)

                if self.mode != 'pt' and self.mode != 'pt_test':
                    self.summary_writer.add_scalar('emo_loss/{}'.format(mode), emo_loss, global_item_idx)

                    if self.do_sentiment:
                        self.summary_writer.add_scalar('sent_loss/{}'.format(mode), sent_loss, global_item_idx)

            # print('global_item_idx: {}'.format(global_item_idx))
            # print("self.mode in ['train', 'pt']: {}".format(self.mode in ['train', 'pt']))
            # print('self.args.grad_summary: {}'.format(self.args.grad_summary))
            # print('global_item_idx % self.args.grad_summary_every == 0: {}'.format(global_item_idx % self.args.grad_summary_every == 0))
            if global_item_idx % self.args.grad_summary_every == 0 \
                    and self.summary_writer is not None and self.mode in ['train', 'pt'] \
                    and self.args.grad_summary and global_item_idx != 0:
                # print('* grad summary *')
                for name, p in self.model.named_parameters():
                    if p.grad is not None and p.grad.data is not None:
                        # print('$$ {} grad is not None! $$'.format(name))
                        self.summary_writer.add_histogram('grad/{}'.format(name), p.grad.data,
                                                          (epoch * n_iters) + batch_idx)
                        self.summary_writer.add_histogram('weight/{}'.format(name), p.data,
                                                          (epoch * n_iters) + batch_idx)

            if iter_since_grad_accum == self.args.n_grad_accum and mode == 'train' or mode == 'pt':
                # print('OPTIMIZER STEP')
                if self.torch_amp:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                if self.scheduler is not None:
                    self.scheduler.step()
                self.optimizer.zero_grad()
                iter_since_grad_accum = 1
            else:
                iter_since_grad_accum += 1

            last_batch_end_time = time.time()

        if iter_since_grad_accum > 1 and mode == 'train':
            if self.torch_amp:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.optimizer.step()
            self.optimizer.zero_grad()

        if self.rank == 0:
            if len(agg_label_probs) > 0:
                def calc_perplexity_from_probs(prob_list):
                    inv_probs = [1 / (p + 1e-12) for p in prob_list]
                    perplexity = sum(inv_probs) / len(inv_probs)

                    return perplexity

                agg_label_probs = torch.cat(agg_label_probs, dim=0).tolist()
                # inv_probs = 1 / agg_label_probs
                # mean_perp = inv_probs.mean()
                mean_perp = calc_perplexity_from_probs(agg_label_probs)
                print('mean_perp: {}'.format(mean_perp))

            if len(agg_probs) > 0:
                agg_probs = torch.cat(agg_probs, dim=0)
                print('agg_probs: {}'.format(agg_probs.shape))
                agg_labels = torch.cat(agg_labels, dim=0)
                print('agg_labels: {}'.format(agg_labels.shape))
                write_data = torch.cat([agg_probs, agg_labels.view(-1, 1)], dim=1)

                headers = ['cls{}'.format(idx) for idx in range(agg_probs.shape[1])]
                headers.append('label')

                df = pd.DataFrame(write_data, columns=headers)
                df.to_csv(
                    os.path.join(self.pred_dir, '{}_pred_probs_e{}.csv'.format(mode, epoch)), index=False
                )

            if len(losses) > 0:
                total_records = sum(batch_sizes)
                batch_pcts = [bz / total_records for bz in batch_sizes]
                weighted_losses = [b_pct * b_loss for b_pct, b_loss in zip(batch_pcts, losses)]

                # avg_loss = sum(losses) / len(losses)
                avg_loss = sum(weighted_losses)
                print('avg_loss: {}'.format(avg_loss))
                with open(os.path.join(self.pred_dir, '{}_mlm_losses.txt'.format(mode)), 'a+') as f:
                    f.write('epoch: {},avg_loss:{}\n'.format(
                        epoch, avg_loss
                    ))

            if self.objective != 'xent' and mode != 'inspect' and self.emotion_queue is not None:
                curr_protos = []
                q_avg = self.emotion_queue.mean(dim=1)
                print('q_avg: {}'.format(q_avg.shape))
                for label_idx in range(q_avg.shape[0]):
                    print('Label {} avg: {}'.format(label_idx, q_avg[label_idx]))
                    curr_protos.append(q_avg[label_idx].view(1, -1))

                curr_protos = torch.cat(curr_protos, dim=0)
                # a_norm = a / a.norm(dim=1)[:, None]
                # b_norm = b / b.norm(dim=1)[:, None]
                # res = torch.mm(a_norm, b_norm.transpose(0,1))
                # print(res)

                # orig_protos_norm = self.original_protos / self.original_protos.norm(dim=1)[:, None]
                curr_protos_norm = curr_protos / curr_protos.norm(dim=1)[:, None]
                proto_sim = torch.mm(curr_protos_norm, curr_protos_norm.transpose(0, 1))
                print('proto_sim:\n{}'.format(proto_sim))

                q_var = q_avg.var(dim=0).mean()
                print('Queue var: {}'.format(q_var))

            if len(inspection_data) > 0:
                print('Saving inspection results...')
                inspection_data = torch.cat(inspection_data, dim=0)
                print('\tinspection_data: {}'.format(inspection_data.shape))
                outfile = os.path.join(self.inspection_dir, 'inspection_e{}.npy'.format(epoch))
                np.save(outfile, inspection_data)

            # if len(emo_preds_to_write) > 0:
            #     print('Saving emotion predictions...')
            #
            #     emo_preds_to_write = torch.cat(emo_preds_to_write, dim=0).numpy()
            #     emo_preds_df = pd.DataFrame(emo_preds_to_write, columns=self.emo_headers)
            #     emo_preds_df.to_csv(
            #         self.pred_fp_tmplt.format(mode, 'emo', epoch), index=False
            #     )
            #
            # if len(sent_preds_to_write) > 0:
            #     print('Saving sentiment predictions...')
            #
            #     sent_preds_to_write = torch.cat(sent_preds_to_write, dim=0).numpy()
            #     sent_preds_df = pd.DataFrame(sent_preds_to_write, columns=self.sent_headers)
            #     sent_preds_df.to_csv(
            #         self.pred_fp_tmplt.format(mode, 'sent', epoch), index=False
            #     )

            if len(agg_emo_preds) > 0 and len(agg_emo_labels) > 0:
                agg_item_sources = torch.cat(agg_item_sources, dim=0).numpy().reshape(-1)
                agg_emo_preds = torch.cat(agg_emo_preds, dim=0).numpy()
                # pred_probs, pred_classes = torch.topk(agg_emo_preds, k=1)
                # pred_classes = pred_classes.numpy()
                # pred_classes = np.argmax(agg_emo_preds, axis=-1)
                pred_labels = torch.cat(agg_emo_labels, dim=0).numpy()

                print('agg_item_sources: {}'.format(agg_item_sources.shape))
                print('agg_emo_preds: {}'.format(agg_emo_preds.shape))
                # print('pred_classes: {}'.format(pred_classes.shape))
                print('pred_labels: {}'.format(pred_labels.shape))
                u_item_sources = np.unique(agg_item_sources)

                for item_source in u_item_sources:
                    source_name = self.reverse_source_mapping[item_source]
                    source_pred_labels = pred_labels[agg_item_sources == item_source]
                    # source_pred_classes = pred_classes[agg_item_sources == item_source]
                    source_agg_emo_preds = agg_emo_preds[agg_item_sources == item_source]
                    if source_name == 'meld':
                        pred_mask = self.meld_pred_mask
                    else:
                        pred_mask = self.emory_nlp_pred_mask
                    print('raw pred_mask: {}'.format(pred_mask.shape))
                    pred_mask = pred_mask.expand(source_agg_emo_preds.shape[0], -1)
                    print('expanded pred_mask: {}'.format(pred_mask.shape))
                    print('raw preds: {}'.format(source_agg_emo_preds))
                    source_agg_emo_preds[pred_mask < 1] = float('-inf')
                    print('masked preds: {}'.format(source_agg_emo_preds))
                    # if source_name == 'meld':
                    #     print('raw meld preds:\n{}'.format(source_agg_emo_preds))
                    #     source_agg_emo_preds[:, 7:] = float('-inf')
                    #     print('updated meld preds:\n{}'.format(source_agg_emo_preds))

                    source_pred_classes = np.argmax(source_agg_emo_preds, axis=-1)

                    micro_f1 = f1_score(source_pred_labels, source_pred_classes, average='micro')
                    macro_f1 = f1_score(source_pred_labels, source_pred_classes, average='macro')
                    weighted_f1 = f1_score(source_pred_labels, source_pred_classes, average='weighted')
                    print('Emo scores for {}:'.format(source_name))
                    print('\tmicro_f1: {}'.format(micro_f1))
                    print('\tmacro_f1: {}'.format(macro_f1))
                    print('\tweighted_f1: {}'.format(weighted_f1))

                    with open(os.path.join(self.pred_dir, 'emo_{}_scores.txt'.format(mode)), 'a+') as f:
                        f.write('epoch: {},source: {},micro_f1: {}, macro_f1: {},weighted_f1: {}\n'.format(
                            epoch, source_name, micro_f1, macro_f1, weighted_f1
                        ))
                        # f.write('macro_f1: {}\n'.format(macro_f1))
                        # f.write('weighted_f1: {}\n'.format(weighted_f1))

                    if self.summary_writer is not None:
                        self.summary_writer.add_scalar('emo_micro_f1/{}'.format(mode), micro_f1, epoch)
                        self.summary_writer.add_scalar('emo_macro_f1/{}'.format(mode), macro_f1, epoch)
                        self.summary_writer.add_scalar('emo_weighted_f1/{}'.format(mode), weighted_f1, epoch)

            if len(agg_sent_preds) > 0 and len(agg_sent_labels) > 0:
                agg_sent_preds = torch.cat(agg_sent_preds, dim=0).numpy()
                pred_classes = np.argmax(agg_sent_preds, axis=-1)
                pred_labels = torch.cat(agg_sent_labels, dim=0).numpy()

                micro_f1 = f1_score(pred_labels, pred_classes, average='micro')
                macro_f1 = f1_score(pred_labels, pred_classes, average='macro')
                weighted_f1 = f1_score(pred_labels, pred_classes, average='weighted')
                print('SENT PREDS:')
                print('\tmicro_f1: {}'.format(micro_f1))
                print('\tmacro_f1: {}'.format(macro_f1))
                print('\tweighted_f1: {}'.format(weighted_f1))

                with open(os.path.join(self.pred_dir, 'sent_{}_scores_e{}.txt'.format(mode, epoch)), 'w+') as f:
                    f.write('micro_f1: {}\n'.format(micro_f1))
                    f.write('macro_f1: {}\n'.format(macro_f1))
                    f.write('weighted_f1: {}\n'.format(weighted_f1))

                if self.summary_writer is not None:
                    self.summary_writer.add_scalar('sent_micro_f1/{}'.format(mode), micro_f1, epoch)
                    self.summary_writer.add_scalar('sent_macro_f1/{}'.format(mode), macro_f1, epoch)
                    self.summary_writer.add_scalar('sent_weighted_f1/{}'.format(mode), weighted_f1, epoch)

