from typing import Any
import math
import pandas as pd

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class FacialBeautyBalancedRMSEMetrics(CompetitionMetrics):
    """
    Group-balanced RMSE for SCUT-FBP5500 facial beauty regression.

    Expects:
    - submission: DataFrame with columns ["image_id", value]
    - ground_truth: DataFrame with columns at least ["image_id", value]; may include extra columns like "group".

    Rows are matched by sorting both frames on their *first* column, so the ID
    column must come first in each frame. Predictions are clipped to [1, 5]
    before scoring.

    The evaluation computes per-group RMSE (if a "group" column exists in ground_truth),
    and averages them with equal weight per group. If no group column exists, it
    falls back to standard RMSE.
    """

    def __init__(self, value: str = "beauty", higher_is_better: bool = False):
        """
        Args:
            value: Name of the column holding the score, in both the submission
                and the ground truth.
            higher_is_better: Forwarded to the base class.
        """
        super().__init__(higher_is_better)
        self.value = value

    @staticmethod
    def _checked_rmse(squared_errors: pd.Series) -> float:
        """Return sqrt(mean(squared_errors)); raise if the result is NaN or infinite."""
        rmse = math.sqrt(float(squared_errors.mean()))
        if not math.isfinite(rmse):
            raise InvalidSubmissionError("Numerical issue in RMSE computation.")
        return rmse

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """
        Compute the (group-balanced) RMSE between ground truth and predictions.

        Args:
            y_true: Ground truth; first column is the ID, must contain ``self.value``
                and may contain a "group" column.
            y_pred: Predictions; first column is the ID, must contain ``self.value``.

        Returns:
            Mean of the per-group RMSEs when ``y_true`` has a "group" column,
            otherwise the RMSE over all rows.

        Raises:
            InvalidSubmissionError: If the frames differ in length, are empty,
                their sorted IDs differ, any prediction is NaN, or an RMSE is
                not a finite number.
        """
        # Arrays of different lengths cannot be compared element-wise, so without
        # this guard the ID check below fails with an unrelated NumPy error.
        if len(y_true) != len(y_pred):
            raise InvalidSubmissionError(
                f"Number of rows in y_pred ({len(y_pred)}) does not match y_true ({len(y_true)})."
            )
        # With zero rows the grouped branch would divide by zero.
        if len(y_true) == 0:
            raise InvalidSubmissionError("Cannot compute RMSE: no rows to evaluate.")

        # sort both by first column to align rows
        id_col_true = y_true.columns[0]
        id_col_pred = y_pred.columns[0]
        y_true = y_true.sort_values(by=id_col_true).reset_index(drop=True)
        y_pred = y_pred.sort_values(by=id_col_pred).reset_index(drop=True)

        # Validate alignment quickly
        if (y_true[id_col_true].values != y_pred[id_col_pred].values).any():
            raise InvalidSubmissionError(
                "IDs in y_true and y_pred do not match or are not in the same order."
            )

        # Clip predictions for stability and fairness to [1, 5]
        preds = y_pred[self.value].astype(float).clip(lower=1.0, upper=5.0)
        trues = y_true[self.value].astype(float)

        # Series.mean() skips NaN by default, so a missing prediction would be
        # dropped from the average instead of being penalised. Reject it instead.
        if preds.isna().any():
            raise InvalidSubmissionError(
                f"Submission column '{self.value}' contains missing (NaN) predictions."
            )

        # Both series share the same RangeIndex after reset_index, so the
        # subtraction is row-aligned.
        err2 = (trues - preds) ** 2

        if "group" not in y_true.columns:
            # Standard RMSE
            return self._checked_rmse(err2)

        # Balanced RMSE: every group contributes equally, regardless of its size.
        groups = y_true["group"].values
        per_group_rmse = []
        for g in pd.unique(groups):
            mask = groups == g
            # Only reachable for a label that does not compare equal to itself
            # (e.g. NaN), since every other unique label matches at least one row.
            if not mask.any():
                raise InvalidSubmissionError(f"No samples for group {g} in ground truth.")
            per_group_rmse.append(self._checked_rmse(err2[mask]))
        return float(sum(per_group_rmse) / len(per_group_rmse))

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """
        Check that a submission is structurally compatible with the ground truth.

        Args:
            submission: Candidate submission; must be a DataFrame.
            ground_truth: Reference data; must be a DataFrame.

        Returns:
            The string "Submission is valid." when every check passes.

        Raises:
            InvalidSubmissionError: If either argument is not a DataFrame, the row
                counts differ, the sorted first-column IDs differ, the value
                column is missing from either frame, or the submission's value
                column is non-numeric or contains missing values.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        # basic length check
        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # Sort the submission and ground truth by their first columns
        submission = submission.sort_values(by=submission.columns[0]).reset_index(drop=True)
        ground_truth = ground_truth.sort_values(by=ground_truth.columns[0]).reset_index(drop=True)

        # Check first column values are identical after sort
        if (submission[submission.columns[0]].values != ground_truth[ground_truth.columns[0]].values).any():
            raise InvalidSubmissionError(
                "First column values (IDs) do not match between submission and ground truth."
            )

        # Required columns in submission
        required_cols = {submission.columns[0], self.value}
        sub_cols = set(submission.columns)
        if not required_cols.issubset(sub_cols):
            missing = required_cols - sub_cols
            raise InvalidSubmissionError(f"Missing required columns in submission: {', '.join(missing)}.")

        # Ground truth must have at least the value column
        if self.value not in ground_truth.columns:
            raise InvalidSubmissionError(
                f"Ground truth must contain the column '{self.value}'."
            )

        # Validate numeric predictions
        try:
            numeric = pd.to_numeric(submission[self.value])
        except (ValueError, TypeError) as err:
            raise InvalidSubmissionError(
                f"Submission column '{self.value}' must be numeric."
            ) from err

        # to_numeric lets NaN/None through; such rows cannot be scored.
        if numeric.isna().any():
            raise InvalidSubmissionError(
                f"Submission column '{self.value}' contains missing (NaN) values."
            )

        return "Submission is valid."
