# Calculation of metrics for uncertainty quantification, such as AU (aleatoric) and EU (epistemic) uncertainty.
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, Union

import numpy as np
from scipy.special import digamma

# Additive smoothing constant placed inside logarithms so that log(0) never occurs.
# NOTE(review): 0.01 is large for a pure log guard and visibly biases smoothed
# quantities (e.g. log(0.5 + EPS) vs log(0.5)); confirm this is intended.
EPS: float = 0.01


def entropy(p):
    """
    Shannon entropy (natural log) of a discrete distribution.

    Zero-probability entries contribute exactly 0 (the 0 * log 0 = 0
    convention), so no smoothing constant is needed. Adding one inside the log
    would bias the result and could even make it negative for a one-hot ``p``.

    Args:
        p: Probability vector (NumPy array or plain sequence).

    Returns:
        The entropy as a NumPy float.
    """
    p = np.asarray(p, dtype=np.float64)
    # Only strictly positive entries contribute; this also avoids log(0).
    positive = p[p > 0]
    return -np.sum(positive * np.log(positive))


def cross_entropy(p, q):
    """
    Cross entropy H(p, q) = -sum_i p_i * log(q_i + EPS).

    EPS is added to ``q`` inside the log so that entries where q_i == 0 stay
    finite. No masking is applied: entries with p_i == 0 simply contribute 0.
    """
    smoothed_log_q = np.log(q + EPS)
    return -np.sum(p * smoothed_log_q)


def kl_divergence(p, q):
    """
    Smoothed KL divergence KL(p || q) = sum_i p_i * log((p_i + EPS) / (q_i + EPS)).

    EPS is added to both numerator and denominator so the log ratio stays
    finite when ``q`` (or ``p``) has zero entries.
    """
    smoothed_ratio = (p + EPS) / (q + EPS)
    return np.sum(p * np.log(smoothed_ratio))


def jensen_shannon_divergence(p, q):
    """
    Jensen-Shannon divergence between two discrete distributions.

    JSD(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m) with m = (p + q) / 2,
    using the EPS-smoothed ``kl_divergence``.

    Args:
        p: Probability vector (NumPy array or plain sequence).
        q: Probability vector of the same length as ``p``.

    Returns:
        The divergence.
    """
    # Convert first: for plain Python lists ``p + q`` would concatenate the
    # lists instead of adding them element-wise.
    p = np.asarray(p, dtype=np.float64)
    q = np.asarray(q, dtype=np.float64)
    m = 0.5 * (p + q)
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


def get_expected_kl_divergence(counts: list, p: list, likelihood_multiplier: float = 1):
    """
    Expected KL(p* || p) when p* is drawn from a Dirichlet distribution.

    The Dirichlet parameters are ``alpha = 1 + likelihood_multiplier * counts``,
    i.e. a Dirichlet(1, ..., 1) prior combined with the (scaled) counts.
    With alpha0 = sum(alpha) this uses
        E[p*_i log p*_i] = (alpha_i / alpha0) * (psi(alpha_i + 1) - psi(alpha0 + 1))
        E[p*_i]          = alpha_i / alpha0
    so that E[KL] = sum_i (alpha_i / alpha0) * (psi(alpha_i + 1) - psi(alpha0 + 1) - log p_i).

    Args:
        counts: Observed counts, one per category.
        p: Reference distribution indexed like ``counts``; EPS is added inside
            its log so zero entries stay finite.
        likelihood_multiplier: Weight applied to the counts.

    Returns:
        The expected KL divergence.
    """
    concentration = 1 + np.array(counts) * likelihood_multiplier
    total_concentration = np.sum(concentration)
    # psi(alpha0 + 1) is the same for every category; compute it once.
    psi_total = digamma(total_concentration + 1)
    return sum(
        (a_k / total_concentration)
        * (digamma(a_k + 1) - psi_total - np.log(p[k] + EPS))
        for k, a_k in enumerate(concentration)
    )


