from typing import Any
import pandas as pd
import numpy as np

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class JobSalaryPredictionMetrics(CompetitionMetrics):
    """Metric class for Job Salary Prediction competition using RMSE.

    Expects dataframes with two columns:
    - First column: identifier (e.g., 'job_id')
    - Target column (default): 'target_salary'

    Rows of the two frames are paired by identifier (after sorting), not by
    position, so a submission may list its rows in any order. By default a
    lower RMSE is better (``higher_is_better=False``).
    """

    def __init__(self, value: str = "target_salary", higher_is_better: bool = False):
        super().__init__(higher_is_better)
        # Name of the column holding the salary values that are compared.
        self.value = value

    @staticmethod
    def _sort_by(df: pd.DataFrame, column: Any) -> pd.DataFrame:
        """Return ``df`` sorted by ``column`` with a fresh 0..n-1 index.

        A stable sort (mergesort) is used so rows that share an identifier keep
        their input order, which makes the row pairing deterministic.

        Raises:
            InvalidSubmissionError: if pandas cannot sort by ``column`` (e.g. the
                values mix incomparable types or the label is not unique).
        """
        try:
            return df.sort_values(by=column, kind="mergesort").reset_index(drop=True)
        except (TypeError, ValueError) as err:
            raise InvalidSubmissionError(
                f"Could not sort by identifier column '{column}': {err}"
            ) from err

    def _values_as_float(self, df: pd.DataFrame, label: str) -> np.ndarray:
        """Return column ``self.value`` of ``df`` as a float ndarray.

        ``label`` only names the frame in error messages.

        Raises:
            InvalidSubmissionError: if the column is missing or cannot be
                converted to float.
        """
        if self.value not in df.columns:
            raise InvalidSubmissionError(f"Column '{self.value}' not found in {label}.")
        try:
            return df[self.value].astype(float).to_numpy()
        except (TypeError, ValueError) as err:
            raise InvalidSubmissionError(
                f"Column '{self.value}' in {label} must be numeric."
            ) from err

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Return the RMSE between ``y_true`` and ``y_pred`` on column ``self.value``.

        The identifier column is ``y_true``'s first column. It is looked up by
        name in ``y_pred``; if ``y_pred`` has no column with that name, its
        first column is used instead (the original positional behaviour).

        Raises:
            InvalidSubmissionError: on a row-count mismatch, empty input, a
                missing identifier column, identifiers that do not align, a
                missing or non-numeric value column, or NaN/infinite predictions.
        """
        if y_true.shape[0] != y_pred.shape[0]:
            raise InvalidSubmissionError(
                f"y_true and y_pred must have the same number of rows. Got {y_true.shape[0]} and {y_pred.shape[0]}."
            )
        # RMSE over zero rows is undefined (np.mean would return NaN with a warning).
        if y_true.shape[0] == 0:
            raise InvalidSubmissionError("Cannot compute RMSE on empty dataframes.")
        if len(y_true.columns) == 0 or len(y_pred.columns) == 0:
            raise InvalidSubmissionError(
                "y_true and y_pred must each have an identifier as their first column."
            )

        true_id = y_true.columns[0]
        # Prefer matching the id column by name so a reordered submission still
        # pairs correctly; fall back to position for differently-named ids.
        pred_id = true_id if true_id in y_pred.columns else y_pred.columns[0]

        y_true = self._sort_by(y_true, true_id)
        y_pred = self._sort_by(y_pred, pred_id)

        # Validate ids align exactly after sorting
        if not np.array_equal(y_true[true_id].to_numpy(), y_pred[pred_id].to_numpy()):
            raise InvalidSubmissionError(
                "Identifier columns do not align between y_true and y_pred after sorting."
            )

        truth = self._values_as_float(y_true, "y_true")
        preds = self._values_as_float(y_pred, "y_pred")
        # A single NaN/inf prediction would otherwise turn the whole score into NaN/inf.
        if not np.isfinite(preds).all():
            raise InvalidSubmissionError(
                f"Column '{self.value}' in y_pred contains NaN or infinite values."
            )

        diff = truth - preds
        return float(np.sqrt(np.mean(diff ** 2)))

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that ``submission`` can be scored against ``ground_truth``.

        Checks, in order: both are DataFrames; same number of rows; identical
        column sets; identical identifier values (ground truth's first column,
        looked up by name in both frames and compared after sorting); column
        ``self.value`` in the submission is numeric and finite.

        Returns:
            The string ``"Submission is valid."`` when every check passes.

        Raises:
            InvalidSubmissionError: describing the first check that failed.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)}). Please ensure both have the same number of rows."
            )

        # Check the column sets first so the id comparison below can safely
        # look the identifier column up by name in both frames.
        sub_cols = set(submission.columns)
        true_cols = set(ground_truth.columns)

        missing_cols = true_cols - sub_cols
        extra_cols = sub_cols - true_cols

        # map(str, ...) keeps the message buildable for non-string labels.
        if missing_cols:
            raise InvalidSubmissionError(
                f"Missing required columns in submission: {', '.join(sorted(map(str, missing_cols)))}."
            )
        if extra_cols:
            raise InvalidSubmissionError(
                f"Extra unexpected columns found in submission: {', '.join(sorted(map(str, extra_cols)))}."
            )

        if len(ground_truth.columns) == 0:
            raise InvalidSubmissionError(
                "Ground truth has no columns; expected an identifier as its first column."
            )

        # Sort both frames by the ground truth's identifier column (by name, so a
        # submission with reordered columns is still compared correctly).
        id_col = ground_truth.columns[0]
        submission_sorted = self._sort_by(submission, id_col)
        ground_truth_sorted = self._sort_by(ground_truth, id_col)

        if not np.array_equal(
            submission_sorted[id_col].to_numpy(), ground_truth_sorted[id_col].to_numpy()
        ):
            raise InvalidSubmissionError(
                "First column values do not match between submission and ground truth. Please ensure the first column values are identical."
            )

        # Validate numeric predictions
        if self.value not in submission.columns:
            raise InvalidSubmissionError(
                f"Column '{self.value}' not found in submission."
            )
        try:
            numeric = pd.to_numeric(submission[self.value])
        except (TypeError, ValueError) as err:
            raise InvalidSubmissionError(
                f"Column '{self.value}' in submission must be numeric."
            ) from err

        # pd.to_numeric accepts NaN; reject it (and inf) since it would poison RMSE.
        if not np.isfinite(numeric.to_numpy(dtype=float, na_value=np.nan)).all():
            raise InvalidSubmissionError(
                f"Column '{self.value}' in submission contains NaN or infinite values."
            )

        return "Submission is valid."
