from typing import Any, List
import numpy as np
import pandas as pd

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class SpotifyPlaylistIntentMetrics(CompetitionMetrics):
    """
    Macro-F1 averaged across multi-label targets after thresholding probabilities.

    - Labels are every column not named "id" in the ground truth.
    - A submission must have exactly the ground truth's columns (names and order).
    - Predicted probabilities are binarised with ``prob >= threshold``; the
      threshold is the ``value`` constructor argument (0.5 by default).
    - Labels whose ground-truth column sums to zero are left out of the average.
    """

    def __init__(self, value: float = 0.5, higher_is_better: bool = True):
        """Store the decision threshold and forward ``higher_is_better`` to the base class."""
        super().__init__(higher_is_better)
        self.threshold = float(value)

    def _target_columns(self, df: pd.DataFrame) -> List[str]:
        """Return the label columns of *df*: every column not named "id"."""
        return [c for c in df.columns if c != "id"]

    def _id_keys(self, df: pd.DataFrame) -> np.ndarray:
        """Return the first column of *df* as an array of sort/comparison keys.

        Ids are coerced to numbers whenever every value parses, so that ids
        stored as strings order numerically ("2" before "10") and compare
        equal to their numeric counterparts. If any value fails to parse, the
        raw values are returned unchanged.
        """
        raw = df.iloc[:, 0]
        numeric = pd.to_numeric(raw, errors="coerce")
        if numeric.isna().any():
            return raw.to_numpy()
        return numeric.to_numpy()

    def _sort_by_id(self, df: pd.DataFrame):
        """Return ``(df sorted by its id keys with a fresh index, sorted id keys)``."""
        keys = self._id_keys(df)
        order = np.argsort(keys, kind="stable")
        return df.iloc[order].reset_index(drop=True), keys[order]

    def _validate_columns_exact(self, submission: pd.DataFrame, sample: pd.DataFrame) -> None:
        """Raise InvalidSubmissionError unless both frames have identical column lists."""
        if list(submission.columns) != list(sample.columns):
            raise InvalidSubmissionError(
                "Submission columns must exactly match sample_submission.csv (names and order)."
            )

    def _validate_ids(self, submission: pd.DataFrame) -> None:
        """Raise InvalidSubmissionError unless "id" holds unique, finite, integer-like numbers."""
        if not submission["id"].is_unique:
            raise InvalidSubmissionError("Duplicate ids in submission.")
        # Coerce to numeric, ensure integer-like
        coerced = pd.to_numeric(submission["id"], errors="coerce")
        if coerced.isna().any():
            raise InvalidSubmissionError("id column contains non-numeric values.")
        id_values = coerced.to_numpy(dtype=float)
        # inf equals its own floor, so finiteness has to be checked separately.
        if not (np.isfinite(id_values) & (id_values == np.floor(id_values))).all():
            raise InvalidSubmissionError("id column must be integer-like.")

    def _validate_values_range(self, submission: pd.DataFrame) -> None:
        """Raise InvalidSubmissionError unless every label value is a finite number in [0, 1]."""
        for c in self._target_columns(submission):
            # Unparseable entries become NaN and are rejected by the finiteness check.
            vals = pd.to_numeric(submission[c], errors="coerce").to_numpy(dtype=float)
            if not np.isfinite(vals).all():
                raise InvalidSubmissionError(f"Non-finite values in column {c}.")
            if not ((vals >= 0.0) & (vals <= 1.0)).all():
                raise InvalidSubmissionError(f"Values out of [0,1] in column {c}.")

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Compute macro-F1 over the label columns.

        Both frames are ordered by their first column before comparison, so
        row order in the inputs does not matter.

        Args:
            y_true: Ground truth; first column is the row id, columns not
                named "id" are the labels.
            y_pred: Predictions with the same columns as ``y_true``.

        Returns:
            The mean per-label F1 over labels whose ground-truth column has a
            non-zero sum, or 0.0 when there is no such label.

        Raises:
            InvalidSubmissionError: If the row counts, label columns or ids of
                the two frames do not match.
        """
        # Explicit raises instead of `assert`, which is removed under `python -O`.
        if len(y_true) != len(y_pred):
            raise InvalidSubmissionError(
                f"Row count mismatch: y_pred has {len(y_pred)} rows, y_true has {len(y_true)} rows."
            )

        y_true, true_ids = self._sort_by_id(y_true)
        y_pred, pred_ids = self._sort_by_id(y_pred)

        true_cols = self._target_columns(y_true)
        pred_cols = self._target_columns(y_pred)
        if true_cols != pred_cols:
            raise InvalidSubmissionError("Target columns between y_true and y_pred must match.")
        if not np.array_equal(true_ids, pred_ids):
            raise InvalidSubmissionError("Ids must align.")

        truth = y_true[true_cols].to_numpy(dtype=float).astype(int)
        probs = np.clip(y_pred[pred_cols].to_numpy(dtype=float), 0.0, 1.0)
        predicted = probs >= self.threshold

        # Labels with no positives are excluded from the macro average.
        scored = truth.sum(axis=0) != 0
        if not scored.any():
            return 0.0
        truth = truth[:, scored]
        predicted = predicted[:, scored]

        is_pos = truth == 1
        is_neg = truth == 0
        tp = (predicted & is_pos).sum(axis=0)
        fp = (predicted & is_neg).sum(axis=0)
        fn = (~predicted & is_pos).sum(axis=0)

        # F1 = 2TP / (2TP + FP + FN), defined as 0 when the denominator is 0.
        denom = 2 * tp + fp + fn
        f1 = np.zeros(denom.shape, dtype=float)
        nonzero = denom > 0
        f1[nonzero] = 2.0 * tp[nonzero] / denom[nonzero]
        return float(f1.mean())

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that *submission* is structurally compatible with *ground_truth*.

        Args:
            submission: Candidate predictions (must be a DataFrame).
            ground_truth: Reference answers (must be a DataFrame); its columns
                define the required submission schema.

        Returns:
            The string "Submission is valid." when every check passes.

        Raises:
            InvalidSubmissionError: On a wrong type, column mismatch, invalid
                ids, out-of-range values, row-count mismatch or id mismatch.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError("Submission must be a pandas DataFrame.")
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError("Ground truth must be a pandas DataFrame.")

        # The required schema is taken from the ground truth's columns.
        sample_like = pd.DataFrame(columns=ground_truth.columns)

        # Column, id, and value checks
        self._validate_columns_exact(submission, sample_like)
        self._validate_ids(submission)
        self._validate_values_range(submission)

        # Row count must match ground truth
        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Row count mismatch: submission has {len(submission)} rows, ground truth has {len(ground_truth)} rows."
            )

        # Coerce ids to numbers before sorting: sorting raw values first would
        # order numeric strings lexicographically and reject a correct id set.
        sub_ids = np.sort(self._id_keys(submission), kind="stable")
        gt_ids = np.sort(self._id_keys(ground_truth), kind="stable")
        if not np.array_equal(sub_ids, gt_ids):
            raise InvalidSubmissionError("Submission ids must match test ids.")

        return "Submission is valid."
