import numpy as np
from comparison_helpers import find_favorite_candidate, cand_preferred_over_committee, voter_unsatisfied, estimate_voter_rank, cand_preferred_over_committee_full_rank
from stable_lottery import find_approx_stable_lottery, oracle_defender_committee, dependent_rounding

def remove_candidates_from_lottery(candidates_to_remove, lottery_support, lottery_probs):
    """Strips the given candidates from every committee of a lottery and renormalizes.

    Committees that become empty after the removal are dropped together with
    their probability; the probabilities of the surviving committees are
    rescaled so that they sum to 1.

    Args:
      candidates_to_remove: iterable of candidate ids to remove from each committee.
      lottery_support: list of committees (each an iterable of candidate ids).
      lottery_probs: sequence of probabilities, index-aligned with lottery_support.

    Returns:
      (fixed_lottery_support, fixed_lottery_probs): a list of committees (as
      lists, duplicates collapsed) and a float numpy array of the matching
      renormalized probabilities. Both are empty if every committee was removed.
    """
    # Accept any iterable (set, list, array) of candidates, not only sets.
    to_remove = set(candidates_to_remove)
    fixed_lottery_support = []
    fixed_lottery_probs = []
    for i, lottery in enumerate(lottery_support):
        fixed_lottery = list(set(lottery) - to_remove)
        if fixed_lottery:
            fixed_lottery_support.append(fixed_lottery)
            fixed_lottery_probs.append(lottery_probs[i])
    # Convert to an array before dividing: a plain list divided by a scalar
    # raises TypeError (and so did the all-removed case, `[] / 0`).
    fixed_lottery_probs = np.asarray(fixed_lottery_probs, dtype=float)
    total = fixed_lottery_probs.sum()
    # Renormalize lottery probs; skip when there is no mass left to avoid 0/0.
    if total > 0:
        fixed_lottery_probs = fixed_lottery_probs / total
    return fixed_lottery_support, fixed_lottery_probs

