from typing import Any
import pandas as pd
import numpy as np

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError

# This competition evaluates abstractive summaries with ROUGE-L F1 (higher is better).
# Expected CSV schemas
# - sample_submission.csv: [id, summary]
# - test_answer.csv:       [id, summary]
# - test.csv:              [id, article]


def _normalize_text(s: str) -> str:
    """Lowercase *s*, trim it, and collapse internal whitespace to single spaces.

    Non-string input is stringified first. Missing values -- ``None`` as well
    as pandas/NumPy missing scalars such as ``NaN`` or ``pd.NA`` -- normalize
    to the empty string rather than to the literal text "nan" / "<NA>", which
    would otherwise be scored as a real token.
    """
    if not isinstance(s, str):
        # The is_scalar guard matters: pd.isna on a list/array returns an
        # array (ambiguous truth value), not a single bool.
        missing = pd.api.types.is_scalar(s) and pd.isna(s)
        s = "" if missing else str(s)
    # A no-argument split() drops leading/trailing whitespace and splits on
    # any whitespace run, so joining with " " trims and collapses in one step.
    return " ".join(s.lower().split())


def _tokenize(s: str) -> list[str]:
    """Split *s* into lowercase tokens after whitespace normalization.

    Blank or missing input yields an empty list rather than ``[""]``.
    """
    normalized = _normalize_text(s)
    # After normalization tokens are separated by exactly one space.
    return normalized.split(" ") if normalized else []


def _lcs_length(a: list[str], b: list[str]) -> int:
    """Return the length of the longest common subsequence of *a* and *b*.

    Standard dynamic programme that keeps only the previous row, so memory is
    O(len(b)) and time is O(len(a) * len(b)).
    """
    if not a or not b:
        return 0
    width = len(b) + 1
    above = [0] * width
    for token in a:
        row = [0] * width
        for col, other in enumerate(b, start=1):
            if token == other:
                # Extend the diagonal match.
                row[col] = above[col - 1] + 1
            else:
                # Best of skipping the current token from either sequence.
                row[col] = max(above[col], row[col - 1])
        above = row
    return above[-1]


def rouge_l_f1(pred: str, ref: str) -> float:
    """Score a predicted summary against a reference with ROUGE-L F1.

    Both texts are tokenized with ``_tokenize``. Two empty texts count as a
    perfect match (1.0); if exactly one is empty the score is 0.0. Otherwise
    precision and recall are the LCS length divided by the prediction and
    reference lengths, combined as their harmonic mean. The result is clamped
    to [0, 1], and a non-finite value maps to 0.0.
    """
    candidate = _tokenize(pred)
    reference = _tokenize(ref)
    if not candidate or not reference:
        # Both empty -> identical; one empty -> nothing can overlap.
        return 1.0 if not candidate and not reference else 0.0
    overlap = _lcs_length(candidate, reference)
    precision = overlap / len(candidate)
    recall = overlap / len(reference)
    denom = precision + recall
    if denom == 0:
        return 0.0
    score = 2 * precision * recall / denom
    return float(max(0.0, min(1.0, score))) if np.isfinite(score) else 0.0


