from typing import Any, Dict, List, Tuple
import math
import pandas as pd

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


def _normalize_text(s: str) -> List[str]:
    if not isinstance(s, str):
        s = "" if s is None else str(s)
    s = s.lower()
    # Replace punctuation with space, keep alnum
    out_chars = []
    for ch in s:
        if ch.isalnum() or ch.isspace():
            out_chars.append(ch)
        else:
            out_chars.append(" ")
    s = "".join(out_chars)
    # Collapse whitespace
    tokens = [t for t in s.strip().split() if t]
    return tokens


def _ngram_counts(tokens: List[str], n: int) -> Dict[Tuple[str, ...], int]:
    counts: Dict[Tuple[str, ...], int] = {}
    if n <= 0 or len(tokens) < n:
        return counts
    for i in range(len(tokens) - n + 1):
        ng = tuple(tokens[i : i + n])
        counts[ng] = counts.get(ng, 0) + 1
    return counts


def _overlap(c_ref: Dict, c_hyp: Dict) -> int:
    keys = set(c_ref.keys()) & set(c_hyp.keys())
    return sum(min(c_ref[k], c_hyp[k]) for k in keys)


def _precision_recall_f1(overlap: int, hyp_total: int, ref_total: int) -> Tuple[float, float, float]:
    p = 0.0 if hyp_total == 0 else overlap / max(hyp_total, 1)
    r = 0.0 if ref_total == 0 else overlap / max(ref_total, 1)
    if p <= 0.0 or r <= 0.0:
        f1 = 0.0
    else:
        denom = p + r
        f1 = 0.0 if denom == 0 else 2 * p * r / denom
    # Guard numerical issues
    for v in (p, r, f1):
        if not (isinstance(v, float) and math.isfinite(v)):
            return 0.0, 0.0, 0.0
    return p, r, f1


def rouge_n_f1(ref: str, hyp: str, n: int) -> float:
    """Return the ROUGE-N F1 between reference *ref* and hypothesis *hyp*.

    Both texts are tokenized with ``_normalize_text``. Their n-gram counts are
    compared with clipped overlap, and the F1 from ``_precision_recall_f1`` is
    returned (0.0 when either side has no n-grams).
    """
    ref_grams = _ngram_counts(_normalize_text(ref), n)
    hyp_grams = _ngram_counts(_normalize_text(hyp), n)
    matched = _overlap(ref_grams, hyp_grams)
    scores = _precision_recall_f1(
        matched,
        sum(hyp_grams.values()),
        sum(ref_grams.values()),
    )
    return scores[2]


def _lcs_length(a: List[str], b: List[str]) -> int:
    # Space-optimized DP for LCS length
    if len(a) < len(b):
        a, b = b, a
    prev = [0] * (len(b) + 1)
    for i in range(1, len(a) + 1):
        curr = [0]
        ai = a[i - 1]
        for j in range(1, len(b) + 1):
            if ai == b[j - 1]:
                curr.append(prev[j - 1] + 1)
            else:
                curr.append(max(prev[j], curr[-1]))
        prev = curr
    return prev[-1]


def rouge_l_f1(ref: str, hyp: str) -> float:
    """Return the sentence-level ROUGE-L F1 between *ref* and *hyp*.

    The LCS length over normalized tokens is the overlap. Precision is taken
    against the hypothesis length and recall against the reference length.
    Returns 0.0 if either text has no tokens.
    """
    ref_tokens = _normalize_text(ref)
    hyp_tokens = _normalize_text(hyp)
    if ref_tokens and hyp_tokens:
        lcs = _lcs_length(ref_tokens, hyp_tokens)
        return _precision_recall_f1(lcs, len(hyp_tokens), len(ref_tokens))[2]
    return 0.0


