import math
import sys
import json
from typing import Iterable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.utils import ModelEma
from timm.utils import accuracy, ModelEma
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy


import utils
def get_sentencepiece_model_for_beit3(spm_path='./model_weights/beit3.spm'):
    """Build the XLM-RoBERTa SentencePiece tokenizer used by the handlers.

    Args:
        spm_path: Path to the SentencePiece model file. Defaults to the
            location that was previously hard-coded, so existing zero-argument
            callers keep their behavior.

    Returns:
        An ``XLMRobertaTokenizer`` constructed from ``spm_path``.
    """
    # Imported lazily so that `transformers` is only required when a tokenizer
    # is actually requested.
    from transformers import XLMRobertaTokenizer
    return XLMRobertaTokenizer(spm_path)

class TaskHandler:
    """Base interface for task-specific training / evaluation logic.

    Subclasses provide ``train_batch``, ``eval_batch`` and ``after_eval``;
    the versions defined here raise ``NotImplementedError``.
    """

    def __init__(self) -> None:
        # Both attributes are populated by before_eval().
        self.metric_logger = self.split = None

    def train_batch(self, model, **kwargs):
        """Run one training step; must be overridden."""
        raise NotImplementedError

    def eval_batch(self, model, **kwargs):
        """Run one evaluation step; must be overridden."""
        raise NotImplementedError

    def before_eval(self, metric_logger, data_loader, **kwargs):
        """Remember the metric logger and the split name of the loader's dataset."""
        self.metric_logger = metric_logger
        dataset = data_loader.dataset
        self.split = dataset.split

    def after_eval(self, **kwargs):
        """Summarize evaluation results; must be overridden."""
        raise NotImplementedError


class NLVR2Handler(TaskHandler):
    """Handler for a two-image + text classification task trained with cross-entropy."""

    def __init__(self) -> None:
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def _forward(model, image, image2, language_tokens, padding_mask):
        """Call the model with the keyword names it expects and return its logits."""
        return model(
            image_a=image,
            image_b=image2,
            text_description=language_tokens,
            padding_mask=padding_mask,
        )

    def train_batch(self, model, image, image2, language_tokens, padding_mask, label):
        """Return the cross-entropy loss and the batch accuracy (a fraction in [0, 1])."""
        logits = self._forward(model, image, image2, language_tokens, padding_mask)
        _, predicted = logits.max(-1)
        acc = (predicted == label).float().mean()
        loss = self.criterion(input=logits, target=label)
        return {"loss": loss, "acc": acc}

    def eval_batch(self, model, image, image2, language_tokens, padding_mask, label):
        """Update the 'acc' meter with this batch's accuracy, expressed in percent."""
        logits = self._forward(model, image, image2, language_tokens, padding_mask)
        batch_size = language_tokens.shape[0]
        _, predicted = logits.max(-1)
        num_correct = (predicted == label).float().sum(0)
        acc = num_correct * 100.0 / batch_size
        self.metric_logger.meters['acc'].update(acc.item(), n=batch_size)

    def after_eval(self, **kwargs):
        """Print the global accuracy and return (all meter averages, "acc")."""
        print('* Acc {acc.global_avg:.3f}'.format(acc=self.metric_logger.acc))
        summary = {}
        for name, meter in self.metric_logger.meters.items():
            summary[name] = meter.global_avg
        return summary, "acc"


