# Copyright 2025 HuggingFace Inc., THUDM, and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library and the THUDM's ChatGLM implementation.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
# https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import warnings
from dataclasses import dataclass
from fractions import Fraction
from typing import TYPE_CHECKING, Dict, List, Optional

import numpy as np
import torch

from transformers.utils import is_jieba_available
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import numpify
from ...extras.packages import is_rouge_available

# --- Optional packages ---
if is_jieba_available():
    import jieba  # type: ignore

if is_rouge_available():
    from rouge_chinese import Rouge  # type: ignore

# sacreBLEU for corpus-level BLEU
import sacrebleu

# METEOR (NLTK)
try:
    from nltk.translate.meteor_score import meteor_score  # type: ignore
except Exception:  # pragma: no cover
    meteor_score = None  # Fallback: disabled

# BERTScore via HuggingFace evaluate
try:
    import evaluate as hf_evaluate  # type: ignore
except Exception:  # pragma: no cover
    hf_evaluate = None

if TYPE_CHECKING:
    from transformers import EvalPrediction, PreTrainedTokenizer


def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
    r"""Reduce logits to their arg-max token ids to shrink the evaluation memory footprint.

    Args:
        logits: A 3-D tensor (batch_size, seq_len, vocab_size), or a list/tuple of model outputs.
            For a sequence, the first entry is used when it is 3-D, otherwise the second one.
        labels: Unused; accepted to match the logits-preprocessing callback signature.

    Returns:
        The index of the largest logit along the last dimension.

    Raises:
        ValueError: If the selected logits are not 3-D.
    """
    if isinstance(logits, (list, tuple)):
        # e.g. MoE models that also return an aux loss: then the logits sit in the second slot.
        logits = logits[0] if logits[0].dim() == 3 else logits[1]

    if logits.dim() == 3:
        return torch.argmax(logits, dim=-1)
    raise ValueError("Cannot process the logits.")


@dataclass
class ComputeAccuracy:
    r"""Compute next-token accuracy and support `batch_eval_metrics`.

    Every call appends one accuracy value per sample to ``score_dict["accuracy"]``; when
    ``compute_result`` is true the mean over everything collected so far is returned and the
    buffer is cleared.
    """

    def _dump(self) -> Optional[Dict[str, float]]:
        """Return the mean of each collected score list (None on first init) and reset the buffer."""
        if hasattr(self, "score_dict"):
            summary = {name: float(np.mean(values)) for name, values in self.score_dict.items()}
        else:
            summary = None

        self.score_dict = {"accuracy": []}
        return summary

    def __post_init__(self):
        # Creates the empty score buffer.
        self._dump()

    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
        preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
        accuracies = self.score_dict["accuracy"]
        for idx, pred_row in enumerate(preds):
            label_row = labels[idx]
            # Shift by one: the prediction at position t is scored against the label at t + 1.
            shifted_pred = pred_row[:-1]
            shifted_label = label_row[1:]
            supervised = shifted_label != IGNORE_INDEX
            accuracies.append(np.mean(shifted_pred[supervised] == shifted_label[supervised]))

        if compute_result:
            return self._dump()


# ---------------- helpers for math/common acc ----------------

def is_number(s):
    """Return True if ``s`` parses with ``float`` or is a single Unicode numeric character.

    Non-string arguments that ``float`` rejects with TypeError (e.g. None) propagate that error.
    """
    try:
        float(s)
    except ValueError:
        pass
    else:
        return True

    import unicodedata

    # Fallback for characters such as vulgar fractions that float() does not understand.
    try:
        unicodedata.numeric(s)
    except (TypeError, ValueError):
        return False
    return True


def extract_gsm_num(completion):
    """Extract the integer answer following the last ``'#### '`` marker (GSM8K style).

    The first number-like token after the marker is parsed: commas are dropped, a single
    ``a/b`` fraction is evaluated, and the result is rounded with ``round`` (ties go to the
    even integer). A zero denominator yields the numerator, as in the original protocol.

    Returns:
        The rounded ``int``, or ``None`` when there is no marker, no number-like token, a
        numerator without digits, or a value that does not fit in a finite float.
    """
    parts = completion.split('#### ')
    if len(parts) < 2:
        return None

    # NOTE: the pattern allows only ONE separator, so "1,234,567" is read as "1,234".
    # Kept unchanged so math_acc stays comparable with the original evaluation protocol.
    match = re.search(r'[\-+]?\d*[\.,/]?\d+', parts[-1].strip())
    if match is None:
        return None
    token = match.group().replace(',', '')

    try:
        if '/' in token:
            numerator, denominator = token.split('/')
            # float() rejects a numerator with no digits ("", "+", "-").
            value = float(numerator)
            # Any zero denominator ("0", "00", ...) keeps the numerator instead of raising.
            if int(denominator) != 0:
                frac = Fraction(token)
                # int / int raises OverflowError when the quotient exceeds the float range.
                value = frac.numerator / frac.denominator
        else:
            value = float(token)
    except (ValueError, OverflowError):
        return None

    # Very long digit strings parse to +/-inf, which round() cannot convert to int.
    if abs(value) == float('inf'):
        return None
    return round(value)