def iterated_rounding_pairwise(
    num_iters, 
    alpha, 
    beta, 
    committee_size, 
    num_voters, 
    voter_sampler, 
    population_prob_num_voters=300, 
    num_voters_best_S = None,
    use_stopping_prob=False, 
    stopping_prob=0.01, 
    best_S=True,
    always_add_one=False,
    approx_stable_lottery_num_iters=None,
    num_committee_trials=30,
    remove_accepted_candidates_from_lotteries=False,
    fix_samples_for_lottery=False):
    """Performs iterated rounding from pairwise comparisons.

    Each iteration finds an approximately stable lottery over committees of
    size ``cur_K`` (via ``find_approx_stable_lottery``), picks one committee
    from its support with ``find_satisfying_S``, and adds its members to the
    output committee. The requested size starts at
    ``int((1 - alpha) * committee_size)`` and is replaced by
    ``int(alpha * K)`` after every iteration; once it reaches 0 the remaining
    seats are requested in a single round.

    Args:
      num_iters: maximum number of rounding iterations.
      alpha: decay factor for the per-iteration committee size.
      beta: float, tolerance parameter (also stored on ``voter_sampler.beta``).
      committee_size: int, maximum size of the final committee.
      num_voters: number of voters to sample each round for the lottery.
      voter_sampler: voter sampler object; its ``all_prev_Ss`` and
        ``all_prev_lotteries`` attributes are bound to this run's history.
      population_prob_num_voters: voters sampled to estimate the fraction of
        unsatisfied voters when ``use_stopping_prob`` is set.
      num_voters_best_S: voters sampled inside ``find_satisfying_S``;
        defaults to ``num_voters``.
      use_stopping_prob: if True, stop early once the estimated fraction of
        unsatisfied voters is below ``stopping_prob``.
      stopping_prob: threshold for the early stop above.
      best_S: forwarded as ``best`` to ``find_satisfying_S``.
      always_add_one: if True, request committees of size 1 every iteration.
      approx_stable_lottery_num_iters: iterations for
        ``find_approx_stable_lottery``; defaults to ``num_iters``.
      num_committee_trials: forwarded to ``find_approx_stable_lottery``.
      remove_accepted_candidates_from_lotteries: if True, strip candidates
        already in the committee from each lottery before choosing S.
      fix_samples_for_lottery: forwarded to ``find_approx_stable_lottery`` as
        ``fix_samples``.

    Returns:
      (full_T, all_prev_Ss, full_T_ordered): the final committee as a set, the
      committee chosen in each iteration, and the committee members in the
      order in which they were first added.
    """
    if num_voters_best_S is None:
        num_voters_best_S = num_voters

    if approx_stable_lottery_num_iters is None:
        approx_stable_lottery_num_iters = num_iters

    full_T = set()  # Final committee to output
    full_T_ordered = []  # Same members, in the order they were first added

    # History of chosen committees and lotteries, shared with voter_sampler.
    all_prev_Ss = []
    all_prev_lotteries = []

    voter_sampler.beta = beta
    voter_sampler.all_prev_Ss = all_prev_Ss
    voter_sampler.all_prev_lotteries = all_prev_lotteries

    K = int((1 - alpha) * committee_size)
    if always_add_one:
        K = 1

    for t in range(num_iters):
        print("Iter", t)
        print("Max pairwise comparisons:", voter_sampler.max_pairwise_comparisons())

        remaining_K = committee_size - len(full_T)
        # `<=` rather than `==`: if a round ever adds more members than
        # requested, remaining_K goes negative; an equality check would then
        # never fire and the code below would request negative-size committees.
        if remaining_K <= 0:
            print("Committee size limit reached, breaking")
            break

        cur_K = K
        if cur_K == 0:
            # Take remainder of committee size
            cur_K = remaining_K

        # Re-bind the shared history lists (presumably guards against a callee
        # replacing these attributes — TODO confirm whether still needed).
        voter_sampler.all_prev_Ss = all_prev_Ss
        voter_sampler.all_prev_lotteries = all_prev_lotteries
        next_lottery_support, next_lottery_probs = find_approx_stable_lottery(
            approx_stable_lottery_num_iters, 
            cur_K, 
            num_voters, 
            voter_sampler,
            num_committee_trials=num_committee_trials,
            fix_samples=fix_samples_for_lottery)
        print("Lottery:\n", next_lottery_support, next_lottery_probs)
        print("Max pairwise comparisons:", voter_sampler.max_pairwise_comparisons())
        print("total voters sampled:", voter_sampler.total_samples)

        # Remove already included candidates from lottery
        if remove_accepted_candidates_from_lotteries:
            next_lottery_support, next_lottery_probs = remove_candidates_from_lottery(full_T, next_lottery_support, next_lottery_probs)
            print("Fixed lottery:\n", next_lottery_support, next_lottery_probs)

        next_S = find_satisfying_S(
            beta,
            num_voters_best_S, 
            voter_sampler,                      
            next_lottery_support,          
            next_lottery_probs, 
            strict=False, 
            best=best_S,
            verbose=False)
        print("next_S", next_S)
        print("Max pairwise comparisons:", voter_sampler.max_pairwise_comparisons())
        print("total voters sampled:", voter_sampler.total_samples)

        # find_satisfying_S returns None for an empty lottery support (e.g. when
        # every committee consisted only of already-accepted candidates);
        # there is nothing left to add, and set(None) below would crash.
        if next_S is None:
            print("No candidate committee available, stopping")
            break

        # Update placeholders
        K = int(alpha * K)
        if always_add_one:
            K = 1

        full_T_ordered.extend(list(set(next_S) - full_T))
        full_T.update(next_S)
        all_prev_Ss.append(next_S)
        all_prev_lotteries.append((next_lottery_support, next_lottery_probs))

        if use_stopping_prob:
            population_prob_unsatisfied = estimate_population_prob_unsatisfied(
                beta, 
                population_prob_num_voters, 
                voter_sampler, 
                all_prev_Ss, 
                all_prev_lotteries, 
                strict=False)
            print("Population probability unsatisfied", population_prob_unsatisfied)
            if population_prob_unsatisfied < stopping_prob:
                print("Population satisfied, stopping")
                break
    return full_T, all_prev_Ss, full_T_ordered


# Functions for finding satisfying S

