import numpy as np
import seaborn as sns 
import argparse
import matplotlib
from matplotlib import pyplot as plt
import math
import os

matplotlib.rcParams.update({'font.size': 18})  # larger default font for the figure

# Command-line options.
#   --dataset      selects which score/sampling files are loaded and is used in
#                  the output plot's file name.
#   --nan_allowed  "yes" keeps NaN scores when ranking; "no" drops them.
parser = argparse.ArgumentParser(description='Settings')
parser.add_argument(
    '--dataset',
    default='mnist',
    choices=['cifar10', 'fmnist', 'mnist', 'svhn_ext'],
)
parser.add_argument('--nan_allowed', default="no", choices=["yes", "no"])
args = parser.parse_args()

dataset = args.dataset
nan_allowed = (args.nan_allowed == "yes")  # yes/no string -> bool

print(f"nan allowed = {nan_allowed}")
#target_epsilon = args.target_epsilon


# Experiment grid: for every (trial, epsilon, sampling index) combination there
# is one LiRA score file and one sampling (membership) file on disk.
TRIALS = [1, 2, 3, 4, 5]
EPSILONS = [1.0, 3.0, 5.0, 7.0, 9.0, 12.0, 15.0]
SAMPLING_IDS = [1, 2, 3, 4, 5, 6, 7, 8, 9]
NUM_SCORED = 1000    # the baseline is the sampled fraction of sampling[:1000]
NUM_THRESHOLDS = 40  # a threshold every 1/40 (= 0.025) of the score ranking


def _ranked_indices(scores, drop_nan):
    """Return indices into *scores* ordered from lowest to highest score.

    When *drop_nan* is true, NaN scores are excluded, but the returned indices
    still refer to positions in the ORIGINAL (unfiltered) array. This keeps
    them aligned with the per-sample sampling array that is indexed with them.
    When *drop_nan* is false, every index is returned and, as with
    ``np.argsort``, NaN scores sort to the end.
    """
    if not drop_nan:
        return np.argsort(scores)
    valid = np.flatnonzero(~np.isnan(scores))
    return valid[np.argsort(scores[valid])]


def _best_threshold_precisions(ind_sorted, membership, num_thresholds=NUM_THRESHOLDS):
    """Return the best (backward, forward) precision over evenly spaced cuts.

    Args:
        ind_sorted: sample indices ordered by increasing score.
        membership: array indexed by sample index; a truthy entry counts as a
            correct acceptance.
        num_thresholds: number of cut points; the cut width is
            ``len(ind_sorted) // num_thresholds``.

    Backward accepts the lowest-scoring ``k * width`` samples, k = 1..num_thresholds.
    Forward accepts every sample from position ``k * width`` onward,
    k = 0..num_thresholds-1. Precision is (# truthy accepted) / (# accepted).
    The maximum precision of each family is returned as a pair of floats.

    Raises:
        ValueError: if fewer than ``num_thresholds`` scores are available. The
            cut width would then be zero and the backward precision undefined.
    """
    n = len(ind_sorted)
    interval = n // num_thresholds
    if interval == 0:
        raise ValueError(
            f"need at least {num_thresholds} scores to place thresholds, got {n}"
        )

    # hits[j] is True when the j-th lowest-scoring sample is a member.
    hits = np.asarray(membership)[ind_sorted].astype(bool)
    # hits_before[L] = number of members among the first L ranked samples.
    hits_before = np.concatenate(([0], np.cumsum(hits)))

    ends = interval * np.arange(1, num_thresholds + 1)
    backward = hits_before[ends] / ends

    starts = interval * np.arange(num_thresholds)
    forward = (hits_before[n] - hits_before[starts]) / (n - starts)

    return float(backward.max()), float(forward.max())


backward_precs_eps_trials = []
forward_precs_eps_trials = []
baseline_eps_trials = []

for trial in TRIALS:
    backward_precs_eps = []
    forward_precs_eps = []
    baseline_eps = []

    for eps in EPSILONS:
        backward_precs = []
        forward_precs = []
        baseline = []

        for sampling_id in SAMPLING_IDS:
            scores = np.load(f'lira_scores/scores_1000_{dataset}_{eps}_{sampling_id}_{trial}.npy')

            if not nan_allowed:
                print("removing nan")
            # NaNs (if dropped) are skipped without renumbering, so these
            # indices can be used directly on `sampling` below.
            ind_sorted = _ranked_indices(scores, drop_nan=not nan_allowed)

            sampling = np.load(f'lira_samplings/sampling_{dataset}_{int(eps)}_target_{sampling_id}_{trial}_0.npy')

            # Fraction of the first NUM_SCORED samples that are sampled.
            baseline.append(np.sum(sampling[:NUM_SCORED]) / NUM_SCORED)

            best_backward, best_forward = _best_threshold_precisions(ind_sorted, sampling)
            backward_precs.append(best_backward)
            forward_precs.append(best_forward)

        backward_precs_eps.append(backward_precs)
        forward_precs_eps.append(forward_precs)
        baseline_eps.append(baseline)

    backward_precs_eps_trials.append(backward_precs_eps)
    forward_precs_eps_trials.append(forward_precs_eps)
    baseline_eps_trials.append(baseline_eps)