def extract_commonsense_from_text(completion):
    """Return the first word after the last 'the correct answer is' in ``completion``.

    The marker match is case-sensitive. Returns None if the marker is absent or no word
    follows it.
    """
    _, marker, tail = completion.rpartition('the correct answer is')
    if not marker:
        return None

    word = re.search(r'\b\w+\b', tail.strip())
    return None if word is None else word.group(0)


def compute_accuracy(predictions: List[str], references: List[str]) -> float:
    """Compute exact-match accuracy between aligned predictions and references.

    Returns:
        The fraction of positions where ``predictions[i] == references[i]``; 0.0 for empty input.

    Raises:
        ValueError: If the two sequences differ in length (``zip`` would otherwise drop the
            unmatched tail while the denominator still counted it).
    """
    if len(predictions) != len(references):
        raise ValueError(
            f"predictions and references must have the same length, got {len(predictions)} and {len(references)}"
        )
    if not predictions:
        return 0.0

    correct = sum(pred == ref for pred, ref in zip(predictions, references))
    return correct / len(predictions)


def _contains_chinese(s: str) -> bool:
    return bool(re.search(r"[\u4e00-\u9fff]", s))


def _tok_for_meteor(text: str) -> List[str]:
    """Split ``text`` into tokens for METEOR.

    Text containing CJK ideographs is segmented with jieba when jieba is available;
    anything else (including Chinese text without jieba) is split on whitespace.
    """
    if not _contains_chinese(text) or not is_jieba_available():
        return text.split()
    return list(jieba.cut(text))


