from typing import Any
import pandas as pd
from sklearn.metrics import f1_score

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class KineticsActionRecognitionMetrics(CompetitionMetrics):
    """Macro-F1 for single-label action recognition submissions.

    Expected CSV schema (aligned with public/sample_submission.csv and private/test_answer.csv):
    - Columns: [video_id, label]
    - One row per test example, covering exactly the same set of video_id values as the
      ground truth.

    Row order does not matter, because both frames are sorted by the id column before scoring.
    The id column is identified by name (the ground truth's first column), so the submission's
    column order does not matter either.
    """

    def __init__(self, value: str = "label", higher_is_better: bool = True):
        super().__init__(higher_is_better)
        # Name of the column holding the class label in both ground truth and submission.
        self.value = value

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Compute macro-averaged F1.

        Args:
            y_true: DataFrame with columns [video_id, label]; hidden answers for the test split.
                Its first column is taken as the id column.
            y_pred: DataFrame with columns [video_id, label] in any order; participant submission.
        Returns:
            float: macro F1 score in [0, 1].
        Raises:
            InvalidSubmissionError: if the submission fails ``validate_submission``.
        """
        # Validate before touching the frames, so malformed input (not a DataFrame,
        # missing columns, ...) raises InvalidSubmissionError rather than an
        # AttributeError/IndexError from sort_values or columns[0].
        self.validate_submission(y_pred, y_true)

        # Align rows by the id column *name*. Sorting by the submission's positional
        # first column would mis-align a submission ordered [label, video_id].
        id_col = y_true.columns[0]
        y_true_sorted = y_true.sort_values(by=id_col).reset_index(drop=True)
        y_pred_sorted = y_pred.sort_values(by=id_col).reset_index(drop=True)

        y_t = y_true_sorted[self.value].astype(str).tolist()
        y_p = y_pred_sorted[self.value].astype(str).tolist()
        # With labels=None, sklearn averages over the union of true and predicted labels,
        # so a predicted class that never appears in y_true contributes an F1 of 0.
        return float(f1_score(y_t, y_p, average="macro", zero_division=0))

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Validate participant submission DataFrame against ground truth DataFrame.

        Both must be pandas DataFrames with identical row counts. The ground truth's first
        column is the id column, and the ground truth must also contain ``self.value``.

        The submission must satisfy all of the following:
        - It has exactly the columns {<id column>, self.value}, in any order.
        - Its sorted id values equal the ground truth's sorted id values, including dtype.
        - Every label is a non-empty string.

        Returns:
            str: "Submission is valid." when every check passes.
        Raises:
            InvalidSubmissionError: on the first failed check.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        # The ground truth defines the schema: an id column first, plus the label column
        # that evaluate() reads. Without this check a bad ground truth surfaces later as
        # an IndexError/KeyError.
        if ground_truth.shape[1] < 2 or self.value not in ground_truth.columns[1:]:
            raise InvalidSubmissionError(
                f"Ground truth must contain an id column (first) and a '{self.value}' column."
            )

        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # Expected columns
        id_col = ground_truth.columns[0]
        expected_cols = {id_col, self.value}
        sub_cols = set(submission.columns)

        missing_cols = expected_cols - sub_cols
        extra_cols = sub_cols - expected_cols
        if missing_cols:
            raise InvalidSubmissionError(
                f"Missing required columns in submission: {', '.join(sorted(map(str, missing_cols)))}."
            )
        if extra_cols:
            raise InvalidSubmissionError(
                f"Extra unexpected columns found in submission: {', '.join(sorted(map(str, extra_cols)))}."
            )

        # Compare the id columns by name after sorting, so neither row order nor the
        # submission's column order matters.
        try:
            submission_ids = submission[id_col].sort_values().reset_index(drop=True)
        except TypeError as err:
            # Mixed, mutually incomparable value types (e.g. int and str) cannot be sorted.
            raise InvalidSubmissionError(
                f"Submission column '{id_col}' contains values of incomparable types."
            ) from err
        ground_ids = ground_truth[id_col].sort_values().reset_index(drop=True)

        # Series.equals also requires matching dtypes, so integer ids are not
        # considered equal to the same ids stored as strings.
        if not submission_ids.equals(ground_ids):
            raise InvalidSubmissionError(
                "First column values (video_id) do not match between submission and ground truth."
            )

        # Labels must be non-empty strings
        labels = submission[self.value]
        if labels.isna().any():
            raise InvalidSubmissionError("Submission contains missing labels.")
        if not labels.map(lambda x: isinstance(x, str)).all():
            raise InvalidSubmissionError("Submission labels must be strings.")
        # All labels are str at this point, so equality with "" is an empty-string test.
        # Unlike the .str accessor, this also works on a 0-row non-object Series.
        if labels.eq("").any():
            raise InvalidSubmissionError("Submission labels must be non-empty strings.")

        return "Submission is valid."
