from typing import Any
import pandas as pd
import numpy as np
from sklearn.metrics import f1_score

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class EmotionClassificationMetrics(CompetitionMetrics):
    """Macro-F1 metric for emotion classification submissions.

    Expected schema:
    - Ground truth (y_true): columns ['id', 'Emotion']
    - Submission (y_pred): columns ['id', 'Emotion']

    The first column of each frame is treated as the row identifier and is
    used to align submission rows with ground-truth rows. The name of the
    label column is configurable through ``value``.
    """

    def __init__(self, value: str = "Emotion", higher_is_better: bool = True):
        """Store the label column name and the score direction.

        Args:
            value: Name of the column holding the class labels.
            higher_is_better: Forwarded to the base class constructor.
        """
        super().__init__(higher_is_better)
        self.value = value

    @staticmethod
    def _sorted_by_id(frame: pd.DataFrame, role: str) -> tuple:
        """Return ``frame`` ordered by its first column, plus the ordered IDs.

        IDs are cast to ``str`` *before* sorting so that the sort key is the
        same representation that is later compared. Sorting in the native
        dtype and comparing as strings would order integer IDs numerically
        and string IDs lexicographically, making identical ID sets look
        different.

        Args:
            frame: DataFrame whose first column holds the row identifiers.
            role: Human-readable name of the frame, used in the error message.

        Returns:
            A ``(sorted_frame, sorted_ids)`` tuple, where ``sorted_frame`` has
            a fresh 0..n-1 index and ``sorted_ids`` is an array of the string
            IDs in the same order.

        Raises:
            InvalidSubmissionError: If ``frame`` has no columns at all.
        """
        if frame.shape[1] == 0:
            raise InvalidSubmissionError(f"{role} must contain an ID column.")

        # A fresh positional index guarantees unique labels for ``.loc`` below.
        positional = frame.reset_index(drop=True)
        ids = positional.iloc[:, 0].astype(str)
        # mergesort is stable, so equal IDs keep their original relative order.
        order = ids.sort_values(kind="mergesort").index
        return positional.loc[order].reset_index(drop=True), ids.loc[order].to_numpy()

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Compute the macro-averaged F1 score of a submission.

        Rows are aligned by the first column of each frame. Labels are
        compared as strings, and the macro average is taken over the labels
        that occur in the ground truth only.

        Args:
            y_true: Ground-truth frame with an ID column first and the label
                column named ``self.value``.
            y_pred: Submission frame with the same layout.

        Returns:
            The macro-F1 score as a float; ``0.0`` when the ground truth has
            no labels or the score is not finite.

        Raises:
            InvalidSubmissionError: If either input is not a DataFrame, has no
                columns, the row counts or IDs differ, or the label column is
                missing.
        """
        if not isinstance(y_true, pd.DataFrame):
            raise InvalidSubmissionError("Ground truth must be a pandas DataFrame.")
        if not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("Submission must be a pandas DataFrame.")

        y_true_sorted, true_ids = self._sorted_by_id(y_true, "Ground truth")
        y_pred_sorted, pred_ids = self._sorted_by_id(y_pred, "Submission")

        # Check lengths explicitly: comparing arrays of different lengths
        # element-wise does not yield a per-row result and would surface as
        # an opaque NumPy error instead of a submission error.
        if len(pred_ids) != len(true_ids):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(pred_ids)}) does not match ground truth ({len(true_ids)})."
            )
        if not np.array_equal(true_ids, pred_ids):
            raise InvalidSubmissionError(
                "ID values do not match between ground truth and submission after sorting."
            )

        # Ensure the label column exists on both sides.
        if self.value not in y_true_sorted.columns:
            raise InvalidSubmissionError(f"Ground truth does not contain required column '{self.value}'.")
        if self.value not in y_pred_sorted.columns:
            raise InvalidSubmissionError(f"Submission does not contain required column '{self.value}'.")

        # Cast to string labels to avoid dtype mismatches between the frames.
        y_t = y_true_sorted[self.value].astype(str).values
        y_p = y_pred_sorted[self.value].astype(str).values

        # Restrict the average to labels that appear in the ground truth, so
        # spurious predicted labels do not add extra zero-score classes.
        labels_present = np.unique(y_t)
        if labels_present.size == 0:
            return 0.0

        score = f1_score(y_t, y_p, labels=labels_present, average="macro", zero_division=0)
        if not np.isfinite(score):
            score = 0.0
        return float(score)

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that a submission is structurally compatible with the truth.

        Checks run in this order: input types, row counts, ID alignment on
        the first column, presence of the label column, column-set equality
        with the ground truth, and uniqueness of submission IDs.

        Args:
            submission: Candidate submission; must be a DataFrame.
            ground_truth: Reference frame; must be a DataFrame.

        Returns:
            The string ``"Submission is valid."`` when every check passes.

        Raises:
            InvalidSubmissionError: On the first failed check.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        # Check row counts
        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # Sort both on the string form of the first column and compare IDs.
        _, submission_ids = self._sorted_by_id(submission, "Submission")
        _, ground_truth_ids = self._sorted_by_id(ground_truth, "Ground truth")
        if not np.array_equal(submission_ids, ground_truth_ids):
            raise InvalidSubmissionError(
                "First column values (IDs) do not match between submission and ground truth."
            )

        # Ensure required label column exists
        if self.value not in submission.columns:
            raise InvalidSubmissionError(f"Missing required column in submission: '{self.value}'.")

        # Check for extra/missing columns compared to ground truth schema
        sub_cols = set(submission.columns)
        true_cols = set(ground_truth.columns)
        missing_cols = true_cols - sub_cols
        extra_cols = sub_cols - true_cols
        if missing_cols:
            raise InvalidSubmissionError(f"Missing required columns in submission: {', '.join(sorted(missing_cols))}.")
        if extra_cols:
            raise InvalidSubmissionError(f"Extra unexpected columns found in submission: {', '.join(sorted(extra_cols))}.")

        # Duplicate IDs would make the row alignment ambiguous.
        if not submission.iloc[:, 0].is_unique:
            raise InvalidSubmissionError("Submission IDs must be unique.")

        return "Submission is valid."
