######################################################################################## IMPORTS ########################################################################################

import pandas as pd
import numpy as np
import random
import ast
from collections import defaultdict
from itertools import combinations
import matplotlib.pyplot as plt
from scipy.stats import kendalltau, spearmanr
from tqdm import tqdm
from scipy.interpolate import make_interp_spline

######################################################################################## Voting Rules ########################################################################################

# Function to apply Borda Voting method
def borda_voting_rule(votes):
    """Rank options by Borda count.

    Each option at 0-based position ``p`` of a vote earns ``L - p`` points,
    where ``L`` is the length of the longest vote, so shorter votes still
    award the top score to their first entry.

    Parameters
    ----------
    votes : iterable of sequences
        Ballots, most preferred first. Must be non-empty (``max`` raises
        ValueError otherwise).

    Returns
    -------
    list
        Options sorted by total points, highest first; ties keep the order in
        which options were first seen.
    """
    longest = max(len(ballot) for ballot in votes)

    totals = {}
    for ballot in votes:
        for position, option in enumerate(ballot):
            totals[option] = totals.get(option, 0) + longest - position

    ranked = sorted(totals, key=totals.get, reverse=True)
    return ranked

# Function to apply a simplified Copeland method based only on each vote's first choice
def copeland_voting_rule(votes):
    """Rank alternatives with a simplified, first-choice Copeland score.

    Only the first entry of every vote is used. Each alternative's score is
    the number of alternatives it has strictly more first-choice votes than,
    minus the number it has strictly fewer than.

    Parameters
    ----------
    votes : iterable of sequences
        Ballots; every ballot must have at least one entry.

    Returns
    -------
    list
        Alternatives sorted by score, highest first; ties keep first-seen order.
    """
    first_choice_counts = {}
    for ballot in votes:
        top = ballot[0]  # only the most preferred entry counts
        first_choice_counts[top] = first_choice_counts.get(top, 0) + 1

    tallies = list(first_choice_counts.values())
    copeland = {}
    for alternative, count in first_choice_counts.items():
        # An alternative's own tally is neither < nor > itself, so it never counts.
        beaten = sum(1 for other in tallies if other < count)
        beaten_by = sum(1 for other in tallies if other > count)
        copeland[alternative] = beaten - beaten_by

    return sorted(copeland, key=copeland.get, reverse=True)

# Function to apply a simplified Maximin method: each option is scored by its best (len(vote) - position) across votes
def maximin_voting_rule(votes):
    """Rank options by the best position score they reach in any vote.

    An option at 0-based position ``p`` of a vote of length ``n`` scores
    ``n - p`` for that vote; its overall score is the maximum over all votes.

    Parameters
    ----------
    votes : iterable of sequences
        Ballots, most preferred first.

    Returns
    -------
    list
        Options sorted by best score, highest first; ties keep first-seen order.
    """
    best = {}
    for ballot in votes:
        ballot_length = len(ballot)
        for position, option in enumerate(ballot):
            points = ballot_length - position
            if option not in best or points > best[option]:
                best[option] = points

    return sorted(best, key=best.get, reverse=True)

# Function to apply Schulze Voting method
def schulze_voting_rule(votes):
    """Rank alternatives with the Schulze (beatpath) method, winning-votes variant.

    Parameters
    ----------
    votes : iterable of sequences
        Ballots, most preferred first. Two alternatives are compared only on
        ballots that list both of them.

    Returns
    -------
    list
        Alternatives sorted by the number of opponents they beat under the
        strongest-path relation, most wins first. Ties keep the iteration order
        of the set of alternatives.
    """
    # Collect every alternative that appears on any ballot.
    alternatives_set = set()
    for vote in votes:
        alternatives_set.update(vote)
    alternatives = list(alternatives_set)
    index_of = {alternative: i for i, alternative in enumerate(alternatives)}
    n = len(alternatives)

    # d[i, j] = number of ballots ranking alternatives[i] above alternatives[j].
    d = np.zeros((n, n))
    for vote in votes:
        first_position = {}
        for position, alternative in enumerate(vote):
            first_position.setdefault(alternative, position)  # same as vote.index
        for x, y in combinations(vote, 2):
            if first_position[x] < first_position[y]:
                d[index_of[x], index_of[y]] += 1
            else:
                d[index_of[y], index_of[x]] += 1

    # Schulze seeds the path strengths with *winning* links only: a link i->j
    # has strength d[i, j] when i beats j head-to-head, otherwise 0. Seeding
    # with raw counts (as before) let defeats act as path links.
    p = np.where(d > d.T, d, 0.0)

    # Floyd-Warshall widest path. Row/column `via` does not change during its
    # own iteration, so the whole update can be applied at once.
    for via in range(n):
        p = np.maximum(p, np.minimum(p[:, via:via + 1], p[via:via + 1, :]))

    # Count, for each alternative, the opponents it beats along strongest paths.
    wins = (p > p.T).sum(axis=1)
    order = sorted(range(n), key=lambda i: -wins[i])  # stable: ties keep order
    return [alternatives[i] for i in order]

