from typing import Optional, Tuple, List, Dict
import logging
import os
import json
import torch
import torch.nn.functional as F

import numpy as np
import random

from scipy.optimize import brentq
from scipy.interpolate import interp1d
from sklearn.metrics import roc_auc_score, roc_curve


def load_id2label(path: str) -> Dict[str, str]:
    """
    Load an id2label mapping from a JSON file.

    JSON object keys are always strings, so the returned mapping is keyed by
    the string form of each id (e.g. ``"0"``), not by ``int``. Values are
    returned exactly as stored in the file (presumably label names). This
    function only loads and type-checks the top-level value: it does not
    build a label2id mapping and does not validate against a class count.

    Args:
        path (str): Path to id2label.json file.

    Returns:
        Dict[str, str]: id2label as decoded from JSON (string keys).

    Raises:
        ValueError: If the top-level JSON value is not an object (dict).
    """
    # JSON text is UTF-8 by specification; do not depend on the platform's
    # locale encoding, which would garble non-ASCII label names.
    with open(path, "r", encoding="utf-8") as f:
        id2label = json.load(f)

    if not isinstance(id2label, dict):
        raise ValueError(f"Expected dict in {path}, got {type(id2label)}")
    return id2label


def save_id2label(id2label: Dict[int, str], dir: str, overwrite: bool = True) -> None:
    """
    Save id2label mapping to ``<dir>/id2label.json``.

    Behaviour when the file already exists:
      * If its content is identical to `id2label`, nothing is written
        (regardless of `overwrite`).
      * If it differs and `overwrite` is False, the existing file is kept.
      * If it differs and `overwrite` is True, the existing file is moved to
        ``<dir>/id2label.bak.json`` (replacing any previous backup) and the
        new mapping is written.

    Args:
        id2label (Dict[int, str]): Mapping to serialise. Non-string keys are
            converted to strings by the JSON encoder.
        dir (str): Directory that holds (or will hold) id2label.json.
        overwrite (bool): Whether a differing existing file may be replaced.
    """
    path = os.path.join(dir, 'id2label.json')
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            existing = json.load(f)

        # The file on disk always has string keys, so compare against the
        # JSON round-tripped form of the input; a raw comparison would never
        # match an int-keyed mapping and would trigger a needless backup.
        normalized = json.loads(json.dumps(id2label, ensure_ascii=False))
        if existing == normalized:
            logging.info(f"Existing id2label at {path} is identical. Skipping save.")
            return
        if not overwrite:
            logging.info(f"Not overwriting existing id2label at {path}")
            return

        # Build the backup name from the directory rather than substituting
        # ".json" inside the full path, which would also rewrite directory
        # names that happen to contain ".json".
        backup_path = os.path.join(dir, 'id2label.bak.json')
        logging.warning(f"Backing up existing id2label to {backup_path}")
        # os.replace overwrites a previous backup on every platform;
        # os.rename raises on Windows when the destination exists.
        os.replace(path, backup_path)

    # ensure_ascii=False emits raw non-ASCII characters, so the encoding must
    # be explicit instead of following the platform locale.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(id2label, f, indent=2, ensure_ascii=False)
    logging.info(f"Saved id2label to {path}")