class NewsSummarizationMetrics(CompetitionMetrics):
    """ROUGE-L F1 metric for CNN/DailyMail style summarization.

    value: name of the summary column, read from both the ground truth and
        the submission (defaults to 'summary').
    higher_is_better: True for ROUGE-L F1.
    """

    def __init__(self, value: str = "summary", higher_is_better: bool = True):
        super().__init__(higher_is_better)
        self.value = value

    @staticmethod
    def _column_to_text(column: pd.Series) -> list[str]:
        """Stringify every cell of *column*, mapping missing cells to ''.

        ``Series.astype(str)`` would turn ``None``/``NaN`` into the literal
        words "None"/"nan", which the tokenizer would then score as real text.
        """
        missing = column.isna().tolist()
        return [
            "" if is_missing else str(cell)
            for cell, is_missing in zip(column, missing)
        ]

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Return the mean ROUGE-L F1 of ``y_pred`` against ``y_true``.

        Rows are paired by sorting both frames on ``id``, and the sorted id
        columns must then match element-wise. Missing cells in the value
        column are scored as empty strings. Per-row scores and the mean are
        clamped to [0, 1], and an empty ground truth yields 0.0.

        Raises:
            InvalidSubmissionError: if an input is not a DataFrame, lacks the
                ``id`` or value column, or the row counts / ids disagree.
        """
        # Validate inputs
        if not isinstance(y_true, pd.DataFrame):
            raise InvalidSubmissionError("y_true must be a pandas DataFrame.")
        if not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("y_pred must be a pandas DataFrame.")
        if "id" not in y_true.columns:
            raise InvalidSubmissionError("y_true must contain an 'id' column.")
        if self.value not in y_true.columns:
            raise InvalidSubmissionError(f"y_true must contain a '{self.value}' column.")
        if "id" not in y_pred.columns:
            raise InvalidSubmissionError("y_pred must contain an 'id' column.")
        if self.value not in y_pred.columns:
            raise InvalidSubmissionError(f"y_pred must contain a '{self.value}' column.")

        # Sort and align by id, check ids identical
        y_true_sorted = y_true.sort_values(by=["id"]).reset_index(drop=True)
        y_pred_sorted = y_pred.sort_values(by=["id"]).reset_index(drop=True)
        if len(y_true_sorted) != len(y_pred_sorted):
            raise InvalidSubmissionError(
                f"Row count mismatch: y_true={len(y_true_sorted)}, y_pred={len(y_pred_sorted)}"
            )
        if (y_true_sorted["id"].values != y_pred_sorted["id"].values).any():
            raise InvalidSubmissionError("The 'id' values of y_true and y_pred must match exactly in order.")

        refs = self._column_to_text(y_true_sorted[self.value])
        preds = self._column_to_text(y_pred_sorted[self.value])

        # Compute mean ROUGE-L F1
        scores: list[float] = []
        for ref, pred in zip(refs, preds):
            try:
                s = rouge_l_f1(pred, ref)
            except Exception:
                # Deliberate best-effort: one unscorable row counts as 0.0
                # instead of aborting the whole evaluation.
                s = 0.0
            if not np.isfinite(s):
                s = 0.0
            scores.append(max(0.0, min(1.0, float(s))))

        if not scores:
            return 0.0
        m = float(np.mean(scores))
        if not np.isfinite(m):
            return 0.0
        return max(0.0, min(1.0, m))

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that *submission* is well-formed relative to *ground_truth*.

        The submission must have exactly the columns ``id`` and the value
        column (each appearing once) and the same row count as the ground
        truth. It must contain no null ids and no null predictions (empty
        strings are allowed), its ids must be unique, and it must have the
        same ids as the ground truth.

        Returns:
            "Submission is valid." when every check passes.

        Raises:
            InvalidSubmissionError: describing the first failed check.
        """
        # Type checks
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        # Column checks. Repeated labels would slip past the set comparison
        # below and later surface as a raw pandas ValueError.
        dup_mask = submission.columns.duplicated()
        if dup_mask.any():
            dupes = sorted({str(c) for c in submission.columns[dup_mask]})
            raise InvalidSubmissionError(
                f"Duplicate column names in submission: {', '.join(dupes)}."
            )
        required_cols = {"id", self.value}
        if set(submission.columns) != required_cols:
            extra_cols = set(submission.columns) - required_cols
            missing_cols = required_cols - set(submission.columns)
            msg_parts = []
            if missing_cols:
                msg_parts.append(f"Missing required columns: {', '.join(sorted(missing_cols))}.")
            if extra_cols:
                msg_parts.append(f"Extra unexpected columns: {', '.join(sorted(extra_cols))}.")
            raise InvalidSubmissionError(" ".join(msg_parts) or "Submission has incorrect columns.")

        if "id" not in ground_truth.columns:
            raise InvalidSubmissionError("Ground truth must contain an 'id' column.")

        if submission.shape[0] != ground_truth.shape[0]:
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # Null and duplicate checks run BEFORE the id comparison. Otherwise a
        # NaN id (NaN != NaN) or a repeated id would be reported as a generic
        # id mismatch, and these more specific messages would never be shown.
        # Ids cannot be null; predictions can be empty strings but not NaN.
        if submission["id"].isnull().any():
            raise InvalidSubmissionError("Submission contains null ids.")
        if submission[self.value].isnull().any():
            raise InvalidSubmissionError(
                f"Submission contains null values in '{self.value}'. Use empty strings if necessary."
            )
        if submission["id"].duplicated().any():
            raise InvalidSubmissionError("Duplicate ids found in submission.")

        # Sort by id and compare
        sub_sorted = submission.sort_values(by=["id"]).reset_index(drop=True)
        gt_sorted = ground_truth.sort_values(by=["id"]).reset_index(drop=True)
        if (sub_sorted["id"].values != gt_sorted["id"].values).any():
            raise InvalidSubmissionError(
                "First column values (ids) do not match between submission and ground truth."
            )

        return "Submission is valid."