################################################################# SP Format Adapted Vote Aggregation #########################################

def SP_borda_rule(votes):
    """Rank alternatives by Borda count over a collection of rankings.

    In a ranking of length ``n`` the entry at 0-based position ``p`` earns
    ``n - 1 - p`` points (the last entry earns 0).

    Parameters
    ----------
    votes : iterable of sequences
        Rankings, most preferred first (iterated twice).

    Returns
    -------
    list
        Alternatives sorted by total points, highest first; ties keep the
        iteration order of the set of alternatives.
    """
    candidates = set(alt for ranking in votes for alt in ranking)
    points = dict.fromkeys(candidates, 0)

    for ranking in votes:
        last_index = len(ranking) - 1
        for position, alt in enumerate(ranking):
            points[alt] += last_index - position

    return sorted(points, key=points.get, reverse=True)

def SP_copeland_rule(votes):
    """Rank alternatives by Copeland score over a collection of rankings.

    For each pair of alternatives, the rankings that place one above the other
    are counted. The pairwise majority winner gains one point and the loser
    loses one. Tied pairs, including pairs never ranked together, change
    nothing.

    Parameters
    ----------
    votes : iterable of sequences
        Rankings, most preferred first (iterated twice, so pass a re-iterable).

    Returns
    -------
    list
        Alternatives sorted by Copeland score, highest first; ties keep the
        iteration order of the set of alternatives.
    """
    # pairwise_comparisons[(x, y)] = number of rankings placing x above y.
    pairwise_comparisons = defaultdict(int)
    for ranking in votes:
        for i, winner in enumerate(ranking):
            for loser in ranking[i + 1:]:
                pairwise_comparisons[(winner, loser)] += 1

    unique_alts = set(alt for ranking in votes for alt in ranking)
    copeland_scores = {alt: 0 for alt in unique_alts}

    # Decide every pair by majority. Merely noting that a comparison occurred
    # in either direction would credit both sides.
    for a, b in combinations(unique_alts, 2):
        a_over_b = pairwise_comparisons.get((a, b), 0)
        b_over_a = pairwise_comparisons.get((b, a), 0)
        if a_over_b > b_over_a:
            copeland_scores[a] += 1
            copeland_scores[b] -= 1
        elif b_over_a > a_over_b:
            copeland_scores[b] += 1
            copeland_scores[a] -= 1

    # Sort the alternatives based on their Copeland scores
    sorted_alts = sorted(copeland_scores.items(), key=lambda x: x[1], reverse=True)
    return [option for option, _ in sorted_alts]

def SP_maximin_rule(rankings):
    """Rank alternatives by the best (smallest) position they reach in any ranking.

    Parameters
    ----------
    rankings : iterable of sequences
        Rankings, most preferred first (iterated twice).

    Returns
    -------
    list
        Alternatives sorted by best 0-based position, best first; ties keep the
        iteration order of the set of alternatives. A position is capped at the
        number of distinct alternatives, which only matters for malformed
        rankings with repeated entries.
    """
    unique_alts = set(alt for ranking in rankings for alt in ranking)
    # Start every alternative at a position no real ranking entry can beat.
    # Tracking the minimum directly avoids using 0 as an "unset" sentinel,
    # which would clash with a genuine first place.
    best_position = {alt: len(unique_alts) for alt in unique_alts}

    for ranking in rankings:
        for position, alt in enumerate(ranking):
            if position < best_position[alt]:
                best_position[alt] = position

    # Ascending: position 0 (top of some ranking) is best.
    return sorted(best_position, key=best_position.get)

def SP_schulze_rule(rankings):
    """Rank alternatives with the Schulze (beatpath) method over a set of rankings.

    Path strengths are seeded with winning links only: ``x -> y`` has strength
    equal to the number of rankings placing x above y when that beats the
    reverse count, otherwise 0. A widest-path (Floyd-Warshall) closure follows.

    Parameters
    ----------
    rankings : iterable of sequences
        Rankings, most preferred first (iterated twice).

    Returns
    -------
    list
        Alternatives sorted by the number of opponents they beat under the
        strongest-path relation, most first; ties keep set iteration order.
    """
    candidates = set(alt for options in rankings for alt in options)

    # head_to_head[(x, y)] = number of rankings placing x above y.
    head_to_head = defaultdict(int)
    for ranking in rankings:
        for position, winner in enumerate(ranking):
            for loser in ranking[position + 1:]:
                head_to_head[(winner, loser)] += 1

    strength = {x: {y: 0 for y in candidates} for x in candidates}
    for x in candidates:
        for y in candidates:
            if x != y and head_to_head[(x, y)] > head_to_head[(y, x)]:
                strength[x][y] = head_to_head[(x, y)]

    for via in candidates:
        for src in candidates:
            if src == via:
                continue
            for dst in candidates:
                if dst == via or dst == src:
                    continue
                through_via = min(strength[src][via], strength[via][dst])
                strength[src][dst] = max(strength[src][dst], through_via)

    def opponents_beaten(x):
        return sum(strength[x][y] > strength[y][x] for y in candidates)

    return sorted(candidates, key=opponents_beaten, reverse=True)


