from typing import Any
import pandas as pd
import numpy as np
from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class AthleticsPointsMAEMetrics(CompetitionMetrics):
    """MAE metric for World Athletics result_score predictions.

    Expected CSV schema for evaluation/validation:
    - Two columns: id, result_score
    - One row per id in the public test set
    """

    def __init__(self, value: str = "result_score", higher_is_better: bool = False):
        super().__init__(higher_is_better)
        self.value = value

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        # Sort both dataframes by first column (expected to be 'id')
        true_sorted = y_true.sort_values(by=y_true.columns[0]).reset_index(drop=True)
        pred_sorted = y_pred.sort_values(by=y_pred.columns[0]).reset_index(drop=True)

        # Basic alignment checks
        if len(true_sorted) != len(pred_sorted):
            raise InvalidSubmissionError(
                f"Number of rows in predictions ({len(pred_sorted)}) does not match ground truth ({len(true_sorted)})."
            )
        if (true_sorted[true_sorted.columns[0]].values != pred_sorted[pred_sorted.columns[0]].values).any():
            raise InvalidSubmissionError(
                "Row ids between y_true and y_pred do not match after sorting."
            )

        # Compute MAE on the specified value column
        if self.value not in true_sorted.columns or self.value not in pred_sorted.columns:
            raise InvalidSubmissionError(
                f"Column '{self.value}' must exist in both y_true and y_pred."
            )

        y_t = pd.to_numeric(true_sorted[self.value], errors="coerce")
        y_p = pd.to_numeric(pred_sorted[self.value], errors="coerce")
        if y_t.isna().any():
            raise InvalidSubmissionError("y_true contains non-numeric/NaN values in the target column.")
        if y_p.isna().any():
            raise InvalidSubmissionError("y_pred contains non-numeric/NaN values in the target column.")

        mae = float(np.mean(np.abs(y_t.to_numpy(dtype=float) - y_p.to_numpy(dtype=float))))
        return mae

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        # Basic row count check
        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # Sort the submission and ground truth by the first column
        sub_sorted = submission.sort_values(by=submission.columns[0])
        gt_sorted = ground_truth.sort_values(by=ground_truth.columns[0])

        # First column values must match exactly in order
        if (sub_sorted[sub_sorted.columns[0]].values != gt_sorted[gt_sorted.columns[0]].values).any():
            raise InvalidSubmissionError(
                "First column values (ids) do not match between submission and ground truth."
            )

        # Columns must match exactly
        sub_cols = set(submission.columns)
        true_cols = set(ground_truth.columns)
        missing_cols = true_cols - sub_cols
        extra_cols = sub_cols - true_cols

        if missing_cols:
            raise InvalidSubmissionError(
                f"Missing required columns in submission: {', '.join(sorted(missing_cols))}."
            )
        if extra_cols:
            raise InvalidSubmissionError(
                f"Extra unexpected columns found in submission: {', '.join(sorted(extra_cols))}."
            )

        # Validate values in the target column
        if self.value not in submission.columns:
            raise InvalidSubmissionError(
                f"Submission must contain the target column '{self.value}'."
            )
        pred_vals = pd.to_numeric(submission[self.value], errors="coerce")
        if pred_vals.isna().any() or not np.isfinite(pred_vals).all():
            raise InvalidSubmissionError(
                "Submission contains non-numeric or non-finite values in the target column."
            )

        return "Submission is valid."
