import numpy as np

import torch
from torch import digamma, lgamma


def aleatoric_uncertainty(alpha):
    """
    Aleatoric (data) uncertainty of a batch of Dirichlet distributions.

    For every row this evaluates
        sum_k (alpha_k / S) * (digamma(S + 1) - digamma(alpha_k + 1)),
    with S = sum_k alpha_k, i.e. the expected entropy of the categorical
    distribution drawn from the Dirichlet.

    alpha: Tensor of shape [batch_size, n_classes], Dirichlet parameters
    returns: Tensor of shape [batch_size, 1]
    """
    # Dirichlet strength per sample; keepdim so it broadcasts against alpha.
    strength = alpha.sum(dim=1, keepdim=True)
    # Mean of the Dirichlet, i.e. the expected class probabilities.
    expected_probs = alpha / strength
    # Per-class digamma difference that weights each expected probability.
    digamma_gap = digamma(strength + 1) - digamma(alpha + 1)
    return (expected_probs * digamma_gap).sum(dim=1, keepdim=True)


def epistemic_uncertainty(alpha):
    """
    Epistemic uncertainty of a batch, measured as the differential entropy
    of each Dirichlet distribution.

    For every row this evaluates
        [sum_k lgamma(alpha_k) - lgamma(S)]
        - sum_k (alpha_k - 1) * (digamma(alpha_k) - digamma(S)),
    with S = sum_k alpha_k. The value may be negative, as is usual for a
    differential entropy.

    alpha: Tensor of shape [batch_size, n_classes], Dirichlet parameters
    returns: Tensor of shape [batch_size, 1]
    """
    strength = alpha.sum(dim=1, keepdim=True)
    # Log of the multivariate Beta function (the Dirichlet normalizer).
    log_beta = lgamma(alpha).sum(dim=1, keepdim=True) - lgamma(strength)
    # Expected log-probability contribution, weighted by (alpha_k - 1).
    digamma_term = ((alpha - 1) * (digamma(alpha) - digamma(strength))).sum(
        dim=1, keepdim=True
    )
    return log_beta - digamma_term


def update_client_stats(
        client_id: int,
        round_idx: int,
        client_stats: dict,
        client_last_update: dict,
        new_stats: np.ndarray,
        beta: float = 0.8
):
    """
    Fold a client's freshly reported statistics into its running record using
    an exponential moving average (EMA). Both dictionaries are mutated in place.

    client_id: key identifying the client in both dictionaries
    round_idx: index of the current round
    client_stats: dict mapping client_id -> stored statistics array
    client_last_update: dict mapping client_id -> round of the last update
    new_stats: 2-D array whose number of rows is taken as n_classes; only the
        first n_classes + 4 columns are blended. The aggregation functions in
        this module read those columns as per-class evidence, sample count,
        total, aleatoric and epistemic uncertainty.
    beta: per-round weight kept from the old statistics
    returns: None

    The first report from a client is stored as a copy, unchanged. Later
    reports are blended as
        effective_beta * old + (1 - effective_beta) * new,
    where effective_beta = beta ** max(1, rounds since the last update), so a
    client that has been absent for several rounds keeps less of its history.
    The EMA is applied to every one of the first n_classes + 4 columns,
    including the sample-count column.
    """
    n_classes = new_stats.shape[0]
    n_cols = n_classes + 4

    # If client has no previous stats, initialize directly
    if client_id not in client_stats:
        client_stats[client_id] = new_stats.copy()
        client_last_update[client_id] = round_idx
        return

    # Compute effective decay factor considering rounds since last update
    rounds_since_last = max(1, round_idx - client_last_update[client_id])
    effective_beta = beta ** rounds_since_last

    old_stats = client_stats[client_id]

    # The blend is fractional. Writing it back into an integer (or bool) array
    # would silently truncate it, so work in a floating dtype: keep the
    # promoted dtype when it is already inexact, otherwise use float64.
    work_dtype = np.promote_types(old_stats.dtype, new_stats.dtype)
    if not np.issubdtype(work_dtype, np.inexact):
        work_dtype = np.float64

    # astype() returns a new array, so the previously stored array is left
    # untouched; any columns beyond n_cols keep their old values.
    updated_stats = old_stats.astype(work_dtype)
    updated_stats[:, :n_cols] = (
        effective_beta * old_stats[:, :n_cols]
        + (1 - effective_beta) * new_stats[:, :n_cols]
    )

    # Store updated stats
    client_stats[client_id] = updated_stats
    client_last_update[client_id] = round_idx


