from typing import Any

import numpy as np
import pandas as pd
from sklearn.metrics import f1_score

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


# Label strings accepted in a submission's label column. validate_submission
# rejects a submission containing any label (compared as str) outside this set.
VALID_LABELS = {"CC", "EC", "HGSC", "LGSC", "MC"}


class OvarianCancerClassificationMetrics(CompetitionMetrics):
    """
    Macro-averaged F1 score for Ovarian Cancer Subtype Classification.

    Expected schemas:
    - Ground truth (y_true): DataFrame with columns ["id", "label"].
    - Submission (y_pred): DataFrame with columns ["id", "label"].

    Header names are compared case-insensitively and the data is then read by
    column position (first column = id, second column = label).

    Row order does not matter: both frames are sorted by id (compared as
    strings) before validation and scoring, so only the set of ids has to
    match between submission and ground truth.
    """

    def __init__(self, value: str = "label", higher_is_better: bool = True):
        """
        Args:
            value: Expected name of the label column (matched case-insensitively).
            higher_is_better: Forwarded to the base class.
        """
        super().__init__(higher_is_better)
        self.value = value

    @staticmethod
    def _sorted_ids_and_labels(frame: pd.DataFrame):
        """Return ``(ids, labels)`` of a two-column frame as string arrays sorted by id.

        Ids are converted to ``str`` *before* sorting so that two frames holding
        the same ids under different dtypes (e.g. int vs. str) end up in the
        same order and stay row-aligned.
        """
        ids = frame.iloc[:, 0].astype(str).reset_index(drop=True)
        labels = frame.iloc[:, 1].astype(str).reset_index(drop=True)
        # mergesort is stable, which keeps the ordering deterministic.
        order = ids.sort_values(kind="mergesort").index
        return ids.loc[order].to_numpy(), labels.loc[order].to_numpy()

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Compute the macro-averaged F1 score of ``y_pred`` against ``y_true``.

        Only classes that occur in the ground truth take part in the macro
        average.

        Args:
            y_true: Ground-truth frame with an id column and a label column.
            y_pred: Submission frame with the same schema.

        Returns:
            The macro F1 score as a Python float.

        Raises:
            InvalidSubmissionError: If either argument is not a DataFrame or
                the submission fails ``validate_submission``.
        """
        # Basic type checks
        if not isinstance(y_true, pd.DataFrame):
            raise InvalidSubmissionError("y_true must be a pandas DataFrame.")
        if not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("y_pred must be a pandas DataFrame.")

        # Validate both dataframes have the same schema and matching ids
        self.validate_submission(y_pred, y_true)

        # Validation guarantees identical, unique id sets, so sorting both
        # frames by id aligns them row by row.
        _, true_labels = self._sorted_ids_and_labels(y_true)
        _, pred_labels = self._sorted_ids_and_labels(y_pred)

        # Only consider classes present in ground truth for macro averaging
        present_classes = sorted(pd.unique(true_labels))
        score = f1_score(true_labels, pred_labels, labels=present_classes, average="macro")
        # Ensure numeric float
        return float(score)

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that ``submission`` is structurally compatible with ``ground_truth``.

        Checks, in order: both are DataFrames; both have exactly two columns
        named ``id`` and ``self.value`` (case-insensitive) in that order; row
        counts match; submission ids are unique; the id sets match; every
        submitted label belongs to ``VALID_LABELS``.

        Args:
            submission: Candidate submission.
            ground_truth: Reference frame to validate against.

        Returns:
            The string ``"Submission is valid."`` when every check passes.

        Raises:
            InvalidSubmissionError: On the first failed check.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        # Column validation
        sub_cols = list(submission.columns)
        true_cols = list(ground_truth.columns)
        if len(sub_cols) != 2 or len(true_cols) != 2:
            raise InvalidSubmissionError(
                "Both submission and ground truth must have exactly two columns: 'id' and 'label'."
            )

        # Normalize header names for robust checking. Both sides are
        # lower-cased, and str() keeps non-string headers from crashing.
        expected_cols = ["id", str(self.value).lower()]
        if [str(c).lower() for c in sub_cols] != expected_cols:
            raise InvalidSubmissionError(
                f"Submission must have columns exactly ['id', '{self.value}'] in this order."
            )
        if [str(c).lower() for c in true_cols] != expected_cols:
            raise InvalidSubmissionError(
                f"Ground truth must have columns exactly ['id', '{self.value}'] in this order."
            )

        # Length check
        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # Ids are compared as strings, sorted after the conversion.
        sub_ids, _ = self._sorted_ids_and_labels(submission)
        true_ids, _ = self._sorted_ids_and_labels(ground_truth)

        # Check duplicates first so they are reported as such rather than as
        # a generic id mismatch.
        if len(set(sub_ids)) != len(sub_ids):
            raise InvalidSubmissionError("Duplicate ids found in submission.")

        if (sub_ids != true_ids).any():
            raise InvalidSubmissionError(
                "First column 'id' values do not match between submission and ground truth."
            )

        # Validate label domain. The label column is read by position because
        # the header check above is case-insensitive.
        labels = {str(v) for v in submission.iloc[:, 1].unique()}
        invalid = labels - VALID_LABELS
        if invalid:
            raise InvalidSubmissionError(
                f"Submission contains invalid labels {sorted(invalid)}. Valid labels are {sorted(VALID_LABELS)}."
            )

        return "Submission is valid."
