import math
import re
import unicodedata
from typing import Any, Dict

import pandas as pd

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError

# Normalization rules for fair ASR scoring
# - fold to ASCII: accented letters keep their base letter (e.g. "café" -> "cafe");
#   characters with no ASCII decomposition are dropped
# - lowercase
# - remove punctuation except internal apostrophes (e.g., don't)
# - collapse whitespace
_PUNCT_REGEX = re.compile(r"[^a-z0-9'\s]")
_WS_REGEX = re.compile(r"\s+")
# One or more apostrophes that are NOT flanked by an alphanumeric on both sides:
# leading ('tis), trailing (dogs'), or free-standing ( ' ). An apostrophe between
# two alphanumerics (don't) matches neither alternative and is therefore kept.
_STRAY_APOSTROPHE_REGEX = re.compile(r"(?<![a-z0-9])'+|'+(?![a-z0-9])")


def normalize_text(s: str) -> str:
    """Normalize a transcript so that reference and hypothesis compare fairly.

    Steps, in order: decompose accented characters and fold to ASCII,
    lowercase, replace every character other than ``a-z``, ``0-9``, apostrophe
    and whitespace with a space, replace apostrophes that are not word-internal
    with a space, then collapse whitespace runs and strip the ends.

    Args:
        s: Raw transcript text. ``None`` is treated as an empty transcript.

    Returns:
        The normalized text; words are separated by single spaces.
    """
    if s is None:
        return ""
    # NFKD splits an accented letter into base letter + combining mark, so the
    # ASCII encode below drops only the mark instead of the whole letter.
    s = unicodedata.normalize("NFKD", s)
    s = s.encode("ascii", "ignore").decode("ascii")
    s = s.lower()
    # keep letters, numbers, apostrophes, and whitespace
    s = _PUNCT_REGEX.sub(" ", s)
    # Drop stray apostrophes. A `\b`-based pattern cannot be used here: the
    # apostrophe is a non-word character, so `\b'` matches precisely the
    # word-internal apostrophes that must be preserved.
    s = _STRAY_APOSTROPHE_REGEX.sub(" ", s)
    # collapse whitespace
    return _WS_REGEX.sub(" ", s).strip()


