from PromptPATE.pate.pate_utils import *
import numpy as np
import pandas as pd
import os
from prompt_graph.utils import get_args
import ipdb
import time
import torch

def get_how_many_answered(predicted_labels):
    """
    Count how many queries received a label in the raw PATE output.

    Args:
        predicted_labels: array-like of aggregated labels in which -1 marks a query
            that was not answered.

    Returns:
        The number of entries that differ from -1 (an absolute count, not a fraction).
    """
    # Convert first: comparing a plain Python list with -1 yields a single bool instead of
    # an element-wise mask, which would silently produce a wrong count.
    labels = np.asarray(predicted_labels)
    return np.sum(labels != -1)

def get_how_many_correctly_answered(predicted_labels, true_labels):
    """
    Count the positions where the noisy teacher prediction equals the true label.

    Args:
        predicted_labels: array-like of aggregated labels.
        true_labels: array-like of reference labels, compared element-wise with
            ``predicted_labels``.

    Returns:
        The number of matching positions (an absolute count, not a fraction). Entries are
        compared by plain equality, so a -1 prediction only matches a -1 true label.
    """
    # Convert both sides: `list == list` is a single bool in Python, not an element-wise
    # comparison, so list inputs would otherwise always yield 0 or 1.
    predictions = np.asarray(predicted_labels)
    targets = np.asarray(true_labels)
    return np.sum(predictions == targets)


def get_final_noisy_labels(noisy_labels, indices_answered, max_num_query):
    """
    Assemble the label vector that is finally released.

    Args:
        noisy_labels: NumPy array with the noisy label obtained for every query.
        indices_answered: integer indices of the queries that were not rejected.
        max_num_query: how many queries the privacy budget allows us to answer.

    Returns:
        An integer array shaped like ``noisy_labels`` that holds the noisy label where a
        query is answered and -1 where it is not (rejected, or budget exhausted).
    """
    # Everything starts out as "not answered"; the array is float here and is converted to
    # int right before returning.
    released = np.full(noisy_labels.shape, -1.0)

    # Only the queries that were not rejected carry a label.
    released[indices_answered] = noisy_labels[indices_answered]

    # If more queries were answered than the budget allows, withdraw the answered queries
    # that come after the first `max_num_query` ones.
    num_released = np.count_nonzero(released != -1)
    if num_released > max_num_query:
        over_budget = indices_answered[max_num_query:]
        released[over_budget] = -1

    return released.astype(int)

def get_actually_consumed_epsilon(dp_eps):
    """
    Return the epsilon that was actually consumed, given the per-query epsilon values.

    Rules:
      * If all entries sum to zero, no privacy cost was recorded: a message is printed and
        0.0 is returned.
      * If the last entry is zero, the budget was exhausted before the end of the queries.
        The largest recorded value is then the one above the target epsilon, so the
        second-largest value is returned. (This equals the second-to-last non-zero entry
        when the recorded values are non-decreasing.)
      * Otherwise the last entry is returned.

    Args:
        dp_eps: array-like of epsilon values, one per query; it is flattened before use.

    Returns:
        The consumed epsilon as described above.
    """
    # Flatten once so lists and multi-dimensional arrays are handled uniformly; the builtin
    # `sum` on a 2-D array returns an array and makes the `if` below ambiguous.
    eps = np.asarray(dp_eps).ravel()

    if np.sum(eps) == 0:
        print("No significant privacy costs incurred. Probably delta is quite large.")
        return 0.0

    if eps[-1] == 0.:
        # A non-zero sum with a zero last entry guarantees at least two elements here.
        # np.partition places the second-largest value at index -2.
        return np.partition(eps, -2)[-2]

    return eps[-1]

