# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Dict, Literal

import numpy as np
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from rouge.rouge import Rouge
from torch import Tensor
from transformers.trainer_utils import EvalPrediction

from .logger import get_logger

# Module-level logger (from the package-local `get_logger`); used in this
# module to report evaluation samples that could not be scored.
logger = get_logger()


def compute_nlg_metrics(prediction, tokenizer):
    """Compute ROUGE-1/2/L (jieba-segmented) and character-level BLEU-4.

    Each (prediction, label) row of token ids is decoded with ``tokenizer``
    as one sequence. The decoded text is segmented with ``jieba`` for ROUGE
    and split into characters for BLEU. Per-sample F-scores / BLEU are scaled
    to 0-100, rounded to 4 decimals and averaged over all samples that could
    be scored.

    Args:
        prediction: Indexable pair ``(preds, labels)`` (e.g. an
            ``EvalPrediction``) of token-id arrays, one row per sample.
        tokenizer: Tokenizer providing ``decode``, ``vocab_size`` and
            ``eos_token_id`` (presumably a Hugging Face tokenizer).

    Returns:
        Dict mapping ``'rouge-1'``, ``'rouge-2'``, ``'rouge-l'`` and
        ``'bleu-4'`` to the mean score, or ``nan`` when no sample could be
        scored.
    """
    import jieba
    preds, labels = prediction[0], prediction[1]

    score_dict = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
    # The scorers carry no per-sample state; build them once, not per row.
    rouge = Rouge()
    smoothing = SmoothingFunction().method3

    def _decode(tokens):
        """Decode one row of token ids into a single string."""
        tokens = np.asarray(tokens)
        # Rows may be padded with -100 (labels ignored by the loss, and
        # predictions padded when batches are gathered), and ids beyond
        # vocab_size are not decodable. Drop every id outside
        # [0, vocab_size).
        valid = (tokens >= 0) & (tokens < tokenizer.vocab_size)
        # Decode the whole row at once. Decoding ids one by one and joining
        # loses SentencePiece spaces and breaks multi-byte characters that
        # byte-level BPE spreads over several ids.
        return tokenizer.decode(tokens[valid].tolist(),
                                skip_special_tokens=True)

    for pred_ids, label_ids in zip(preds, labels):
        pred = _decode(pred_ids)
        label = _decode(label_ids)
        hypothesis = list(jieba.cut(pred))
        if len(hypothesis) == 0 or ''.join(hypothesis) == '.':
            # Nothing meaningful to score: substitute the EOS text
            # (presumably so ROUGE does not reject an empty hypothesis).
            hypothesis = [tokenizer.decode(tokenizer.eos_token_id)]
        reference = list(jieba.cut(label))
        try:
            rouge_result = rouge.get_scores(' '.join(hypothesis),
                                            ' '.join(reference))[0]
            # Character-level BLEU on the raw decoded strings.
            bleu_score = sentence_bleu([list(label)],
                                       list(pred),
                                       smoothing_function=smoothing)
            sample_scores = {
                k: round(v['f'] * 100, 4)
                for k, v in rouge_result.items()
            }
            sample_scores['bleu-4'] = round(bleu_score * 100, 4)
        except Exception as e:
            # Best effort: log and skip this sample instead of aborting the
            # whole evaluation.
            logger.error(e)
            logger.error(f'eval error {hypothesis}, {reference}')
            continue
        # Append only after every scorer succeeded so the lists stay aligned.
        for k, v in sample_scores.items():
            score_dict[k].append(v)

    return {
        k: float(np.mean(v)) if v else float('nan')
        for k, v in score_dict.items()
    }


def compute_acc_metrics(
        eval_prediction: 'EvalPrediction',
        acc_strategy: Literal['token',
                              'sentence'] = 'token') -> Dict[str, float]:
    """Compute accuracy of predicted ids against one-position-shifted labels.

    Predictions at position ``t`` are compared with labels at position
    ``t + 1``. Label positions equal to -100 are ignored.

    Args:
        eval_prediction: Object exposing ``predictions`` and ``label_ids``
            as numpy arrays (presumably token ids with a leading batch axis).
        acc_strategy: ``'token'`` averages over all unmasked positions.
            ``'sentence'`` counts a row as correct only if every unmasked
            position in it matches; a fully masked row counts as correct.

    Returns:
        ``{'acc': value}``, where ``value`` is ``nan`` if there is nothing to
        average, or ``{}`` when the shifted predictions and labels differ in
        shape.
    """
    labels = eval_prediction.label_ids[..., 1:]
    predictions = eval_prediction.predictions[..., :-1]
    if predictions.shape != labels.shape:
        return {}
    masks = labels != -100
    if acc_strategy == 'sentence':
        # Masked positions are forced to True so they can never fail a row;
        # reduce over every non-batch axis to get one verdict per row.
        token_ok = (predictions == labels) | ~masks
        hits = token_ok.all(axis=tuple(range(1, token_ok.ndim)))
    else:
        hits = predictions[masks] == labels[masks]
    # Avoid numpy's "Mean of empty slice" warning; the result is nan either
    # way.
    acc = float(np.mean(hits)) if hits.size else float('nan')
    return {'acc': acc}


def preprocess_logits_for_metrics(logits: Tensor, labels: Tensor) -> Tensor:
    """Reduce logits to predicted ids before metrics are computed.

    Args:
        logits: A logits tensor, or a list/tuple whose first element is the
            logits tensor (any further elements are ignored).
        labels: Unused here; presumably accepted to match the
            ``(logits, labels)`` hook signature the caller expects.

    Returns:
        The argmax index along the last dimension.
    """
    scores = logits[0] if isinstance(logits, (list, tuple)) else logits
    return scores.argmax(dim=-1)
