import numpy as np
import seaborn as sns 
import argparse
import matplotlib
from matplotlib import pyplot as plt
import math
import os

from sklearn import metrics

matplotlib.rcParams.update({'font.size': 18})


def gen_tpr_fpr(sampling, scores, n_thresholds=40):
    """Compute ROC points (TPR, FPR) from binary labels and per-example scores.

    Examples are ranked by ``scores`` in ascending order. For each of
    ``n_thresholds`` evenly spaced cut-off ranks, every example ranked at or
    above the cut-off is predicted positive (i.e. higher score => positive).

    Parameters
    ----------
    sampling : array-like
        Ground-truth labels aligned element-wise with ``scores``; truthy
        entries are positives. Expected to be 0/1 or boolean.
    scores : array-like
        One score per example. ``np.argsort`` sorts NaNs last, so NaN scores
        are treated as the highest scores -- filter them out beforehand if
        that is not wanted.
    n_thresholds : int, default 40
        Number of cut-off ranks to evaluate.

    Returns
    -------
    tpr, fpr : np.ndarray
        Float arrays of length ``n_thresholds + 2``. Both start at 0.0, end at
        1.0 and are non-decreasing, so they can be passed straight to
        ``np.interp(x, fpr, tpr)`` (which requires increasing x-coordinates).

    Raises
    ------
    ValueError
        If ``sampling`` and ``scores`` differ in shape, or if there are no
        positives or no negatives (the rates would be undefined).
    """
    is_member = np.asarray(sampling).astype(bool)
    scores = np.asarray(scores)
    if is_member.shape != scores.shape:
        raise ValueError(
            f"sampling shape {is_member.shape} does not match scores shape {scores.shape}"
        )

    n_pos = is_member.sum()
    n_neg = is_member.size - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("ROC curve needs at least one positive and one negative example")

    # Labels reordered by ascending score: a suffix is the set accepted at a threshold.
    members_by_rank = is_member[np.argsort(scores)]
    n = members_by_rank.size
    # Cut-off ranks i*n // n_thresholds: equal to i*(n // n_thresholds) when n is a
    # multiple of n_thresholds, and never collapses to all-zero when n < n_thresholds.
    cutoffs = (np.arange(n_thresholds) * n) // n_thresholds

    tpr = [0.0]
    fpr = [0.0]
    # Walk from the strictest threshold (largest cut-off) to the loosest so the
    # accepted set only grows and both rates are non-decreasing.
    for cutoff in cutoffs[::-1]:
        accepted = members_by_rank[cutoff:]
        correct = accepted.sum()
        incorrect = accepted.size - correct
        tpr.append(correct / n_pos)
        fpr.append(incorrect / n_neg)

    tpr.append(1.0)
    fpr.append(1.0)

    tpr = np.array(tpr, dtype=float)
    fpr = np.array(fpr, dtype=float)
    print(tpr, fpr)

    return tpr, fpr


# Command-line options selecting the dataset and how NaN scores are handled.
parser = argparse.ArgumentParser(description='Settings')
parser.add_argument(
    '--dataset',
    default='mnist',
    choices=['cifar10', 'fmnist', 'mnist', 'svhn_ext'],
)
parser.add_argument('--nan_allowed', default="no", choices=["yes", "no"])

args = parser.parse_args()
dataset = args.dataset

# Boolean form of --nan_allowed; when False the trial loop drops NaN scores.
nan_allowed = args.nan_allowed == "yes"
print(f"nan allowed = {nan_allowed}")

eps = 3.0


plt.figure()

# Common FPR grid onto which every per-run ROC curve is interpolated, so the
# curves can be averaged point-wise across trials.
fpr_values = np.arange(0.0, 1.01, 0.01)

tpr_trials = []
for trial in [1, 2, 3, 4, 5]:

    print(f"On Trial {trial}")

    tpr_sample = []
    for i in [1, 3, 5, 7, 9]:

        print(f"On {i}")

        # Score files embed eps as a float ("3.0"), sampling files as an int ("3").
        scores = np.load(f'lira_scores/scores_1000_{dataset}_{eps}_{i}_{trial}.npy')
        sampling = np.load(f'lira_samplings/sampling_{dataset}_{int(eps)}_target_{i}_{trial}_0.npy')
        # Keep only the labels of the scored examples so both arrays stay aligned.
        our_sampling = sampling[:len(scores)]

        if not nan_allowed:
            print("removing nan")
            keep = ~np.isnan(scores)
            our_sampling = our_sampling[keep]
            scores = scores[keep]

        tpr, fpr = gen_tpr_fpr(our_sampling, scores)
        print(len(fpr), len(tpr))

        # Resample onto the shared grid; np.interp needs fpr in increasing order.
        tpr_sample.append(np.interp(fpr_values, fpr, tpr))

    tpr_trials.append(tpr_sample)


# Shape: (n_trials, n_sampling_configs, len(fpr_values)).
tpr_trials_np = np.array(tpr_trials)
n_trials = tpr_trials_np.shape[0]
tpr_trials_avg = np.mean(tpr_trials_np, axis=0)
# Sample standard deviation (ddof=1) is the estimator that belongs in a CI of the mean.
tpr_trials_std = np.std(tpr_trials_np, axis=0, ddof=1)
# NOTE(review): 1.96 is the normal quantile; with only a handful of trials a
# Student-t quantile would give a more honest interval.
tpr_trials_ci = (1.96 / math.sqrt(n_trials)) * tpr_trials_std

# y = x reference diagonal.
baseline = np.arange(0, 1.0, 0.001)
plt.plot(baseline, baseline, linestyle='dashed')

# One mean curve per sampling configuration, with a shaded confidence band.
for i, sampling in enumerate([0.1, 0.3, 0.5, 0.7, 0.9]):
    plt.plot(fpr_values, tpr_trials_avg[i], label=f'{sampling}')
    plt.fill_between(
        fpr_values,
        tpr_trials_avg[i] - tpr_trials_ci[i],
        tpr_trials_avg[i] + tpr_trials_ci[i],
        alpha=0.1,
    )


plt.xlabel("False Positive Rate")
plt.xscale("log")
plt.ylabel("True Positive Rate")

yscale_log = False

if yscale_log:
    plt.yscale("log")
plt.legend()

# Raw f-string: "\e" is not a valid escape sequence (SyntaxWarning on Python 3.12+);
# the rendered title text is unchanged.
plt.title(rf"{dataset} with $\epsilon =$ {eps}")
plt.tight_layout()
# savefig does not create missing directories.
os.makedirs("Plots", exist_ok=True)
plt.savefig(f"Plots/tpr_fpr_{dataset}_{eps}_{yscale_log}.pdf")