################################################################## Vote Function ########################################################################################

# Apply the requested voting rule ('Borda', 'Copeland', 'Maximin', 'Schulze') to plain votes ('vote') or to Partial-SP rankings ('sp').
def votes_function(votes, aggregation_type, rule):
    """Aggregate `votes` with the named voting rule.

    Parameters
    ----------
    votes : iterable of sequences
        Ballots or rankings, most preferred first.
    aggregation_type : {'vote', 'sp'}
        'vote' applies the rule variant for raw respondent votes; 'sp' applies
        the SP-adapted variant used for Partial-SP rankings.
    rule : {'Borda', 'Copeland', 'Maximin', 'Schulze'}
        Voting rule name (case-sensitive).

    Returns
    -------
    list
        Alternatives ordered from best to worst.

    Raises
    ------
    ValueError
        If `rule` or `aggregation_type` is not one of the values above. An
        explicit error here is clearer than a None result failing downstream.
    """
    if aggregation_type not in ('vote', 'sp'):
        raise ValueError(
            f"Unknown aggregation_type {aggregation_type!r}; expected 'vote' or 'sp'"
        )

    # rule -> (function for raw votes, function for SP rankings)
    rules = {
        'Borda': (borda_voting_rule, SP_borda_rule),
        'Copeland': (copeland_voting_rule, SP_copeland_rule),
        'Maximin': (maximin_voting_rule, SP_maximin_rule),
        'Schulze': (schulze_voting_rule, SP_schulze_rule),
    }
    if rule not in rules:
        raise ValueError(f"Unknown voting rule {rule!r}; expected one of {sorted(rules)}")

    vote_rule, sp_rule = rules[rule]
    return vote_rule(votes) if aggregation_type == 'vote' else sp_rule(votes)
################################################################## Partial-SP Algorithm ########################################################################################

# Encode one respondent's vote on the pair (a, b): 1 if it favours a, 0 if it favours b,
# and -1 if the approved entries of a top/approval-style vote contain neither a nor b.
def infoab(votes, treatment, a, b):
    """Encode one respondent's vote on the pair (a, b).

    Parameters
    ----------
    votes : list
        The respondent's vote. For treatments 4, 5 and 6 it is a full ranking.
        For the other formats only its leading entries are read: 1 entry for
        treatments 1, 2, 3 and 8, 2 for treatment 9, and 3 for treatment 7.
    treatment : int
        Elicitation-format code.
    a, b : hashable
        The two alternatives being compared.

    Returns
    -------
    int or None
        1 if the vote favours a, 0 if it favours b, and -1 if neither is among
        the leading entries (non-ranking formats only). Returns None for an
        unrecognised treatment, or when a == b in a ranking format, in which
        case a message is printed.

    Raises
    ------
    ValueError
        From ``list.index`` when a or b is absent from a ranking-format vote.
    IndexError
        When a vote has fewer entries than its format reads.
    """
    if treatment in (4, 5, 6):  # Rank-None, Rank-Top, Rank-Rank
        rank_a, rank_b = votes.index(a), votes.index(b)
        if rank_a == rank_b:
            print('Cant find a or b')
            return None
        return 1 if rank_a < rank_b else 0

    # Number of leading entries the respondent approved, per format:
    # Top-* formats approve 1, Approval(2)-Approval(2) 2, Approval(3)-Rank 3.
    approved_count = {1: 1, 2: 1, 3: 1, 8: 1, 9: 2, 7: 3}.get(treatment)
    if approved_count is None:
        return None

    approved = [votes[k] for k in range(approved_count)]
    if a in approved:
        return 1
    if b in approved:
        return 0
    return -1