def evi_agg_weights_(client_ids, n_classes, client_stats):
    """
    Aggregation weights based on evidence concentration and on total,
    aleatoric and epistemic uncertainty relative to a global baseline.

    client_ids: sequence of client identifiers taking part in aggregation
    n_classes: number of classes
    client_stats: dict mapping client_id -> 2-D statistics array, indexed here
        as: columns [0, n_classes) per-class evidence, column n_classes + 1
        total uncertainty, n_classes + 2 aleatoric, n_classes + 3 epistemic
    returns: list of weights aligned with client_ids and summing to 1; an
        empty list when client_ids is empty

    Score of a client with stats:
        concentration_sum * R_unc * R_ale * R_epi
    where concentration_sum = sum_i stats[i, i] / (row-i evidence total) and
    each R is global_average / (client sum over classes + 1e-8).
    Clients without stats receive a fixed score of 1.0. If no client has
    stats, or the scores do not sum to a positive value, uniform weights are
    returned.

    NOTE(review): the global baseline is a mean of per-class means, while the
    per-client denominator is a sum over classes, so every R term carries the
    same extra 1 / n_classes factor. It cancels between clients that have
    stats but not against the fixed 1.0 score of clients without stats.
    Confirm this is intended.
    """
    n_clients = len(client_ids)
    if n_clients == 0:
        # Nothing to weight; also avoids dividing by zero below.
        return []

    eps = 1e-8
    unc_col = n_classes + 1
    ale_col = n_classes + 2
    epi_col = n_classes + 3

    # Step 1: global baseline, averaged over the clients that have stats.
    known = [client_stats[cid] for cid in client_ids if cid in client_stats]
    if not known:
        return [1.0 / n_clients] * n_clients

    global_avg_unc = sum(s[:, unc_col].mean() for s in known) / len(known)
    global_avg_ale = sum(s[:, ale_col].mean() for s in known) / len(known)
    global_avg_epi = sum(s[:, epi_col].mean() for s in known) / len(known)

    # Step 2: quality score per client.
    scores = []
    for cid in client_ids:
        if cid not in client_stats:
            scores.append(1.0)  # default weight for new client
            continue

        stats = client_stats[cid]

        # Evidence concentration: share of row i's evidence on the diagonal.
        concentration_sum = sum(
            stats[i, i] / (stats[i, :n_classes].sum() + eps)
            for i in range(n_classes)
        )

        # Lower uncertainty than the baseline yields a ratio above the rest.
        R_unc = global_avg_unc / (stats[:n_classes, unc_col].sum() + eps)
        R_ale = global_avg_ale / (stats[:n_classes, ale_col].sum() + eps)
        R_epi = global_avg_epi / (stats[:n_classes, epi_col].sum() + eps)

        scores.append(concentration_sum * R_unc * R_ale * R_epi)

    # Step 3: normalize scores to weights.
    total_score = sum(scores)
    if total_score > 0:
        return [s / total_score for s in scores]
    return [1.0 / len(scores)] * len(scores)


def evi_agg_weights(client_ids, n_classes, client_stats):
    """
    Evidence-weighted aggregation based on:
      - Evidence concentration (Q^(k))
      - Relative aleatoric uncertainty (R_ale^(k))
      - Relative epistemic uncertainty (R_epi^(k))
      - Relative total uncertainty (R_tot^(k))

    client_ids: sequence of client identifiers taking part in aggregation
    n_classes: number of classes
    client_stats: dict mapping client_id -> 2-D statistics array, indexed here
        as: columns [0, n_classes) per-class evidence, column n_classes + 1
        total uncertainty, n_classes + 2 aleatoric, n_classes + 3 epistemic
    returns: list of weights aligned with client_ids and summing to 1; an
        empty list when client_ids is empty

    Reliability of a client with stats: s^(k) = Q^(k) * R_ale * R_epi * R_tot,
    where Q^(k) is the mean over classes of stats[i, i] / (row-i evidence
    total) and each R is (sum over all clients with stats) / (this client's
    sum + 1e-8). Clients without stats receive a fixed score of 1.0. If no
    client has stats, or the scores do not sum to a positive value, uniform
    weights are returned.
    """
    n_clients = len(client_ids)
    if n_clients == 0:
        # Nothing to weight; also avoids dividing by zero below.
        return []

    eps = 1e-8
    tot_col = n_classes + 1
    ale_col = n_classes + 2
    epi_col = n_classes + 3

    # Step 1: global sums over every client that has stats.
    known = [client_stats[cid] for cid in client_ids if cid in client_stats]
    if not known:
        return [1.0 / n_clients] * n_clients

    global_tot_sum = sum(s[:, tot_col].sum() for s in known)
    global_ale_sum = sum(s[:, ale_col].sum() for s in known)
    global_epi_sum = sum(s[:, epi_col].sum() for s in known)

    # Step 2: client-level reliability score s^(k).
    scores = []
    for cid in client_ids:
        if cid not in client_stats:
            scores.append(1.0)  # default for new or missing clients
            continue

        stats = client_stats[cid]

        # (1) Evidence concentration Q^(k): average diagonal share per class.
        Q_k = sum(
            stats[i, i] / (stats[i, :n_classes].sum() + eps)
            for i in range(n_classes)
        ) / n_classes

        # (2) Relative ratios: global sum over this client's own sum.
        R_ale = global_ale_sum / (stats[:, ale_col].sum() + eps)
        R_epi = global_epi_sum / (stats[:, epi_col].sum() + eps)
        R_tot = global_tot_sum / (stats[:, tot_col].sum() + eps)

        # (3) Combined reliability s^(k).
        scores.append(Q_k * R_ale * R_epi * R_tot)

    # Step 3: normalize s^(k) to obtain w^(k).
    total_score = sum(scores)
    if total_score > 0:
        return [s / total_score for s in scores]
    return [1.0 / len(scores)] * len(scores)