def get_weights(w_min, w_max, num_teachers, scores):
    """
    Turn per-teacher scores into weights.

    The scores are min-max normalized to [0, 1], mapped linearly onto [w_min, w_max] and then
    rescaled so that the weights sum to ``num_teachers``.

    Args:
        w_min: weight assigned to the lowest score before the final rescaling.
        w_max: weight assigned to the highest score before the final rescaling.
        num_teachers: value the returned weights sum to.
        scores: array-like of scores, one per weight to produce.

    Returns:
        A NumPy array of weights with the same shape as ``scores``.
    """
    scores = np.asarray(scores)
    lowest = np.min(scores)
    span = np.max(scores) - lowest

    if span == 0:
        # All scores are identical: min-max normalization would divide by zero and yield
        # NaN weights. Identical scores carry no ranking information, so weigh uniformly.
        return np.full(scores.shape, num_teachers / scores.size)

    # normalize the scores
    normalized = (scores - lowest) / span
    # get the weights
    weights = w_min + (w_max - w_min) * normalized
    # normalize the weights
    return weights / np.sum(weights) * num_teachers

def tune_pate(vote_array, threshold_list, sigma_threshold_list, sigma_gnmax_list, epsilon_list, delta_list, num_classes=2, savepath='', true_labels=None):
    """
    Grid-search PATE hyperparameters and log one CSV row per combination.

    For every combination of threshold, sigma_threshold, sigma_gnmax, epsilon and delta the
    privacy accounting and the noisy query are run, and a row with the parameters, the
    achieved epsilon, the number of answered queries, the number of correctly answered
    queries, the resulting accuracy and a timestamp is appended to ``savepath``.

    Args:
        vote_array: teacher votes, passed on to the accounting and query helpers.
        threshold_list, sigma_threshold_list, sigma_gnmax_list, epsilon_list, delta_list:
            iterables with the candidate values of each hyperparameter.
        num_classes: number of classes, passed on to ``query``.
        savepath: CSV file to append to; the header is written only if the file does not
            exist yet.
        true_labels: optional array of reference labels. When None, the number of correctly
            answered queries is logged as 0.
    """
    from itertools import product

    header = ['target_epsilon',  'achieved_eps', 'threshold', 'sigma_threshold', 'sigma_noise', 'delta', 'num_classes', 'num_answered', 'num_correctly_answered', 'accuracy', 'created_time']

    # Same iteration order as five nested loops: the right-most list varies fastest.
    grid = product(threshold_list, sigma_threshold_list, sigma_gnmax_list, epsilon_list, delta_list)
    for threshold, sigma_threshold, sigma_gnmax, epsilon, delta in grid:
        # this part is for the privacy accounting:
        max_num_query, dp_eps, _partition, _answered, _order_opt = analyze_multiclass_confident_gnmax(
            votes=vote_array,
            threshold=threshold,
            sigma_threshold=sigma_threshold,
            sigma_gnmax=sigma_gnmax,
            budget=epsilon,
            delta=delta,
            file=None)

        # this is for getting the labels
        noisy_labels, indices_answered = query(vote_array, threshold, sigma_threshold, sigma_gnmax,
                                               num_classes)

        final_labels = get_final_noisy_labels(noisy_labels, indices_answered, max_num_query)

        num_answered = get_how_many_answered(final_labels)
        # `is not None` is the correct test here: calling `.any()` on the default None
        # raises AttributeError, and `<bool> != None` is always True anyway.
        if true_labels is not None:
            num_correctly_answered = get_how_many_correctly_answered(final_labels, true_labels)
        else:
            num_correctly_answered = 0
        # Make the "nothing answered" case explicit instead of relying on NumPy's 0/0.
        accuracy = num_correctly_answered / num_answered if num_answered > 0 else float('nan')
        achieved_epsilon = get_actually_consumed_epsilon(dp_eps)
        # current timestamp in year-month-day-hour-minute-second format
        current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())

        row = [epsilon, achieved_epsilon, threshold, sigma_threshold, sigma_gnmax, delta, num_classes, num_answered, num_correctly_answered, accuracy, current_time]
        results_df = pd.DataFrame([row], columns=header)

        # mode 'a' appends to the csv file; the header is only written for a new file
        results_df.to_csv(savepath, mode='a', index=False,
                          header=not os.path.isfile(savepath))