# Encode one respondent's prediction on the pair (a, b): alpha if it puts a ahead of b, beta if it puts b ahead, and 0.5 if a top/approval-style prediction mentions neither.
def predab(predictions, treatment, a, b, alpha, beta):
    """Encode one respondent's prediction report on the pair (a, b).

    Parameters
    ----------
    predictions : list
        The respondent's prediction. For treatments 3, 6 and 7 it is a full
        ranking. For treatments 2 and 5 only the first entry is read, for
        treatment 9 the first two, and for treatment 8 the first three.
    treatment : int
        Elicitation-format code.
    a, b : hashable
        The two alternatives being compared.
    alpha, beta : float
        Values returned when the prediction favours a or b respectively.

    Returns
    -------
    float or None
        alpha if a is predicted ahead of b, and beta if b is. For the
        top/approval formats, 0.5 when neither is mentioned. Returns None for
        treatments without a handled prediction format (e.g. 1 and 4), or when
        a == b in a ranking format, in which case a message is printed.

    Raises
    ------
    ValueError
        From ``list.index`` when a or b is absent from a ranking prediction.
    IndexError
        When a prediction has fewer entries than its format reads.
    """
    if treatment in (3, 6, 7):  # Top-Rank, Rank-Rank, Approval(3)-Rank
        rank_a, rank_b = predictions.index(a), predictions.index(b)
        if rank_a == rank_b:
            print('Same location for prediction report')
            return None
        return alpha if rank_a < rank_b else beta

    # Number of leading predicted entries read, per format:
    # Top-Top / Rank-Top read 1, Approval(2)-Approval(2) 2, Top-Approval(3) 3.
    predicted_count = {2: 1, 5: 1, 9: 2, 8: 3}.get(treatment)
    if predicted_count is None:
        return None

    predicted = [predictions[k] for k in range(predicted_count)]
    if a in predicted:
        return alpha
    if b in predicted:
        return beta
    return 0.5

# For one pair (a, b): count the respondents who prefer a and those who prefer b, weight each side by its members' predictions,
# and return which side wins the prediction-normalised vote (1 for a, 0 for b).
def Aggregate_II(information, prediction):
    """Decide the pair (a, b) with the prediction-normalised vote.

    Parameters
    ----------
    information : pandas.Series
        Per-respondent output of infoab: 1 (prefers a) or 0 (prefers b). Any
        other value (e.g. -1 or NaN) is ignored.
    prediction : pandas.Series
        Per-respondent output of predab, indexed like `information`.

    Returns
    -------
    int
        1 if a wins, 0 otherwise. If nobody expressed a preference the result
        is 0. If preferences exist on only one side, that side wins.
    """
    idx1 = information[information == 1].index
    idx0 = information[information == 0].index
    n1, n0 = len(idx1), len(idx0)

    # Handle the degenerate cases before taking any mean, so empty slices
    # never reach np.mean (which warns and yields NaN).
    if n1 == 0:
        # Either no information at all (tie broken towards b) or only b supporters.
        return 0
    if n0 == 0:
        # Every informative respondent prefers a: nv0 would be 0, so a wins.
        return 1

    # p10: mean prediction score among b-supporters;
    # p01: mean of (1 - prediction) among a-supporters.
    p10 = np.mean(prediction.loc[idx0])
    p01 = np.mean(1 - prediction.loc[idx1])

    total = n1 + n0
    nv1 = n1 / total * (1 + p01 / p10)
    nv0 = n0 / total * (1 + p10 / p01)

    return 1 if nv1 >= nv0 else 0

def complete_ranking(lpairs):
    """Turn decided pairs into a ranking by counting pairwise wins.

    Parameters
    ----------
    lpairs : iterable of sequences
        Each item is ``[winner, loser]`` (only the first two entries are read).

    Returns
    -------
    list
        Every alternative that appears in a pair, sorted by number of wins,
        most first. Ties keep the iteration order of the set of alternatives.
    """
    candidates = list(set(alt for pair in lpairs for alt in (pair[0], pair[1])))
    wins = dict.fromkeys(candidates, 0)
    for pair in lpairs:
        wins[pair[0]] += 1

    return sorted(candidates, key=wins.__getitem__, reverse=True)


