from __future__ import annotations

from typing import Any, Tuple
import pandas as pd
import numpy as np
import re

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


_WORD_RE = re.compile(r"[^a-z0-9' ]+")


def _normalize_text(s: str) -> str:
    if s is None:
        return ""
    s = s.strip().lower().replace("-", " ")
    s = _WORD_RE.sub(" ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s


def _edit_ops_counts(ref_tokens: list[str], hyp_tokens: list[str]) -> tuple[int, int, int, int]:
    n, m = len(ref_tokens), len(hyp_tokens)
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    bp = [[None] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        dp[i][0] = i
        bp[i][0] = "del"
    for j in range(1, m + 1):
        dp[0][j] = j
        bp[0][j] = "ins"
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            if ref_tokens[i - 1] == hyp_tokens[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
                bp[i][j] = "ok"
            else:
                sub = dp[i - 1][j - 1] + 1
                ins = dp[i][j - 1] + 1
                delete = dp[i - 1][j] + 1
                best = min(sub, ins, delete)
                dp[i][j] = best
                if best == sub:
                    bp[i][j] = "sub"
                elif best == ins:
                    bp[i][j] = "ins"
                else:
                    bp[i][j] = "del"
    i, j = n, m
    S = D = I = 0
    while i > 0 or j > 0:
        op = bp[i][j]
        if op == "ok":
            i -= 1
            j -= 1
        elif op == "sub":
            S += 1
            i -= 1
            j -= 1
        elif op == "del":
            D += 1
            i -= 1
        elif op == "ins":
            I += 1
            j -= 1
        else:
            break
    N = n
    return S, D, I, N


class LibriSpeechASRMetrics(CompetitionMetrics):
    """Corpus-level WER for LibriSpeech-style ASR submissions.

    WER = (substitutions + deletions + insertions) / number of reference
    words. Edit counts are summed over all utterances before dividing.
    Transcripts are normalized with ``_normalize_text`` and split on
    whitespace before alignment.

    Expected CSV/DataFrame schema:
    - ground truth (private/test_answer.csv): columns ['id', 'transcript']
    - submission (public/sample_submission.csv or user submission): columns ['id', 'transcript']

    The identifier column is the ground truth's first column. The submission
    must contain a column with that same name; its position does not matter.
    """

    def __init__(self, value: str = "transcript", higher_is_better: bool = False):
        """
        Args:
            value: Name of the transcript column in both frames.
            higher_is_better: Forwarded to ``CompetitionMetrics``. WER is an
                error rate, so the default is ``False``.
        """
        super().__init__(higher_is_better)
        self.value = value

    def evaluate(self, y_true: Any, y_pred: Any) -> float:
        """Return the corpus-level WER of ``y_pred`` against ``y_true``.

        Raises:
            InvalidSubmissionError: if either input is not a DataFrame or the
                submission fails ``validate_submission``.
        """
        if not isinstance(y_true, pd.DataFrame):
            raise InvalidSubmissionError("Ground truth must be a pandas DataFrame.")
        if not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("Submission must be a pandas DataFrame.")

        # Validate first: afterwards both frames hold exactly the same ids.
        self.validate_submission(y_pred, y_true)

        # Align rows by sorting both frames on the ground-truth id column,
        # looked up by name so the submission's column order is irrelevant.
        id_col = y_true.columns[0]
        y_true = y_true.sort_values(by=id_col).reset_index(drop=True)
        y_pred = y_pred.sort_values(by=id_col).reset_index(drop=True)

        # Missing transcripts (e.g. empty CSV cells read as NaN) mean "no
        # words", not the literal token "nan" that str(NaN) would produce.
        refs = y_true[self.value].fillna("").astype(str)
        hyps = y_pred[self.value].fillna("").astype(str)

        errors_total = 0
        ref_words_total = 0
        for ref_text, hyp_text in zip(refs, hyps):
            subs, dels, ins, n_ref = _edit_ops_counts(
                _normalize_text(ref_text).split(),
                _normalize_text(hyp_text).split(),
            )
            errors_total += subs + dels + ins
            ref_words_total += n_ref

        if ref_words_total == 0:
            # No reference words at all: perfect only if nothing was
            # hypothesized; otherwise report 1.0 rather than dividing by 0.
            return 0.0 if errors_total == 0 else 1.0
        return float(errors_total / ref_words_total)

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that ``submission`` can be scored against ``ground_truth``.

        The submission must:
        - be a DataFrame with the same number of rows as the ground truth;
        - have exactly the ground truth's id column (its first column) and
          the ``self.value`` column, with no duplicate column names;
        - contain the same ids as the ground truth (in any row order);
        - have no null values in the transcript column.

        Returns:
            ``"Submission is valid."`` when every check passes.

        Raises:
            InvalidSubmissionError: describing the first failed check.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError("Submission must be a pandas DataFrame.")
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError("Ground truth must be a pandas DataFrame.")

        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        id_col = ground_truth.columns[0]

        # Check the schema before the ids. This guarantees the id column
        # exists in the submission and keeps the error message precise.
        required_cols = {id_col, self.value}
        sub_cols = set(submission.columns)
        if sub_cols != required_cols:
            missing = required_cols - sub_cols
            extra = sub_cols - required_cols
            msgs = []
            # key=str: column labels may mix types (e.g. ints and strings).
            if missing:
                msgs.append(f"missing columns: {sorted(missing, key=str)}")
            if extra:
                msgs.append(f"extra columns: {sorted(extra, key=str)}")
            raise InvalidSubmissionError(
                f"Submission must have exactly columns {{{id_col!r},{self.value!r}}}; " + ", ".join(msgs)
            )
        duplicated = submission.columns[submission.columns.duplicated()]
        if len(duplicated):
            raise InvalidSubmissionError(
                f"Submission has duplicate column names: {sorted(set(duplicated), key=str)}"
            )

        # Compare the sorted id columns, selected by name rather than position.
        sub_ids = submission[id_col].sort_values().to_numpy()
        gt_ids = ground_truth[id_col].sort_values().to_numpy()
        if (sub_ids != gt_ids).any():
            raise InvalidSubmissionError(
                "First-column identifiers do not match ground truth test ids."
            )

        # Basic type checks for transcript column
        if submission[self.value].isnull().any():
            raise InvalidSubmissionError("Submission contains null values in transcript column.")

        return "Submission is valid."
