"""
Functions for computing approximately stable lotteries for committee selection.
Based on Section 2.5 of the paper "Group fairness in committee selection".
"""

from collections import Counter
import math
import numpy as np
import sampler_classes
from comparison_helpers import find_favorite_candidate, cand_preferred_over_committee, voter_unsatisfied

def find_approx_stable_lottery(num_iters, committee_size, num_voters, voter_sampler, resample_every_iter=False, fix_samples=False, eps=0.05, seed=None, num_committee_trials=10):
    """
    Approximately stable lottery from Section 2.5 of paper "Group fairness in committee selection".

    Runs multiplicative weights (MWU) for an attacker over the m singleton
    committees (i,). In every iteration the current attacker distribution is
    handed to oracle_defender_committee, the committee it returns is recorded,
    and the attacker's log-weights are increased by eta times the fraction of
    sampled voters that prefer each candidate over that committee. The
    returned lottery is the empirical distribution of the recorded committees.

    Args:
      num_iters: number of MWU iterations (must be >= 1).
      committee_size: size of the committee; must not exceed the number of
        candidates.
      num_voters: number of voters to sample in each iteration (must be >= 1).
      voter_sampler: sampler for voters. Must expose `num_candidates` and
        `sample(n)`; the sampled voters are passed unchanged to
        cand_preferred_over_committee.
      resample_every_iter: forwarded to oracle_defender_committee.
      fix_samples: if True, one voter sample is drawn up front and reused for
        every oracle call and every gain computation; otherwise fresh voters
        are sampled each time they are needed.
      eps: accepted for API compatibility; not used by this function.
      seed: accepted for API compatibility; it is not forwarded to the oracle
        or to the voter sampler, so it does not make runs reproducible.
      num_committee_trials: number of candidate committees the oracle draws
        per iteration.

    Returns:
        support: np.ndarray shape (s, k) of committees
        probs:   np.ndarray shape (s,)   of probabilities

    Raises:
      ValueError: if num_iters < 1, num_voters < 1, or committee_size exceeds
        the number of candidates.
    """
    # Fail early with a clear message instead of a ZeroDivisionError in eta
    # (num_iters == 0) or NaN weights from 0/0 (num_voters == 0).
    if num_iters < 1:
        raise ValueError(f"num_iters={num_iters} must be at least 1.")
    if num_voters < 1:
        raise ValueError(f"num_voters={num_voters} must be at least 1.")

    m = int(voter_sampler.num_candidates)
    k = int(committee_size)
    if k > m:
        raise ValueError(f"committee_size={k} cannot exceed number of candidates m={m}.")

    # MWU over m singleton attacker committees (i,)
    logw = np.zeros(m, dtype=float)
    # max(m, 2) keeps the log strictly positive when there is a single candidate.
    eta = math.sqrt(2.0 * math.log(max(m, 2)) / float(num_iters))

    def attacker_probs():
        # Softmax of the log-weights; subtracting the max avoids overflow in exp.
        z = logw - logw.max()
        p = np.exp(z)
        return p / p.sum()

    fixed_samples = voter_sampler.sample(num_voters) if fix_samples else None

    def gain_vector_vs(Sd):
        # Length-m vector whose i-th entry is the fraction of sampled voters
        # for which cand_preferred_over_committee(voter, i, Sd) holds.
        voter_samples = fixed_samples if fix_samples else voter_sampler.sample(num_voters)
        wins = np.zeros(m)
        for i in range(m):
            wins[i] = sum(
                1 for voter in voter_samples
                if cand_preferred_over_committee(voter, i, Sd)
            )
        return wins / num_voters

    defender_draws = []
    for _ in range(num_iters):
        p = attacker_probs()
        att_dist = {(i,): float(p[i]) for i in range(m)}  # attacker distribution over singletons

        Sd = oracle_defender_committee(
            attacker_dist=att_dist,
            voter_sampler=voter_sampler,
            fixed_samples=fixed_samples,
            K=k,
            num_committee_trials=num_committee_trials,
            num_voters_eval=num_voters,
            resample_every_iter=resample_every_iter,
        )
        defender_draws.append(Sd)

        # Attacker update: candidates that beat the chosen committee gain weight.
        logw += eta * gain_vector_vs(Sd)

    # Empirical distribution of the committees played by the defender.
    cnt = Counter(defender_draws)
    support = np.array(list(cnt.keys()), dtype=int)          # (s, k)
    probs = np.array([c / num_iters for c in cnt.values()])  # (s,)
    return support, probs