def inference_pate(vote_array, threshold, sigma_threshold, sigma_gnmax, epsilon, delta, num_classes=10):
    """
    Infer the teachers' final aggregated private labels for one fixed set of hyperparameters.

    Runs the privacy accounting with ``epsilon`` as budget, queries the (unweighted) noisy
    aggregation, prints the achieved epsilon and returns the final labels, where -1 marks a
    query that is not answered. Nothing is written to disk by this function.
    """
    # Privacy accounting; only the number of answerable queries and the epsilon values
    # are needed from the five returned values.
    accounting_kwargs = dict(
        votes=vote_array,
        threshold=threshold,
        sigma_threshold=sigma_threshold,
        sigma_gnmax=sigma_gnmax,
        budget=epsilon,
        delta=delta,
        file=None,
    )
    max_num_query, dp_eps, _, _, _ = analyze_multiclass_confident_gnmax(**accounting_kwargs)

    # Noisy labels from the teachers' votes, without per-teacher weights.
    noisy_labels, indices_answered = query(
        vote_array, threshold, sigma_threshold, sigma_gnmax, num_classes,
        weights=None, weight=False,
    )

    print(get_actually_consumed_epsilon(dp_eps))

    return get_final_noisy_labels(noisy_labels, indices_answered, max_num_query)

def inference_pate_weight(vote_array, threshold, sigma_threshold, sigma_gnmax, epsilon, delta, num_classes=10):
    """
    Infer the final aggregated private labels with per-teacher weights.

    The weights are derived from the centrality scores stored under
    ``./dataspace/CentralityScore/...`` for the run described by the module-level ``args``.
    They are rescaled to [0.5, 2.0] and normalized to sum to the number of teachers
    (``vote_array.shape[1]``).

    Args:
        vote_array: teacher votes; the second axis is taken as the number of teachers.
        threshold, sigma_threshold: passed on to ``query``.
        sigma_gnmax: noise scale, used both for the privacy accounting and for ``query``.
        epsilon, delta: accepted for signature parity with ``inference_pate``; they are not
            used by this function.
        num_classes: number of classes, passed on to ``query``.

    Returns:
        The final labels, with -1 for every query that is not answered. Nothing is written
        to disk by this function.
    """
    # load average centrality score
    # NOTE(review): presumably one score per teacher -- confirm against the file producer.
    average_centrality_score = torch.load('./dataspace/CentralityScore/{}shot/{}_{}/seed_{}/{}_{}_{}.pt'.format(args.shot_num, args.dataset_name, args.pre_train_data, args.seed, args.pre_train_type, args.prompt_type, args.gnn_type))
    weights = get_weights(0.5, 2.0, num_teachers=vote_array.shape[1], scores=average_centrality_score)

    # this part is for the privacy accounting:
    # Use the `sigma_gnmax` argument rather than `args.sigma_gnmax`, so the accounting always
    # matches the noise scale that `query` applies below.
    max_num_query, _dp_eps = analyze_weighing_gnmax(votes_array=vote_array, weights=weights, sigma_gnmax=sigma_gnmax)

    # this is for getting the labels
    noisy_labels, indices_answered = query(vote_array, threshold, sigma_threshold, sigma_gnmax, num_classes, weights, weight=True)
    final_labels = get_final_noisy_labels(noisy_labels, indices_answered, max_num_query)
    return final_labels
    
# Run configuration shared by the whole module. It is created when the module is loaded and
# read as a global by `inference_pate_weight`, `main` and the script entry point.
# NOTE(review): presumably `get_args` parses the command line -- confirm in prompt_graph.utils.
args = get_args()

