from typing import Any
import numpy as np
import pandas as pd

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class MedicinalPropertiesMetrics(CompetitionMetrics):
    """
    Metrics for multilabel medicinal-properties prediction.

    Evaluation: micro-averaged log loss over all labels.
    Submission format: CSV with columns [id, label_...]
    where label_... are probability columns in [0,1].
    """

    def __init__(self, value: Any | None = None, higher_is_better: bool = False):
        """Create the metric.

        Args:
            value: Unused for multilabel scoring; stored only for API parity.
            higher_is_better: Forwarded to the base class. Defaults to False.
        """
        super().__init__(higher_is_better)
        # value is unused for multilabel, kept for API parity
        self.value = value

    @staticmethod
    def _to_float_matrix(frame: pd.DataFrame, error_message: str) -> np.ndarray:
        """Return every column after the first one as a float matrix.

        Args:
            frame: DataFrame whose first column is the id column.
            error_message: Message used when a cell cannot be converted.

        Raises:
            InvalidSubmissionError: If any cell is not convertible to float
                (for example a free-text value in a probability column).
        """
        try:
            return frame.iloc[:, 1:].to_numpy(dtype=float)
        except (ValueError, TypeError) as err:
            # Surface conversion failures as a submission problem instead of
            # leaking a raw numpy/pandas exception to the caller.
            raise InvalidSubmissionError(error_message) from err

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Compute the micro-averaged log loss of ``y_pred`` against ``y_true``.

        Rows are aligned by sorting both frames on their first (id) column, so
        the two frames do not need to be in the same row order.

        Args:
            y_true: Ground truth; first column is the id, the rest are labels.
            y_pred: Predictions with the same columns as ``y_true``; label
                columns hold probabilities in [0, 1].

        Returns:
            The mean binary log loss over every (row, label) cell.

        Raises:
            InvalidSubmissionError: If the inputs are not DataFrames, the
                columns or ids differ, predictions are non-numeric,
                non-finite or outside [0, 1], or the score is not finite.
        """
        if not isinstance(y_true, pd.DataFrame) or not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("Both y_true and y_pred must be pandas DataFrames.")

        # Validate columns match exactly. Done before touching columns[0] so a
        # frame without columns yields a clear error instead of an IndexError.
        if list(y_true.columns) != list(y_pred.columns):
            raise InvalidSubmissionError(
                "Columns of prediction do not match ground truth. Ensure the header matches exactly."
            )
        if len(y_true.columns) == 0:
            raise InvalidSubmissionError(
                "Ground truth and prediction must contain an id column."
            )

        # Sort by id/first column to align rows for scoring. A stable sort
        # keeps the relative order of rows that share an id, so the pairing of
        # duplicated ids is deterministic.
        id_col = y_true.columns[0]
        y_true_sorted = y_true.sort_values(by=id_col, kind="mergesort").reset_index(drop=True)
        y_pred_sorted = y_pred.sort_values(by=id_col, kind="mergesort").reset_index(drop=True)

        # Both id columns are already sorted, so an element-wise comparison is
        # enough to check that the ids match (a length mismatch also fails).
        if not np.array_equal(
            y_true_sorted.iloc[:, 0].values, y_pred_sorted.iloc[:, 0].values
        ):
            raise InvalidSubmissionError(
                "IDs in prediction do not match ground truth IDs."
            )

        # Extract matrices
        y_t = y_true_sorted.iloc[:, 1:].to_numpy(dtype=float)
        y_p = self._to_float_matrix(y_pred_sorted, "Predictions contain non-numeric values.")

        # Validate probabilities
        if not np.isfinite(y_p).all():
            raise InvalidSubmissionError("Predictions contain non-finite values (NaN/Inf).")
        if (y_p < 0).any() or (y_p > 1).any():
            raise InvalidSubmissionError("Predicted probabilities must be within [0, 1].")

        # Compute micro-averaged log loss; clipping keeps log() away from 0.
        eps = 1e-15
        y_p = np.clip(y_p, eps, 1 - eps)
        loss = -(y_t * np.log(y_p) + (1 - y_t) * np.log(1 - y_p))
        score = float(np.mean(loss))
        if not np.isfinite(score):
            raise InvalidSubmissionError("Computed score is not finite.")
        return score

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that ``submission`` has the exact layout of ``ground_truth``.

        Unlike ``evaluate``, this requires the ids to appear in the same row
        order as in ``ground_truth``.

        Args:
            submission: Candidate submission DataFrame.
            ground_truth: Reference DataFrame defining columns and id order.

        Returns:
            The string "Submission is valid." when every check passes.

        Raises:
            InvalidSubmissionError: If either argument is not a DataFrame, the
                columns, row count or id order differ, or the probability
                columns are non-numeric, non-finite or outside [0, 1].
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        if list(submission.columns) != list(ground_truth.columns):
            raise InvalidSubmissionError(
                "Submission columns must exactly match the ground truth columns (including order)."
            )
        if len(submission.columns) == 0:
            # Guard the positional id lookup below against an IndexError.
            raise InvalidSubmissionError("Submission must contain an id column.")

        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # Must match exactly row-by-row (same order as public/test.csv and private/test_answer.csv)
        if not np.array_equal(submission.iloc[:, 0].values, ground_truth.iloc[:, 0].values):
            raise InvalidSubmissionError(
                "First column (IDs) does not match order of ground truth."
            )

        # probability values check
        probs = self._to_float_matrix(submission, "Submission contains non-numeric values.")
        if not np.isfinite(probs).all():
            raise InvalidSubmissionError("Submission contains non-finite values (NaN/Inf).")
        if (probs < 0).any() or (probs > 1).any():
            raise InvalidSubmissionError("Submission probabilities must be in [0, 1].")

        return "Submission is valid."