from typing import Any
import numpy as np
import pandas as pd

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError

# Name of the identifier column; both the ground-truth and submission frames
# are indexed by it when the metric aligns rows.
ID_COL = "Id"
# Default ground-truth label column (default for the metric's `value` argument).
LABEL_COL = "Y"
# Class names. They are both the submission's probability columns and the set
# of allowed ground-truth labels; list position is the class index used when
# the metric maps labels to probability-matrix columns.
CLASS_COLUMNS = ["HQ", "LQ_EDIT", "LQ_CLOSE"]


class StackOverflowQualityMetrics(CompetitionMetrics):
    """Macro log loss for Stack Overflow question quality classification.

    Expected files/frames:
    - Ground truth (y_true): DataFrame with columns [Id, Y]
    - Submission (y_pred): DataFrame with columns [Id, HQ, LQ_EDIT, LQ_CLOSE]

    Score definition:
    - For each class, compute the mean negative log-likelihood among examples of that class
    - Final score is the arithmetic mean across classes present in the ground
      truth (lower is better)
    """

    def __init__(self, value: str = LABEL_COL, higher_is_better: bool = False):
        """Store the ground-truth label column name and the score direction.

        Args:
            value: Name of the label column in the ground-truth frame.
            higher_is_better: Forwarded to the base class; log loss is a
                lower-is-better score, hence the ``False`` default.
        """
        super().__init__(higher_is_better)
        self.value = value

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Return the macro-averaged log loss of ``y_pred`` against ``y_true``.

        Args:
            y_true: Ground-truth frame containing ``Id`` and the label column.
            y_pred: Submission frame with columns ``Id`` plus one probability
                column per class.

        Returns:
            Mean over the classes present in the ground truth of the per-class
            mean negative log probability assigned to the true class.

        Raises:
            InvalidSubmissionError: If either input is not a DataFrame, the
                submission fails validation, or no class is present to score.
        """
        # Basic type checks
        if not isinstance(y_true, pd.DataFrame):
            raise InvalidSubmissionError("y_true must be a pandas DataFrame.")
        if not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("y_pred must be a pandas DataFrame.")

        # Validate submission and alignment with ground truth
        self.validate_submission(y_pred, y_true)

        # Align by Id: look up each ground-truth row's submission row by label.
        # Validation guarantees submission Ids are unique and cover exactly the
        # ground-truth Ids, so the probability matrix always has one row per
        # ground-truth row, in ground-truth order.
        ans = y_true.set_index(ID_COL)
        sub = y_pred.set_index(ID_COL)
        P = sub.loc[ans.index, CLASS_COLUMNS].to_numpy(dtype=float)

        # Map class labels to column indices of P
        label_to_idx = {c: i for i, c in enumerate(CLASS_COLUMNS)}
        y_idx = ans[self.value].map(label_to_idx).to_numpy()
        if np.any(pd.isna(y_idx)):
            raise InvalidSubmissionError("Ground truth contains unknown class labels not in expected set.")

        # Numerical stability: keep log() away from 0
        eps = 1e-15
        P = np.clip(P, eps, 1 - eps)

        # Compute class-wise log loss and average (macro)
        losses = []
        for k in range(len(CLASS_COLUMNS)):
            mask = (y_idx == k)
            if not np.any(mask):
                # If a class is absent in ground truth, skip it
                continue
            loss_k = -float(np.mean(np.log(P[mask, k])))
            if not np.isfinite(loss_k):
                raise InvalidSubmissionError("Encountered non-finite loss; check submission probabilities.")
            losses.append(loss_k)

        if not losses:
            raise InvalidSubmissionError("No valid classes present to compute loss.")

        return float(np.mean(losses))

    def validate_submission(self, submission: Any, ground_truth: Any) -> tuple[bool, str]:
        """Check that ``submission`` is well-formed and matches ``ground_truth``.

        Args:
            submission: Candidate submission; must be a DataFrame whose columns
                are exactly ``Id`` followed by the class columns.
            ground_truth: Ground-truth DataFrame containing ``Id`` and the
                label column.

        Returns:
            ``(True, "Submission is valid.")`` when every check passes.

        Raises:
            InvalidSubmissionError: On wrong types, wrong or missing columns,
                duplicate or mismatched Ids, non-numeric / non-finite /
                negative probabilities, rows not summing to 1, or
                ground-truth labels outside the expected class set.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        # Required columns
        expected_sub_cols = [ID_COL] + CLASS_COLUMNS
        if list(submission.columns) != expected_sub_cols:
            raise InvalidSubmissionError(
                f"Submission must have columns exactly: {expected_sub_cols}."
            )

        expected_gt_cols = [ID_COL, self.value]
        for col in expected_gt_cols:
            if col not in ground_truth.columns:
                raise InvalidSubmissionError(
                    f"Ground truth must contain columns: {expected_gt_cols}."
                )

        # Id checks. Set equality alone would let duplicated Ids through, and
        # scoring would then be ambiguous (which duplicate row counts?).
        sub_ids = submission[ID_COL]
        if sub_ids.duplicated().any():
            raise InvalidSubmissionError("Submission contains duplicate Ids.")
        if set(sub_ids.tolist()) != set(ground_truth[ID_COL].tolist()):
            raise InvalidSubmissionError(
                "Submission Ids must match ground truth Ids exactly (set equality)."
            )

        # Probability checks
        try:
            probs = submission[CLASS_COLUMNS].to_numpy(dtype=float)
        except (TypeError, ValueError) as exc:
            # Non-numeric cells (e.g. strings) cannot be cast to float.
            raise InvalidSubmissionError(
                "Submission probability columns must contain numeric values."
            ) from exc
        if not np.all(np.isfinite(probs)):
            raise InvalidSubmissionError("Submission contains non-finite values (nan/inf).")
        # Small negative tolerance absorbs floating-point noise around zero.
        if np.any(probs < -1e-12):
            raise InvalidSubmissionError("Submission contains negative probabilities.")

        row_sums = probs.sum(axis=1)
        if not np.all(np.abs(row_sums - 1.0) <= 1e-6):
            raise InvalidSubmissionError(
                "Each submission row must sum to 1 across class columns (within 1e-6 tolerance)."
            )

        # Label set sanity
        allowed = set(CLASS_COLUMNS)
        gt_labels = set(ground_truth[self.value].astype(str).unique().tolist())
        if not gt_labels.issubset(allowed):
            raise InvalidSubmissionError(
                f"Ground truth contains labels outside of expected set {sorted(allowed)}."
            )

        return True, "Submission is valid."