@dataclass
class ComputeSimilarity:
    r"""Compute text similarity scores and support `batch_eval_metrics`.

    Reported keys (ROUGE, METEOR, BLEU and BERTScore are on a 0-100 scale):
    - ``rouge-1`` / ``rouge-2`` / ``rouge-l``: per-sample ROUGE F-scores (rouge_chinese), averaged.
    - ``bleu-4``: corpus-level sacreBLEU over every sample seen since the last reset, so it stays
      a true corpus score when predictions arrive in several batches.
    - ``meteor``: per-sample METEOR (NLTK), averaged.
    - ``bert_f1`` / ``bert_p`` / ``bert_r``: per-sample BERTScore (HF evaluate), averaged.
    - ``math_acc`` / ``common_acc``: per-sample exact match of answers extracted with
      ``extract_gsm_num`` / ``extract_commonsense_from_text``, averaged.
    - ``average``: mean(BLEU, BERT-F1, BERT-R, BERT-P, METEOR, ROUGE-L) when all six are present.

    NOTE:
    - Keeps original `common_acc` and `math_acc` logic unchanged.
    - Keeps an `acc` container for compatibility; it is never filled, so it is omitted from results.
    - A missing optional backend (rouge_chinese, NLTK METEOR, evaluate) scores 0.0 instead of failing.

    Args:
        tokenizer: Decodes prediction and label token ids.
        bertscore_lang: ``lang`` argument forwarded to BERTScore.
        bertscore_model: ``model_type`` argument forwarded to BERTScore.
    """

    tokenizer: "PreTrainedTokenizer"
    bertscore_lang: str = "en"
    bertscore_model: str = "roberta-large"

    def _dump(self) -> Optional[Dict[str, float]]:
        """Return aggregated metrics (None on the very first call) and reset all accumulators."""
        result = None
        if hasattr(self, "score_dict"):
            result = {k: float(np.mean(v)) for k, v in self.score_dict.items() if v}
            # BLEU is corpus-level: score all accumulated samples at once instead of keeping
            # only the most recent batch's value.
            result["bleu-4"] = self._corpus_bleu(self._bleu_hyps, self._bleu_refs)

            # Table-style average over six metrics.
            needed = ["bleu-4", "bert_f1", "bert_r", "bert_p", "meteor", "rouge-l"]
            if all(k in result for k in needed):
                result["average"] = float(np.mean([result[k] for k in needed]))

        # Per-sample score containers; each list is averaged in the next dump.
        self.score_dict = {
            "rouge-1": [], "rouge-2": [], "rouge-l": [],
            "meteor": [],
            "bert_f1": [], "bert_p": [], "bert_r": [],
            "acc": [],  # kept for compatibility; not filled here
            "common_acc": [], "math_acc": [],
        }
        # Raw decoded texts for corpus BLEU, accumulated across calls until the next dump.
        self._bleu_hyps = []
        self._bleu_refs = []
        return result

    def __post_init__(self):
        # The BERTScore module is loaded lazily once and reused across calls and dumps.
        self._bertscore_metric = None
        self._dump()

    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
        preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)

        # Replace ignored positions with the pad token, presumably dropped by skip_special_tokens.
        preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
        labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)

        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Per-sample metrics: ROUGE, math/common accuracy, METEOR.
        rouge = Rouge() if is_rouge_available() else None
        for pred, label in zip(decoded_preds, decoded_labels):
            self._score_sample(pred, label, rouge)

        # Corpus BLEU is computed in `_dump` over everything accumulated here.
        self._bleu_hyps.extend(decoded_preds)
        self._bleu_refs.extend(decoded_labels)

        self._score_bertscore(decoded_preds, decoded_labels)

        if compute_result:
            return self._dump()

    def _score_sample(self, pred: str, label: str, rouge) -> None:
        """Append ROUGE-1/2/L, math_acc, common_acc and METEOR for one prediction/label pair.

        ``rouge`` is a ``Rouge`` scorer, or None when rouge_chinese is unavailable.
        """
        # ROUGE tokenization: jieba when available, otherwise whitespace.
        if is_jieba_available():
            hypothesis = " ".join(jieba.cut(pred))
            reference = " ".join(jieba.cut(label))
        else:
            hypothesis = " ".join(pred.split())
            reference = " ".join(label.split())

        if rouge is None or not hypothesis.split() or not reference.split():
            r1 = r2 = rl = 0.0
        else:
            scores = rouge.get_scores(hypothesis, reference)[0]
            r1, r2, rl = scores["rouge-1"]["f"], scores["rouge-2"]["f"], scores["rouge-l"]["f"]

        self.score_dict["rouge-1"].append(round(r1 * 100, 4))
        self.score_dict["rouge-2"].append(round(r2 * 100, 4))
        self.score_dict["rouge-l"].append(round(rl * 100, 4))

        # Original math_acc / common_acc logic: exact match of extracted answers
        # (when neither side yields an answer, None == None counts as a match).
        self.score_dict["math_acc"].append(1.0 if extract_gsm_num(pred) == extract_gsm_num(label) else 0.0)
        common_pre = extract_commonsense_from_text(pred)
        common_ref = extract_commonsense_from_text(label)
        self.score_dict["common_acc"].append(1.0 if common_pre == common_ref else 0.0)

        # METEOR (sentence level; averaged at dump time). Best-effort: failures score 0.0.
        mete = 0.0
        if meteor_score is not None:
            try:
                mete = meteor_score([_tok_for_meteor(label)], _tok_for_meteor(pred))
            except Exception as err:
                warnings.warn(
                    f"METEOR scoring failed ({type(err).__name__}); recording 0.0 for this sample.",
                    RuntimeWarning,
                )
                mete = 0.0
        self.score_dict["meteor"].append(round(mete * 100, 4))

    def _score_bertscore(self, preds: List[str], labels: List[str]) -> None:
        """Append per-sample BERTScore F1/P/R (x100); zeros if evaluate is missing or scoring fails."""
        n = len(preds)
        if hf_evaluate is None or n == 0:
            self._fill_bertscore_zeros(n)
            return

        try:
            if self._bertscore_metric is None:
                # Load once per instance instead of on every call.
                self._bertscore_metric = hf_evaluate.load("bertscore")
            bs_res = self._bertscore_metric.compute(
                predictions=preds,
                references=labels,
                lang=self.bertscore_lang,
                model_type=self.bertscore_model,
                rescale_with_baseline=True,
            )
            # Build all three lists before touching score_dict so a failure can't leave them misaligned.
            f1 = [round(float(v) * 100, 4) for v in bs_res["f1"]]
            precision = [round(float(v) * 100, 4) for v in bs_res["precision"]]
            recall = [round(float(v) * 100, 4) for v in bs_res["recall"]]
        except Exception as err:
            warnings.warn(f"BERTScore failed ({err!r}); recording 0.0 for this batch.", RuntimeWarning)
            self._fill_bertscore_zeros(n)
            return

        self.score_dict["bert_f1"].extend(f1)
        self.score_dict["bert_p"].extend(precision)
        self.score_dict["bert_r"].extend(recall)

    @staticmethod
    def _corpus_bleu(hyps: List[str], refs: List[str]) -> float:
        """Corpus-level sacreBLEU score rounded to 4 decimals; 0.0 if either side is all blank."""
        if not (any(h.strip() for h in hyps) and any(r.strip() for r in refs)):
            return 0.0
        use_zh = any(_contains_chinese(x) for x in hyps) or any(_contains_chinese(x) for x in refs)
        cb = sacrebleu.corpus_bleu(hyps, [refs], tokenize="zh" if use_zh else "13a", lowercase=False)
        return round(cb.score, 4)

    def _fill_bertscore_zeros(self, n: int):
        """Append ``n`` zero BERTScore entries (at least one) to each of F1/P/R."""
        # At least one entry keeps the bert_* keys (and thus "average") present in the result.
        n = max(n, 1)
        self.score_dict["bert_f1"] += [0.0] * n
        self.score_dict["bert_p"] += [0.0] * n
        self.score_dict["bert_r"] += [0.0] * n