def main(vote_array, threshold, sigma_threshold, sigma_gnmax, epsilon, delta, num_classes, savepath='', tune_hyper=False, true_labels=None, weighted_pate=False):
    """
    Either search for good PATE hyperparameters or infer labels with known hyperparameters.

    Tuning mode (``tune_hyper=True``): the call is forwarded to ``tune_pate``. In this mode
    ``threshold``, ``sigma_threshold``, ``sigma_gnmax``, ``epsilon`` and ``delta`` must be
    iterables of candidate values, and results are appended to ``savepath``.

    Inference mode (``tune_hyper=False``): labels are inferred with ``inference_pate_weight``
    when ``weighted_pate`` is set, otherwise with ``inference_pate``. Only the answered
    queries (label != -1) are kept. Their labels and their positions in the vote array are
    written to two text files under ``./dataspace/PateInference/...``, named after the run
    described by the module-level ``args``.
    """
    if tune_hyper:
        tune_pate(vote_array, threshold, sigma_threshold, sigma_gnmax, epsilon, delta,
                  num_classes=num_classes, savepath=savepath, true_labels=true_labels)
        return

    # Both output files share one stem, so they always land in the same directory.
    # The index path previously pointed to a misspelled './datasapce/' directory that was
    # never created.
    output_stem = './dataspace/PateInference/{}shot/{}_{}/seed_{}/{}_{}_{}'.format(args.shot_num, args.dataset_name, args.pre_train_data, args.seed, args.pre_train_type, args.prompt_type, args.gnn_type)
    final_label_path = output_stem + '.txt'
    final_index_path = output_stem + '_index.txt'
    os.makedirs(os.path.dirname(final_label_path), exist_ok=True)

    if weighted_pate:
        final_labels = inference_pate_weight(vote_array, threshold, sigma_threshold, sigma_gnmax, epsilon, delta, num_classes=num_classes)
    else:
        final_labels = inference_pate(vote_array, threshold, sigma_threshold, sigma_gnmax, epsilon, delta, num_classes=num_classes)

    # Keep only the answered queries, together with their original positions.
    final_index = np.arange(len(final_labels))
    mask = (final_labels != -1)
    final_labels = final_labels[mask]
    final_index = final_index[mask]
    print(len(final_labels))
    print(len(final_labels), len(final_index))
    np.savetxt(final_label_path, final_labels, fmt="%i")
    np.savetxt(final_index_path, final_index, fmt="%i")



if __name__ == "__main__":
    # Number of classes of each supported dataset.
    num_classes_by_dataset = {'Cora': 7, 'CiteSeer': 6, 'PubMed': 3}
    if args.dataset_name not in num_classes_by_dataset:
        # Fail early with a clear message instead of a NameError on `num_classes` later on.
        raise ValueError(
            "Unsupported dataset_name {!r}; expected one of {}".format(
                args.dataset_name, sorted(num_classes_by_dataset)))
    num_classes = num_classes_by_dataset[args.dataset_name]

    # All input/output paths are named after the same run description.
    run_id = (args.shot_num, args.dataset_name, args.pre_train_data, args.seed, args.pre_train_type, args.prompt_type, args.gnn_type)

    # load the votes
    votes_load_path = './dataspace/PateVotesArray/{}shot/{}_{}/seed_{}/{}_{}_{}_votes.txt'.format(*run_id)
    true_label_load_path = './dataspace/PateVotesArray/{}shot/{}_{}/seed_{}/{}_{}_{}_true_labels.txt'.format(*run_id)
    votes_df = pd.read_csv(votes_load_path, sep=" ", header=None)
    vote_array = votes_df.to_numpy(dtype=int) # THIS IS VERY IMPORTANT: MAKE SURE THE SHAPE IS (num_samples, num_teachers)

    true_label_df = pd.read_csv(true_label_load_path, sep=" ", header=None)
    true_labels = true_label_df.to_numpy(dtype=int).flatten()

    # where to save the outcome of the tuning.
    savepath = './dataspace/PateTune/{}shot/{}_{}/seed_{}/{}_{}_{}.csv'.format(*run_id)
    os.makedirs(os.path.dirname(savepath), exist_ok=True)

    # This call must live inside the guard: everything it needs (vote_array, num_classes,
    # true_labels) only exists when the file is run as a script, so leaving it at module
    # level makes a plain `import` of this file fail with NameError.
    # The epsilon budget is fixed to 1 here. `savepath` is only read in tuning mode.
    main(vote_array, args.threshold, args.sigma_threshold, args.sigma_gnmax, 1, args.delta, num_classes, savepath=savepath, tune_hyper=False, true_labels=true_labels, weighted_pate=args.weighted_pate)