def sp_voting(df, treatment, alpha=0.55, beta=0.1):
    """Compute one Partial-SP ranking per distinct (domain, question).

    For every pair of a question's options, each respondent of `treatment` is
    encoded with infoab (vote) and predab (prediction). The pair is decided by
    Aggregate_II, and the decided pairs are turned into a ranking with
    complete_ranking.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs 'domain', 'question', 'options', 'treatment', 'votes' and
        'predictions' columns. The caller's frame is not modified.
    treatment : int
        Elicitation-format code; only rows with this treatment are aggregated.
    alpha, beta : float, optional
        Prediction scores passed to predab (defaults 0.55 and 0.1).

    Returns
    -------
    pandas.Series
        Named 'ranking', one list of int option ids per distinct (domain,
        question), in first-seen order. If a question appears with several
        option tuples, only the first is used.
    """
    # Ignore rows with a missing domain, and make options hashable (tuples) so
    # they can be de-duplicated. assign() returns a copy, so the input is untouched.
    df = df.loc[df['domain'].notna()]
    df = df.assign(
        options=df['options'].apply(lambda x: x if isinstance(x, tuple) else tuple(x))
    )

    questions = df[['domain', 'question', 'options']].drop_duplicates()
    treated = df[df['treatment'] == treatment]  # loop-invariant filter

    rankings = []
    seen = set()
    for _, question_row in questions.iterrows():
        domain = question_row['domain']
        question = question_row['question']
        if (domain, question) in seen:
            continue
        seen.add((domain, question))

        dfsub = treated[(treated['domain'] == domain) & (treated['question'] == question)]

        ordered_pairs = []
        for first, second in combinations(question_row['options'], 2):
            v1, v2 = int(first), int(second)
            # Build the per-respondent columns directly; unlike apply(axis=1),
            # this still yields a Series when dfsub is empty.
            information = pd.Series(
                [infoab(v, t, v1, v2) for v, t in zip(dfsub['votes'], dfsub['treatment'])],
                index=dfsub.index,
            )
            prediction = pd.Series(
                [predab(p, t, v1, v2, alpha, beta)
                 for p, t in zip(dfsub['predictions'], dfsub['treatment'])],
                index=dfsub.index,
            )
            if Aggregate_II(information, prediction) == 1:
                ordered_pairs.append([v1, v2])
            else:
                ordered_pairs.append([v2, v1])

        rankings.append(complete_ranking(ordered_pairs))

    # Build the result once instead of concatenating inside the loop.
    return pd.Series(rankings, name='ranking', dtype=object)

################################################################## Evaluation Metrics ########################################################################################

def calculate_correct_hits(ground_truth, rankings, max_difference=24):
    """Measure order agreement between two equal-length rankings at fixed offsets.

    For each offset ``d`` in ``1..max_difference``, every position pair
    ``(i, i + d)`` is checked. It counts as a hit when `rankings` orders its two
    entries (by value, via ``<``/``>``) the same way `ground_truth` does. Pairs
    of equal values never count.

    Parameters
    ----------
    ground_truth, rankings : sequences
        Rankings of the same length.
    max_difference : int, optional
        Largest offset to evaluate (default 24).

    Returns
    -------
    tuple[list[int], list[float]]
        The offsets and, for each, the fraction of hits. The fraction is NaN
        when the rankings are too short to contain any pair at that offset.
    """
    differences = list(range(1, max_difference + 1))
    correct_hits = []

    for d in differences:
        total_pairs = len(ground_truth) - d
        if total_pairs <= 0:
            # No pairs at this offset: the fraction is undefined.
            correct_hits.append(float('nan'))
            continue

        correct_pairs = sum(
            1
            for i in range(total_pairs)
            if (ground_truth[i] < ground_truth[i + d] and rankings[i] < rankings[i + d])
            or (ground_truth[i] > ground_truth[i + d] and rankings[i] > rankings[i + d])
        )
        correct_hits.append(correct_pairs / total_pairs)

    return differences, correct_hits

def calculate_correct_hits_for_range_k(ground_truth, rankings, max_k):
    """Adjacent-pair order agreement within the top-k prefixes, for k = 1..max_k.

    For each k, adjacent pairs ``(i, i + 1)`` of the first k entries are
    compared by value. A pair is a hit when both prefixes order it the same way.

    Parameters
    ----------
    ground_truth, rankings : sequences
        Rankings to compare.
    max_k : int
        Largest prefix length.

    Returns
    -------
    tuple[list[int], list[float]]
        The k values and the hit fraction for each. The fraction is 0 when the
        prefix has fewer than two entries.
    """
    k_values = list(range(1, max_k + 1))
    correct_hits_values = []

    for k in k_values:
        gt_prefix = ground_truth[:k]
        rk_prefix = rankings[:k]
        n_adjacent = len(gt_prefix) - 1

        agreements = sum(
            1
            for i in range(n_adjacent)
            if (gt_prefix[i] < gt_prefix[i + 1] and rk_prefix[i] < rk_prefix[i + 1])
            or (gt_prefix[i] > gt_prefix[i + 1] and rk_prefix[i] > rk_prefix[i + 1])
        )
        # A prefix of length 1 has no adjacent pairs; report 0 rather than divide by 0.
        correct_hits_values.append(agreements / n_adjacent if n_adjacent > 0 else 0)

    return k_values, correct_hits_values

def calculate_top_t_hit_rate(ground_truth, rankings, max_k):
    """Top-k overlap between two rankings for k = 1..max_k.

    Parameters
    ----------
    ground_truth, rankings : sequences
        Rankings to compare.
    max_k : int
        Largest k.

    Returns
    -------
    tuple[list[int], list[float]]
        The k values and, for each, ``|top_k(ground_truth) & top_k(rankings)| / k``.
    """
    ks = list(range(1, max_k + 1))
    rates = [
        len(set(ground_truth[:k]).intersection(rankings[:k])) / k
        for k in ks
    ]
    return ks, rates


