from typing import Any
import pandas as pd
import numpy as np

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


# Columns that jointly identify one segment row; submission and ground truth
# are sorted by these and must agree on them exactly before labels are scored.
KEY_COLUMNS = ["subject_id", "segment_start", "segment_end"]
# Every column a submission must provide: the key columns plus the predicted label.
REQUIRED_COLUMNS = [*KEY_COLUMNS, "label"]


def _ensure_int_series(s: pd.Series, name: str) -> pd.Series:
    if s.isna().any():
        raise InvalidSubmissionError(f"Column {name} contains NA values.")
    # allow strings/numbers that can be converted to int
    try:
        vals = s.astype(int)
    except Exception:
        # try stricter conversion via float->int only if integral
        try:
            f = s.astype(float)
            if not np.all(np.floor(f) == f):
                raise InvalidSubmissionError(f"Column {name} must contain integers.")
            vals = f.astype(int)
        except Exception as e:
            raise InvalidSubmissionError(f"Column {name} must be integer-convertible: {e}")
    return vals


def _macro_f1(y_true: np.ndarray, y_pred: np.ndarray, num_classes: int = 7) -> float:
    eps = 1e-12
    f1s = []
    for c in range(num_classes):
        tp = int(np.sum((y_true == c) & (y_pred == c)))
        fp = int(np.sum((y_true != c) & (y_pred == c)))
        fn = int(np.sum((y_true == c) & (y_pred != c)))
        precision = tp / (tp + fp + eps)
        recall = tp / (tp + fn + eps)
        f1 = 2.0 * precision * recall / (precision + recall + eps)
        f1s.append(f1)
    return float(np.mean(f1s))


class OEPProctoringMetrics(CompetitionMetrics):
    """Macro-F1 classification metric for OEP segment labels (0..6).

    Submission and ground-truth rows are matched by sorting both frames on
    ``KEY_COLUMNS`` (subject_id, segment_start, segment_end) and requiring the
    sorted keys to be identical. The submission's ``label`` column is then
    scored against the ground-truth ``self.value`` column with macro-F1 over
    the seven classes 0..6.
    """

    def __init__(self, value: str = "label", higher_is_better: bool = True):
        """Create the metric.

        Args:
            value: Name of the ground-truth label column. Predictions are
                always read from their ``label`` column.
            higher_is_better: Forwarded to ``CompetitionMetrics``.
        """
        super().__init__(higher_is_better)
        self.value = value  # kept for API parity; fixed to "label" for this task

    @staticmethod
    def _sort_by_keys(df: pd.DataFrame, who: str) -> pd.DataFrame:
        """Return *df* sorted by ``KEY_COLUMNS`` with a fresh 0..n-1 index.

        ``who`` names the frame in the error message. A ``TypeError`` from
        ``sort_values`` (e.g. a key column mixing ints and strings) is reported
        as ``InvalidSubmissionError`` so callers only need to catch that type.
        """
        try:
            # reset_index(drop=True) already returns a new frame; no extra copy needed.
            return df.sort_values(KEY_COLUMNS).reset_index(drop=True)
        except TypeError as e:
            raise InvalidSubmissionError(
                f"{who} key columns contain values that cannot be sorted: {e}"
            ) from e

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Return the macro-F1 of ``y_pred`` against ``y_true``.

        Args:
            y_true: Ground truth containing ``KEY_COLUMNS`` and ``self.value``.
            y_pred: Predictions containing ``REQUIRED_COLUMNS``.

        Returns:
            Macro-F1 over classes 0..6.

        Raises:
            InvalidSubmissionError: if an input is not a DataFrame, a required
                column is missing, the key columns differ between the frames,
                or a label is not an integer in [0, 6].
        """
        if not isinstance(y_true, pd.DataFrame) or not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("y_true and y_pred must be pandas DataFrames")

        # Validate columns presence
        true_cols = set(y_true.columns)
        pred_cols = set(y_pred.columns)
        required_true = set(KEY_COLUMNS + [self.value])
        required_pred = set(REQUIRED_COLUMNS)
        if not required_true.issubset(true_cols):
            missing = required_true - true_cols
            raise InvalidSubmissionError(f"Ground truth missing required columns: {', '.join(sorted(missing))}")
        if not required_pred.issubset(pred_cols):
            missing = required_pred - pred_cols
            raise InvalidSubmissionError(f"Predictions missing required columns: {', '.join(sorted(missing))}")

        # Sort and align by keys
        y_true_sorted = self._sort_by_keys(y_true, "Ground truth")
        y_pred_sorted = self._sort_by_keys(y_pred, "Predictions")

        # Keys must match exactly. np.array_equal is False on a shape mismatch,
        # so this also rejects frames with different row counts.
        if not np.array_equal(y_true_sorted[KEY_COLUMNS].values, y_pred_sorted[KEY_COLUMNS].values):
            raise InvalidSubmissionError("First three key columns (subject_id, segment_start, segment_end) do not match between y_true and y_pred.")

        # Cast values to integers and validate ranges
        yt = _ensure_int_series(y_true_sorted[self.value], self.value).to_numpy()
        yp = _ensure_int_series(y_pred_sorted["label"], "label").to_numpy()
        if (yt < 0).any() or (yt > 6).any():
            raise InvalidSubmissionError("Ground truth labels must be in [0, 6].")
        if (yp < 0).any() or (yp > 6).any():
            raise InvalidSubmissionError("Submission labels must be in [0, 6].")

        return _macro_f1(yt, yp, num_classes=7)

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that *submission* can be scored against *ground_truth*.

        Args:
            submission: Candidate predictions; must be a DataFrame with
                ``REQUIRED_COLUMNS``.
            ground_truth: Reference DataFrame with ``KEY_COLUMNS`` and
                ``self.value``.

        Returns:
            The string ``"Submission is valid."`` when every check passes.

        Raises:
            InvalidSubmissionError: on a wrong type, missing columns, a row
                count mismatch, key mismatch, non-integer key/label values, or
                labels outside [0, 6].
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError("Submission must be a pandas DataFrame.")
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError("Ground truth must be a pandas DataFrame.")

        # Column checks
        sub_cols = set(submission.columns)
        gt_cols = set(ground_truth.columns)
        required_gt = set(KEY_COLUMNS + [self.value])
        if not set(REQUIRED_COLUMNS).issubset(sub_cols):
            missing = set(REQUIRED_COLUMNS) - sub_cols
            raise InvalidSubmissionError(f"Missing required columns in submission: {', '.join(sorted(missing))}.")
        if not required_gt.issubset(gt_cols):
            missing = required_gt - gt_cols
            raise InvalidSubmissionError(f"Missing required columns in ground truth: {', '.join(sorted(missing))}.")

        # Length match
        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # Sort by keys and ensure they match exactly
        sub_sorted = self._sort_by_keys(submission, "Submission")
        gt_sorted = self._sort_by_keys(ground_truth, "Ground truth")
        if not np.array_equal(sub_sorted[KEY_COLUMNS].values, gt_sorted[KEY_COLUMNS].values):
            raise InvalidSubmissionError(
                "First three key columns do not match between submission and ground truth."
            )

        # Check dtypes/values are valid
        # Keys must be integer-convertible for segment_start and segment_end
        _ensure_int_series(sub_sorted["segment_start"], "segment_start")
        _ensure_int_series(sub_sorted["segment_end"], "segment_end")
        # Label integer and in range
        labels = _ensure_int_series(sub_sorted["label"], "label")
        if (labels < 0).any() or (labels > 6).any():
            raise InvalidSubmissionError("Submission label must be integer in [0, 6].")

        return "Submission is valid."