def oracle_defender_committee(attacker_dist, voter_sampler, K, num_committee_trials, num_voters_eval, fixed_samples=None, resample_every_iter=False):
    """
    Oracle defender committee from Section 2.5 of paper "Group fairness in committee selection".

    Draws num_committee_trials size-K committees from the sampler returned by
    dependent_rounding(attacker_dist, K) and returns the one with the fewest
    observed attacks. For each committee, every sampled voter is paired with
    one candidate drawn from attacker_dist; the pair counts as an attack when
    cand_preferred_over_committee(voter, cand, committee, strict=True) holds.
    Ties keep the earliest committee drawn.

    Args:
      attacker_dist: dict mapping a candidate index, or a tuple whose first
        element is the candidate index, to a weight. Weights are normalised
        here and need not sum to 1.
      voter_sampler: object with `sample(n)`; used when fixed_samples is None
        or resample_every_iter is True.
      K: committee size.
      num_committee_trials: number of committees to draw (must be >= 1).
      num_voters_eval: number of voters requested from voter_sampler.
      fixed_samples: optional pre-drawn voters used for every trial.
      resample_every_iter: if True, fresh voters are sampled for every trial,
        replacing fixed_samples as well.

    Returns:
      The best committee found, in the form produced by the
      dependent_rounding sampler.

    Raises:
      ValueError: if num_committee_trials < 1 (previously this returned None).
    """
    if num_committee_trials < 1:
        raise ValueError("num_committee_trials must be at least 1.")

    rng = np.random.default_rng()
    draw_committee = dependent_rounding(attacker_dist, K)

    # Build the attacker sampler once instead of per voter. Candidates are
    # stored as plain ints and drawn by position: passing the tuple keys to
    # rng.choice would return a 1-element array rather than a tuple, so the
    # old isinstance(..., tuple) flattening never applied.
    candidates = [
        int(key[0]) if isinstance(key, tuple) else int(key)
        for key in attacker_dist
    ]
    weights = np.array([float(w) for w in attacker_dist.values()], dtype=float)
    probs = weights / weights.sum()
    num_candidates = len(candidates)

    voter_samples = fixed_samples
    if fixed_samples is None:
        voter_samples = voter_sampler.sample(num_voters_eval)

    best_committee = None
    # The voter count is the same for every trial, so raw attack counts rank
    # committees exactly like frequencies; inf guarantees the first trial wins.
    fewest_attacks = math.inf
    for _ in range(num_committee_trials):
        committee = draw_committee()
        if resample_every_iter:
            voter_samples = voter_sampler.sample(num_voters_eval)

        attacks = 0
        for voter in voter_samples:
            cand = candidates[rng.choice(num_candidates, p=probs)]
            if cand_preferred_over_committee(voter, cand, committee, strict=True):
                attacks += 1

        if attacks < fewest_attacks:
            fewest_attacks = attacks
            best_committee = committee
    return best_committee

