from typing import Any
import math
import pandas as pd
import numpy as np
from sklearn.metrics import average_precision_score

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


# Column names a submission frame is checked against. Kept as a list because
# callers compare it to `list(df.columns)`, so order is significant.
REQUIRED_SUBMISSION_COLUMNS = [
    "video_id",
    "p_video_fake",
    "p_audio_fake",
]
# Column names the ground-truth frame is checked against (same list-equality
# comparison, so order is significant here too).
GT_COLUMNS = [
    "video_id",
    "label_video_fake",
    "label_audio_fake",
]


class LAVDFDetectionMetrics(CompetitionMetrics):
    """
    Metric class for LAV-DF Deepfake Manipulation Detection.

    Computes mean Average Precision (mAP) over two labels: video_fake and
    audio_fake. Ground-truth and submission rows are paired on ``video_id``;
    ids are compared by their string form so that frames whose id columns were
    loaded with different dtypes (e.g. int64 vs. str) still align.
    """

    # Key column shared by the ground-truth and submission frames.
    _ID_COLUMN = "video_id"
    # Submission columns holding the per-label probabilities.
    _PREDICTION_COLUMNS = ("p_video_fake", "p_audio_fake")

    def __init__(self, value: str = "mAP", higher_is_better: bool = True):
        """
        Args:
            value: Stored unchanged on ``self.value``.
            higher_is_better: Forwarded to the base class constructor.
        """
        super().__init__(higher_is_better)
        self.value = value

    @staticmethod
    def _sort_by_id(frame: pd.DataFrame, id_col: str) -> pd.DataFrame:
        """Return a copy of *frame* ordered by the string form of *id_col*.

        The id column is cast to ``str`` *before* sorting. Sorting first and
        casting afterwards would order an integer column numerically and a
        string column lexicographically, so two frames with the same ids could
        end up in different orders. The input frame is not modified.
        """
        keyed = frame.assign(**{id_col: frame[id_col].astype(str)})
        # A stable sort keeps the result deterministic even if ids repeat.
        return keyed.sort_values(by=id_col, kind="mergesort").reset_index(drop=True)

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Compute the mean of the video and audio Average Precision scores.

        Args:
            y_true: Ground truth whose first three columns are ``GT_COLUMNS``.
            y_pred: Submission whose first three columns are
                ``REQUIRED_SUBMISSION_COLUMNS``.

        Returns:
            ``(AP_video + AP_audio) / 2`` as a float.

        Raises:
            InvalidSubmissionError: If the column layout is wrong, the id sets
                do not match, or the resulting score is NaN/Inf.
        """
        # Validate the layout first: sorting relies on the id column existing,
        # and a malformed frame should surface as InvalidSubmissionError.
        if list(y_true.columns)[:3] != GT_COLUMNS:
            raise InvalidSubmissionError(
                f"Ground truth must have columns {GT_COLUMNS}, got {list(y_true.columns)}"
            )
        if list(y_pred.columns)[:3] != REQUIRED_SUBMISSION_COLUMNS:
            raise InvalidSubmissionError(
                f"Submission must have columns {REQUIRED_SUBMISSION_COLUMNS}, got {list(y_pred.columns)}"
            )

        # Align rows on the string form of the id (same rule as validate_submission).
        y_true = self._sort_by_id(y_true, self._ID_COLUMN)
        y_pred = self._sort_by_id(y_pred, self._ID_COLUMN)

        # array_equal also returns False when the lengths differ.
        if not np.array_equal(
            y_true[self._ID_COLUMN].to_numpy(dtype=object),
            y_pred[self._ID_COLUMN].to_numpy(dtype=object),
        ):
            raise InvalidSubmissionError(
                "video_id values in submission do not match ground truth after sorting."
            )

        # Extract labels and predictions
        yv_true = y_true["label_video_fake"].astype(int).values
        ya_true = y_true["label_audio_fake"].astype(int).values

        # Clip predictions to [0, 1]
        yv_pred = np.clip(y_pred["p_video_fake"].astype(float).values, 0.0, 1.0)
        ya_pred = np.clip(y_pred["p_audio_fake"].astype(float).values, 0.0, 1.0)

        # Compute AP for both tasks and average
        ap_v = average_precision_score(yv_true, yv_pred)
        ap_a = average_precision_score(ya_true, ya_pred)
        m_ap = float((ap_v + ap_a) / 2.0)
        if not math.isfinite(m_ap):
            raise InvalidSubmissionError("Metric computation failed: NaN/Inf encountered.")
        return m_ap

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that *submission* is well-formed with respect to *ground_truth*.

        Args:
            submission: Expected to be a DataFrame with exactly the columns
                ``REQUIRED_SUBMISSION_COLUMNS``.
            ground_truth: Expected to be a DataFrame whose first three columns
                are ``GT_COLUMNS``.

        Returns:
            The string ``"Submission is valid."`` when every check passes.

        Raises:
            InvalidSubmissionError: On a wrong type, wrong columns, row-count
                mismatch, duplicate or mismatching ids, or prediction values
                that are non-numeric, non-finite, or outside ``[0, 1]``.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        # Check required columns in submission
        sub_cols = list(submission.columns)
        if sub_cols != REQUIRED_SUBMISSION_COLUMNS:
            raise InvalidSubmissionError(
                f"Submission must have columns exactly {REQUIRED_SUBMISSION_COLUMNS}, got {sub_cols}"
            )

        # Check ground truth columns minimally
        gt_cols = list(ground_truth.columns)
        if gt_cols[:3] != GT_COLUMNS:
            raise InvalidSubmissionError(
                f"Ground truth must have columns {GT_COLUMNS}, got {gt_cols}"
            )

        # Length must match
        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # Ids are cast to str and then sorted, so dtype differences between
        # the two frames cannot change the relative order.
        sub_ids = self._sort_by_id(submission, self._ID_COLUMN)[self._ID_COLUMN]
        gt_ids = self._sort_by_id(ground_truth, self._ID_COLUMN)[self._ID_COLUMN]

        # Report duplicates before the equality check: a duplicated id would
        # otherwise be hidden behind the generic "do not match" error.
        if sub_ids.duplicated().any():
            raise InvalidSubmissionError("Duplicate video_id values found in submission.")
        if not np.array_equal(
            sub_ids.to_numpy(dtype=object), gt_ids.to_numpy(dtype=object)
        ):
            raise InvalidSubmissionError(
                "First column values (video_id) do not match between submission and ground truth."
            )

        # Validate prediction ranges and numeric type. Row order is irrelevant
        # here, so the unsorted submission is inspected directly.
        for col in self._PREDICTION_COLUMNS:
            # Non-numeric entries are coerced to NaN; isfinite rejects both
            # NaN and +/-Inf in a single pass.
            vals = pd.to_numeric(submission[col], errors="coerce").astype(float).to_numpy()
            if not np.isfinite(vals).all():
                raise InvalidSubmissionError(
                    f"Column {col} contains NaN/Inf or non-numeric values."
                )
            if ((vals < 0.0) | (vals > 1.0)).any():
                raise InvalidSubmissionError(
                    f"Column {col} has values outside [0, 1]."
                )

        return "Submission is valid."
