import numpy as np
import seaborn as sns 
import argparse
import matplotlib
from matplotlib import pyplot as plt
import math
import os

matplotlib.rcParams.update({'font.size': 18})


# Command-line settings for the LiRA no-DP positive-accuracy plot.
parser = argparse.ArgumentParser(description='Settings')
parser.add_argument(
    '--dataset',
    default='mnist',
    choices=['cifar10', 'fmnist', 'mnist', 'svhn_ext'],
    help="dataset name used to build the lira_scores/ and lira_samplings/ file paths",
)
parser.add_argument(
    '--nan_allowed',
    default="no",
    choices=["yes", "no"],
    help="'yes' keeps NaN scores; 'no' drops them (and the paired sampling entries)",
)

args = parser.parse_args()

dataset = args.dataset
# True only when the user explicitly passes --nan_allowed yes.
nan_allowed = args.nan_allowed == "yes"
print(f"nan allowed = {nan_allowed}")


# For every trial and every sampling configuration, load the LiRA scores and the
# matching sampling indicators, then record:
#   * baseline  - fraction of truthy sampling entries among the scored samples;
#   * forward   - best precision over a sweep of score thresholds, where each
#                 threshold drops the lowest-scoring samples (ascending argsort).
# Each *_trials list ends up with one inner list per trial, and one entry per
# sampling configuration inside it.
backward_precs_sample_trials = []  # NOTE(review): per-trial lists are never filled; backward sweep not implemented.
forward_precs_sample_trials = []
baseline_sample_trials = []

# NOTE: hardcoded eps for now (not used below; these runs are all "no_dp").
eps = 3.0

trial_ids = [1, 2, 3, 4, 5]
# Index embedded in the file names. Presumably P_x = index / 10, matching the
# 0.1..0.9 x-axis of the plot -- TODO confirm.
sampling_ids = [1, 3, 5, 7, 9]
# The sweep uses equal-width rank steps of len(scores) // 40 samples (= 2.5%).
num_thresholds = 40

for trial in trial_ids:
    backward_precs_sample = []
    forward_precs_sample = []
    baseline_sample = []

    for sampling_id in sampling_ids:
        scores = np.load(f'lira_scores/scores_1000_{dataset}_{sampling_id}_{trial}_no_dp.npy')
        sampling = np.load(f'lira_samplings/sampling_{dataset}_target_{sampling_id}_{trial}_0_no_dp.npy')
        # Only the first 1000 sampling entries are paired with the scores
        # (the scores file is named scores_1000 -- presumably 1000 scores).
        our_sampling = sampling[:1000]

        if not nan_allowed:
            print("removing nan")
            keep = ~np.isnan(scores)
            our_sampling = our_sampling[keep]
            scores = scores[keep]

        if len(scores) == 0:
            raise ValueError(
                f"no usable scores for dataset={dataset}, sampling id={sampling_id}, trial={trial}"
            )

        # Normalize by the number of scored samples actually kept, so the
        # baseline stays a fraction after NaN rows are dropped.
        baseline_sample.append(np.sum(our_sampling) / len(our_sampling))

        ind_sorted = np.argsort(scores)
        interval = len(scores) // num_thresholds

        precision_list_forwards = []
        for step in range(num_thresholds):
            accepted_inds = ind_sorted[step * interval:]
            # Index our_sampling, not the raw sampling array: after NaN removal,
            # ind_sorted refers to positions in the filtered arrays.
            correct = np.count_nonzero(our_sampling[accepted_inds])
            precision_list_forwards.append(correct / len(accepted_inds))

        forward_precs_sample.append(max(precision_list_forwards))

    backward_precs_sample_trials.append(backward_precs_sample)
    forward_precs_sample_trials.append(forward_precs_sample)
    baseline_sample_trials.append(baseline_sample)


# Shape (num_trials, num_sampling_configs).
forward_precs_sample_trials_np = np.array(forward_precs_sample_trials)
baseline_sample_trials_np = np.array(baseline_sample_trials)

# Average over trials (axis 0), one value per sampling configuration.
forward_precs_sample_trials_avg = np.mean(forward_precs_sample_trials_np, axis=0)
baseline_sample_trials_avg = np.mean(baseline_sample_trials_np, axis=0)

# Normal-approximation 95% confidence half-width over trials.
# NOTE(review): np.std defaults to ddof=0 and z=1.96 is used; with this few
# trials, ddof=1 and a t-quantile would give a wider, more honest band.
num_trials = len(forward_precs_sample_trials_np)
forward_precs_sample_trials_std = np.std(forward_precs_sample_trials_np, axis=0)
forward_precs_sample_trials_ci = (1.96 / math.sqrt(num_trials)) * forward_precs_sample_trials_std


# x-axis values: one per sampling configuration, in the same order as the
# columns of forward_precs_sample_trials_avg.
sampling_probs = [0.1, 0.3, 0.5, 0.7, 0.9]

# NOTE(review): baseline_attack / baseline_attack_trials are not referenced
# anywhere in this file; kept only in case something else imports them.
baseline_attack = [np.array(sampling_probs) for _ in range(5)]
baseline_attack_trials = [baseline_attack for _ in range(len(forward_precs_sample_trials_np))]

plt.figure()

# Mean best forward precision across trials, with a shaded +/- CI band.
plt.plot(sampling_probs, forward_precs_sample_trials_avg)
plt.fill_between(
    sampling_probs,
    forward_precs_sample_trials_avg - forward_precs_sample_trials_ci,
    forward_precs_sample_trials_avg + forward_precs_sample_trials_ci,
    alpha=0.1,
)

# Dashed y = x reference line.
plt.plot(sampling_probs, sampling_probs, label='baseline', linestyle='dashed')

plt.xlabel("Sampling Probability $P_{x^*}(1)$")
plt.ylabel("Positive Accuracy")
plt.title(f"{dataset} with No DP")

# plt.legend() is currently disabled; only the baseline line carries a label.

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs('Plots', exist_ok=True)
num_trials_plotted = len(forward_precs_sample_trials_np)
if nan_allowed:
    plt.savefig(f'Plots/{dataset}_no_dp_LIRA_POS_ACC_{num_trials_plotted}_trials.pdf')
else:
    print("Saving no nan plot")
    plt.savefig(f'Plots/{dataset}_no_dp_LIRA_POS_ACC_{num_trials_plotted}_trials_no_nan.pdf')