def get_expected_entropy(counts: list, likelihood_multiplier: float = 1):
    """
    Expected Shannon entropy E[H(p*)] when p* is drawn from a Dirichlet distribution.

    The Dirichlet parameters are ``alpha = 1 + likelihood_multiplier * counts``,
    i.e. a Dirichlet(1, ..., 1) prior combined with the (scaled) counts.
    With alpha0 = sum(alpha):
        E[p*_i log p*_i] = (alpha_i / alpha0) * (psi(alpha_i + 1) - psi(alpha0 + 1))
    and E[H] = -sum_i E[p*_i log p*_i].
    Derivation: https://math.stackexchange.com/questions/3303461/expected-value-of-dirichlet-distribution-times-its-logarithm-component-wise

    Args:
        counts: Observed counts, one per category.
        likelihood_multiplier: Weight applied to the counts.

    Returns:
        The expected entropy.
    """
    concentration = 1 + np.array(counts) * likelihood_multiplier
    total_concentration = np.sum(concentration)
    # psi(alpha0 + 1) is the same for every category; compute it once.
    psi_total = digamma(total_concentration + 1)
    expected_p_log_p = sum(
        (a_k / total_concentration) * (digamma(a_k + 1) - psi_total)
        for a_k in concentration
    )
    return -expected_p_log_p


def get_EU(
    p_star: np.ndarray,
    p_hat: np.ndarray,
    counts: np.ndarray,
    likelihood_multipliers: Sequence[float] = (1.0,),
) -> Dict[str, Any]:
    """
    Calculate the epistemic uncertainty (EU) metrics for a single sample.

    Two families of metrics are reported:
    1. Using the given p* directly: its entropy and KL(p* || p_hat).
    2. Expectations under a Dirichlet over p* built from ``counts``, once per
       likelihood multiplier (see ``get_expected_kl_divergence`` and
       ``get_expected_entropy``).

    Args:
        p_star (np.ndarray): The estimated p* distribution.
        p_hat (np.ndarray): The model's distribution, element-aligned with p_star.
        counts (np.ndarray): Answer counts, element-aligned with p_hat.
        likelihood_multipliers (Sequence[float]): Multipliers applied to the
            counts for the expected KL divergence and expected entropy.
            The default is an immutable tuple so it cannot be shared/mutated
            across calls.

    Returns:
        Dict[str, Any]: ``"entropy"`` and ``"kl_divergence"``, plus
        ``"expected_kl_{m}"`` and ``"expected_entropy_{m}"`` for every
        multiplier ``m``.
    """
    result = {"entropy": entropy(p_star), "kl_divergence": kl_divergence(p_star, p_hat)}

    for likelihood_multiplier in likelihood_multipliers:
        result[f"expected_kl_{likelihood_multiplier}"] = get_expected_kl_divergence(
            counts, p_hat, likelihood_multiplier
        )
        result[f"expected_entropy_{likelihood_multiplier}"] = get_expected_entropy(
            counts, likelihood_multiplier
        )

    return result


def align_answers_with_vocab(
    vocab: list, mapping: dict, answers: List[List[str]], counts: List[int]
) -> np.ndarray:
    """
    Align the dataset answer counts with the model vocabulary order.

    The model's answer order (``vocab``, the order used by p_hat) differs from
    the order of the answers in the dataset. The run therefore produced a
    ``mapping`` from dataset answers to vocab entries. This function uses that
    mapping to reorder ``counts`` so that position k of the result holds the
    count for ``vocab[k]``.

    Params:
        vocab (list): The vocabulary of possible answers in the order used by p_hat (model).
        mapping (dict): The mapping from dataset answers to vocab entries used
            while running the model.
        answers (List[List[str]]): The answers from the dataset, one list per
            semantic group; only the first entry of each group is used.
        counts (List[int]): The counts of the answers from the dataset, parallel to ``answers``.
    Returns:
        np.ndarray: One count per vocab entry, in vocab order; vocab entries
        without a matching dataset answer get 0.
    Raises:
        AssertionError: If answers and counts still differ in length after
            duplicate answers are removed.
    """
    # If a semantic group lists several answers, use the first one as its key.
    answers = [group[0] for group in answers]

    if len(answers) != len(counts):
        # AmbigQA can contain duplicate answers; drop them, keeping the first occurrence.
        answers = list(dict.fromkeys(answers))

    assert len(answers) == len(counts), "Answers and counts must have the same length"

    aligned_counts = []
    for voc in vocab:
        # Dataset answers that the run mapped onto this vocab entry (mapping order).
        keys = [k for k, v in mapping.items() if v == voc]
        for key in keys:
            if key in answers:
                # NOTE(review): only the first matching key is used; if several
                # dataset answers map to the same vocab entry their counts are
                # not summed -- confirm this is intended.
                aligned_counts.append(counts[answers.index(key)])
                break
        else:
            # No dataset answer maps to this vocab entry. Using for/else instead
            # of a sentinel value keeps exactly one entry per vocab item
            # regardless of the count values.
            aligned_counts.append(0)
    return np.array(aligned_counts)