class PubMedSummarizationMetrics(CompetitionMetrics):
    """Metrics for PubMed Article Summarization using mean ROUGE-1/2/L F1.

    Expected inputs:
    - y_true: pandas DataFrame with columns ['id', <value>] (e.g., test_answer.csv)
    - y_pred: pandas DataFrame with columns ['id', <value>] (e.g., sample_submission.csv or participant submission)

    ``<value>`` is the constructor's ``value`` argument ("abstract" by default).

    The evaluation sorts both frames by ``id``, checks that the ids line up
    one-to-one, and returns the average of ROUGE-1 F1, ROUGE-2 F1 and ROUGE-L F1,
    each averaged over all examples. Missing texts are scored as empty strings.
    """

    def __init__(self, value: str = "abstract", higher_is_better: bool = True):
        # Higher is better for ROUGE-based metrics
        super().__init__(higher_is_better)
        # Name of the text column compared between ground truth and predictions.
        self.value = value

    @staticmethod
    def _as_text(column: pd.Series) -> pd.Series:
        """Return *column* as strings, mapping missing values to ``""``."""
        # Plain astype(str) turns NaN/None into the literal tokens "nan"/"None",
        # which would then be scored as words. An empty string yields zero overlap.
        return column.fillna("").astype(str)

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Score predictions against references.

        Returns:
            The mean of the corpus-averaged ROUGE-1, ROUGE-2 and ROUGE-L F1
            scores, clipped to [0, 1].

        Raises:
            InvalidSubmissionError: if either input is not a DataFrame, does not
                have exactly the columns ['id', self.value], the row counts
                differ, ids are duplicated, or the id sets differ.
        """
        if not isinstance(y_true, pd.DataFrame) or not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("Both y_true and y_pred must be pandas DataFrames.")

        # Basic column checks
        expected_cols = ["id", self.value]
        for name, df in [("y_true", y_true), ("y_pred", y_pred)]:
            if list(df.columns) != expected_cols:
                raise InvalidSubmissionError(
                    f"{name} must have exactly two columns ['id', '{self.value}'] in that order."
                )

        if len(y_true) != len(y_pred):
            raise InvalidSubmissionError(
                f"Number of rows in predictions ({len(y_pred)}) does not match ground truth ({len(y_true)})."
            )

        # Sort both dataframes by id and align
        y_true = y_true.sort_values(by="id").reset_index(drop=True)
        y_pred = y_pred.sort_values(by="id").reset_index(drop=True)

        # ID alignment and uniqueness. Comparing plain Python lists stays
        # well-defined even when the two id columns have different dtypes.
        if not y_true["id"].is_unique or not y_pred["id"].is_unique:
            raise InvalidSubmissionError("IDs must be unique in both y_true and y_pred.")
        if y_true["id"].tolist() != y_pred["id"].tolist():
            raise InvalidSubmissionError("ID values do not match between ground truth and predictions.")

        r1_list: List[float] = []
        r2_list: List[float] = []
        rl_list: List[float] = []

        refs = self._as_text(y_true[self.value])
        hyps = self._as_text(y_pred[self.value])
        for ref, hyp in zip(refs, hyps):
            scores = (
                rouge_n_f1(ref, hyp, n=1),
                rouge_n_f1(ref, hyp, n=2),
                rouge_l_f1(ref, hyp),
            )
            # Defensive: never let a non-finite value leak into the averages.
            if not all(isinstance(v, float) and math.isfinite(v) for v in scores):
                scores = (0.0, 0.0, 0.0)
            r1_list.append(scores[0])
            r2_list.append(scores[1])
            rl_list.append(scores[2])

        def _mean(xs: List[float]) -> float:
            if not xs:
                return 0.0
            return sum(xs) / len(xs)

        final = (_mean(r1_list) + _mean(r2_list) + _mean(rl_list)) / 3.0
        # Clip to [0,1]
        return max(0.0, min(1.0, final))

    def validate_submission(self, submission: Any, ground_truth: Any) -> tuple[bool, str]:
        """Check that *submission* is well-formed relative to *ground_truth*.

        Returns:
            ``(True, "Submission is valid.")`` when every check passes.

        Raises:
            InvalidSubmissionError: on the first failed check (wrong type, wrong
                columns, row-count mismatch, duplicate or mismatched ids, or an
                empty/missing text in the submission).
        """
        col = self.value
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                f"Submission must be a pandas DataFrame with columns ['id', '{col}']."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                f"Ground truth must be a pandas DataFrame with columns ['id', '{col}']."
            )

        # Columns must be exactly ['id', <value>] in that order.
        if list(submission.columns) != ["id", col]:
            raise InvalidSubmissionError(
                f"Submission must have exactly two columns: ['id', '{col}'] in that order."
            )
        if list(ground_truth.columns) != ["id", col]:
            raise InvalidSubmissionError(
                f"Ground truth must have exactly two columns: ['id', '{col}'] in that order."
            )

        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # Sort both by id and compare ids
        sub_sorted = submission.sort_values(by="id").reset_index(drop=True)
        gt_sorted = ground_truth.sort_values(by="id").reset_index(drop=True)

        # IDs must match exactly and be unique
        if not sub_sorted["id"].is_unique:
            raise InvalidSubmissionError("Duplicate ids in submission.")
        if not gt_sorted["id"].is_unique:
            raise InvalidSubmissionError("Duplicate ids in ground truth.")
        if sub_sorted["id"].tolist() != gt_sorted["id"].tolist():
            raise InvalidSubmissionError("Submission ids do not match the ground truth ids.")

        # No empty (or missing / whitespace-only) texts
        if self._as_text(sub_sorted[col]).str.strip().eq("").any():
            raise InvalidSubmissionError("Submission contains empty abstracts.")

        return True, "Submission is valid."
