from typing import Any
import pandas as pd
import numpy as np

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


# Closed set of accepted Intonation labels. Submitted labels are stripped and
# lower-cased before being checked for membership in this set.
VALID_INTONATIONS = {"bored", "excited", "neutral", "question"}


class SpokenNumeralsMetrics(CompetitionMetrics):
    """
    Evaluation metric for Spoken Numerals with Intonation.

    Computes a weighted accuracy score combining Numeral accuracy (0.7),
    Intonation accuracy (0.3), and a bonus for joint correctness (0.2),
    clipped to [0, 1].

    Submission and ground-truth rows are matched through the ``id`` column,
    so neither frame needs to be pre-sorted and the ground truth may list its
    columns in any order.
    """

    # Columns a submission must have, in exactly this order.
    _REQUIRED_COLUMNS = ("id", "Numeral", "Intonation")

    def __init__(self, value: str | None = None, higher_is_better: bool = True):
        super().__init__(higher_is_better)
        # `value` kept for API compatibility with samples; not used here
        self.value = value

    @staticmethod
    def _coerce_numeral(value: Any) -> int:
        """Convert a single Numeral cell to ``int``.

        Accepts ints, integral floats (``3.0``) and numeric strings
        (``"3"``, ``"3.0"``).

        Raises:
            ValueError: For non-numeric input, or for non-integral values
                such as ``3.5``, NaN or infinity. Truncating ``3.5`` to ``3``
                could otherwise score a wrong prediction as correct.
            TypeError: For values ``float()`` cannot handle (e.g. ``None``).
        """
        if isinstance(value, (int, np.integer)):
            # Skip the float round-trip, which loses precision above 2**53.
            return int(value)
        if isinstance(value, str):
            try:
                return int(value)
            except ValueError:
                pass  # may still be an integral float string such as "3.0"
        as_float = float(value)
        if not as_float.is_integer():  # False for NaN and +/-inf as well
            raise ValueError(f"Numeral {value!r} is not an integer.")
        return int(as_float)

    @classmethod
    def _normalize(cls, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` with canonical Intonation labels and int Numerals."""
        out = df.copy()
        out["Intonation"] = out["Intonation"].astype(str).str.strip().str.lower()
        out["Numeral"] = out["Numeral"].apply(cls._coerce_numeral)
        return out

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Score a submission against the ground truth.

        Args:
            y_true: Ground truth containing (at least) the columns
                ``id``, ``Numeral`` and ``Intonation``.
            y_pred: Submission with exactly the columns
                ``[id, Numeral, Intonation]``.

        Returns:
            ``0.7 * numeral_acc + 0.3 * intonation_acc + 0.2 * joint_acc``,
            clipped to ``[0, 1]``. Returns ``0.0`` when there are no rows.

        Raises:
            InvalidSubmissionError: If ``y_pred`` fails validation.
        """
        self.validate_submission(y_pred, y_true)

        cols = list(self._REQUIRED_COLUMNS)
        # Sort on the "id" column by name. The ground truth is only required
        # to *contain* the columns, so its first column is not necessarily "id".
        true_df = self._normalize(y_true[cols]).sort_values("id").reset_index(drop=True)
        pred_df = self._normalize(y_pred[cols]).sort_values("id").reset_index(drop=True)

        if true_df.empty:
            return 0.0

        num_correct = pred_df["Numeral"].to_numpy() == true_df["Numeral"].to_numpy()
        int_correct = (
            pred_df["Intonation"].to_numpy() == true_df["Intonation"].to_numpy()
        )
        joint_correct = num_correct & int_correct

        score = (
            0.7 * float(num_correct.mean())
            + 0.3 * float(int_correct.mean())
            + 0.2 * float(joint_correct.mean())
        )
        # A fully correct submission sums to 1.2, so clip to the documented range.
        return float(np.clip(score, 0.0, 1.0))

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that ``submission`` can be scored against ``ground_truth``.

        Returns:
            The string ``"Submission is valid."``.

        Raises:
            InvalidSubmissionError: If either input is not a DataFrame,
                the submission columns differ from ``[id, Numeral, Intonation]``,
                the ground truth is missing a required column, the row counts
                differ, submission ids are duplicated or do not match the
                ground-truth ids, a Numeral is not an integer, or an
                Intonation label is outside ``VALID_INTONATIONS``.
        """
        # Type checks
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame with columns [id, Numeral, Intonation]."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame with columns [id, Numeral, Intonation]."
            )

        required_cols = list(self._REQUIRED_COLUMNS)
        # Submission columns must match exactly, including their order.
        if list(submission.columns) != required_cols:
            raise InvalidSubmissionError(
                f"Submission must have columns exactly {required_cols} in this order."
            )
        # The ground truth only has to contain them; extra columns are ignored.
        if not set(required_cols).issubset(ground_truth.columns):
            raise InvalidSubmissionError(
                "Ground truth must contain columns: id, Numeral, Intonation."
            )

        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Submission has {len(submission)} rows but expected {len(ground_truth)}."
            )

        # Ensure ids match and are unique
        if not submission["id"].is_unique:
            raise InvalidSubmissionError("Duplicate ids found in submission.")
        if set(submission["id"]) != set(ground_truth["id"]):
            raise InvalidSubmissionError(
                "Submission ids must match ground truth ids exactly."
            )

        # Sort both frames by the "id" column by name (not by position) and
        # confirm the ids line up row by row.
        sub_ids = submission["id"].sort_values().to_numpy()
        true_ids = ground_truth["id"].sort_values().to_numpy()
        if (sub_ids != true_ids).any():
            raise InvalidSubmissionError(
                "Submission ids do not align with ground truth after sorting."
            )

        # Numeral must be an integer (integral floats / numeric strings allowed)
        try:
            submission["Numeral"].apply(self._coerce_numeral)
        except (TypeError, ValueError, OverflowError) as err:
            raise InvalidSubmissionError("Numeral column must contain integers.") from err

        # Intonation values must be within the allowed set
        intonations = submission["Intonation"].astype(str).str.strip().str.lower()
        if not set(intonations.unique()).issubset(VALID_INTONATIONS):
            raise InvalidSubmissionError(
                f"Intonation values must be among {sorted(VALID_INTONATIONS)}."
            )

        return "Submission is valid."