class ImageNetHandler(TaskHandler):
    """Handler for single-image classification.

    The training criterion is selected at construction time:

    * ``SoftTargetCrossEntropy`` when mixup is active (smoothing is then
      handled by the mixup label transform),
    * ``LabelSmoothingCrossEntropy`` when ``label_smoothing > 0``,
    * plain ``CrossEntropyLoss`` otherwise.
    """

    def __init__(self, mixup_active: bool = False, label_smoothing: float = 0.1) -> None:
        """Choose the loss function.

        Args:
            mixup_active: Whether labels fed to ``train_batch`` are soft
                (mixed) targets. Defaults to ``False``, matching the value
                that used to be hard-coded.
            label_smoothing: Smoothing factor used when mixup is inactive.
                Defaults to ``0.1``, matching the previous hard-coded value.
        """
        super().__init__()
        if mixup_active:
            # smoothing is handled with mixup label transform
            self.criterion = SoftTargetCrossEntropy()
        elif label_smoothing > 0.:
            self.criterion = LabelSmoothingCrossEntropy(smoothing=label_smoothing)
        else:
            self.criterion = torch.nn.CrossEntropyLoss()

    def train_batch(self, model, image, label):
        """Return a dict holding the criterion value under the key "loss"."""
        logits = model(image=image)
        return {
            "loss": self.criterion(logits, label),
        }

    def eval_batch(self, model, image, label):
        """Update the 'acc1' and 'acc5' meters with timm's top-1 / top-5 accuracy."""
        logits = model(image=image)
        batch_size = image.shape[0]
        acc1, acc5 = accuracy(logits, label, topk=(1, 5))
        self.metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        self.metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    def after_eval(self, **kwargs):
        """Print the global top-1 / top-5 accuracy and return (all meter averages, "acc1")."""
        print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}'
              .format(top1=self.metric_logger.acc1, top5=self.metric_logger.acc5))
        return {k: meter.global_avg for k, meter in self.metric_logger.meters.items()}, "acc1"


class RetrievalHandler(TaskHandler):
    """Handler for image-text retrieval.

    During evaluation the per-batch image features, text features and image
    ids are accumulated; ``after_eval`` then builds the image-by-text score
    matrix and reports recall@1/5/10 in both directions.
    """

    def __init__(self) -> None:
        super().__init__()
        self.image_feats = []
        self.text_feats = []
        self.image_ids = []
        self.metric_logger = None

    def train_batch(self, model, image, language_tokens, padding_mask, image_id):
        """Return the loss produced by the model; ``image_id`` is not used for training."""
        loss, vision_cls, language_cls = model(
            image=image, text_description=language_tokens, padding_mask=padding_mask)
        return {
            "loss": loss,
        }

    def before_eval(self, metric_logger, **kwargs):
        """Drop features left over from a previous evaluation and store the logger."""
        self.image_feats.clear()
        self.text_feats.clear()
        self.image_ids.clear()
        self.metric_logger = metric_logger

    def eval_batch(self, model, image, language_tokens, padding_mask, image_id):
        """Encode images and texts separately and stash copies of the features and ids."""
        vision_cls, _ = model(image=image, only_infer=True)
        _, language_cls = model(
            text_description=language_tokens, padding_mask=padding_mask, only_infer=True)

        self.image_feats.append(vision_cls.clone())
        self.text_feats.append(language_cls.clone())
        self.image_ids.append(image_id.clone())

    @staticmethod
    def _recall_at_k(scores, iids, tiids, k):
        """Return (text-retrieval recall@k, image-retrieval recall@k) as 0-dim tensors.

        ``scores`` has one row per unique image and one column per text;
        ``iids`` holds the image id of each row and ``tiids`` the image id
        associated with each column. ``k`` is clamped to the number of
        available candidates so that small evaluation sets do not make
        ``topk`` raise.
        """
        num_images, num_texts = scores.size(0), scores.size(1)

        # Text retrieval: for every image, look at its k best-scoring texts.
        top_texts = scores.topk(min(k, num_texts), dim=1).indices
        tr = (iids.unsqueeze(1) == tiids[top_texts]).float().max(dim=1)[0].mean()

        # Image retrieval: for every text, look at its k best-scoring images.
        top_images = scores.topk(min(k, num_images), dim=0).indices
        ir = (tiids.unsqueeze(0) == iids[top_images]).float().max(dim=0)[0].mean()
        return tr, ir

    def after_eval(self, **kwargs):
        """Compute retrieval recalls from the accumulated features.

        Returns:
            A tuple ``(eval_result, "average_score")`` where ``eval_result``
            maps ``tr_r{1,5,10}`` / ``ir_r{1,5,10}`` and ``average_score`` to
            percentages.
        """
        # The same image id can occur several times (once per stored text);
        # keep only the first feature seen for each id.
        image_feats = {}
        for feats, ids in zip(self.image_feats, self.image_ids):
            for i, _idx in enumerate(ids):
                idx = _idx.item()
                if idx not in image_feats:
                    image_feats[idx] = feats[i]

        tiids = torch.cat(self.image_ids, dim=0)
        sorted_ids = sorted(image_feats)
        image_cls_feats = torch.cat(
            [image_feats[key].view(1, -1) for key in sorted_ids], dim=0)
        text_cls_feats = torch.cat(self.text_feats, dim=0)

        scores = image_cls_feats @ text_cls_feats.t()
        iids = torch.LongTensor(sorted_ids).to(scores.device)

        print("scores: {}".format(scores.size()))
        print("iids: {}".format(iids.size()))
        print("tiids: {}".format(tiids.size()))

        tr_r10, ir_r10 = self._recall_at_k(scores, iids, tiids, 10)
        tr_r5, ir_r5 = self._recall_at_k(scores, iids, tiids, 5)
        tr_r1, ir_r1 = self._recall_at_k(scores, iids, tiids, 1)

        eval_result = {
            "tr_r10": tr_r10.item() * 100.0,
            "tr_r5": tr_r5.item() * 100.0,
            "tr_r1": tr_r1.item() * 100.0,
            "ir_r10": ir_r10.item() * 100.0,
            "ir_r5": ir_r5.item() * 100.0,
            "ir_r1": ir_r1.item() * 100.0,
            "average_score": 100.0 * (tr_r1 + tr_r5 + tr_r10 + ir_r1 + ir_r5 + ir_r10).item() / 6.0,
        }

        print('* Eval result = %s' % json.dumps(eval_result))
        return eval_result, "average_score"


