from typing import Any
import pandas as pd
import numpy as np
from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class QuranSurahClassificationMetrics(CompetitionMetrics):
    """
    Metric class for Quran Surah classification from audio.

    Expected CSV schemas:
    - private/test_answer.csv: columns = ['id', 'surah_id']
    - public/sample_submission.csv: columns = ['id', 'surah_id']

    Metric: macro-F1 (higher is better). Per-class F1 scores are averaged over
    the labels that occur in the ground truth or in the predictions (the label
    set scikit-learn's ``f1_score(average="macro")`` uses by default), so a
    perfect submission scores 1.0 even if the test set does not cover every
    surah. Valid labels are the integers 1..114.

    Rows are matched on the ground truth's first column (the id column), so a
    submission may list its columns in any order.
    """

    def __init__(self, value: str = "surah_id", higher_is_better: bool = True):
        super().__init__(higher_is_better)
        # Name of the label column in both the submission and the ground truth.
        self.value = value
        self.valid_labels = set(range(1, 115))  # 1..114 inclusive

    @staticmethod
    def _sort_by_id(df: pd.DataFrame, id_col: Any) -> pd.DataFrame:
        """Return ``df`` sorted by the string form of ``id_col``, with a fresh index.

        Sorting on the string form (rather than the native dtype) guarantees two
        frames end up in the same row order even when one stores ids as ints and
        the other as strings, which matters because ids are compared as strings.
        """
        return df.sort_values(
            by=id_col, key=lambda col: col.astype(str), kind="mergesort"
        ).reset_index(drop=True)

    def _macro_f1(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Compute macro-averaged F1 over labels present in ``y_true`` or ``y_pred``.

        Args:
            y_true: true labels, aligned element-wise with ``y_pred``.
            y_pred: predicted labels.

        Returns:
            Mean per-label F1; 0.0 when both inputs are empty.

        Raises:
            ValueError: if the two inputs have different shapes.
        """
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        if y_true.shape != y_pred.shape:
            raise ValueError(
                f"y_true and y_pred must have the same shape, got {y_true.shape} and {y_pred.shape}."
            )

        # Labels absent from both arrays are excluded: counting them as F1 = 0
        # would cap the best achievable score below 1.0.
        labels = np.union1d(y_true, y_pred)
        if labels.size == 0:
            return 0.0

        f1s = []
        for label in labels:
            is_true = y_true == label
            is_pred = y_pred == label
            tp = int(np.count_nonzero(is_true & is_pred))
            fp = int(np.count_nonzero(is_pred & ~is_true))
            fn = int(np.count_nonzero(is_true & ~is_pred))
            # F1 = 2TP / (2TP + FP + FN). The denominator is > 0 because the
            # label occurs in at least one of the two arrays.
            f1s.append(2 * tp / (2 * tp + fp + fn))
        return float(np.mean(f1s))

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Score a submission that has already passed ``validate_submission``.

        Args:
            y_true: ground-truth frame; its first column is the id column.
            y_pred: submission frame containing the same ids and ``self.value``.

        Returns:
            Macro-F1 in [0, 1].
        """
        # Align rows on the ground truth's id column *name*. Sorting each frame
        # by its own first column would misalign submissions whose columns are
        # ordered differently.
        id_col = y_true.columns[0]
        y_true = self._sort_by_id(y_true, id_col)
        y_pred = self._sort_by_id(y_pred, id_col)

        y_t = y_true[self.value].astype(int).to_numpy()
        y_p = y_pred[self.value].astype(int).to_numpy()
        return self._macro_f1(y_t, y_p)

    def _coerce_labels(self, values: pd.Series, source: str) -> np.ndarray:
        """Convert a label column to an int array.

        Rejects non-numeric, missing, or fractional entries explicitly. A plain
        ``astype(int)`` would silently truncate e.g. 3.7 to 3.

        Args:
            values: the label column to convert.
            source: "Submission" or "Ground truth", used in error messages.

        Raises:
            InvalidSubmissionError: if any entry is not an integer value.
        """
        prefix = f"{source} column `{self.value}` must contain integer values (1..114)."
        try:
            numeric = pd.to_numeric(values, errors="raise")
        except (TypeError, ValueError, OverflowError) as e:
            raise InvalidSubmissionError(f"{prefix} Error: {e}") from e

        as_float = numeric.to_numpy(dtype=float, na_value=np.nan)
        if not (np.isfinite(as_float).all() and (as_float == np.round(as_float)).all()):
            raise InvalidSubmissionError(f"{prefix} Found missing or non-integral values.")
        return as_float.astype(int)

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that ``submission`` can be scored against ``ground_truth``.

        Checks, in order:
        - Both inputs are DataFrames.
        - Row counts match.
        - The submission contains the id column (the ground truth's first column)
          and ``self.value``.
        - Any other submission columns also exist in the ground truth.
        - Ids match as strings after sorting.
        - Both label columns hold integers in 1..114.

        Returns:
            "Submission is valid." when every check passes.

        Raises:
            InvalidSubmissionError: describing the first failed check.
        """
        # Type checks
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        # Basic length checks
        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)}). Please ensure both have the same number of rows."
            )

        # Required columns: the id column (ground truth's first column) and the
        # value column (surah_id).
        id_col = ground_truth.columns[0]
        required_cols = {id_col, self.value}
        sub_cols = set(submission.columns)
        true_cols = set(ground_truth.columns)

        missing_cols = required_cols - sub_cols
        if missing_cols:
            raise InvalidSubmissionError(
                f"Missing required columns in submission: {', '.join(sorted(map(str, missing_cols)))}."
            )

        # Submission columns beyond the required ones are tolerated only if
        # they also exist in the ground truth.
        extra_cols = sub_cols - true_cols
        if extra_cols:
            raise InvalidSubmissionError(
                f"Extra unexpected columns found in submission: {', '.join(sorted(map(str, extra_cols)))}."
            )

        # Align both frames on the ground truth's id column by name. Using the
        # submission's own first column would break when its columns are
        # reordered.
        submission_sorted = self._sort_by_id(submission, id_col)
        ground_truth_sorted = self._sort_by_id(ground_truth, id_col)

        if not np.array_equal(
            submission_sorted[id_col].astype(str).to_numpy(),
            ground_truth_sorted[id_col].astype(str).to_numpy(),
        ):
            raise InvalidSubmissionError(
                "First column values (ids) do not match between submission and ground truth. Please ensure the first column values are identical."
            )

        # Validate value column contents are integers in the valid range
        pred_vals = self._coerce_labels(submission_sorted[self.value], "Submission")
        true_vals = self._coerce_labels(ground_truth_sorted[self.value], "Ground truth")

        if not all(v in self.valid_labels for v in pred_vals.tolist()):
            raise InvalidSubmissionError(
                "Submission contains labels outside the valid range 1..114."
            )
        if not all(v in self.valid_labels for v in true_vals.tolist()):
            raise InvalidSubmissionError(
                "Ground truth contains labels outside the valid range 1..114."
            )

        return "Submission is valid."
