from typing import Any
import pandas as pd
from sklearn.metrics import f1_score

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class BundesligaVideoClassificationMetrics(CompetitionMetrics):
    """Macro F1 metric for video classification submissions.

    Expected CSV schema:
    - Ground truth (test_answer.csv): columns [id, label]
    - Submission (sample_submission.csv): columns [id, label]

    The id column is the ground truth's first column. The submission is
    matched on that column *by name*, so the order of the submission's columns
    does not matter. Ids and labels are compared on their string form in both
    validation and scoring.

    Behavior:
    - evaluate(y_true, y_pred): returns macro F1 over hard labels, matching by id
    - validate_submission(submission, ground_truth): validates format and contents
    """

    def __init__(self, value: str = "label", higher_is_better: bool = True):
        """Store the label column name and forward the score direction.

        Args:
            value: Name of the column holding the class labels.
            higher_is_better: Passed through to ``CompetitionMetrics.__init__``.
        """
        super().__init__(higher_is_better=higher_is_better)
        self.value = value

    def _ensure_dataframe(self, obj: Any, name: str) -> pd.DataFrame:
        """Return a copy of ``obj`` if it is a DataFrame.

        Args:
            obj: Object expected to be a ``pandas.DataFrame``.
            name: Human-readable name used in the error message.

        Raises:
            InvalidSubmissionError: If ``obj`` is not a ``pandas.DataFrame``.
        """
        if isinstance(obj, pd.DataFrame):
            return obj.copy()
        raise InvalidSubmissionError(f"{name} must be a pandas DataFrame. Please provide a valid pandas DataFrame.")

    def _normalize(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return ``df`` sorted by its first column with a fresh 0..n-1 index.

        Kept for backward compatibility; ``evaluate`` aligns rows by id
        explicitly instead of relying on positional sorting.
        """
        return df.sort_values(by=df.columns[0]).reset_index(drop=True)

    def _id_column(self, ground_truth: pd.DataFrame) -> Any:
        """Return the name of the id column (the ground truth's first column).

        Raises:
            InvalidSubmissionError: If the ground truth has no columns or does
                not contain the label column ``self.value``.
        """
        if len(ground_truth.columns) == 0 or self.value not in ground_truth.columns:
            raise InvalidSubmissionError(
                f"Ground truth must contain an id column and a '{self.value}' column."
            )
        return ground_truth.columns[0]

    def evaluate(self, y_true: Any, y_pred: Any) -> float:
        """Compute macro F1 between ground-truth and submitted labels.

        The submission is validated first, then its labels are re-ordered to
        follow the ground truth's row order by matching on the id column.

        Args:
            y_true: Ground-truth DataFrame (id column first, plus label column).
            y_pred: Submission DataFrame with the same id and label columns.

        Returns:
            The macro-averaged F1 score as a Python float.

        Raises:
            InvalidSubmissionError: If either input is not a DataFrame or the
                submission fails ``validate_submission``.
        """
        y_true = self._ensure_dataframe(y_true, "Ground truth")
        y_pred = self._ensure_dataframe(y_pred, "Submission")

        # Validation guarantees: both label columns exist, submission ids are
        # unique, and the two id sets are equal (as strings).
        self.validate_submission(y_pred, y_true)

        id_col = y_true.columns[0]
        truth_ids = y_true[id_col].astype(str).to_numpy()

        # Align by id rather than by sorted position so that a differing
        # column order or id dtype in the submission cannot shift rows.
        pred_by_id = pd.Series(
            y_pred[self.value].astype(str).to_numpy(),
            index=y_pred[id_col].astype(str).to_numpy(),
        )
        aligned_pred = pred_by_id.reindex(truth_ids)

        # Labels are scored on their string form, the same form used by the
        # allowed-label check in validate_submission.
        return float(
            f1_score(
                y_true[self.value].astype(str).to_numpy(),
                aligned_pred.to_numpy(),
                average="macro",
            )
        )

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that a submission is well-formed with respect to the ground truth.

        Checks, in order: both inputs are DataFrames; row counts match; the
        submission has the id and label columns and no column absent from the
        ground truth; submission ids are unique; the id sets are equal; every
        submitted label occurs in the ground truth.

        Args:
            submission: Submission DataFrame.
            ground_truth: Ground-truth DataFrame (id column first).

        Returns:
            The string ``"Submission is valid."`` when every check passes.

        Raises:
            InvalidSubmissionError: On the first failed check.
        """
        submission = self._ensure_dataframe(submission, "Submission")
        ground_truth = self._ensure_dataframe(ground_truth, "Ground truth")

        # Basic length check
        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)}). Please ensure both have the same number of rows."
            )

        id_col = self._id_column(ground_truth)

        # Required columns are looked up by name, not by position.
        sub_cols = list(submission.columns)
        true_cols = list(ground_truth.columns)
        required_cols = list(dict.fromkeys([id_col, self.value]))
        missing = [c for c in required_cols if c not in sub_cols]
        if missing:
            raise InvalidSubmissionError(
                f"Missing required columns in submission: {', '.join(map(str, missing))}."
            )

        extra_cols = set(sub_cols) - set(true_cols)
        if extra_cols:
            raise InvalidSubmissionError(
                f"Extra unexpected columns found in submission: {', '.join(sorted(map(str, extra_cols)))}."
            )

        # Ids are compared on their string form so that e.g. 1 and "1" match.
        sub_ids = submission[id_col].astype(str)
        true_ids = ground_truth[id_col].astype(str)

        if sub_ids.duplicated().any():
            dupes = sub_ids[sub_ids.duplicated()].unique().tolist()
            raise InvalidSubmissionError(f"Duplicate id(s) in submission: {', '.join(dupes)}.")

        if set(sub_ids) != set(true_ids):
            raise InvalidSubmissionError(
                "Submission ids must match the ground truth test ids exactly (no missing or extra ids)."
            )

        # No separate sorted-order comparison is needed: equal row counts,
        # unique submission ids and equal id sets already imply that the two
        # id columns hold the same values once sorted.

        # Validate labels belong to the allowed set observed in ground truth
        allowed = set(ground_truth[self.value].astype(str).unique())
        submitted_labels = set(submission[self.value].astype(str).unique())
        bad = submitted_labels - allowed
        if bad:
            raise InvalidSubmissionError(
                f"Submission contains invalid label(s) not present in the ground truth label set: {', '.join(sorted(bad))}."
            )

        return "Submission is valid."