def estimate_pstar(
    wandb_config: Dict[str, Any],
    run: List[Dict[str, Any]],
    pstar_estimators: List[Any],
    likelihood_multipliers: Sequence[float] = (1.0,),
) -> List[Dict[str, Any]]:
    """
    Estimate P* for every sample of a run using the provided P* estimators.

    Args:
        wandb_config (Dict[str, Any]): Run configuration. The names of the UQ
            estimator fields to copy from each sample are read from
            ``wandb_config["orchestrator"]["kwargs"]["uq_estimators"]``.
        run (List[Dict[str, Any]]): The input run containing samples. Each sample
            must provide "question", "id", "vocab", "mapping", "p_hat",
            "answer", "tokens_decoded_generated" and one entry per UQ estimator name.
        pstar_estimators (List[Any]): ``(estimator_samples, estimator_name)``
            pairs. ``estimator_samples[i]`` must describe the same question/id
            as ``run[i]`` and provide "question", "id", "answer" and "counts".
        likelihood_multipliers (Sequence[float]): Likelihood multipliers for the
            expected KL divergence and entropy calculations. The default is an
            immutable tuple so it cannot be shared/mutated across calls.

    Returns:
        List[Dict[str, Any]]: One dict per sample, keyed by estimator name,
        holding the P* estimate and the derived uncertainty metrics.

    Raises:
        AssertionError: If a sample's question or id differs between the run
            and an estimator.
    """
    uq_estimators = wandb_config["orchestrator"]["kwargs"]["uq_estimators"]
    results = []
    # Samples are matched by position; the asserts below verify that alignment.
    for i, sample in enumerate(run):
        sample_results = {}
        question = sample["question"]
        sample_id = sample["id"]  # not named `id` to avoid shadowing the builtin
        vocab = sample["vocab"]
        mapping = sample["mapping"]
        p_hat = sample["p_hat"]
        for estimator, estimator_name in pstar_estimators:
            estimator_sample = estimator[i]
            question_estimator = estimator_sample["question"]
            id_estimator = estimator_sample["id"]

            assert question == question_estimator, (
                f"Question mismatch: {question} != {question_estimator}"
            )
            assert sample_id == id_estimator, f"ID mismatch: {sample_id} != {id_estimator}"

            # Reorder the dataset counts into the vocab order used by p_hat.
            counts = align_answers_with_vocab(
                vocab,
                mapping,
                estimator_sample["answer"],
                estimator_sample["counts"],
            )
            total = counts.sum()
            if total > 0:
                p_star = counts / total
            else:
                # No positive count aligned with the vocab: fall back to a uniform p*.
                p_star = np.full(len(counts), 1.0 / max(len(counts), 1), dtype=np.float64)

            epistemic_uncertainty = get_EU(
                p_star, p_hat, counts, likelihood_multipliers=likelihood_multipliers
            )

            # True when p_hat and p_star both put positive mass on at least one answer.
            correct = bool(np.any((p_hat > 0) & (p_star > 0)))

            sample_results[estimator_name] = {
                "question": question,
                "answer": sample["answer"],
                "tokens_decoded_generated": sample["tokens_decoded_generated"],
                "raw_counts": estimator_sample["counts"],  # raw dataset counts
                "counts": counts,  # aligned with the vocab
                "vocab": vocab,
                "p_hat": p_hat,
                "p_star": p_star,
                "correct": correct,
                "epistemic_uncertainty": epistemic_uncertainty,
                "uq_estimators": {
                    uq_estimator: sample[uq_estimator] for uq_estimator in uq_estimators
                },
            }
        results.append(sample_results)
    return results


