from typing import Any, Tuple

import pandas as pd
from sklearn.metrics import f1_score

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class ModelNet40ClassificationMetrics(CompetitionMetrics):
    """
    Macro-F1 metric for ModelNet40 3D object classification.

    Expected CSV schemas:
    - public/sample_submission.csv: columns ['id', 'label'] (predicted labels)
    - private/test_answer.csv: columns ['id', 'label'] (ground-truth labels)

    The ID column is the first column of the ground truth. The submission must
    contain exactly the same IDs; rows are aligned by ID (sorting on the ID
    column looked up by name), so the submission's column order does not matter.
    """

    def __init__(self, value: str = "label", higher_is_better: bool = True):
        """
        Args:
            value: Name of the column holding class labels in both frames.
            higher_is_better: Forwarded to ``CompetitionMetrics``; macro-F1 is
                maximised, hence the ``True`` default.
        """
        super().__init__(higher_is_better)
        self.value = value

    @staticmethod
    def _resolve_id_columns(y_true: pd.DataFrame, y_pred: pd.DataFrame) -> Tuple[Any, Any]:
        """Return the (ground-truth, prediction) ID column labels.

        The ID column is the first column of ``y_true``. The prediction frame
        uses the column with the same name when it has one, so a submission
        whose columns are merely reordered (e.g. ['label', 'id']) is still
        aligned on IDs; otherwise it falls back to the prediction's first column.
        """
        id_col_true = y_true.columns[0]
        if id_col_true in y_pred.columns:
            return id_col_true, id_col_true
        return id_col_true, y_pred.columns[0]

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """
        Compute the macro-averaged F1 score of ``y_pred`` against ``y_true``.

        Both frames are sorted by their ID column, the sorted IDs must be
        identical, and the ``self.value`` columns are then compared row by row.

        Args:
            y_true: Ground-truth frame (ID column first, plus ``self.value``).
            y_pred: Prediction frame with the same IDs and a ``self.value`` column.

        Returns:
            Macro F1 as a Python float.

        Raises:
            InvalidSubmissionError: if either input is not a DataFrame, the label
                column is missing from either frame, or the IDs do not match.
        """
        if not isinstance(y_true, pd.DataFrame) or not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("y_true and y_pred must be pandas DataFrames.")

        if self.value not in y_true.columns or self.value not in y_pred.columns:
            raise InvalidSubmissionError(
                f"Column '{self.value}' must exist in both y_true and y_pred."
            )

        # Resolve the prediction ID column by name so a reordered submission is
        # not accidentally sorted by its label column.
        id_col_true, id_col_pred = self._resolve_id_columns(y_true, y_pred)

        y_true_sorted = y_true.sort_values(by=id_col_true).reset_index(drop=True)
        y_pred_sorted = y_pred.sort_values(by=id_col_pred).reset_index(drop=True)

        # Series.equals fails on differing length or elements (and, per pandas,
        # differing dtypes), so missing/extra IDs are rejected here too.
        if not y_true_sorted[id_col_true].equals(y_pred_sorted[id_col_pred]):
            raise InvalidSubmissionError(
                "IDs in predictions do not match ground truth IDs after sorting."
            )

        return float(
            f1_score(
                y_true_sorted[self.value], y_pred_sorted[self.value], average="macro"
            )
        )

    def validate_submission(
        self, submission: Any, ground_truth: Any
    ) -> Tuple[bool, str]:
        """
        Check that ``submission`` is structurally compatible with ``ground_truth``.

        Checks, in order: both are DataFrames; the submission has no duplicated
        column names; column sets match; the label column exists; row counts
        match; the ID and label columns have no nulls or empty strings; and the
        sorted IDs are identical.

        Returns:
            ``(True, "Submission is valid.")`` when every check passes.

        Raises:
            InvalidSubmissionError: on the first failed check.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        # Duplicated names are invisible to the set comparison below and would
        # make `submission[col]` return a DataFrame instead of a Series.
        duplicated_mask = submission.columns.duplicated()
        if duplicated_mask.any():
            dupes = list(submission.columns[duplicated_mask].unique())
            raise InvalidSubmissionError(f"Submission has duplicated columns: {dupes}")

        # Basic column checks: must match ground truth columns (id + label)
        sub_cols = list(submission.columns)
        gt_cols = list(ground_truth.columns)
        if set(sub_cols) != set(gt_cols):
            missing = set(gt_cols) - set(sub_cols)
            extra = set(sub_cols) - set(gt_cols)
            msg_parts = []
            if missing:
                msg_parts.append(f"missing columns: {sorted(missing)}")
            if extra:
                msg_parts.append(f"extra columns: {sorted(extra)}")
            raise InvalidSubmissionError(
                "Submission columns must match ground truth columns. " + ", ".join(msg_parts)
            )

        if self.value not in gt_cols:
            raise InvalidSubmissionError(
                f"Column '{self.value}' must exist in both submission and ground truth."
            )

        # Length match
        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Row count mismatch. submission={len(submission)}, ground_truth={len(ground_truth)}"
            )

        # The ID column is the ground truth's first column; the column-set check
        # above guarantees the submission has a column with the same name, which
        # is used regardless of where it appears in the submission.
        id_col = gt_cols[0]

        # No NaNs or empty strings in required columns. Done before the ID
        # comparison so a null ID is reported as such, not as an ID mismatch.
        for col in (id_col, self.value):
            values = submission[col]
            if values.isnull().any():
                raise InvalidSubmissionError(f"Column '{col}' contains null values.")
            if (values.astype(str).str.len() == 0).any():
                raise InvalidSubmissionError(f"Column '{col}' contains empty strings.")

        sub_ids = submission[id_col].sort_values().reset_index(drop=True)
        gt_ids = ground_truth[id_col].sort_values().reset_index(drop=True)
        if not sub_ids.equals(gt_ids):
            raise InvalidSubmissionError(
                "First column (IDs) must match exactly between submission and ground truth."
            )

        return True, "Submission is valid."