def wer(ref: str, hyp: str) -> float:
    """Compute the word error rate (WER) of a hypothesis against a reference.

    Both strings are passed through ``normalize_text`` and split on
    whitespace. The word-level Levenshtein distance (substitution, deletion
    and insertion each cost 1) is divided by the number of reference words:
    ``WER = (S + D + I) / N``.

    Args:
        ref: Reference transcript.
        hyp: Hypothesis transcript.

    Returns:
        The WER as a float. It can exceed 1.0 when the hypothesis contains
        more words than the reference. If the normalized reference has no
        words, the result is 0.0 when the hypothesis is also empty and 1.0
        otherwise.
    """
    ref_words = normalize_text(ref).split()
    hyp_words = normalize_text(hyp).split()
    if not ref_words:
        return 0.0 if not hyp_words else 1.0

    # Levenshtein DP keeping only two rows: `prev_row[j]` is the distance
    # between the reference prefix processed so far and the first j
    # hypothesis words. Row 0 is "insert j words".
    prev_row = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        # Column 0 is "delete i reference words".
        curr_row = [i]
        for j, hyp_word in enumerate(hyp_words, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr_row.append(
                min(
                    prev_row[j] + 1,  # deletion
                    curr_row[j - 1] + 1,  # insertion
                    prev_row[j - 1] + cost,  # substitution / match
                )
            )
        prev_row = curr_row

    # ref_words is non-empty here, so the division is always well defined and
    # the result is a finite, non-negative float.
    return prev_row[-1] / len(ref_words)


def corpus_wer(refs: Dict[str, str], hyps: Dict[str, str]) -> float:
    """Return the mean per-utterance WER over all reference utterances.

    Every key in ``refs`` contributes exactly one score. A key missing from
    ``hyps`` is scored against an empty hypothesis, and keys that appear only
    in ``hyps`` are ignored. Each per-utterance score is capped to the range
    [0, 1] before averaging, so a single utterance never contributes more
    than 1.0. Lower is better.

    Args:
        refs: Mapping from utterance id to reference transcript.
        hyps: Mapping from utterance id to hypothesis transcript.

    Returns:
        The arithmetic mean of the capped scores, or 0.0 when ``refs`` is
        empty.
    """
    if not refs:
        return 0.0

    def capped_score(utt_id: str, reference: str) -> float:
        # Score one utterance; non-finite values fall back to the worst score.
        score = wer(reference, hyps.get(utt_id, ""))
        if math.isnan(score) or math.isinf(score):
            return 1.0
        return max(0.0, min(1.0, score))

    total = sum(capped_score(utt_id, reference) for utt_id, reference in refs.items())
    return float(total / len(refs))


class LJSpeechASRMetrics(CompetitionMetrics):
    """Metric class for LJ Speech ASR competition using average WER (lower is better).

    Both the ground truth and the submission are pandas DataFrames holding a
    ``clip_id`` column and a transcript column whose name is ``self.value``.
    Columns are always looked up by name, so column order and extra columns
    do not matter.
    """

    def __init__(self, value: str = "transcript", higher_is_better: bool = False):
        """Store the transcript column name and forward the score direction.

        Args:
            value: Name of the column that holds the transcript text.
            higher_is_better: Forwarded to the base class constructor.
        """
        super().__init__(higher_is_better)
        self.value = value

    def _transcripts_by_id(self, df: pd.DataFrame) -> Dict[str, str]:
        """Map ``str(clip_id)`` to transcript text for every row of ``df``.

        Missing transcripts (NaN/None/NA) become the empty string rather than
        the literal word "nan". If a clip_id occurs more than once, the last
        row in ``df`` wins.
        """
        return {
            str(clip_id): "" if pd.isna(text) else str(text)
            for clip_id, text in zip(df["clip_id"], df[self.value])
        }

    def evaluate(self, y_true: Any, y_pred: Any) -> float:
        """Score predictions against the ground truth.

        Args:
            y_true: Ground-truth DataFrame with ``clip_id`` and ``self.value``.
            y_pred: Prediction DataFrame with ``clip_id`` and ``self.value``.

        Returns:
            The value of ``corpus_wer`` for the two id -> transcript mappings.

        Raises:
            InvalidSubmissionError: If either argument is not a DataFrame or
                lacks a required column.
        """
        # Expect pandas DataFrames with columns [clip_id, transcript]
        if not isinstance(y_true, pd.DataFrame):
            raise InvalidSubmissionError("y_true must be a pandas DataFrame.")
        if not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("y_pred must be a pandas DataFrame.")
        required_cols = {"clip_id", self.value}
        if not required_cols.issubset(set(y_true.columns)):
            raise InvalidSubmissionError(
                f"y_true must contain columns {required_cols}, got {set(y_true.columns)}"
            )
        if not required_cols.issubset(set(y_pred.columns)):
            raise InvalidSubmissionError(
                f"y_pred must contain columns {required_cols}, got {set(y_pred.columns)}"
            )

        # Rows are matched through the clip_id keys of the two dicts, so no
        # sorting or positional alignment of the frames is needed.
        refs = self._transcripts_by_id(y_true)
        hyps = self._transcripts_by_id(y_pred)
        return corpus_wer(refs, hyps)

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that a submission is structurally compatible with the ground truth.

        Checks, in order: both arguments are DataFrames, both contain the
        ``clip_id`` and ``self.value`` columns, both have the same number of
        rows, and both contain the same multiset of clip_ids (compared as
        strings, which is also how ``evaluate`` keys them).

        Args:
            submission: Candidate submission.
            ground_truth: Reference data to validate against.

        Returns:
            The string "Submission is valid." when every check passes.

        Raises:
            InvalidSubmissionError: On the first failed check.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        sub_cols = set(submission.columns)
        true_cols = set(ground_truth.columns)
        # We expect both to have [clip_id, transcript]
        required_cols = {"clip_id", self.value}
        missing_cols = required_cols - sub_cols
        if missing_cols:
            raise InvalidSubmissionError(
                f"Missing required columns in submission: {', '.join(sorted(missing_cols))}."
            )
        missing_gt_cols = required_cols - true_cols
        if missing_gt_cols:
            raise InvalidSubmissionError(
                f"Missing required columns in ground truth: {', '.join(sorted(missing_gt_cols))}."
            )

        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # Compare the clip_id column by name (not whichever column happens to
        # be first), as sorted strings so that row order and id dtype
        # (e.g. int vs. str) cannot cause a spurious mismatch.
        submission_ids = sorted(str(clip_id) for clip_id in submission["clip_id"])
        ground_truth_ids = sorted(str(clip_id) for clip_id in ground_truth["clip_id"])
        if submission_ids != ground_truth_ids:
            raise InvalidSubmissionError(
                "clip_id values (IDs) do not match between submission and ground truth."
            )

        return "Submission is valid."
