from typing import Any
import numpy as np
import pandas as pd

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class SephoraRatingsMetrics(CompetitionMetrics):
    """
    Evaluation metric for Sephora skincare review rating prediction.

    Uses Quadratic Weighted Kappa (QWK) on integer ratings in {1, 2, 3, 4, 5}.
    Predictions are clipped to [1, 5] and rounded to the nearest integer prior to scoring.

    Args:
        value: Name of the rating column in both the submission and the
            ground truth.
        higher_is_better: Forwarded to ``CompetitionMetrics``; defaults to
            ``True`` since a larger kappa means closer agreement.
    """

    def __init__(self, value: str = "rating", higher_is_better: bool = True):
        super().__init__(higher_is_better)
        self.value = value

    @staticmethod
    def _qwk(y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """
        Compute quadratic weighted kappa between true and predicted ratings.

        Pairs where either side is NaN are skipped. Predictions are clipped to
        [1, 5] and rounded with ``np.rint`` (ties go to the even integer).
        True ratings are truncated to ``int``.

        Args:
            y_true: True ratings; every non-NaN value must lie in [1, 5].
            y_pred: Predicted ratings; any real values (including +/-inf).

        Returns:
            The kappa value as a float. Returns 0.0 when no valid pairs remain,
            when the expected weighted disagreement is zero (e.g. a single
            class on both sides), or when the result is not finite.

        Raises:
            ValueError: If the inputs have different shapes, or a non-NaN true
                rating lies outside [1, 5].
        """
        min_rating = 1
        max_rating = 5
        n_ratings = max_rating - min_rating + 1

        y_true = np.asarray(y_true, dtype=float)
        y_pred = np.asarray(y_pred, dtype=float)
        if y_true.shape != y_pred.shape:
            # zip() would otherwise silently drop the unmatched tail.
            raise ValueError(
                f"y_true and y_pred must have the same shape, got {y_true.shape} and {y_pred.shape}."
            )

        # Drop missing pairs BEFORE casting to int: NaN has no integer
        # representation and astype(int) turns it into an arbitrary value.
        keep = ~(np.isnan(y_true) | np.isnan(y_pred))
        y_true = y_true[keep]
        y_pred = y_pred[keep]
        if y_true.size == 0:
            return 0.0

        # Out-of-range labels would otherwise wrap via negative indexing
        # (0 -> index -1, i.e. counted as rating 5) or raise IndexError.
        out_of_range = (y_true < min_rating) | (y_true > max_rating)
        if out_of_range.any():
            bad = y_true[out_of_range][0]
            raise ValueError(
                f"True ratings must lie in [{min_rating}, {max_rating}]; found {bad!r}."
            )

        true_idx = y_true.astype(int) - min_rating
        pred_idx = np.rint(np.clip(y_pred, min_rating, max_rating)).astype(int) - min_rating

        # Observed confusion matrix O: rows = true rating, columns = predicted.
        # np.add.at accumulates correctly when the same cell repeats.
        O = np.zeros((n_ratings, n_ratings), dtype=float)
        np.add.at(O, (true_idx, pred_idx), 1.0)

        # Expected matrix E under independence: outer product of the marginal
        # histograms, normalised by N (N > 0 because at least one pair survived).
        act_hist = O.sum(axis=1)
        pred_hist = O.sum(axis=0)
        N = O.sum()
        E = np.outer(act_hist, pred_hist) / N

        # Quadratic weights (i - j)^2 / (n - 1)^2, so the largest weight is 1.
        grid = np.arange(n_ratings)
        W = np.subtract.outer(grid, grid) ** 2 / ((n_ratings - 1) ** 2)

        num = float(np.sum(W * O))
        den = float(np.sum(W * E))
        if den == 0:
            return 0.0

        kappa = 1.0 - num / den
        return kappa if np.isfinite(kappa) else 0.0

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """
        Score predictions against the ground truth with QWK.

        Both frames are sorted by their first column so rows line up
        positionally; ``validate_submission`` checks that those first-column
        values match. Entries of ``self.value`` that cannot be parsed as
        numbers become NaN and are skipped by ``_qwk``.

        Args:
            y_true: Ground-truth frame containing ``self.value``.
            y_pred: Submission frame containing ``self.value``.

        Returns:
            The quadratic weighted kappa as a float.
        """
        # Sort by the first column to align rows before evaluation
        y_true = y_true.sort_values(by=y_true.columns[0]).reset_index(drop=True)
        y_pred = y_pred.sort_values(by=y_pred.columns[0]).reset_index(drop=True)

        # to_numpy with na_value also handles nullable (extension) dtypes,
        # mapping their missing values to NaN instead of failing on the cast.
        yt = pd.to_numeric(y_true[self.value], errors="coerce").to_numpy(
            dtype=float, na_value=np.nan
        )
        yp = pd.to_numeric(y_pred[self.value], errors="coerce").to_numpy(
            dtype=float, na_value=np.nan
        )
        return self._qwk(yt, yp)

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """
        Check that a submission is structurally compatible with the ground truth.

        Args:
            submission: Candidate predictions; must be a DataFrame.
            ground_truth: Reference labels; must be a DataFrame.

        Returns:
            The string "Submission is valid." when every check passes.

        Raises:
            InvalidSubmissionError: If either input is not a DataFrame, the row
                counts differ, the column sets differ, the rating column is
                absent, the first-column values differ after sorting, or the
                rating column holds non-numeric or missing values.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)}). Please ensure both have the same number of rows."
            )

        # Column checks come before sorting so that an empty or misnamed
        # column set is reported as InvalidSubmissionError rather than
        # crashing on columns[0] / a missing key.
        sub_cols = set(submission.columns)
        true_cols = set(ground_truth.columns)

        missing_cols = true_cols - sub_cols
        extra_cols = sub_cols - true_cols

        # map(str, ...) keeps sorting/joining safe for non-string labels.
        if missing_cols:
            raise InvalidSubmissionError(
                f"Missing required columns in submission: {', '.join(sorted(map(str, missing_cols)))}."
            )
        if extra_cols:
            raise InvalidSubmissionError(
                f"Extra unexpected columns found in submission: {', '.join(sorted(map(str, extra_cols)))}."
            )
        if self.value not in sub_cols:
            raise InvalidSubmissionError(
                f"Missing required columns in submission: {self.value}."
            )

        # Sort the submission and ground truth by the first column
        submission = submission.sort_values(by=submission.columns[0])
        ground_truth = ground_truth.sort_values(by=ground_truth.columns[0])

        # First column values (IDs) must match exactly
        sub_ids = submission[submission.columns[0]].astype(str).values
        true_ids = ground_truth[ground_truth.columns[0]].astype(str).values
        if (sub_ids != true_ids).any():
            raise InvalidSubmissionError(
                "First column values do not match between submission and ground truth. Please ensure the first column values are identical."
            )

        # Ensure prediction column is numeric
        try:
            predictions = pd.to_numeric(submission[self.value], errors="raise")
        except (TypeError, ValueError) as err:
            raise InvalidSubmissionError(
                f"Column '{self.value}' must contain numeric values."
            ) from err

        # NaN counts as numeric for to_numeric, but a missing prediction would
        # be silently dropped from scoring, so reject it here.
        if predictions.isna().any():
            raise InvalidSubmissionError(
                f"Column '{self.value}' must not contain missing values."
            )

        return "Submission is valid."
