from typing import Any
import math
import pandas as pd
from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class ArabicNLIMetrics(CompetitionMetrics):
    """Macro-F1 metric for Arabic NLI classification.

    Expected CSV schemas:
    - Ground truth (test_answer.csv): columns ['id', 'label']
    - Submission (sample_submission.csv): columns ['id', 'label']
    Labels are integers in {0, 1, 2}.

    Rows of the two frames are aligned by sorting each frame on its first
    column before labels are compared.
    """

    VALID_LABELS = {0, 1, 2}

    def __init__(self, value: str = "label", higher_is_better: bool = True):
        """Store the name of the label column and the score direction.

        Args:
            value: Name of the column holding the class label.
            higher_is_better: Forwarded to the base class constructor.
        """
        super().__init__(higher_is_better)
        self.value = value

    @staticmethod
    def _safe_int_series(s: pd.Series, name: str) -> pd.Series:
        """Convert a column to integers, rejecting anything lossy.

        Args:
            s: Column to convert.
            name: Column name, used only in error messages.

        Returns:
            The column cast to ``int``.

        Raises:
            InvalidSubmissionError: If the column contains nulls, values that
                cannot be parsed as numbers, or numbers that are not whole
                (e.g. ``1.5``).
        """
        if s.isnull().any():
            raise InvalidSubmissionError(f"Column '{name}' contains null values.")
        try:
            numeric = pd.to_numeric(s, errors="raise")
            as_int = numeric.astype(int)
        except Exception as e:
            # Broad on purpose: any conversion failure is reported to the
            # caller as an invalid submission, with the cause chained.
            raise InvalidSubmissionError(
                f"Column '{name}' must be integer-castable: {e}"
            ) from e
        # astype(int) truncates toward zero, so 1.7 would silently become 1.
        # Comparing against the pre-cast values catches any such loss.
        if not (numeric == as_int).all():
            raise InvalidSubmissionError(
                f"Column '{name}' must be integer-castable: found non-integer values."
            )
        return as_int

    @staticmethod
    def _f1(precision: float, recall: float) -> float:
        """Return the harmonic mean of precision and recall.

        Returns 0.0 when either input is non-positive or when their sum is
        not a finite positive number.
        """
        if precision <= 0.0 or recall <= 0.0:
            return 0.0
        denom = precision + recall
        if denom <= 0.0 or not math.isfinite(denom):
            return 0.0
        return 2.0 * precision * recall / denom

    @classmethod
    def _macro_f1(cls, y_true: list[int], y_pred: list[int], labels: list[int]) -> float:
        """Compute the unweighted mean of per-class F1 over ``labels``.

        Args:
            y_true: Ground-truth labels.
            y_pred: Predicted labels, positionally aligned with ``y_true``.
            labels: Classes to average over. A class with no true positives
                contributes an F1 of 0.0.

        Returns:
            The macro-averaged F1, or 0.0 when ``labels`` is empty.
        """
        per_label_tp = {c: 0 for c in labels}
        per_label_fp = {c: 0 for c in labels}
        per_label_fn = {c: 0 for c in labels}

        for t, p in zip(y_true, y_pred):
            if t == p:
                if t in per_label_tp:
                    per_label_tp[t] += 1
            else:
                # A value outside `labels` is not scored as its own class,
                # but the mismatch still counts against the listed class.
                if p in per_label_fp:
                    per_label_fp[p] += 1
                if t in per_label_fn:
                    per_label_fn[t] += 1

        f1s = []
        for c in labels:
            tp = per_label_tp[c]
            fp = per_label_fp[c]
            fn = per_label_fn[c]
            prec = 0.0 if (tp + fp) == 0 else tp / (tp + fp)
            rec = 0.0 if (tp + fn) == 0 else tp / (tp + fn)
            f1s.append(cls._f1(prec, rec))
        return float(sum(f1s) / len(f1s)) if f1s else 0.0

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Score a submission against the ground truth with Macro-F1.

        Args:
            y_true: Ground-truth frame containing the label column.
            y_pred: Submission frame containing the label column.

        Returns:
            Macro-F1 over ``VALID_LABELS``, in the range [0.0, 1.0].

        Raises:
            InvalidSubmissionError: If the label column is missing from either
                frame, the row counts differ, labels are not integers in
                ``VALID_LABELS``, or the resulting score is not a finite
                number in [0.0, 1.0].
        """
        # Check columns before sorting: sorting reads columns[0], which would
        # raise a bare IndexError on a frame with no columns.
        if self.value not in y_true.columns:
            raise InvalidSubmissionError(f"Ground truth missing column: '{self.value}'")
        if self.value not in y_pred.columns:
            raise InvalidSubmissionError(f"Submission missing column: '{self.value}'")

        # Pairing labels with zip would silently drop the surplus rows of the
        # longer frame, so a length mismatch must be rejected explicitly.
        if len(y_true) != len(y_pred):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(y_pred)}) does not match ground truth ({len(y_true)})."
            )

        # Sort both dataframes by first column so rows line up positionally.
        y_true = y_true.sort_values(by=y_true.columns[0]).reset_index(drop=True)
        y_pred = y_pred.sort_values(by=y_pred.columns[0]).reset_index(drop=True)

        yt = self._safe_int_series(y_true[self.value], self.value).tolist()
        yp = self._safe_int_series(y_pred[self.value], self.value).tolist()

        # Validate labels are within expected set
        bad_true = set(yt) - self.VALID_LABELS
        bad_pred = set(yp) - self.VALID_LABELS
        if bad_true:
            raise InvalidSubmissionError(f"Ground truth contains invalid labels: {bad_true}")
        if bad_pred:
            raise InvalidSubmissionError(f"Submission contains invalid labels: {bad_pred}")

        score = self._macro_f1(yt, yp, labels=sorted(self.VALID_LABELS))
        if not (0.0 <= score <= 1.0) or not math.isfinite(score):
            raise InvalidSubmissionError(f"Computed score is out of bounds or not finite: {score}")
        return score

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that a submission is structurally compatible with the ground truth.

        Args:
            submission: Candidate submission; must be a pandas DataFrame.
            ground_truth: Reference answers; must be a pandas DataFrame.

        Returns:
            The string ``"Submission is valid."`` when every check passes.

        Raises:
            InvalidSubmissionError: If either argument is not a DataFrame, the
                row counts differ, the column sets differ, the label column is
                absent, the first-column values differ after sorting, or the
                labels are not integers in ``VALID_LABELS``.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        # Check length
        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)}). Please ensure both have the same number of rows."
            )

        # Column set must match exactly. This runs before anything reads
        # columns[0], so a frame with no columns yields a clear error instead
        # of an IndexError.
        sub_cols = set(submission.columns)
        true_cols = set(ground_truth.columns)

        missing_cols = true_cols - sub_cols
        extra_cols = sub_cols - true_cols

        if missing_cols:
            raise InvalidSubmissionError(f"Missing required columns in submission: {', '.join(sorted(missing_cols))}.")
        if extra_cols:
            raise InvalidSubmissionError(f"Extra unexpected columns found in submission: {', '.join(sorted(extra_cols))}.")
        if self.value not in submission.columns:
            raise InvalidSubmissionError(f"Submission must contain '{self.value}' column.")

        # Sort both by first column (which should be 'id')
        submission = submission.sort_values(by=submission.columns[0]).reset_index(drop=True)
        ground_truth = ground_truth.sort_values(by=ground_truth.columns[0]).reset_index(drop=True)

        # First column values must match exactly. Plain lists are compared so
        # the result is always a bool, whatever the dtypes of the two columns.
        sub_keys = submission[submission.columns[0]].tolist()
        true_keys = ground_truth[ground_truth.columns[0]].tolist()
        if sub_keys != true_keys:
            raise InvalidSubmissionError(
                "First column values do not match between submission and ground truth. Please ensure the first column values are identical."
            )

        # Validate label column integer and within expected set
        labels = self._safe_int_series(submission[self.value], self.value)
        # tolist() yields plain Python ints, keeping the message format the
        # same as the one produced by evaluate().
        bad = set(labels.tolist()) - self.VALID_LABELS
        if bad:
            raise InvalidSubmissionError(f"Submission contains invalid labels: {bad}")

        return "Submission is valid."
