from typing import Any
import pandas as pd
import numpy as np

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class MovieYearPredictionMetrics(CompetitionMetrics):
    """
    Metric class for Movie Year Prediction from Dialogue Transcripts.

    - Metric: Mean Absolute Error (MAE)
    - Lower is better
    - Expected schema for both y_true and y_pred: pandas DataFrame with columns [id, year]

    The first column is always treated as the row id, whatever it is called;
    the second column must be named ``self.value``.
    """

    def __init__(self, value: str = "year", higher_is_better: bool = False):
        """
        Args:
            value: Name of the column holding the predicted/true value.
            higher_is_better: Forwarded unchanged to ``CompetitionMetrics``.
        """
        super().__init__(higher_is_better)
        self.value = value

    def _has_expected_columns(self, frame: pd.DataFrame) -> bool:
        """Return True when ``frame`` has exactly two columns, the second named ``self.value``."""
        # Checking the length first avoids indexing columns[0] on a frame
        # that has no columns at all (which would raise IndexError).
        columns = list(frame.columns)
        return len(columns) == 2 and columns[1] == self.value

    @staticmethod
    def _sort_by_id(frame: pd.DataFrame, label: str) -> pd.DataFrame:
        """
        Return ``frame`` ordered by its first column with a fresh 0..n-1 index.

        Raises:
            InvalidSubmissionError: if the id values cannot be compared with
                each other (e.g. a column mixing strings and integers).
        """
        try:
            # A stable sort keeps the pairing of rows deterministic even when
            # ids are repeated; for unique ids the result is the same as before.
            ordered = frame.sort_values(by=frame.columns[0], kind="mergesort")
        except TypeError as err:
            raise InvalidSubmissionError(
                f"{label} ids cannot be sorted because the id column mixes incomparable types."
            ) from err
        return ordered.reset_index(drop=True)

    @staticmethod
    def _finite_floats(column: pd.Series, nan_message: str, inf_message: str) -> pd.Series:
        """
        Convert ``column`` to float64, rejecting non-numeric and non-finite entries.

        Args:
            column: Raw values to convert.
            nan_message: Error text used when a value is missing or not numeric.
            inf_message: Error text used when a value is +/-infinity.
        Returns:
            A float64 Series with the same index as ``column``.
        Raises:
            InvalidSubmissionError: on any non-numeric or non-finite value.
        """
        numeric = pd.to_numeric(column, errors="coerce")
        if numeric.isna().any():
            raise InvalidSubmissionError(nan_message)
        # Cast once so boolean/integer inputs subtract cleanly later on.
        as_float = numeric.astype("float64")
        # isna() does not flag infinities, which would turn the MAE into inf.
        if not np.isfinite(as_float.to_numpy()).all():
            raise InvalidSubmissionError(inf_message)
        return as_float

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """
        Compute Mean Absolute Error between predicted years and true years.

        Rows are paired by sorting both frames on their first column; the id
        values must then match exactly, position by position.

        Args:
            y_true: DataFrame with columns [id, year]
            y_pred: DataFrame with columns [id, year]
        Returns:
            float MAE (NaN when both frames are empty).
        Raises:
            InvalidSubmissionError: on row-count, column, id or value problems.
        """
        # Basic shape/column checks similar to validate but lightweight here.
        # They run before sorting so that malformed frames produce an
        # InvalidSubmissionError rather than an IndexError from columns[0].
        if y_true.shape[0] != y_pred.shape[0]:
            raise InvalidSubmissionError(
                f"Row count mismatch: y_true has {y_true.shape[0]}, y_pred has {y_pred.shape[0]}"
            )
        if not self._has_expected_columns(y_true):
            raise InvalidSubmissionError(
                f"y_true must have columns [id, {self.value}] but got {list(y_true.columns)}"
            )
        if not self._has_expected_columns(y_pred):
            raise InvalidSubmissionError(
                f"y_pred must have columns [id, {self.value}] but got {list(y_pred.columns)}"
            )

        # Sort both dataframes by the first column (id) before calculating score
        y_true = self._sort_by_id(y_true, "y_true")
        y_pred = self._sort_by_id(y_pred, "y_pred")

        # Ensure ids align exactly
        if not np.array_equal(y_true.iloc[:, 0].values, y_pred.iloc[:, 0].values):
            raise InvalidSubmissionError("Row ids in prediction do not match ground truth ids.")

        true_vals = self._finite_floats(
            y_true[self.value],
            "Ground truth contains non-numeric values in year column.",
            "Ground truth contains non-finite values in year column.",
        )
        pred_vals = self._finite_floats(
            y_pred[self.value],
            "Submission contains non-numeric values in year column.",
            "Submission contains non-finite values in year column.",
        )

        # Both series carry a fresh 0..n-1 index, so subtraction is positional.
        mae = true_vals.subtract(pred_vals).abs().mean()
        return float(mae)

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """
        Validate that the submission matches the expected format and ids of ground truth.

        Args:
            submission: pandas DataFrame with columns [id, year]
            ground_truth: pandas DataFrame with columns [id, year]
        Returns:
            A human-readable success message if valid.
        Raises:
            InvalidSubmissionError on any validation failure.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        # Check required columns and no extra columns
        expected_cols = list(ground_truth.columns)
        if not self._has_expected_columns(ground_truth):
            raise InvalidSubmissionError(
                f"Ground truth must have columns [id, {self.value}] but got {expected_cols}."
            )

        if list(submission.columns) != expected_cols:
            raise InvalidSubmissionError(
                f"Submission columns must exactly match ground truth columns {expected_cols}, but got {list(submission.columns)}."
            )

        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # Sort both by the first column and compare id order
        sub_sorted = self._sort_by_id(submission, "Submission")
        gt_sorted = self._sort_by_id(ground_truth, "Ground truth")

        if not np.array_equal(sub_sorted.iloc[:, 0].values, gt_sorted.iloc[:, 0].values):
            raise InvalidSubmissionError(
                "First column (ids) do not match between submission and ground truth."
            )

        # Check numeric, finite year values in submission; the converted
        # values themselves are not needed here, only the validation.
        self._finite_floats(
            sub_sorted[self.value],
            "Submission contains non-numeric values in the year column.",
            "Submission contains non-finite values in the year column.",
        )

        return "Submission is valid."