def dependent_rounding(attacker_dist, K, seed=None):
    """
    Dependent rounding from Gandhi et al. (pairwise dependent rounding).

    Builds per-candidate inclusion targets alpha_i = min(1, K * p_i) from the
    normalised attacker distribution p, tops them up so they sum to K, and
    returns a sampler that rounds alpha to a 0/1 vector two fractional
    coordinates at a time.

    Args:
      attacker_dist: attacker distribution over singleton candidates.
        Accepts either:
          - {i: prob, ...}  or
          - {(i,): prob, ...}
        (Does not need to sum to 1.) The number of candidates m is inferred
        as the largest index plus one; indices that are absent get weight 0.
      K: size of the committee
      seed: optional seed for the sampler's random generator. The default
        (None) draws fresh entropy, as before.

    Returns:
      Sd: defender lottery over size-K committees (represented as a sampler function).
          Call Sd() to draw one size-K committee (sorted tuple of candidate indices).

    Raises:
      ValueError: if K is not positive, attacker_dist is empty, an index is
        negative, a weight is negative or non-finite, the total weight is not
        positive, or K exceeds the inferred number of candidates.
    """
    if K <= 0:
        raise ValueError("K must be positive.")
    if not attacker_dist:
        raise ValueError("attacker_dist must be non-empty.")

    # --- normalize attacker distribution and infer m ---
    def _to_idx(key):
        return int(key[0]) if isinstance(key, tuple) else int(key)

    indices = [_to_idx(key) for key in attacker_dist]
    weights = np.array([float(w) for w in attacker_dist.values()], dtype=float)

    # A negative index would wrap around in p[...] and silently credit the
    # weight to the wrong candidate.
    if min(indices) < 0:
        raise ValueError("attacker_dist candidate indices must be non-negative.")
    if not np.all(np.isfinite(weights)) or weights.min() < 0:
        raise ValueError("attacker_dist weights must be finite and non-negative.")

    m = max(indices) + 1
    if K > m:
        raise ValueError(f"K={K} cannot exceed number of candidates m={m}.")

    total = float(weights.sum())
    if total <= 0:
        raise ValueError("attacker_dist must have positive total weight.")

    p = np.zeros(m, dtype=float)
    for idx, w in zip(indices, weights):
        p[idx] += w / total  # keys mapping to the same index accumulate

    # --- probability matching marginals for L=1: q_i = min(1, K * p_i) ---
    alpha = np.minimum(1.0, K * p).astype(float)

    # Fill alpha up to sum=K (keeping alpha_i <= 1). Any fill order is fine.
    deficit = float(K) - float(alpha.sum())
    if deficit > 1e-12:
        slack = 1.0 - alpha
        for i in np.argsort(-slack):  # fill larger slack first
            if deficit <= 1e-12:
                break
            add = min(deficit, slack[i])
            alpha[i] += add
            deficit -= add

    # Numerical cleanup: keep every target inside [0, 1].
    alpha = np.clip(alpha, 0.0, 1.0)

    rng = np.random.default_rng(seed)
    tol = 1e-12  # values within tol of 0 or 1 are treated as integral

    def Sd():
        """Sample one size-K committee from Gandhi dependent rounding on alpha."""
        x = alpha.copy()

        # indices with fractional values, in increasing order
        frac = np.flatnonzero((x > tol) & (x < 1.0 - tol)).tolist()

        while len(frac) >= 2:
            i = frac.pop()
            j = frac.pop()

            xi, xj = x[i], x[j]
            d_plus = min(1.0 - xi, xj)    # largest shift of mass from j to i
            d_minus = min(xi, 1.0 - xj)   # largest shift of mass from i to j

            # Both entries are at least tol away from 0 and 1, so d_plus and
            # d_minus are positive and the ratio is well defined. Choosing
            # "+d_plus" with probability d_minus / (d_plus + d_minus) leaves
            # the expectation of x[i] and x[j] unchanged.
            if rng.random() < d_minus / (d_plus + d_minus):
                x[i] = xi + d_plus
                x[j] = xj - d_plus
            else:
                x[i] = xi - d_minus
                x[j] = xj + d_minus

            # Each step makes at least one of the two integral; re-add the
            # other if it is still fractional.
            for idx in (i, j):
                if tol < x[idx] < 1.0 - tol:
                    frac.append(idx)

        # Convert to an exact-K set (robust to small float drift)
        chosen = tuple(sorted(int(i) for i in np.argsort(-x)[:K]))
        return chosen

    return Sd