def estimate_population_unsat_rate(beta, num_voters, voter_sampler, S, lottery_support, lottery_probs, strict=True, verbose=False):
    """Estimates P_v(Rank(v;S,Delta) <= beta), the share of unsatisfied voters.

    Draws ``num_voters`` voters via ``voter_sampler.sample`` (always with the
    sampler's own verbose output switched off). For each voter it estimates
    Rank(v; S, Delta) with ``estimate_voter_rank`` against the lottery given
    by (lottery_support, lottery_probs). A voter counts when that rank is
    at most ``beta``.

    Args:
      beta: tolerance parameter.
      num_voters: number of voters to sample from the population.
      voter_sampler: voter sampler class.
      S: committee to evaluate, numpy array of length k.
      lottery_support: committees as a numpy array of shape (number of committees, k).
      lottery_probs: probabilities, same length as lottery_support.
      strict: forwarded to the sampler and to estimate_voter_rank.
      verbose: print each voter's full ranking and estimated rank.

    Returns:
      float, count of voters with rank <= beta divided by ``num_voters``.
    """
    population_sample = voter_sampler.sample(num_voters, verbose=False, strict=strict)

    unsatisfied_count = 0
    for sampled_voter in population_sample:
        voter_rank = estimate_voter_rank(
            sampled_voter, S, lottery_support, lottery_probs, strict=strict)
        if verbose:
            print("voter", sampled_voter.full_rank)
            print("rank", voter_rank)
        is_unsatisfied = voter_rank <= beta
        if is_unsatisfied:
            unsatisfied_count += 1

    unsat_rate = unsatisfied_count / num_voters
    return unsat_rate


def find_satisfying_S(beta, num_voters, voter_sampler, lottery_support, lottery_probs, strict=True, verbose=False, best=False):
    """Finds a committee S such that P_v(Rank(v;S,Delta) <= beta) <= beta.

    Scans the committees of the lottery in order and estimates each one's
    unsatisfied rate with ``estimate_population_unsat_rate``. Unless ``best``
    is set, the first committee whose rate is at most ``beta`` is returned
    immediately. Otherwise, and whenever no committee qualifies, it returns
    the committee with the lowest rate seen; on ties the later committee wins.

    Args:
      beta: tolerance parameter.
      num_voters: number of voters to sample from the population per committee.
      voter_sampler: voter sampler class.
      lottery_support: committees as a numpy array of shape (number of committees, k).
      lottery_probs: probabilities, same length as lottery_support.
      strict: forwarded to estimate_population_unsat_rate.
      verbose: print per-committee and final diagnostics.
      best: if True, always scan everything and return the minimizing committee.

    Returns:
      A single committee from lottery_support, or None if lottery_support is empty.
    """
    lowest_rate, lowest_committee = 1, None

    for candidate_committee in lottery_support:
        rate = estimate_population_unsat_rate(
            beta, num_voters, voter_sampler, candidate_committee,
            lottery_support, lottery_probs, strict=strict, verbose=verbose)
        if verbose:
            print("S", candidate_committee)
            print("unsat_rate", rate)

        # Early exit on the first acceptable committee unless searching for the best.
        if not best and rate <= beta:
            return candidate_committee

        if rate <= lowest_rate:
            lowest_rate, lowest_committee = rate, candidate_committee

    if verbose:
        print("best S:", lowest_committee)
        print("best unsat rate:", lowest_rate)
    return lowest_committee


def estimate_population_prob_unsatisfied(beta, num_voter_samples, voter_sampler, all_prev_Ss, all_prev_lotteries, strict=True, verbose=False):
    """Estimates the fraction of voters that are unsatisfied given the history.

    Draws ``num_voter_samples`` voters with ``voter_sampler.sample_original``.
    It returns the fraction of them for which ``voter_unsatisfied`` holds
    with respect to all previously chosen committees and lotteries.

    Args:
      beta: tolerance parameter forwarded to voter_unsatisfied.
      num_voter_samples: number of voters to draw.
      voter_sampler: voter sampler class.
      all_prev_Ss: committees chosen in earlier iterations.
      all_prev_lotteries: (support, probs) lotteries from earlier iterations.
      strict: forwarded to voter_unsatisfied.
      verbose: print each voter and whether it was unsatisfied.

    Returns:
      float, number of unsatisfied sampled voters divided by num_voter_samples.
    """
    unsatisfied_count = 0
    for sampled_voter in voter_sampler.sample_original(num_voter_samples):
        if verbose:
            print('voter', sampled_voter.full_rank)
        is_unsatisfied = voter_unsatisfied(
            beta,
            sampled_voter,
            all_prev_Ss,
            all_prev_lotteries,
            strict=strict,
            verbose=verbose)
        if not is_unsatisfied:
            continue
        if verbose:
            print('unsatisfied')
        unsatisfied_count += 1

    return unsatisfied_count / num_voter_samples