class VQAHandler(TaskHandler):
    """Handler for visual question answering.

    With labels available, evaluation accumulates a score meter; without
    labels it collects ``{"question_id", "answer"}`` predictions instead.
    """

    def __init__(self) -> None:
        super().__init__()
        self.predictions = []
        self.criterion = nn.BCEWithLogitsLoss(reduction='mean')
        # Filled in by before_eval() from the evaluation dataset.
        self.label2ans = None

    @staticmethod
    def _forward(model, image, language_tokens, padding_mask):
        """Call the model with the keyword names it expects and return its logits."""
        return model(image=image, question=language_tokens, padding_mask=padding_mask)

    def train_batch(self, model, image, language_tokens, padding_mask, labels):
        """Return the mean BCE-with-logits loss scaled by ``labels.shape[1]``."""
        logits = self._forward(model, image, language_tokens, padding_mask)
        bce = self.criterion(input=logits.float(), target=labels.float())
        return {"loss": bce * labels.shape[1]}

    def before_eval(self, metric_logger, data_loader, **kwargs):
        """Reset collected predictions and pick up the logger and label-to-answer table."""
        self.predictions.clear()
        self.metric_logger = metric_logger
        self.label2ans = data_loader.dataset.label2ans

    def eval_batch(self, model, image, language_tokens, padding_mask, labels=None, qid=None):
        """Score the batch when labels are given, otherwise record predicted answers."""
        logits = self._forward(model, image, language_tokens, padding_mask)
        batch_size = language_tokens.shape[0]

        if labels is None:
            # No ground truth: store the top-scoring answer for each question id.
            answer_indices = logits.max(-1)[1]
            for question_id, answer_index in zip(qid, answer_indices):
                entry = {
                    "question_id": question_id.item(),
                    "answer": self.label2ans[answer_index.item()],
                }
                self.predictions.append(entry)
            return

        scores = utils.VQAScore()(logits, labels) * 100.0
        self.metric_logger.meters['score'].update(scores.item(), n=batch_size)

    def after_eval(self, **kwargs):
        """Return collected predictions if any, otherwise the meter averages."""
        if self.predictions:
            return self.predictions, "prediction"

        print('* Score {score.global_avg:.3f}'.format(score=self.metric_logger.score))
        summary = {}
        for name, meter in self.metric_logger.meters.items():
            summary[name] = meter.global_avg
        return summary, "score"


