from typing import Any, List

import numpy as np
import pandas as pd

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class ESGRegressionMetrics(CompetitionMetrics):
    """Metric class for ESG multi-target regression using MCRMSE (lower is better).

    MCRMSE (mean column-wise root mean squared error) is the RMSE of every
    evaluated target column, averaged over those columns.

    Attributes:
        TARGET_COLUMNS: Default target columns scored when none are supplied.
        value: The target columns this instance evaluates.
    """

    TARGET_COLUMNS = ["e_score", "s_score", "g_score", "total_score"]

    def __init__(self, value: List[str] | None = None, higher_is_better: bool = False):
        """Configure which target columns are scored.

        Args:
            value: Target column names to evaluate. When ``None``, a copy of
                ``TARGET_COLUMNS`` is used.
            higher_is_better: Forwarded to the base class; defaults to
                ``False``.
        """
        super().__init__(higher_is_better)
        # value represents the target columns to evaluate.
        # The default is copied so that mutating ``self.value`` in place can
        # never alter the shared class-level TARGET_COLUMNS list (and through
        # it every other default-constructed instance).
        self.value = (
            value if value is not None else list(ESGRegressionMetrics.TARGET_COLUMNS)
        )

    @staticmethod
    def _ensure_required_columns(df: pd.DataFrame, required: List[str], name: str) -> None:
        """Raise if any of ``required`` is absent from ``df.columns``.

        Args:
            df: Frame to inspect.
            required: Column names that must be present.
            name: Human-readable label for ``df`` used in the error message.

        Raises:
            InvalidSubmissionError: If one or more required columns are missing.
        """
        missing = [c for c in required if c not in df.columns]
        if missing:
            raise InvalidSubmissionError(f"{name} is missing required columns: {missing}")

    @staticmethod
    def _coerce_numeric(df: pd.DataFrame, cols: List[str], name: str) -> pd.DataFrame:
        """Return a copy of ``df`` with ``cols`` (except ``"id"``) made numeric.

        Each column is converted with ``pd.to_numeric(errors="coerce")``, so
        unparseable entries become NaN and are then rejected together with
        pre-existing NaNs. Infinite values are rejected as well.

        Args:
            df: Frame to convert; it is not modified. It must contain an
                ``"id"`` column, which is used to report offending rows.
            cols: Columns to convert; an ``"id"`` entry is skipped.
            name: Human-readable label for ``df`` used in error messages.

        Returns:
            The converted copy of ``df``.

        Raises:
            InvalidSubmissionError: If a column contains non-numeric, NaN or
                infinite values.
        """
        df = df.copy()
        for c in cols:
            if c == "id":
                continue
            df[c] = pd.to_numeric(df[c], errors="coerce")
            nan_mask = df[c].isna()
            if nan_mask.any():
                # Report at most five offending ids to keep the message short.
                bad = df.loc[nan_mask, "id"].tolist()[:5]
                raise InvalidSubmissionError(
                    f"Non-numeric or NaN values found in column '{c}' of {name} for ids: {bad}..."
                )
            if not np.isfinite(df[c].to_numpy(dtype="float64")).all():
                raise InvalidSubmissionError(
                    f"Non-finite values (inf/-inf) found in column '{c}' of {name}"
                )
        return df

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Compute MCRMSE between ``y_pred`` and ``y_true`` over ``self.value``.

        Both frames are sorted by their own first column and then compared row
        by row. This method does not itself verify that the first-column
        values agree or that the target columns are present; those checks
        live in ``validate_submission``.

        Args:
            y_true: Ground-truth frame containing the target columns.
            y_pred: Prediction frame containing the target columns.

        Returns:
            The mean over target columns of the per-column RMSE.

        Raises:
            InvalidSubmissionError: If the selected arrays differ in shape or
                the resulting score is not finite.
        """
        # Sort both dataframes by first column before calculating score
        y_true = y_true.sort_values(by=y_true.columns[0]).reset_index(drop=True)
        y_pred = y_pred.sort_values(by=y_pred.columns[0]).reset_index(drop=True)

        # Align columns to self.value
        cols = list(self.value)
        y_true_arr = y_true[cols].to_numpy(dtype="float64")
        y_pred_arr = y_pred[cols].to_numpy(dtype="float64")

        if y_true_arr.shape != y_pred_arr.shape:
            raise InvalidSubmissionError(
                f"Shape mismatch for scoring: y_true {y_true_arr.shape} vs y_pred {y_pred_arr.shape}"
            )

        # Mean Columnwise Root Mean Squared Error
        diffs = y_pred_arr - y_true_arr
        mse = np.mean(np.square(diffs), axis=0)
        # Clamp at zero before the square root as a purely defensive measure.
        rmse = np.sqrt(np.maximum(mse, 0.0))
        score = float(np.mean(rmse))
        if not np.isfinite(score):
            raise InvalidSubmissionError("Computed score is not finite")
        return score

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that ``submission`` is structurally compatible with ``ground_truth``.

        The checks run in this order: both inputs are DataFrames; both contain
        ``"id"`` plus every column in ``self.value``; row counts match; the
        first-column values match after sorting; the column sets are
        identical; and all target columns are numeric and finite.

        Args:
            submission: Candidate submission, expected to be a DataFrame.
            ground_truth: Reference data, expected to be a DataFrame.

        Returns:
            The string ``"Submission is valid."`` when every check passes.

        Raises:
            InvalidSubmissionError: On the first failed check.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        required_cols = ["id"] + list(self.value)
        self._ensure_required_columns(submission, required_cols, "submission")
        self._ensure_required_columns(ground_truth, required_cols, "ground truth")

        # Check number of rows
        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # Sort the submission and ground truth by the first column
        submission_sorted = submission.sort_values(by=submission.columns[0]).reset_index(drop=True)
        ground_truth_sorted = ground_truth.sort_values(by=ground_truth.columns[0]).reset_index(drop=True)

        # Check if first columns are identical (row counts are already known
        # to be equal, so the element-wise comparison is well defined).
        if (
            submission_sorted[submission_sorted.columns[0]].values
            != ground_truth_sorted[ground_truth_sorted.columns[0]].values
        ).any():
            raise InvalidSubmissionError(
                "First column values do not match between submission and ground truth. Please ensure the first column values are identical."
            )

        # Check for missing/extra columns exactly
        sub_cols = set(submission.columns)
        true_cols = set(ground_truth.columns)
        missing_cols = true_cols - sub_cols
        extra_cols = sub_cols - true_cols
        if missing_cols:
            raise InvalidSubmissionError(
                f"Missing required columns in submission: {', '.join(sorted(missing_cols))}."
            )
        if extra_cols:
            raise InvalidSubmissionError(
                f"Extra unexpected columns found in submission: {', '.join(sorted(extra_cols))}."
            )

        # Coerce numeric and validate values are finite; the converted frames
        # are discarded because only the validation side effect is needed.
        self._coerce_numeric(submission_sorted, required_cols, "submission")
        self._coerce_numeric(ground_truth_sorted, required_cols, "ground truth")

        return "Submission is valid."
