from typing import Any
import pandas as pd
from sklearn.metrics import f1_score

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class UCF101ActionRecognitionMetrics(CompetitionMetrics):
    """
    Metric class for UCF101 Action Recognition.
    Evaluation metric: Macro-averaged F1 score over class labels.
    Submission/ground-truth schema: two columns [id, label] (names configurable
    via ``id_col`` and ``value``). Rows are matched on the id column by name,
    so the column order inside either DataFrame does not matter.
    """

    def __init__(
        self,
        value: str = "label",
        higher_is_better: bool = True,
        id_col: str = "id",
    ):
        """
        Args:
            value: Name of the label column in both DataFrames.
            higher_is_better: Forwarded to ``CompetitionMetrics``.
            id_col: Name of the column used to align submission rows with
                ground-truth rows.
        """
        super().__init__(higher_is_better)
        self.value = value
        self.id_col = id_col

    def _sort_by_id(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return ``df`` sorted by the id column with a fresh 0..n-1 index."""
        # Sort by column *name*, not position, so a submission whose columns
        # are ordered [label, id] is still aligned correctly. A stable sort
        # keeps the relative order of tied rows deterministic if ids repeat.
        return df.sort_values(by=self.id_col, kind="mergesort").reset_index(drop=True)

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """
        Compute macro-F1 score between ground truth and submission predictions.

        Both inputs must be DataFrames containing the id column and the label
        column. Rows are aligned by sorting both frames on the id column.
        Labels are compared as strings. The set of classes averaged over is
        the set of labels present in the ground truth.

        Raises:
            InvalidSubmissionError: if either input is not a DataFrame or the
                submission fails ``validate_submission``.
        """
        if not isinstance(y_true, pd.DataFrame):
            raise InvalidSubmissionError("Ground truth must be a pandas DataFrame.")
        if not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("Submission must be a pandas DataFrame.")

        # Validate and align
        ok, msg = self.validate_submission(y_pred, y_true)
        if not ok:
            raise InvalidSubmissionError(msg)

        y_true = self._sort_by_id(y_true)
        y_pred = self._sort_by_id(y_pred)

        true_labels = y_true[self.value].astype(str)
        pred_labels = y_pred[self.value].astype(str)

        # Use labels present in y_true to define the label space; predicted
        # labels outside it do not add extra (zero-F1) classes to the average.
        label_space = sorted(true_labels.unique())
        return float(
            f1_score(
                true_labels.tolist(),
                pred_labels.tolist(),
                labels=label_space,
                average="macro",
            )
        )

    def validate_submission(self, submission: Any, ground_truth: Any) -> tuple[bool, str]:
        """
        Validate that submission and ground truth:
        - are pandas DataFrames
        - have the same number of rows
        - submission contains exactly the columns {id, label}
        - ground truth contains at least the columns {id, label}
        - have identical ids once both are sorted by the id column
        - submission has no missing labels

        Returns:
            ``(True, "Submission is valid.")`` when every check passes.

        Raises:
            InvalidSubmissionError: on the first failed check (this method
                raises rather than returning ``False``).
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        expected_cols = {self.id_col, self.value}
        sub_cols = set(submission.columns)
        true_cols = set(ground_truth.columns)
        missing_cols = expected_cols - sub_cols
        extra_cols = sub_cols - expected_cols
        if missing_cols:
            raise InvalidSubmissionError(
                f"Missing required columns in submission: {', '.join(sorted(missing_cols))}."
            )
        if extra_cols:
            raise InvalidSubmissionError(
                f"Extra unexpected columns found in submission: {', '.join(sorted(extra_cols))}."
            )
        # Without this check a malformed ground truth fails later with a bare
        # KeyError instead of a descriptive error.
        missing_true_cols = expected_cols - true_cols
        if missing_true_cols:
            raise InvalidSubmissionError(
                f"Missing required columns in ground truth: {', '.join(sorted(missing_true_cols))}."
            )

        # Compare ids by column name (not position) after sorting both frames.
        # Series.equals is dtype-sensitive, which keeps the later sort-based
        # alignment in evaluate() consistent between the two frames.
        sub_ids = self._sort_by_id(submission)[self.id_col]
        true_ids = self._sort_by_id(ground_truth)[self.id_col]
        if not sub_ids.equals(true_ids):
            raise InvalidSubmissionError(
                f"Values in column '{self.id_col}' (ids) do not match between submission and ground truth."
            )

        # Basic type checks for labels
        if submission[self.value].isna().any():
            raise InvalidSubmissionError("Submission contains missing labels.")

        return True, "Submission is valid."