# Functions for evaluation
def estimate_approx_factor(S, K, num_candidates, voter_sampler, num_voter_samples, fixed_samples=None, verbose=False):
    """Estimates the approximation factor of committee S from sampled voters.

    For every candidate c in range(num_candidates) that is not in S, it
    estimates the fraction of voters who prefer c over S. Preference is judged
    with ``cand_preferred_over_committee(..., strict=True)``. It returns the
    largest such fraction multiplied by K.

    Args:
      S: committee, iterable of candidate ids.
      K: scale factor applied to the maximum preference fraction.
      num_candidates: total number of candidates (ids 0..num_candidates-1).
      voter_sampler: sampler used via ``sample_original`` when fixed_samples is None.
      num_voter_samples: number of voters to draw; ignored if fixed_samples is given.
      fixed_samples: optional pre-sampled voters to reuse instead of sampling.
      verbose: print intermediate values.

    Returns:
      float. 0.0 when S already contains every candidate; the original
      crashed here because max() of an empty list raises ValueError.
    """
    if fixed_samples is None:
        sampled_voters = voter_sampler.sample_original(num_voter_samples)
    else:
        sampled_voters = fixed_samples
        num_voter_samples = len(fixed_samples)

    if verbose:
        print("sampled_voters:\n", [sampled_voter.full_rank for sampled_voter in sampled_voters])
    # Filter for candidates not in S
    candidates_not_in_S = set(range(num_candidates)) - set(S)
    if verbose:
        print("candidates_not_in_S:", candidates_not_in_S)
    c_preferred_probs = []
    for c in candidates_not_in_S:
        # Get proportion of voters that prefer c over S
        num_voters_preferring_c = sum(
            1 for voter in sampled_voters
            if cand_preferred_over_committee(voter, c, S, strict=True))
        c_preferred_probs.append(num_voters_preferring_c / num_voter_samples)
    if verbose:
        print("c_preferred_probs:", c_preferred_probs)
    # No candidate outside S means nobody can prefer an outsider.
    if not c_preferred_probs:
        return 0.0
    return max(c_preferred_probs) * K


def cand_preferred_over_committee_full_rank_helper(voter_values_as_ranks, cand, S):
    """Returns True iff cand is preferred over all candidates in S.

    ``voter_values_as_ranks[x]`` is read as the voter's rank of candidate x,
    where a smaller rank means more preferred. cand wins if its rank is strictly
    smaller than both the best (smallest) rank among S's members and
    ``len(voter_values_as_ranks)``, which serves as the cap when S is empty.
    """
    ranks = voter_values_as_ranks
    best_rank_in_S = min([len(ranks)] + [ranks[member] for member in S])
    return ranks[cand] < best_rank_in_S


def estimate_approx_factor_full_ranks(S, K, num_candidates, voter_values_as_ranks, verbose=False):
    """Estimates the approximation factor of committee S from full voter rankings.

    This is the full-ranking counterpart of ``estimate_approx_factor``. Each
    row of ``voter_values_as_ranks`` gives one voter's rank for every candidate.
    For every candidate c outside S, it computes the fraction of voters for
    whom ``cand_preferred_over_committee_full_rank_helper`` says c beats S,
    and returns the largest such fraction multiplied by K.

    Args:
      S: committee, iterable of candidate ids.
      K: scale factor applied to the maximum preference fraction.
      num_candidates: total number of candidates (ids 0..num_candidates-1).
      voter_values_as_ranks: sequence of per-voter rank vectors, indexed by candidate.
      verbose: print intermediate values.

    Returns:
      float. 0.0 when S already contains every candidate; the original
      crashed here because max() of an empty list raises ValueError.
    """
    num_voter_samples = len(voter_values_as_ranks)

    # Filter for candidates not in S
    candidates_not_in_S = set(range(num_candidates)) - set(S)
    if verbose:
        print("candidates_not_in_S:", candidates_not_in_S)
    c_preferred_probs = []
    for c in candidates_not_in_S:
        # Get proportion of voters that prefer c over S
        num_voters_preferring_c = sum(
            1 for voter_rank in voter_values_as_ranks
            if cand_preferred_over_committee_full_rank_helper(voter_rank, c, S))
        c_preferred_probs.append(num_voters_preferring_c / num_voter_samples)
    if verbose:
        print("c_preferred_probs:", c_preferred_probs)
    # No candidate outside S means nobody can prefer an outsider.
    if not c_preferred_probs:
        return 0.0
    return max(c_preferred_probs) * K