# Convert the nested trial -> eps -> sampling-index lists into arrays
# (axis 0 = trial) and then average across trials.
(
    backward_precs_eps_trials_np,
    forward_precs_eps_trials_np,
    baseline_eps_trials_np,
) = (
    np.array(nested)
    for nested in (backward_precs_eps_trials, forward_precs_eps_trials, baseline_eps_trials)
)

(
    backward_precs_eps_trials_avg,
    forward_precs_eps_trials_avg,
    baseline_eps_trials_avg,
) = (
    np.mean(per_trial, axis=0)
    for per_trial in (backward_precs_eps_trials_np, forward_precs_eps_trials_np, baseline_eps_trials_np)
)

plt.figure()

# Reference positive accuracy for each of the nine sampling indices. These are
# the same values as the dashed 'baseline' (y = x) line in the plot below.
baseline_attack = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])

# Broadcasting subtracts the reference from every (trial, eps) row, so no
# nested per-trial / per-eps copies of it are needed.
difference_trials = forward_precs_eps_trials_np - baseline_attack

# Pool all (trial, eps) runs as rows, one column per sampling index. The row
# count is derived from the data instead of being hard-coded (was 35 = 5 x 7),
# so this keeps working if the number of trials or epsilons changes.
difference_reshape = difference_trials.reshape(-1, difference_trials.shape[-1])
difference_avg = np.mean(difference_reshape, axis=0)
# Sample standard deviation (ddof=1): the spread is estimated from the data
# itself, which is what the confidence interval of the mean calls for.
difference_std = np.std(difference_reshape, axis=0, ddof=1)
# Half-width of a normal-approximation 95% interval around difference_avg.
difference_ci = (1.96 / math.sqrt(len(difference_reshape))) * difference_std


# x-axis values, one per sampling index (1..9) used when loading the results.
sampling_probs = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

# One curve per epsilon: trial-averaged best forward-threshold precision.
for i, eps in enumerate([1.0, 3.0, 5.0, 7.0, 9.0, 12.0, 15.0]):
    print(f"i {i}, eps {eps}")
    plt.plot(sampling_probs, forward_precs_eps_trials_avg[i], label=f"eps = {eps}")

# Reference line y = x.
plt.plot(sampling_probs, sampling_probs, label='baseline', linestyle='dashed')

plt.xlabel("sampling probability")
plt.ylabel("positive accuracy")

# Title intentionally disabled; kept for reference.
'''
if dataset == 'svhn_ext':
    plt.title("SVHN Extended LiRA Positive Accuracy")

else:
    plt.title(f"{dataset} LiRA Positive Accuracy")
'''

plt.legend()
plt.tight_layout()

# savefig does not create missing directories; make sure Plots/ exists.
os.makedirs('Plots', exist_ok=True)
if nan_allowed:
    plot_path = f'Plots/{dataset}_LIRA_POS_ACC_5_trials.pdf'
else:
    print("Saving no nan plot")
    plot_path = f'Plots/{dataset}_LIRA_POS_ACC_5_trials_no_nan.pdf'
plt.savefig(plot_path)

# Disabled code (a bare string literal, never executed): plots the mean
# improvement of forward positive accuracy over the reference values
# (difference_avg) with a 1.96 * std / sqrt(N) band (difference_ci), then saves it.
# NOTE(review): if re-enabled, start a new figure first (e.g. plt.figure()).
# As written it would draw onto the positive-accuracy figure that was just saved.
'''
plt.plot([0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9], difference_avg)
plt.fill_between([0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9], difference_avg - difference_ci, difference_avg + difference_ci, alpha = 0.1)

plt.xlabel("sampling probability")
plt.ylabel("Average Improvement Over Baseline")

if dataset == 'svhn_ext':
    plt.title("SVHN Extended LiRA Average Positive Accuracy Improvement")

else:
    plt.title(f"{dataset} LiRA Average Positive Accuracy Improvement")


if nan_allowed:
    plt.savefig(f'Plots/{dataset}_LIRA_improvement_5_trials.pdf')
else:
    print("Saving no nan plot")
    plt.savefig(f'Plots/{dataset}_LIRA_improvement_5_trials_no_nan.pdf')

'''



# Apparently a leftover from a notebook session; `files` is not defined in this script.
#files.download("CIFAR10_LIRA_improvement_5_trials.pdf") 


