from typing import Any
import pandas as pd
import numpy as np

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class CattleWeightEstimationMetrics(CompetitionMetrics):
    """Metric class for cattle weight estimation using Mean Absolute Error (MAE).

    Expected CSV/DataFrame schema (aligned with public/sample_submission.csv and private/test_answer.csv):
      - filename: str, image filename present in public/test.csv
      - weight_kg: float, predicted or true weight in kilograms

    Both predicted and true weights are clipped to [CLIP_MIN, CLIP_MAX]
    before the absolute errors are averaged.
    """

    # Bounds (kg) applied to BOTH predictions and ground truth before scoring.
    CLIP_MIN = 30.0
    CLIP_MAX = 700.0

    def __init__(self, value: str = "weight_kg", higher_is_better: bool = False):
        """Create the metric.

        Args:
            value: Name of the numeric column that is scored.
            higher_is_better: Forwarded to the base class; defaults to False
                because a lower MAE is better.
        """
        super().__init__(higher_is_better)
        self.value = value

    @staticmethod
    def _sort_by_key(df: pd.DataFrame) -> pd.DataFrame:
        """Return ``df`` sorted by its first (key) column, with a fresh RangeIndex."""
        return df.sort_values(by=df.columns[0]).reset_index(drop=True)

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Compute MAE between true and predicted weights after aligning by filename.

        Args:
            y_true: DataFrame with columns [filename, weight_kg]
            y_pred: DataFrame with columns [filename, weight_kg]
        Returns:
            Mean Absolute Error as a float.
        Raises:
            InvalidSubmissionError: If validation fails, there are no rows to
                score, or the resulting MAE is non-finite.
        """
        self.validate_submission(y_pred, y_true)

        if len(y_true) == 0:
            # np.mean of an empty array warns and yields NaN; fail explicitly instead.
            raise InvalidSubmissionError("Cannot compute MAE on an empty submission.")

        # Validation guarantees identical columns and identical sorted key
        # values, so sorting each frame by its first column aligns the rows.
        y_true_sorted = self._sort_by_key(y_true)
        y_pred_sorted = self._sort_by_key(y_pred)

        # Clip both predictions and ground truth to the same range; a prediction
        # and a true value beyond the same bound therefore contribute zero error.
        preds = y_pred_sorted[self.value].astype(float).clip(self.CLIP_MIN, self.CLIP_MAX).to_numpy()
        trues = y_true_sorted[self.value].astype(float).clip(self.CLIP_MIN, self.CLIP_MAX).to_numpy()

        mae = float(np.mean(np.abs(preds - trues)))
        # Predictions were checked for finiteness during validation; this still
        # guards against missing/non-finite values on the ground-truth side.
        if not np.isfinite(mae):
            raise InvalidSubmissionError("Computed MAE is non-finite; please check your predictions.")
        return mae

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Validate submission format and contents.

        Checks, in order:
          - Both must be pandas DataFrames
          - Row count must match
          - Columns must match exactly the ground truth columns (names and order)
          - First column values (filenames) must match after sorting by the first column
          - The scored value column must be numeric with only finite values

        Returns:
            "Submission is valid." when every check passes.
        Raises:
            InvalidSubmissionError: If any check fails.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)}). Please ensure both have the same number of rows."
            )

        # Compare columns before any positional column access, so a malformed
        # submission gets a column-level message instead of an IndexError or a
        # misleading filename-mismatch error.
        sub_cols = list(submission.columns)
        gt_cols = list(ground_truth.columns)
        if sub_cols != gt_cols:
            # Provide detailed messaging about column mismatch
            sub_only = [c for c in sub_cols if c not in gt_cols]
            gt_only = [c for c in gt_cols if c not in sub_cols]
            msg_bits = []
            if gt_only:
                msg_bits.append(f"Missing required columns in submission: {', '.join(gt_only)}")
            if sub_only:
                msg_bits.append(f"Extra unexpected columns in submission: {', '.join(sub_only)}")
            if not msg_bits:
                msg_bits.append(f"Submission columns {sub_cols} must exactly match ground truth columns {gt_cols}")
            raise InvalidSubmissionError("; ".join(msg_bits) + ".")

        # Sort both by the first column (filename) and require identical keys.
        sub_sorted = self._sort_by_key(submission)
        gt_sorted = self._sort_by_key(ground_truth)
        if not np.array_equal(sub_sorted.iloc[:, 0].to_numpy(), gt_sorted.iloc[:, 0].to_numpy()):
            raise InvalidSubmissionError(
                "First column values do not match between submission and ground truth. Please ensure filenames match those in public/test.csv."
            )

        # Validate the same column that evaluate() scores.
        if self.value not in submission.columns:
            raise InvalidSubmissionError(f"Submission is missing the '{self.value}' column.")
        try:
            values = pd.to_numeric(submission[self.value])
        except (TypeError, ValueError) as err:
            raise InvalidSubmissionError(f"The '{self.value}' column must be numeric.") from err

        # Reject NaN/inf explicitly: evaluate() clips before averaging, which
        # would otherwise silently turn +/-inf into CLIP_MAX/CLIP_MIN.
        numeric = values.to_numpy(dtype=float, na_value=np.nan)
        n_bad = int((~np.isfinite(numeric)).sum())
        if n_bad:
            raise InvalidSubmissionError(
                f"The '{self.value}' column contains {n_bad} missing or non-finite value(s); all predictions must be finite numbers."
            )

        return "Submission is valid."
