"""Selection algorithms for positional representation and positional proportionality."""

import cvxpy as cp
import math
import numpy as np
import pandas as pd


########## Data processing ##########


def scores_to_ranks(scores_df, all_tasks, tie_break_method='max'):
    """Turn a dataframe of benchmark scores into a dataframe of ranks.

    Note: the `all_tasks` columns of `scores_df` are converted to numeric
    values in place, so the caller's dataframe is modified.

    Args:
      scores_df: dataframe where each row represents a candidate, and each column represents a benchmark.
        scores_df[i].index[j] = the score that voter i gives candidate j. Higher scores are better.
      all_tasks: column names of `scores_df` that are coerced with `pd.to_numeric`.
      tie_break_method: value forwarded as `method` to `DataFrame.rank`.

    Returns: ranks_df: ranks_df[i].index[j] = the rank that voter i gives candidate j. Lower ranks are better. 1-indexed.
      Only numeric columns are ranked (`numeric_only=True`).
    """
    numeric_scores = scores_df[all_tasks].apply(pd.to_numeric)
    scores_df[all_tasks] = numeric_scores
    ranks_df = scores_df.rank(
        method=tie_break_method,
        ascending=False,
        numeric_only=True,
    )
    return ranks_df


########## IP for Positional Proportionality ##########


def construct_A_matrix_pp(ranks_df, alternative_index):
    """Build the 0/1 top-r indicator matrix for a single alternative.

    Args:
      ranks_df: dataframe of ranks; rows are alternatives, columns are voters.
      alternative_index: row label (looked up with `.loc`) of the alternative.

    Returns: integer array with one row per rank threshold r = 1, ..., m-1
      (m = number of rows of ranks_df) and one column per column of ranks_df.
      Entry [r-1, i] is 1 if voter i ranks the alternative at r or better.
    """
    num_alternatives = len(ranks_df)
    ranks_of_alternative = ranks_df.loc[alternative_index]
    indicator_rows = [
        (ranks_of_alternative <= threshold).astype(int).to_numpy()
        for threshold in range(1, num_alternatives)
    ]
    return np.array(indicator_rows)

def construct_all_bs_pp(ranks_df, all_tasks):
    """Compute, for every alternative, the fraction of voters ranking it in the top r.

    Args:
      ranks_df: dataframe of ranks; rows are alternatives, columns are voters.
      all_tasks: list of voter (benchmark) column names to count over.

    Returns: numpy array in which row i is the vector b for alternative i;
      column r-1 holds the share of `all_tasks` voters that rank alternative i
      at r or better, for r = 1, ..., m-1.
    """
    num_alternatives = len(ranks_df)
    num_voters = len(all_tasks)
    per_rank_fractions = [
        ((ranks_df[all_tasks] <= threshold).sum(axis=1) / num_voters).to_numpy()
        for threshold in range(1, num_alternatives)
    ]
    # Rows were built per rank threshold; transpose so rows are alternatives.
    return np.array(per_rank_fractions).T