class CaptioningHandler(TaskHandler):
    """Task handler for image captioning.

    Training predicts the tokens at masked positions; evaluation generates a
    caption per image with beam search and stores it in ``self.predictions``
    as ``{"image_id": ..., "caption": ...}``.
    """

    def __init__(self) -> None:
        """Set up the loss, the tokenizer and the beam-search settings."""
        super().__init__()
        self.predictions = []
        label_smoothing = 0.1
        drop_worst_ratio=0.2
        # NOTE(review): the third positional argument (12000) is interpreted by
        # utils.BertCaptioningLoss, which is not visible here -- confirm its meaning there.
        self.criterion = utils.BertCaptioningLoss(label_smoothing, drop_worst_ratio, 12000)
        self.tokenizer = get_sentencepiece_model_for_beit3()
        # Beam-search settings used by eval_batch().
        self.num_beams = 3
        self.max_len = 64
        self.length_penalty = 0.6
        # eval_batch() asserts that the model's per-token output has exactly this width.
        self.vocab_size = 64010

    def train_batch(self, model, image, language_tokens, masked_tokens, language_masked_pos, padding_mask, image_id,
                    global_step):
        """Compute the captioning loss and the accuracy on masked positions.

        Returns:
            A dict with "loss" (criterion applied to the logits, the masked
            labels and ``global_step``) and "acc" (number of correctly
            predicted masked tokens divided by the sum of ``language_masked_pos``).
        """
        logits, _ = model(
            image=image, text_ids=masked_tokens, padding_mask=padding_mask, language_masked_pos=language_masked_pos,
            image_id=image_id)
        # Ground-truth tokens at the masked positions only.
        masked_labels = language_tokens[language_masked_pos.bool()]
        score = torch.max(logits, -1)[1].data == masked_labels
        acc = torch.sum(score.float()) / torch.sum(language_masked_pos)
        return {
            "loss": self.criterion(logits, masked_labels, global_step),
            "acc": acc
        }

    def before_eval(self, metric_logger, data_loader, **kwargs):
        """Reset collected predictions and store the logger (``data_loader`` is unused)."""
        self.predictions.clear()
        self.metric_logger = metric_logger

    def eval_batch(self, model, image, image_id=None):
        """Generate one caption per image with beam search and record it.

        The decoder input always consists of two tokens: the previously chosen
        token (initially [CLS]) followed by [MASK]; the model's output at the
        [MASK] position provides the next-token distribution. Decoding runs
        until every example is done or ``self.max_len`` is exceeded. The best
        hypothesis of every example is decoded to text and appended to
        ``self.predictions`` together with the matching entry of ``image_id``.
        """
        # Length of the sequence decoded so far, counting the trailing [MASK] slot.
        cur_len = 2
        # Number of finished hypotheses kept (and returned) per example.
        num_keep_best = 1
        # Candidates considered per beam: top-k below uses TOPN_PER_BEAM * num_beams.
        TOPN_PER_BEAM = 3

        batch_size = image.size(0)
        mask_id = self.tokenizer.mask_token_id
        cls_id = self.tokenizer.cls_token_id
        pad_id = self.tokenizer.pad_token_id
        sep_id = self.tokenizer.sep_token_id
        # The separator token doubles as end-of-sequence marker.
        eos_token_ids = [sep_id]

        cls_ids = torch.full(
            (batch_size, 1), cls_id, dtype=torch.long, device=image.device
        )
        mask_ids = torch.full(
            (batch_size, 1), mask_id, dtype=torch.long, device=image.device
        )
        # Initial model input per example: [CLS, MASK].
        cur_input_ids = torch.cat([cls_ids, mask_ids], dim=1)
        tmp_ids = torch.full(
            (batch_size, self.max_len - 1), mask_id, dtype=torch.long, device=image.device
        )
        # Full-length buffer of decoded tokens: [CLS] followed by [MASK] placeholders.
        decoding_results = torch.cat([cls_ids, tmp_ids], dim=1)

        # Expand input to num beams
        cur_input_ids = cur_input_ids.unsqueeze(1).expand(batch_size, self.num_beams, cur_len)
        cur_input_ids = cur_input_ids.contiguous().view(batch_size * self.num_beams,
                                                        cur_len)  # (batch_size * num_beams, cur_len)
        decoding_results = decoding_results.unsqueeze(1).expand(batch_size, self.num_beams, self.max_len)
        decoding_results = decoding_results.contiguous().view(batch_size * self.num_beams,
                                                              self.max_len)  # (batch_size * num_beams, max_len)
        # Replicate every image num_beams times, keeping its last three dimensions.
        image = image.unsqueeze(1).expand(batch_size, self.num_beams, image.size(-3), image.size(-2), image.size(-1))
        image = image.contiguous().view(batch_size * self.num_beams, image.size(-3), image.size(-2), image.size(-1))

        generated_hyps = [
            utils.BeamHypotheses(
                num_keep_best, self.max_len, length_penalty=self.length_penalty, early_stopping=False
            ) for _ in range(batch_size)
        ]
        # scores for each sentence in the beam
        beam_scores = torch.zeros((batch_size, self.num_beams), dtype=torch.float, device=cur_input_ids.device)
        # All beams start from the same input; pushing every beam but the first to a
        # very low score makes the first top-k effectively draw from beam 0 only.
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)

        # done sentences
        done = [False for _ in range(batch_size)]
        incremental_state = {}

        while cur_len <= self.max_len:
            # Index of the [MASK] slot inside the two-token input.
            next_token_idx = 1
            # All-zero padding mask with the shape of the current input.
            padding_masks = torch.full(
                cur_input_ids.shape, 0, dtype=torch.long, device=image.device
            )
            # The image is only passed on the first step.
            # NOTE(review): presumably later steps rely on incremental_state for the
            # image context -- confirm against the model implementation.
            input_image = image
            if cur_len != 2:
                input_image = None

            outputs, incremental_state_next = model(
                image=input_image, text_ids=cur_input_ids, language_masked_pos=None,
                padding_mask=padding_masks, text_len=cur_len, incremental_state=incremental_state)
            incremental_state = incremental_state_next

            # assert outputs.shape[1] == token_len
            scores = outputs[:, next_token_idx, :]  # (batch_size * num_beams, vocab_size)
            scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
            assert scores.size() == (batch_size * self.num_beams, self.vocab_size)
            # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
            _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
            # re-organize to group the beam together (we are keeping top hypothesis across beams)
            _scores = _scores.view(batch_size, self.num_beams * self.vocab_size)  # (batch_size, num_beams * vocab_size)
            next_scores, next_words = torch.topk(_scores, TOPN_PER_BEAM * self.num_beams, dim=1, largest=True,
                                                 sorted=True)
            assert next_scores.size() == next_words.size() == (batch_size, TOPN_PER_BEAM * self.num_beams)

            # next batch beam content
            # list of (batch_size * num_beams) tuple(next hypothesis score, next word, current position in the batch)
            next_batch_beam = []
            # for each sentence
            for batch_ex in range(batch_size):
                # if we are done with this sentence
                done[batch_ex] = done[batch_ex] or generated_hyps[batch_ex].is_done(next_scores[batch_ex].max().item())
                if done[batch_ex]:
                    next_batch_beam.extend([(0, pad_id, 0)] * self.num_beams)  # pad the batch
                    continue

                # next sentence beam content
                next_sent_beam = []
                for idx, score in zip(next_words[batch_ex], next_scores[batch_ex]):
                    # get beam and word IDs
                    # (idx indexes the flattened (num_beams * vocab_size) axis)
                    beam_id = idx // self.vocab_size
                    word_id = idx % self.vocab_size
                    # end of sentence, or next word
                    # A hypothesis is finalized when the candidate is an EOS token (and the
                    # length limit is not exceeded) or when the next length hits max_len.
                    if (word_id.item() in eos_token_ids and cur_len + 1 <= self.max_len) or (
                            cur_len + 1 == self.max_len):
                        # Store the tokens decoded so far (without the candidate itself);
                        # the EOS token is written when the final output is assembled.
                        generated_hyps[batch_ex].add(
                            decoding_results[batch_ex * self.num_beams + beam_id, :cur_len].clone(), score.item()
                        )
                    else:
                        next_sent_beam.append((score, word_id, batch_ex * self.num_beams + beam_id))
                    # the beam for next step is full
                    if len(next_sent_beam) == self.num_beams:
                        break

                # update next beam content
                if cur_len + 1 == self.max_len:
                    assert len(next_sent_beam) == 0
                else:
                    assert len(next_sent_beam) == self.num_beams

                if len(next_sent_beam) == 0:
                    next_sent_beam = [(0, pad_id, 0)] * self.num_beams  # pad the batch
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == self.num_beams * (batch_ex + 1)

            # sanity check / prepare next batch
            assert len(next_batch_beam) == batch_size * self.num_beams
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_words = cur_input_ids.new([x[1] for x in next_batch_beam])
            beam_idx = cur_input_ids.new([x[2] for x in next_batch_beam])

            # re-order batch
            cur_input_ids = cur_input_ids[beam_idx, :]
            decoding_results = decoding_results[beam_idx, :]
            # Reorder every cached tensor to follow the selected beams and drop its last
            # entry along dim 2.
            # NOTE(review): presumably that entry belongs to the [MASK] slot, which is
            # re-fed on the next step -- confirm against the model's cache layout.
            for module in incremental_state:
                for key in incremental_state[module]:
                    result = incremental_state[module][key].index_select(0, beam_idx)
                    incremental_state[module][key] = result[:, :, :-1, :]

            next_ids = torch.full(
                (batch_size * self.num_beams, 1), mask_id, dtype=torch.long, device=image.device
            )
            # Next input per beam: [chosen token, MASK].
            cur_input_ids = torch.cat([beam_words.unsqueeze(1), next_ids], dim=1)
            # Record the chosen token in the slot that held the [MASK] placeholder.
            decoding_results[:, cur_len - 1] = beam_words
            # update current length
            cur_len = cur_len + 1
            # stop when we are done with each sentence
            if all(done):
                break

        # select the best hypotheses
        # tgt_len is created without a device argument (default device).
        tgt_len = torch.ones(batch_size, num_keep_best, dtype=torch.long)
        logprobs = torch.zeros(batch_size, num_keep_best,
                               dtype=torch.float).fill_(-1e5).to(cur_input_ids.device)
        all_best = []

        for i, hypotheses in enumerate(generated_hyps):
            best = []
            # Each stored hypothesis is a (score, token tensor) pair.
            hyp_scores = torch.tensor([x[0] for x in hypotheses.hyp])
            _, best_indices = torch.topk(hyp_scores,
                                         min(num_keep_best, len(hyp_scores)), largest=True)
            for best_idx, hyp_idx in enumerate(best_indices):
                conf, best_hyp = hypotheses.hyp[hyp_idx]
                best.append(best_hyp)
                logprobs[i, best_idx] = conf
                tgt_len[i, best_idx] = len(best_hyp) + 1  # +1 for the <EOS> symbol
            all_best.append(best)

        # generate target batch, pad to the same length
        decoded = cur_input_ids.new(batch_size, num_keep_best, self.max_len).fill_(pad_id)
        for batch_idx, best in enumerate(all_best):
            for best_idx, hypo in enumerate(best):
                decoded[batch_idx, best_idx, : tgt_len[batch_idx, best_idx] - 1] = hypo
                decoded[batch_idx, best_idx, tgt_len[batch_idx, best_idx] - 1] = eos_token_ids[0]

        # squeeze(1) removes the num_keep_best axis (size 1) before decoding to text.
        captions = self.tokenizer.batch_decode(decoded.squeeze(1), skip_special_tokens=True)
        for qid, pred in zip(image_id, captions):
            self.predictions.append({
                "image_id": qid.item(),
                "caption": pred,
            })

    def after_eval(self, **kwargs):
        """Return the collected captions together with the tag "prediction"."""
        return self.predictions, "prediction"