# Smooth the lower/upper confidence-interval bounds with cubic splines for plotting
def smooth_ci(x, y_lower, y_upper, smoothing_factor=10):
    """Resample CI bounds onto a denser grid with cubic interpolating splines.

    Parameters
    ----------
    x : numpy.ndarray
        Sample positions (needs at least 4 points for a cubic spline).
    y_lower, y_upper : array-like
        Lower and upper bound values at `x`.
    smoothing_factor : int, optional
        Number of evenly spaced points in the output grid (default 10).

    Returns
    -------
    tuple
        ``(grid, lower_on_grid, upper_on_grid)``.
    """
    grid = np.linspace(x.min(), x.max(), smoothing_factor)
    lower_curve, upper_curve = (
        make_interp_spline(x, bound, k=3)(grid) for bound in (y_lower, y_upper)
    )
    return grid, lower_curve, upper_curve


##################################################################  Partial-SP ########################################################################################


def run_partial_sp(elicitation_format, rule, domain, num_bootstraps=1000, sample_size=20, random_state=None):
    """Bootstrap-compare Partial-SP and plain vote aggregation for one setting.

    Reads ``Elicitation Formats/<folder>/<folder>_<domain>.csv``; the three
    approval formats are stored under their 'subset' folder names. Then, for
    `num_bootstraps` replicates:

    1. resample `sample_size` responses per question, with replacement;
    2. build a Partial-SP ranking (sp_voting, then `rule` in 'sp' mode) and a
       plain vote ranking (`rule` in 'vote' mode);
    3. score both against a reference ranking using Kendall tau, Spearman rho,
       the offset-based correct-hit fraction and the top-k hit rate.

    Finally a bar chart of the mean correlations with 2.5/97.5 percentile
    intervals is shown.

    Parameters
    ----------
    elicitation_format : str
        One of the values of the module-level ``map_elicitation``.
    rule : str
        Voting-rule name accepted by votes_function.
    domain : str
        Domain suffix of the CSV file name (e.g. 'Geography').
    num_bootstraps : int, optional
        Number of bootstrap replicates (default 1000).
    sample_size : int, optional
        Responses drawn per question in each replicate (default 20).
    random_state : int or None, optional
        Seed for reproducible resampling. None uses pandas' default global RNG.

    Returns
    -------
    dict
        'correlations' maps each bar label to ``(mean, ci_low, ci_high)``.
        'correct_hits_sp', 'correct_hits_votes', 'top_k_hits_sp' and
        'top_k_hits_votes' are DataFrames with one row per replicate.

    Raises
    ------
    ValueError
        If `elicitation_format` is not a value of ``map_elicitation``.
    """
    # Reverse lookup of the numeric treatment code.
    treatment = None
    for key, value in map_elicitation.items():
        if value == elicitation_format:
            treatment = key
    if treatment is None:
        raise ValueError(f"Unknown elicitation format: {elicitation_format!r}")

    # The approval formats are stored on disk under their 'subset' names.
    folder_aliases = {
        'approval(3)-rank': 'subset3-rank',
        'top-approval(3)': 'top-subset3',
        'approval(2)-approval(2)': 'subset2-subset2',
    }
    folder = folder_aliases.get(elicitation_format, elicitation_format)

    df = pd.read_csv(f'Elicitation Formats/{folder}/{folder}_{domain}.csv')
    # These columns hold the string representation of Python lists.
    for column in ('votes', 'predictions', 'options'):
        df[column] = df[column].apply(ast.literal_eval)

    kendall_tau_sp, kendall_tau_vote = [], []
    spearman_rho_sp, spearman_rho_vote = [], []
    bootstrap_correct_hits_sp, bootstrap_correct_hits_votes = [], []
    bootstrap_top_k_sp, bootstrap_top_k_votes = [], []

    rng = None if random_state is None else np.random.RandomState(random_state)
    grouped = df.groupby('question')  # loop-invariant, so group once

    for _ in tqdm(range(num_bootstraps), desc="Bootstrapping", unit="bootstrap"):
        # Resample `sample_size` responses (with replacement) for every question.
        bootstrap_sample_df = pd.concat(
            [group.sample(sample_size, replace=True, random_state=rng) for _name, group in grouped],
            ignore_index=True,
        )

        # Partial ground truths per question, then aggregate them with `rule`.
        sp_partial_ground_truth = sp_voting(bootstrap_sample_df, treatment)
        sp_ranking = votes_function(sp_partial_ground_truth, aggregation_type='sp', rule=rule)
        vote_ranking = votes_function(bootstrap_sample_df['votes'], aggregation_type='vote', rule=rule)
        # Reference ranking: alternative ids in ascending order.
        # NOTE(review): this presumes ids are numbered in true-rank order; confirm against the dataset.
        ground_truth_ranking = sorted(vote_ranking)

        # Only alternatives present in both rankings can be compared.
        common_alternatives = set(sp_ranking) & set(vote_ranking)
        ground_truth_ranking = [alt for alt in ground_truth_ranking if alt in common_alternatives]
        vote_ranking = [alt for alt in vote_ranking if alt in common_alternatives]
        sp_ranking = [alt for alt in sp_ranking if alt in common_alternatives]

        tau_sp, _p = kendalltau(ground_truth_ranking, sp_ranking)
        tau_vote, _p = kendalltau(ground_truth_ranking, vote_ranking)
        rho_sp, _p = spearmanr(ground_truth_ranking, sp_ranking)
        rho_vote, _p = spearmanr(ground_truth_ranking, vote_ranking)
        kendall_tau_sp.append(tau_sp)
        kendall_tau_vote.append(tau_vote)
        spearman_rho_sp.append(rho_sp)
        spearman_rho_vote.append(rho_vote)

        bootstrap_correct_hits_sp.append(calculate_correct_hits(ground_truth_ranking, sp_ranking)[1])
        bootstrap_correct_hits_votes.append(calculate_correct_hits(ground_truth_ranking, vote_ranking)[1])
        bootstrap_top_k_votes.append(calculate_top_t_hit_rate(ground_truth_ranking, vote_ranking, max_k=20)[1])
        bootstrap_top_k_sp.append(calculate_top_t_hit_rate(ground_truth_ranking, sp_ranking, max_k=20)[1])

    ######################################################################################## Plotting ########################################################################################

    print_mapping_elicitation = {
        1: 'Top-None',
        2: 'Top-Top',
        3: 'Top-Rank',
        4: 'Rank-None',
        5: 'Rank-Top',
        6: 'Rank-Rank',
        7: 'Approval(3)-Rank',
        8: 'Top-Approval(3)',
        9: 'Approval(2)-Approval(2)'
    }
    print_elicitation = print_mapping_elicitation[treatment]

    samples = {
        'Kendall Tau SP': kendall_tau_sp,
        'Kendall Tau Vote': kendall_tau_vote,
        'Spearman Rho SP': spearman_rho_sp,
        'Spearman Rho Vote': spearman_rho_vote,
    }
    correlations = {}
    for label, values in samples.items():
        series = pd.Series(values, dtype=float)
        # Series.mean skips NaN (e.g. tau for a constant ranking); percentiles do not.
        ci_low, ci_high = np.percentile(series, [2.5, 97.5])
        correlations[label] = (series.mean(), ci_low, ci_high)

    metrics = list(correlations)
    means = np.array([correlations[m][0] for m in metrics])
    ci_lower_bounds = np.array([correlations[m][1] for m in metrics])
    ci_upper_bounds = np.array([correlations[m][2] for m in metrics])
    errors = [means - ci_lower_bounds, ci_upper_bounds - means]

    plt.figure(figsize=(12, 6))
    bar_positions = range(len(metrics))
    plt.bar(bar_positions, means, yerr=errors, align='center', alpha=0.7, capsize=10, color=['red', 'blue', 'green', 'purple'])
    plt.xticks(bar_positions, metrics)
    plt.ylabel('Values')
    plt.title(f'Mean and 95% Confidence Intervals for Correlation Metrics ({print_elicitation}, {rule}, {domain})')
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.show()

    return {
        'correlations': correlations,
        'correct_hits_sp': pd.DataFrame(bootstrap_correct_hits_sp),
        'correct_hits_votes': pd.DataFrame(bootstrap_correct_hits_votes),
        'top_k_hits_sp': pd.DataFrame(bootstrap_top_k_sp),
        'top_k_hits_votes': pd.DataFrame(bootstrap_top_k_votes),
    }