def build_verification_pairs(
    embeddings: torch.Tensor,
    speaker_ids: List[str],
    sample_ratio: Optional[float] = None,
) -> Tuple[List[float], List[int]]:
    """
    Construct pairwise similarity scores and ground-truth labels for speaker verification during validation stage
    Note: better set sample_ratio if validation set is large

    Args:
        embeddings (Tensor): (N, D) speaker embeddings. Rows are L2-normalized
            here, so they do not need to be normalized beforehand; rows that
            normalize to NaN (e.g. all-zero rows) are replaced by zeros.
        speaker_ids (List[str]): Speaker ID for each embedding (length N).
        sample_ratio (float, optional): If set to a value strictly between 0
            and 1 (e.g. 0.1), randomly sample that proportion of all pairs
            using the `random` module. Any other value keeps every pair.

    Returns:
        scores (List[float]): Cosine similarity scores for the selected (i < j) pairs.
        labels (List[int]): 1 for same-speaker pairs, 0 for different-speaker pairs.
        Both lists are empty when fewer than two speaker IDs are given.

    Raises:
        ValueError: If `embeddings` is not 2D or its row count differs from
            `len(speaker_ids)`.
    """
    num_utts = len(speaker_ids)
    if num_utts < 2:
        logging.warning("Not enough samples to calculate EER/minDCF.")
        # No pair can be formed: return empty lists so the result has the
        # same (scores, labels) shape as the regular path.
        return [], []

    if embeddings.ndim != 2:
        raise ValueError(f"Embeddings must be 2D (N, D), got {embeddings.shape}")
    if embeddings.shape[0] != num_utts:
        raise ValueError(
            f"Got {embeddings.shape[0]} embeddings but {num_utts} speaker ids"
        )

    # Detach so tensors that still track gradients can be converted to numpy.
    embeddings = embeddings.detach()

    # Normalize. The division creates a new tensor, so the in-place NaN fix
    # below never touches the caller's data.
    embeddings = embeddings / torch.linalg.norm(embeddings, dim=1, keepdim=True)
    embeddings[torch.isnan(embeddings)] = 0.0
    sim_matrix = torch.matmul(embeddings, embeddings.T).cpu().numpy()

    # Upper-triangle indices enumerate every (i < j) pair in row-major order
    # without building a Python list of N*(N-1)/2 tuples.
    rows, cols = np.triu_indices(num_utts, k=1)

    # Sample subset if requested
    if sample_ratio is not None and 0 < sample_ratio < 1.0:
        num_samples = int(len(rows) * sample_ratio)
        # Sample pair positions (without replacement) from the `random`
        # module, so seeding via random.seed() still controls the selection.
        picked = np.asarray(
            random.sample(range(len(rows)), num_samples), dtype=np.intp
        )
        rows, cols = rows[picked], cols[picked]

    scores = sim_matrix[rows, cols].tolist()
    labels = [
        1 if speaker_ids[i] == speaker_ids[j] else 0
        for i, j in zip(rows.tolist(), cols.tolist())
    ]

    return scores, labels




def compute_verification_metrics(similarities, labels, p_target=0.01, c_miss=1, c_fa=1):
    """
    Compute speaker verification metrics: AUC, EER, threshold@EER, and minDCF.

    Args:
        similarities (array-like): Cosine similarity scores between utterance pairs.
        labels (array-like): Ground truth binary labels (1 for same speaker, 0 for different).
        p_target (float): Prior probability of target trials (default 0.01).
        c_miss (float): Cost of a miss (default 1).
        c_fa (float): Cost of a false alarm (default 1).

    Returns:
        auc (float): Area under the ROC curve.
        eer (float): Equal error rate, found on the linearly interpolated ROC curve.
        threshold_at_eer (float): Threshold of the ROC operating point where the
            false-negative and false-positive rates are closest (approximate).
        min_dcf (float): Minimum of the detection cost
            ``c_miss * FNR * p_target + c_fa * FPR * (1 - p_target)`` over all
            ROC operating points. The value is not normalized by the cost of a
            trivial accept-all/reject-all system.

    Note:
        Both classes must be present in `labels`; scikit-learn's
        `roc_auc_score` is expected to reject single-class input.
    """
    auc = roc_auc_score(labels, similarities)
    fpr, tpr, thresholds = roc_curve(labels, similarities)
    fnr = 1 - tpr

    # Compute EER using brentq for better precision: the EER is the FPR value
    # at which the interpolated ROC curve satisfies FNR == FPR.
    roc = interp1d(fpr, tpr, kind='linear')
    eer = brentq(lambda x: 1. - x - roc(x), 0., 1.)

    # Threshold for reporting: take the operating point where FNR and FPR are
    # closest. Matching on FPR alone is ambiguous on vertical ROC segments
    # (several points share one FPR) and can select roc_curve's artificial
    # first threshold, which lies above every score.
    eer_idx = np.nanargmin(np.abs(fnr - fpr))
    threshold_at_eer = thresholds[eer_idx]

    # minDCF
    c_det = c_miss * fnr * p_target + c_fa * fpr * (1 - p_target)
    min_dcf = np.min(c_det)

    return auc, eer, threshold_at_eer, min_dcf