def solve_lp_pp(ranks_df, all_tasks, epsilons, verbose=True, solver='CPLEX', cplex_params=None, starting_tasks=None):
    """Solves the integer program for Positional Proportionality for all epsilons.

    For each epsilon, finds a minimum-size non-empty subset of `all_tasks`
    such that, for every alternative and every rank threshold r = 1..m-1, the
    fraction of selected voters ranking the alternative in the top r is within
    epsilon of the fraction over all of `all_tasks`.

    Args:
      ranks_df: dataframe of ranks; rows are alternatives, columns are voters.
        Rows are looked up with `.loc[i]` for i in range(m).
      all_tasks: list of voter (benchmark) column names; entry k of the
        decision variable corresponds to all_tasks[k].
      epsilons: iterable of tolerance values; one program is solved per value.
      verbose: forwarded to `Problem.solve`.
      solver: solver name forwarded to `Problem.solve`.
      cplex_params: if not None, forwarded to `Problem.solve` as `cplex_params`.
      starting_tasks: optional collection of task names that must be selected.

    Returns: list with the optimal objective value (`prob.value`) for each epsilon.
    """
    # Restrict to (and order by) all_tasks so that column k of every A matrix,
    # entry k of x, and all_tasks[k] all refer to the same benchmark even when
    # ranks_df carries extra columns or a different column order.
    task_ranks_df = ranks_df[all_tasks]
    m = len(task_ranks_df)
    n = len(all_tasks)
    all_tasks_array = np.array(all_tasks)

    x = cp.Variable(n, boolean=True)

    print("Constructing program parameters")
    A_matrices = [construct_A_matrix_pp(task_ranks_df, alternative_index=i) for i in range(m)]
    # A_ones @ x is a vector whose every entry equals the number of selected tasks.
    A_ones = np.ones(A_matrices[0].shape)
    bs = construct_all_bs_pp(task_ranks_df, all_tasks)
    c = np.ones(n)

    if starting_tasks is not None:
        c_starting = np.array([[1 if task in starting_tasks else 0 for task in all_tasks]])
        num_starting = len(starting_tasks)

    # Construct constraint set
    x_results = []  # Kept alongside optimal_values for debugging; not returned.
    optimal_values = []

    for epsilon in epsilons:
        print("Constructing constraint set for epsilon %.2f" % epsilon)
        eps_vec = epsilon * np.ones(m-1)
        # The selected subset must be non-empty.
        constraints = [c.T@x >= 1]
        if starting_tasks is not None:
            constraints.append(c_starting@x >= num_starting)
        for i in range(m):
            # Add a constraint for each alternative: the top-r count in the
            # subset must lie within (b_i +/- epsilon) * |subset|.
            constraints.append(A_matrices[i] @ x >= (cp.multiply((bs[i] - eps_vec), (A_ones @ x))))
            constraints.append(A_matrices[i] @ x <= (cp.multiply((bs[i] + eps_vec), (A_ones @ x))))

        # Define and solve the CVXPY problem.
        print("Solving")
        prob = cp.Problem(cp.Minimize(c.T@x), constraints)

        if cplex_params is not None:
            prob.solve(verbose=verbose, solver=solver, cplex_params=cplex_params)
        else:
            prob.solve(verbose=verbose, solver=solver)

        print("Optimal value: %s" % prob.value)
        for variable in prob.variables():
            print("Variable %s: value %s" % (variable.name(), variable.value))
        x_result = x.value
        if x_result is None:
            # No solution was returned, so there is nothing to threshold.
            print("Tasks: no solution available (status: %s)" % prob.status)
        else:
            print("Tasks:", all_tasks_array[x_result > 0.1])
        x_results.append(x_result)
        optimal_values.append(prob.value)
    return optimal_values #, x_results


########## IP for Positional Representation ##########


def construct_A_matrix_pr(ranks_df, alternative_index):
    """Build the 0/1 top-r indicator matrix of one alternative for the PR program.

    Args:
      ranks_df: dataframe of ranks; rows are alternatives, columns are voters.
      alternative_index: row label (looked up with `.loc`) of the alternative.

    Returns: integer array whose row r-1 (r = 1, ..., m-1) marks with 1 the
      voters that rank the alternative at r or better.
    """
    alternative_ranks = ranks_df.loc[alternative_index]
    rank_thresholds = range(1, len(ranks_df))
    return np.array(
        [(alternative_ranks <= r).astype(int).to_numpy() for r in rank_thresholds]
    )

def construct_b_pr(ranks_df, all_tasks, alternative_index, g):
    """Compute the Positional Representation lower bounds for one alternative.

    Args:
      ranks_df: dataframe of ranks; rows are alternatives, columns are voters.
      all_tasks: list of voter (benchmark) column names to count over.
      alternative_index: row label (looked up with `.loc`) of the alternative.
      g: group size.

    Returns: array whose entry r-1 (r = 1, ..., m-1) is
      floor(C(N, r, a) / g), where C(N, r, a) is the number of `all_tasks`
      voters that rank the alternative at r or better.
    """
    # m was previously undefined here, which raised NameError on every call.
    m = len(ranks_df)
    alternative_ranks = ranks_df[all_tasks].loc[alternative_index]
    # The top-r count is computed inline rather than through an unimported
    # module-qualified name.
    return np.array([
        math.floor((alternative_ranks <= r).sum() / g)
        for r in range(1, m)
    ])