# NOTE(review): the triple-quoted string below is disabled plotting code kept
# for reference. It is a bare module-level string literal, so none of it runs.
# It refers to names that only exist as locals inside run_partial_sp (e.g.
# df_correct_hits_sp, rule, print_elicitation), so it would have to be moved
# back into that function, with the bound arrays it uses recomputed there,
# before it could be re-enabled.
"""     # Plotting
    plt.figure(figsize=(14, 8))
    x = np.arange(1, 25)  # Difference values

    # Use linestyle and hatch for SP data
    x_new, ci_lower_sp_smooth, ci_upper_sp_smooth = smooth_ci(x, np.array(lower_bound_sp), np.array(upper_bound_sp))
    plt.plot(x, df_correct_hits_sp.mean(), label='Mean Correct Hits Partial-SP', color='red', linestyle='--')
    plt.fill_between(x_new, ci_lower_sp_smooth, ci_upper_sp_smooth, color='red', alpha=0.1, hatch='...')

    # Use linestyle and hatch for Votes data
    x_new, ci_lower_votes_smooth, ci_upper_votes_smooth = smooth_ci(x, np.array(lower_bound_votes), np.array(upper_bound_votes))
    plt.plot(x, df_correct_hits_votes.mean(), label='Mean Correct Hits Votes', color='blue', linestyle='-')
    plt.fill_between(x_new, ci_lower_votes_smooth, ci_upper_votes_smooth, color='blue', alpha=0.1, hatch='\\\\')

    plt.xlabel('Difference Value', fontsize=14)
    plt.ylabel('Correct Hits', fontsize=14)
    plt.title(f'Fraction of Correct Hits with 95% CI for Geography MTurk Data {print_elicitation} {rule} using Difference Metric (Partial-SP vs. Votes)', fontsize=12)
    plt.legend(fontsize=12)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.show()

    # Plotting
    plt.figure(figsize=(14, 8))
    x = np.arange(1, 21)  # K values

    # Use linestyle and hatch for SP data
    x_new, ci_lower_sp_smooth, ci_upper_sp_smooth = smooth_ci(x, np.array(lower_bound_k_values_sp), np.array(upper_bound_k_values_sp))
    plt.plot(x, df_correct_hits_k_values_sp.mean(), label='Mean Correct Hits Partial-SP', color='red', linestyle='--')
    plt.fill_between(x_new, ci_lower_sp_smooth, ci_upper_sp_smooth, color='red', alpha=0.1, hatch='...')

    # Use linestyle and hatch for Votes data
    x_new, ci_lower_votes_smooth, ci_upper_votes_smooth = smooth_ci(x, np.array(lower_bound_k_values_votes), np.array(upper_bound_k_values_votes))
    plt.plot(x, df_correct_hits_k_values_votes.mean(), label='Mean Correct Hits Votes', color='blue', linestyle='-')
    plt.fill_between(x_new, ci_lower_votes_smooth, ci_upper_votes_smooth, color='blue', alpha=0.1, hatch='\\\\')

    plt.xlabel('K Value', fontsize=14)
    plt.ylabel('Correct Hits', fontsize=14)
    plt.title(F'Fraction of Correct Hits with 95% CI for Geography MTurk Data {print_elicitation} {rule} using Top-K Metric (Partial-SP vs. Votes)', fontsize=12)
    plt.legend(fontsize=12)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.show()
    plt.close() """
    

