from fairseq.tasks import register_task
from fairseq import metrics, utils
from fairseq.data import data_utils
from fairseq.tasks.translation import TranslationTask, TranslationConfig
import torch
import torch.nn.functional as F
from dataclasses import dataclass, field
import logging

import collections

# Maximum BLEU n-gram order. Not referenced by the visible code of this module.
EVAL_BLEU_ORDER: int = 4

# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)

@dataclass
class SummarizationUtilizationConfig(TranslationConfig):
    """Translation config extended with a destination path for extra statistics."""

    # Output path for the extra statistics. The default name suggests a pickle
    # file; nothing in this module's visible code reads this field — confirm usage.
    stat_save_path: str = field(
        default="extra_stat.pkl",
        metadata={"help": "where to save extra stuff"},
    )

@register_task("summarization_utilization", dataclass=SummarizationUtilizationConfig)
class UtilizationSummarizationTask(TranslationTask):
    """Translation task that records extra per-example statistics during validation.

    On top of the regular validation loss, ``valid_step`` logs:
      - target lengths,
      - the model's per-token log-probabilities of the references,
      - source lengths.

    When ``cfg.eval_bleu`` is set, it also logs:
      - the decoded hypotheses and references,
      - hypothesis scores,
      - the model's log-probabilities of each example's top hypothesis.

    While ``self.validation_epoch`` is truthy, ``reduce_metrics`` appends every
    list-valued logging entry to ``self.extra_statistics``.
    """

    cfg: SummarizationUtilizationConfig

    def __init__(self, cfg: SummarizationUtilizationConfig, src_dict, tgt_dict):
        super().__init__(cfg, src_dict, tgt_dict)

        self.extra_statistics = None
        self.initialize_extra_statistics()
        # NOTE(review): nothing in this class sets this to True; presumably an
        # external hook toggles it around validation — confirm.
        self.validation_epoch = False

    def initialize_extra_statistics(self):
        """Reset the accumulated extra statistics.

        extra_statistics only handles lists: any list added to a logging output is
        appended to ``extra_statistics[key]`` by ``reduce_metrics`` while
        ``self.validation_epoch`` is truthy.
        """
        self.extra_statistics = collections.defaultdict(list)

    def _inference_with_bleu(self, generator, sample, model):
        """Generate for ``sample`` and decode all hypotheses and the references.

        Args:
            generator: sequence generator passed to ``self.inference_step``.
            sample: fairseq batch dict; ``sample["target"]`` holds the references.
            model: the model to generate with.

        Returns:
            dict with keys
              - ``refs_text``: one decoded reference string per example,
              - ``hyps_text``: per example, a list with every returned hypothesis decoded,
              - ``hyps_scores``: per example, a list with every hypothesis' score as a float,
              - ``gen_out``: the raw output of ``inference_step``.
        """

        def decode(toks, escape_unk=False):
            s = self.tgt_dict.string(
                toks.int().cpu(),
                self.cfg.eval_bleu_remove_bpe,
                # The default unknown string in fairseq is `<unk>`, but
                # this is tokenized by sacrebleu as `< unk >`, inflating
                # BLEU scores. Instead, we use a somewhat more verbose
                # alternative that is unlikely to appear in the real
                # reference, but doesn't get split into multiple tokens.
                unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"),
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None)

        # Keep every generated hypothesis (not only the best one) so that
        # pool-level quality can be computed later on.
        hyps, scores, refs = [], [], []
        for i, hypos in enumerate(gen_out):
            hyps.append([decode(hypo["tokens"]) for hypo in hypos])
            scores.append([hypo["score"].item() for hypo in hypos])
            refs.append(
                decode(
                    utils.strip_pad(sample["target"][i], self.tgt_dict.pad()),
                    escape_unk=True,  # don't count <unk> as matches to the hypo
                )
            )
        return {
            'refs_text': refs,
            'hyps_text': hyps,
            'hyps_scores': scores,
            'gen_out': gen_out
        }

    @torch.no_grad()
    def valid_step(self, sample, model, criterion):
        """Validation step that also logs per-example statistics.

        Returns the criterion's ``(loss, sample_size, logging_output)``.
        ``logging_output`` is extended with:
          - ``target_seq_lengths``: non-pad reference length per example (floats),
          - ``target_model_log_probs``: log-probability the model assigns to each
            reference token, as nested lists shaped like ``target.unsqueeze(-1)``
            (padded positions included),
          - ``target_src_seq_lengths``: ``src_lengths`` as a list.

        When ``cfg.eval_bleu`` is set it also adds ``generated_seq_lengths``,
        ``generated_hyps_text``, ``generated_refs_text``, ``generated_hyps_scores``
        and ``generated_model_log_probs``. The last one is computed on the
        top-ranked hypothesis of each example.
        """
        # Original validation step; autograd is already disabled by the decorator.
        model.eval()
        loss, sample_size, logging_output = criterion(model, sample)

        # Use the dictionary's pad index rather than assuming it is 1.
        pad_idx = self.tgt_dict.pad()
        target = sample['target']

        # Teacher-forced pass over the reference targets.
        target_net_output = model(**sample['net_input'])
        target_model_lprobs = F.log_softmax(target_net_output[0], dim=-1)
        target_model_true_lprobs = torch.gather(target_model_lprobs, dim=-1, index=target.unsqueeze(-1))
        logging_output['target_seq_lengths'] = target.ne(pad_idx).float().sum(-1).tolist()
        logging_output['target_model_log_probs'] = target_model_true_lprobs.detach().cpu().tolist()
        logging_output['target_src_seq_lengths'] = sample['net_input']['src_lengths'].tolist()

        if self.cfg.eval_bleu:
            # Generate only once; the same output feeds text logging and rescoring.
            result_dict = self._inference_with_bleu(self.sequence_generator, sample, model)
            generated = result_dict['gen_out']

            # Top-ranked hypothesis of each example.
            beam_hyps = [hypos[0]["tokens"] for hypos in generated]
            generated_seq_lengths = [hyp.numel() for hyp in beam_hyps]

            # Padded hypotheses used as scoring targets, and the same tokens with
            # the final token moved to the front (collate_tokens'
            # move_eos_to_beginning) to serve as teacher-forced decoder input.
            beam_hyps_target = data_utils.collate_tokens(beam_hyps, pad_idx=pad_idx, move_eos_to_beginning=False)
            beam_hyps_input = data_utils.collate_tokens(beam_hyps, pad_idx=pad_idx, move_eos_to_beginning=True)

            generated_net_output = model(sample['net_input']['src_tokens'], sample['net_input']['src_lengths'], beam_hyps_input)
            generated_model_lprobs = F.log_softmax(generated_net_output[0], dim=-1)
            generated_model_true_lprobs = torch.gather(generated_model_lprobs, dim=-1, index=beam_hyps_target.unsqueeze(-1))

            logging_output['generated_seq_lengths'] = generated_seq_lengths
            logging_output['generated_hyps_text'] = result_dict['hyps_text']
            logging_output['generated_refs_text'] = result_dict['refs_text']
            logging_output['generated_hyps_scores'] = result_dict['hyps_scores']
            logging_output['generated_model_log_probs'] = generated_model_true_lprobs.detach().cpu().tolist()

        return loss, sample_size, logging_output

    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs and, during validation, collect list entries.

        Args:
            logging_outputs: list of logging dicts (may be empty).
            criterion: passed through to the parent implementation.
        """
        super().reduce_metrics(logging_outputs, criterion)
        nll_loss_sum = sum(log.get('nll_loss', torch.tensor(0)) for log in logging_outputs)
        utilization_loss = sum(log.get('utilization_loss', torch.tensor(0)) for log in logging_outputs)

        # NOTE(review): logged under the 'train' aggregator even when this runs
        # for validation — confirm this is intended.
        with metrics.aggregate('train'):
            metrics.log_scalar('target/nll_loss_sum', nll_loss_sum)
            metrics.log_scalar('target/utilization_loss', utilization_loss)

        # Extra stats to pickle. Each key's list grows in logging-output order.
        # extend() keeps this linear, unlike sum(lists, []), and an empty
        # logging_outputs list is a no-op instead of an IndexError.
        if self.validation_epoch:
            for log in logging_outputs:
                for key, value in log.items():
                    if isinstance(value, list):
                        self.extra_statistics[key].extend(value)

    def log_tensorboard(self):
        """Log metrics computed over the entire validation epoch.

        Logs ``target_generated_lenratio`` under the 'valid' aggregator when any
        generated lengths were collected. It is the mean, over examples, of
        reference length divided by generated length.
        """
        with metrics.aggregate('valid'):
            generated_lengths = self.extra_statistics['generated_seq_lengths']
            if generated_lengths:
                target_lengths = torch.tensor(self.extra_statistics['target_seq_lengths'])
                length_ratios = target_lengths / torch.tensor(generated_lengths)
                metrics.log_scalar('target_generated_lenratio', length_ratios.mean().item())