def construct_all_bs_pr(ranks_df, all_tasks, g):
    """Compute the Positional Representation lower bounds for every alternative.

    Args:
      ranks_df: dataframe of ranks; rows are alternatives, columns are voters.
      all_tasks: list of voter (benchmark) column names to count over.
      g: group size.

    Returns: numpy array in which row i is the vector b for alternative i;
      column r-1 holds floor(count / g), where count is the number of
      `all_tasks` voters ranking alternative i at r or better, r = 1, ..., m-1.
    """
    num_alternatives = len(ranks_df)
    floored_counts_per_rank = []
    for threshold in range(1, num_alternatives):
        top_r_counts = (ranks_df[all_tasks] <= threshold).sum(axis=1)
        floored_counts_per_rank.append(np.floor(top_r_counts / g).to_numpy())
    # Rows were built per rank threshold; transpose so rows are alternatives.
    return np.array(floored_counts_per_rank).T

def solve_lp_pr(ranks_df, all_tasks, group_sizes, verbose=True, solver='CPLEX', cplex_params=None, starting_tasks=None):
    """Solves the integer program for Positional Representation for all group_sizes.

    For each group size g, finds a minimum-size subset of `all_tasks` such
    that, for every alternative and every rank threshold r = 1..m-1, the number
    of selected voters ranking the alternative in the top r is at least
    floor(C(N, r, a) / g).

    Args:
      ranks_df: dataframe of ranks; rows are alternatives, columns are voters.
        Rows are looked up with `.loc[i]` for i in range(m).
      all_tasks: list of voter (benchmark) column names; entry k of the
        decision variable corresponds to all_tasks[k].
      group_sizes: iterable of group sizes g; one program is solved per value.
      verbose: forwarded to `Problem.solve`.
      solver: solver name forwarded to `Problem.solve`.
      cplex_params: if not None, forwarded to `Problem.solve` as `cplex_params`.
      starting_tasks: optional collection of task names that must be selected.

    Returns: list with the optimal objective value (`prob.value`) for each group size.
    """
    # Restrict to (and order by) all_tasks so that column k of every A matrix,
    # entry k of x, and all_tasks[k] all refer to the same benchmark even when
    # ranks_df carries extra columns or a different column order.
    task_ranks_df = ranks_df[all_tasks]
    m = len(task_ranks_df)
    n = len(all_tasks)
    all_tasks_array = np.array(all_tasks)

    x = cp.Variable(n, boolean=True)

    A_matrices = [construct_A_matrix_pr(task_ranks_df, alternative_index=i) for i in range(m)]

    if starting_tasks is not None:
        c_starting = np.array([[1 if task in starting_tasks else 0 for task in all_tasks]])
        num_starting = len(starting_tasks)

    # The objective (number of selected tasks) does not depend on g.
    c = np.ones(n)

    # Construct constraint set
    x_results = []  # Kept alongside optimal_values for debugging; not returned.
    optimal_values = []

    for g in group_sizes:
        print("Constructing constraint set for group size %d" % g)
        constraints = []
        if starting_tasks is not None:
            constraints.append(c_starting@x >= num_starting)
        bs = construct_all_bs_pr(task_ranks_df, all_tasks, g=g)
        for i in range(m):
            # One lower-bound constraint per alternative, covering all ranks.
            constraints.append(A_matrices[i] @ x >= bs[i])

        # Define and solve the CVXPY problem.
        print("Solving")
        prob = cp.Problem(cp.Minimize(c.T@x), constraints)

        if cplex_params is not None:
            prob.solve(verbose=verbose, solver=solver, cplex_params=cplex_params)
        else:
            prob.solve(verbose=verbose, solver=solver)

        print("Optimal value: %s" % prob.value)
        for variable in prob.variables():
            print("Variable %s: value %s" % (variable.name(), variable.value))
        x_result = x.value
        if x_result is None:
            # No solution was returned, so there is nothing to threshold.
            print("Tasks: no solution available (status: %s)" % prob.status)
        else:
            print("Tasks:", all_tasks_array[x_result > 0.1])
        x_results.append(x_result)
        optimal_values.append(prob.value)
    return optimal_values #, x_results


########## Greedy ##########


def run_greedy_all_gs(scores_df, all_tasks, starting_tasks=None, verbose=False):
    """Run the greedy selection once per group size g = 1, ..., len(all_tasks) - 1.

    Args:
      scores_df: dataframe where each row represents a candidate, and each column represents a benchmark.
        scores_df[i].index[j] = the score that voter i gives candidate j. Higher scores are better.
      all_tasks: list of all tasks to consider.
      starting_tasks: list of strings representing names of benchmark columns to start with in the final set.
      verbose: forwarded to `greedy_selection_from_ranks`.

    Returns:
      final_K_lengths: List of sizes of the final subset |K| for each group size
      group_sizes: the corresponding group sizes, as a `range` object

    """
    ranks_df = scores_to_ranks(scores_df, all_tasks)
    group_sizes = range(1, len(all_tasks))
    final_K_lengths = []
    for g in group_sizes:
        print('Group size:', g)
        selected = greedy_selection_from_ranks(
            g,
            ranks_df[all_tasks],
            all_tasks,
            starting_tasks=starting_tasks,
            verbose=verbose,
        )
        subset_size = len(selected)
        print(selected)
        print('|K| =', subset_size)
        final_K_lengths.append(subset_size)
    return final_K_lengths, group_sizes