######################################################################################## Main Function ##################################################################

# Numeric treatment code -> elicitation-format name (read by run_partial_sp).
map_elicitation = {1: 'top-none', 2: 'top-top', 3: 'top-rank', 4: 'rank-none', 5: 'rank-top', 6: 'rank-rank', 7: 'approval(3)-rank', 8: 'top-approval(3)', 9: 'approval(2)-approval(2)'}

# Formats skipped when running all combinations. predab returns None for their
# treatment codes (1 and 4), presumably because they collect no prediction report.
excluded_formats = ['top-none', 'rank-none']


def main():
    """Prompt for a run mode and evaluate Partial-SP accordingly.

    Choice '1' runs every non-excluded elicitation format x voting rule x
    domain. Choice '2' prompts for one combination. Any other input prints a
    short message and does nothing.
    """
    choice = input("Enter 1 to run all combinations of domain, elicitation format, and voting rule OR Enter 2 to run a specific combination of domain, elicitation format, and voting rule: ").strip()
    if choice == '1':
        elicitation_formats = [value for value in map_elicitation.values() if value not in excluded_formats]
        voting_rules = ['Borda', 'Copeland', 'Maximin', 'Schulze']
        domains = ['Geography', 'Movies', 'Paintings']
        for elicitation_format in elicitation_formats:
            for rule in voting_rules:
                for domain in domains:
                    print("******************************************************************************************")
                    print(f"Running {elicitation_format}, {rule}, {domain}")
                    run_partial_sp(elicitation_format, rule, domain)
                    print(f"Completed {elicitation_format}, {rule}, {domain}")
    elif choice == '2':
        print("In the following, you will be asked to enter the elicitation format, voting rule, and domain for which you want to evaluate the Partial-SP algorithm. Make sure to enter the options in the same case as shown in the prompt.")
        elicitation_format = input("Choose an elicitation format amongst top-top or top-rank or rank-top or rank-rank or approval(3)-rank or top-approval(3) or approval(2)-approval(2): ").strip()
        rule = input("Choose a voting rule amongst Borda, Copeland, Maximin, Schulze: ").strip()
        domain = input("Enter the domain amongst Geography, Movies, and Paintings: ").strip()
        run_partial_sp(elicitation_format, rule, domain)
    else:
        print(f"Unrecognised choice {choice!r}; please enter 1 or 2.")


# Guard the interactive prompts so importing this module has no side effects.
if __name__ == "__main__":
    main()
