# Copyright 2024 HuggingFace Inc., THUDM, and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library and the THUDM's ChatGLM implementation.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
# https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional

import numpy as np
import torch
from transformers.utils import is_jieba_available, is_nltk_available

from ...extras.constants import IGNORE_INDEX
from ...extras.misc import numpify
from ...extras.packages import is_rouge_available


if TYPE_CHECKING:
    from transformers import EvalPrediction, PreTrainedTokenizer


if is_jieba_available():
    import jieba  # type: ignore


if is_nltk_available():
    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu


if is_rouge_available():
    from rouge_chinese import Rouge


def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
    r"""
    Reduces model logits to predicted token ids with an argmax over the last dimension.

    When ``logits`` is a list/tuple of model outputs, the element at index 1 is used. In this setup
    that element holds the safety logits, and the choice is intentionally left unchanged at inference
    time. The upstream guard that required 3-D logits was dropped because the logits are not 3-D at
    inference, so tensors of any rank are accepted.

    Args:
        logits: a logits tensor, or a list/tuple whose second element is the logits tensor to use.
        labels: unused. NOTE(review): presumably kept so the signature matches the Trainer's
            ``preprocess_logits_for_metrics(logits, labels)`` hook — confirm against the caller.

    Returns:
        ``torch.argmax(logits, dim=-1)``.
    """
    if isinstance(logits, (list, tuple)):
        # Always read the safety logits (index 1); do not switch this back to index 0 for inference.
        logits = logits[1]

    # TODO: decide whether a rank check is needed. The old `logits.dim() != 3` guard is disabled
    # because inference-time logits are not 3-D.
    return torch.argmax(logits, dim=-1)


@dataclass
class ComputeAccuracy:
    r"""
    Accuracy metric scored on the first and last supervised token of each sample.

    For every sample, predictions are shifted one step against the labels (``pred[:-1]`` vs
    ``label[1:]``), and positions whose label equals ``IGNORE_INDEX`` are dropped. Three scores are
    then recorded per sample:

    - ``accuracy``: 1.0 if the first supervised token is predicted correctly, else 0.0;
    - ``accuracy_response``: 1.0 if the last supervised token is predicted correctly, else 0.0;
    - ``accuracy_all``: 1.0 only if both the first and the last supervised tokens are correct.

    Scores accumulate in ``score_dict`` across calls until ``_dump`` averages and resets them.
    """

    def _dump(self) -> Optional[Dict[str, float]]:
        r"""
        Averages the accumulated scores and resets the accumulator.

        Returns None on the very first call (from ``__post_init__``), before ``score_dict`` exists.
        """
        result = None
        if hasattr(self, "score_dict"):
            result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}

        self.score_dict = {"accuracy": [], "accuracy_all": [], "accuracy_response": []}
        return result

    def __post_init__(self):
        self._dump()

    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
        r"""
        Accumulates per-sample scores for one batch of predictions.

        Args:
            eval_preds: provides ``predictions`` (token ids) and ``label_ids``. Both are converted with
                ``numpify`` and indexed row by row as ``(batch, seq)`` arrays.
            compute_result: if True, return the averaged scores and reset the accumulator.

        Returns:
            The averaged scores when ``compute_result`` is True, otherwise None.
        """
        preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
        for pred_row, label_row in zip(preds, labels):
            # Prediction at position t is compared with the label at position t + 1 (next-token shift).
            pred, label = pred_row[:-1], label_row[1:]
            # NOTE(review): the original author flagged this masking as something to change between
            # training and inference — confirm it matches the mode this metric is used in.
            label_mask = label != IGNORE_INDEX
            pred, label = pred[label_mask], label[label_mask]

            if label.size == 0:
                # Every label position is IGNORE_INDEX, so there is no first/last supervised token to
                # score. Indexing [0] here used to raise IndexError and abort the whole evaluation.
                continue

            first_correct = bool(pred[0] == label[0])
            last_correct = bool(pred[-1] == label[-1])
            self.score_dict["accuracy"].append(float(first_correct))
            self.score_dict["accuracy_all"].append(float(first_correct and last_correct))
            self.score_dict["accuracy_response"].append(float(last_correct))

        if compute_result:
            return self._dump()


@dataclass
class ComputeSimilarity:
    r"""
    Text-similarity metric that decodes token ids with ``tokenizer``; used in CustomSeq2SeqTrainer.

    For each prediction/label pair it records, as percentages rounded to 4 decimals:

    - ``rouge-1``, ``rouge-2``, ``rouge-l``: ROUGE F-scores over jieba-segmented words (all three
      are recorded as 0 when either side segments to whitespace only);
    - ``bleu-4``: sentence-level BLEU over characters, smoothed with ``SmoothingFunction().method3``.

    Scores accumulate in ``score_dict`` across calls until ``_dump`` averages and resets them.
    """

    tokenizer: "PreTrainedTokenizer"

    def _dump(self) -> Optional[Dict[str, float]]:
        r"""
        Returns the mean of every accumulated metric and starts a fresh, empty accumulator.

        The very first call (from ``__post_init__``) returns None because ``score_dict`` does not
        exist yet.
        """
        averaged = None
        if hasattr(self, "score_dict"):
            averaged = {}
            for metric, values in self.score_dict.items():
                averaged[metric] = float(np.mean(values))

        self.score_dict = {metric: [] for metric in ("rouge-1", "rouge-2", "rouge-l", "bleu-4")}
        return averaged

    def __post_init__(self):
        self._dump()

    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
        r"""
        Scores one batch of predicted token ids against the reference token ids.

        Args:
            eval_preds: provides ``predictions`` and ``label_ids`` as token ids.
            compute_result: when True, return the averaged metrics and reset the accumulator.

        Returns:
            The averaged metrics if ``compute_result`` is True, else None.
        """
        pred_ids = numpify(eval_preds.predictions)
        label_ids = numpify(eval_preds.label_ids)

        # IGNORE_INDEX is not a real token id; substitute the pad id so batch_decode gets valid ids.
        pred_ids = np.where(pred_ids == IGNORE_INDEX, self.tokenizer.pad_token_id, pred_ids)
        label_ids = np.where(label_ids == IGNORE_INDEX, self.tokenizer.pad_token_id, label_ids)

        pred_texts = self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_texts = self.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

        for pred_text, label_text in zip(pred_texts, label_texts):
            hypothesis = " ".join(jieba.cut(pred_text))
            reference = " ".join(jieba.cut(label_text))

            if hypothesis.split() and reference.split():
                rouge_scores = Rouge().get_scores(hypothesis, reference)[0]
            else:
                # Nothing to compare on one side: record zero ROUGE instead of calling Rouge.
                rouge_scores = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}

            for metric, score in rouge_scores.items():
                self.score_dict[metric].append(round(score["f"] * 100, 4))

            # BLEU is computed on characters (list of str), not on jieba words.
            bleu = sentence_bleu([list(label_text)], list(pred_text), smoothing_function=SmoothingFunction().method3)
            self.score_dict["bleu-4"].append(round(bleu * 100, 4))

        return self._dump() if compute_result else None