def greedy_selection_from_ranks(g, ranks_df, voter_columns, starting_tasks=None, verbose=True):
    """Runs Greedy selection algorithm.

    Phase 1 assigns "colors": scanning rank positions from best to worst, each
    time g voters have placed the same alternative (since the last color for
    that alternative), those g voters all receive a fresh color.
    Phase 2 is a greedy cover: repeatedly pick the voter holding the most
    still-uncovered colors until every color is covered.

    Args:
      g: group size; a color is created whenever g voters accumulate on an alternative.
      ranks_df: dataframe where each row represents a candidate, and each column represents a benchmark.
        ranks_df[i].index[j] = the rank that voter i gives candidate j. Lower ranks are better.
        Index labels are used directly as list positions, so they must be
        integers in 0..m-1.
      voter_columns: list of strings representing names of benchmark columns.
      starting_tasks: list of strings representing names of benchmark columns to start with in the final set.
      verbose: if True, print the coloring and the progress of the greedy cover.

    Returns: numpy array of column names which is a subset of voter_columns,
      starting tasks first, then greedy picks in the order they were chosen.
    """
    m = len(ranks_df)  # Number of alternatives
    n = len(voter_columns)  # Number of benchmarks

    # Construct colors
    S = [[] for _ in range(m)]  # S[j] is the set of benchmarks that approve alternative j, but do not have a color.
    C = [set() for _ in range(n)]  # C[i] is the set of colors assigned to benchmark i.
    c = 1  # Next unused color id.

    # For every voter, the alternatives in rank-sorted order (unranked ones dropped).
    sorted_voter_columns = [
        ranks_df[column].sort_values().dropna().index for column in voter_columns
    ]

    for r in range(m):  # Iterate over rank positions
        for i, ranking in enumerate(sorted_voter_columns):  # Iterate over voters
            if len(ranking) < r + 1:
                # Voter i ranked fewer than r + 1 alternatives.
                continue
            j = ranking[r]  # alternative that voter i places at position r
            S[j].append(i)
            # Add a color once S[j] reaches the group size
            if len(S[j]) == g:
                for voter in S[j]:
                    C[voter].add(c)
                c += 1
                S[j] = []
    if verbose:
        print("Coloring:", C)

    # Select subset
    all_colors = set(range(1, c))
    K = []

    def remove_colors(voter_index, all_colors_input, C_input):
        """Mark every color held by voter_index as covered, for all voters."""
        covered = C_input[voter_index]
        all_colors_input = all_colors_input - covered
        for voter in range(n):
            C_input[voter] = C_input[voter] - covered
        return all_colors_input, C_input

    # Remove starting task colors
    if starting_tasks is not None:
        for i in range(n):
            if voter_columns[i] in starting_tasks:
                K.append(i)
                all_colors, C = remove_colors(i, all_colors, C)

    # Counts must reflect the removals above; stale counts could make the first
    # greedy pick re-select a starting task and duplicate it in K.
    color_counts = [len(Ci) for Ci in C]

    max_iters = n
    iters = 0
    while len(all_colors) > 0:
        if verbose:
            print('t=%d' % iters)
            print('color_counts', color_counts)
            print('Q_t for t=%d:' % iters, len(all_colors))
        iters += 1
        # Get voter with the most colors
        i = color_counts.index(max(color_counts))
        K.append(i)
        if verbose:
            print('Chosen colors to remove from index %d:' % i, C[i])
        # Remove colors from rest of coloring
        all_colors, C = remove_colors(i, all_colors, C)
        color_counts = [len(Ci) for Ci in C]

        # Safety bound: at most one pick per voter.
        if iters >= max_iters:
            break
    if len(all_colors) > 0:
        print("Warning: Not all colors were covered!")

    if verbose:
        print("Final subset:", K)
        print("Length of final subset:", len(K))

    return np.array(voter_columns)[K]