@dataclass
class ConcordanceResult:
    """
    Pairwise concordance statistics between two paired 1-D samples (x, y).

    Instances are built by ``auc_cont``, which counts each unordered pair
    (i < j) once. The ratio statistics are NaN when their denominator is zero.
    """

    C: float  # concordance C = (D + 1) / 2 = (c + 0.5 * t_x) / (c + d + t_x)
    D: float  # Somers' D = (c - d) / (c + d + t_x)
    tau_a: float  # Kendall's tau-a = (c - d) / (total number of pairs)
    tau_b: float  # Kendall's tau-b = (c - d) / sqrt((c + d + t_x) * (c + d + t_y))
    gamma: float  # Goodman–Kruskal gamma = (c - d) / (c + d)
    c: int  # concordant pairs
    d: int  # discordant pairs
    t_x: int  # tied on x (but not y)
    t_y: int  # tied on y (but not x)
    t_xy: int  # tied on both x and y
    n: int  # number of observations


def auc_cont(y, x, *, return_result: bool = False) -> Union[float, "ConcordanceResult"]:
    """
    Concordance statistic per Therneau & Atkinson, 'Concordance' vignette,
    pages 1–2 (CRAN survival package):
      C = (c + 0.5 * t_x) / (c + d + t_x)
    where:
      c = #concordant pairs, d = #discordant pairs,
      t_x = #tied on x only, t_y = #tied on y only, t_xy = #tied on both.

    Each unordered pair (i < j) is counted once. NaN values are not handled
    specially (every comparison with NaN is false); remove them beforehand.

    Args:
        y: 1-D array-like of outcomes. Boolean and unsigned-integer inputs are
            supported because values are compared, never subtracted.
        x: 1-D array-like of scores with the same length as ``y``.
        return_result: If True, return the full ``ConcordanceResult`` (all
            statistics and pair counts) instead of only ``C``.

    Returns:
        float: The concordance ``C``. It is NaN when c + d + t_x == 0, i.e.
        every pair is tied on y. If ``return_result`` is True, a
        ``ConcordanceResult`` is returned instead.

    Raises:
        ValueError: If x and y are not 1-D arrays of the same length, or if
            there are fewer than two observations.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if x.ndim != 1 or y.ndim != 1 or x.size != y.size:
        raise ValueError("x and y must be 1D arrays of the same length.")
    n = x.size
    if n < 2:
        raise ValueError("Need at least two observations.")

    # Use only the upper triangle (i < j) so each pair is counted once.
    upper = np.triu(np.ones((n, n), dtype=bool), k=1)

    # Compare values directly instead of taking pairwise differences:
    # subtracting boolean arrays raises TypeError, and unsigned integers wrap
    # around (1 - 2 -> 255), which would silently flip concordance.
    x_i, x_j = x[:, None], x[None, :]
    y_i, y_j = y[:, None], y[None, :]
    x_gt = (x_i > x_j) & upper
    x_lt = (x_i < x_j) & upper
    x_eq = (x_i == x_j) & upper
    y_gt = (y_i > y_j) & upper
    y_lt = (y_i < y_j) & upper
    y_eq = (y_i == y_j) & upper

    # Pair counts (all masks above are already restricted to i < j).
    c = np.count_nonzero((x_gt & y_gt) | (x_lt & y_lt))
    d = np.count_nonzero((x_gt & y_lt) | (x_lt & y_gt))
    t_x = np.count_nonzero(x_eq & ~y_eq)  # tied on x only
    t_y = np.count_nonzero(y_eq & ~x_eq)  # tied on y only
    t_xy = np.count_nonzero(x_eq & y_eq)  # tied on both

    # Statistics per vignette (eqns 1–5); NaN when a denominator is zero.
    denom_tau_a = c + d + t_x + t_y + t_xy
    tau_a = (c - d) / denom_tau_a if denom_tau_a > 0 else np.nan

    denom_tau_b = (c + d + t_x) * (c + d + t_y)
    tau_b = (c - d) / np.sqrt(denom_tau_b) if denom_tau_b > 0 else np.nan

    denom_gamma = c + d
    gamma = (c - d) / denom_gamma if denom_gamma > 0 else np.nan

    denom_D = c + d + t_x
    D = (c - d) / denom_D if denom_D > 0 else np.nan

    # Equivalently (c + 0.5 * t_x) / (c + d + t_x).
    C = (D + 1) / 2 if np.isfinite(D) else np.nan

    if return_result:
        return ConcordanceResult(
            C=float(C),
            D=float(D),
            tau_a=float(tau_a),
            tau_b=float(tau_b),
            gamma=float(gamma),
            c=int(c),
            d=int(d),
            t_x=int(t_x),
            t_y=int(t_y),
            t_xy=int(t_xy),
            n=n,
        )
    return float(C)