########## Evaluation ##########


def top_r_count(ranks_df, voter_columns, r, alternative_index):
    """Equivalent to the C(N,r,a) function.

    Args:
      ranks_df: dataframe of ranks; rows are alternatives, columns are voters.
      voter_columns: equivalent to N
      r: Maximum rank
      alternative_index: index in ranks_df indicating alternative.

    Returns: Number of voters that rank alternative_index in the top r.
    """
    voter_ranks = ranks_df[voter_columns]
    ranks_of_alternative = voter_ranks.loc[alternative_index]
    in_top_r = ranks_of_alternative <= r
    return in_top_r.sum()

def weak_pp_loss(g, ranks_df, voter_columns, subset_voter_columns, verbose=True):
    """Count the (rank, alternative) pairs for which PR is violated.

    A pair (r, a) is a violation when floor(C(N,r,a) / g) exceeds C(K,r,a),
    where N is `voter_columns` and K is `subset_voter_columns`. Both r and the
    alternative index range over 0, ..., m-1.

    Args:
      g: group size.
      ranks_df: dataframe of ranks; rows are alternatives, columns are voters.
      voter_columns: list of all voter column names (N).
      subset_voter_columns: list of selected voter column names (K).
      verbose: if True, print details of every violation.

    Returns: number of violating (rank, alternative) pairs.
    """
    num_alternatives = len(ranks_df)
    violations = 0
    for r in range(num_alternatives):
        for alternative_index in range(num_alternatives):
            required = math.floor(top_r_count(ranks_df, voter_columns, r, alternative_index) / g)
            achieved = top_r_count(ranks_df, subset_voter_columns, r, alternative_index)
            if required - achieved <= 0:
                continue
            if verbose:
                print('alternative_index', alternative_index)
                print('r', r)
                print('C(N,r,a)', top_r_count(ranks_df, voter_columns, r, alternative_index))
                print('C(K,r,a)', top_r_count(ranks_df, subset_voter_columns, r, alternative_index))
            violations += 1
    return violations


def weak_pp_loss_fast(g, ranks_df, voter_columns, subset_voter_columns, verbose=False):
    """Returns number of ranks and alternatives for which PR is violated."""
    m = len(ranks_df)
    loss = 0
    for r in range(1,m):
        all_top_r_counts = np.floor(((ranks_df[voter_columns] <= r).sum(axis=1)) / g).to_numpy()
        subset_top_r_counts = ((ranks_df[subset_voter_columns] <= r).sum(axis=1)).to_numpy()
        diffs = all_top_r_counts - subset_top_r_counts
        if verbose:
            print('r', r)
            print('all_top_r_counts', all_top_r_counts)
            print('subset_top_r_counts', subset_top_r_counts)
            print('diffs', diffs)
        loss += (diffs > 0).sum()
    return loss


def strong_pp_loss_fast(epsilon, ranks_df, voter_columns, subset_voter_columns, verbose=False):
    """Count the (rank, alternative) pairs for which PP is violated.

    For each rank threshold r = 1, ..., m-1, an alternative is a violation when
    the fraction of `voter_columns` ranking it in the top r and the fraction of
    `subset_voter_columns` ranking it in the top r differ by more than epsilon.

    Args:
      epsilon: tolerance on the absolute difference of the two fractions.
      ranks_df: dataframe of ranks; rows are alternatives, columns are voters.
      voter_columns: list of all voter column names.
      subset_voter_columns: list of selected voter column names.
      verbose: if True, print the per-rank ratio vectors.

    Returns: number of violating (rank, alternative) pairs.
    """
    num_alternatives = len(ranks_df)
    n = len(voter_columns)
    k = len(subset_voter_columns)

    def top_r_ratios(columns, r, group_size):
        # Fraction of the given voters that rank each alternative at r or better.
        return ((ranks_df[columns] <= r).sum(axis=1) / group_size).to_numpy()

    loss = 0
    for r in range(1, num_alternatives):
        all_top_r_ratios = top_r_ratios(voter_columns, r, n)
        subset_top_r_ratios = top_r_ratios(subset_voter_columns, r, k)
        diffs = np.abs(all_top_r_ratios - subset_top_r_ratios)
        if verbose:
            print('r', r)
            print('all_top_r_ratios', all_top_r_ratios)
            print('subset_top_r_ratios', subset_top_r_ratios)
            print('diffs', diffs)
        loss += (diffs > epsilon).sum()
    return loss