from matplotlib import pyplot as plt 
import matplotlib
import numpy as np
from sklearn.metrics import roc_curve, roc_auc_score
import os
import seaborn as sns
# Apply seaborn's theme first; the rc overrides below are layered on top of it.
sns.set_theme()

# Font type 42 asks the PDF and PS backends for TrueType fonts
# (see matplotlib's rcParams documentation).
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# Font sizes shared by the rc settings below.
SMALL_SIZE = 15
MEDIUM_SIZE = 20
BIGGER_SIZE = 20

# Default text size.
plt.rc('font', size=SMALL_SIZE)
# Axes title and x/y axis label sizes.
plt.rc('axes', titlesize=SMALL_SIZE, labelsize=MEDIUM_SIZE)
# Tick label sizes, identical on both axes.
for tick_group in ('xtick', 'ytick'):
    plt.rc(tick_group, labelsize=SMALL_SIZE)
# The legend entries are long, so they get a smaller fixed size.
plt.rc('legend', fontsize=9)
# Figure-level title size.
plt.rc('figure', titlesize=BIGGER_SIZE)


# Experiment tag that is substituted into every input file name.
file = "R5_P_60_40_1000"

# Open the figure that the rest of the script draws into.
plt.figure(dpi=250)


def _load_experiment_array(name_template):
    """Load ``scores/<name_template formatted with ``file``>`` via ``np.load``.

    NOTE(review): ``allow_pickle=True`` lets NumPy unpickle object arrays,
    which can execute arbitrary code; only point this at trusted files.
    """
    npy_path = os.path.join("scores", name_template.format(file))
    return np.load(npy_path, allow_pickle=True)


scores = _load_experiment_array("scores_{}.npy")
# Only entry 1 of the tokens file is used by this script.
tokens = _load_experiment_array("tokens_{}.npy")[1]
ppl = _load_experiment_array("ppl13b_{}.npy")

# Summary of the token array: mean, min, max, standard deviation.
T = [statistic() for statistic in (tokens.mean, tokens.min, tokens.max, tokens.std)]
print(T)

def _plot_roc(negatives, positives, name, name_width, mean_ppl):
    """Plot one ROC curve on the current axes, labelled with its metrics.

    Args:
        negatives: 1-D scores for the class labelled 0.
        positives: 1-D scores for the class labelled 1.
        name: Curve name shown at the start of the legend entry.
        name_width: Column width the name is left-justified to.
        mean_ppl: Mean perplexity reported in the legend entry.
    """
    # Build labels and scores once and reuse them for both metrics.
    y_true = np.concatenate([
        np.zeros(len(negatives), dtype=int),
        np.ones(len(positives), dtype=int),
    ])
    y_score = np.concatenate([
        np.asarray(negatives, dtype=float),
        np.asarray(positives, dtype=float),
    ])

    fpr, tpr, _ = roc_curve(y_true, y_score)
    auroc = roc_auc_score(y_true, y_score)
    # Highest TPR among operating points whose FPR does not exceed 1%.
    tpr_at_1pct_fpr = np.max(tpr[fpr <= 0.01])

    label = "{} AUROC:{:.3f}, TPR@1%FPR:{:.3f}, ppl:{:.1f}".format(
        name.ljust(name_width), auroc, tpr_at_1pct_fpr, mean_ppl)
    plt.plot(fpr, tpr, label=label)


# NOTE(review): the layout below is inferred from how the rows are labelled --
# scores[0] is the negative class, scores[1] the watermarked text and
# scores[2:] its paraphrases; ppl[k] pairs with scores[k + 1]. Confirm
# against the script that writes these files.
_plot_roc(scores[0], scores[1], "Watermarking", 12, ppl[0].mean())

for pp_index, pp_scores in enumerate(scores[2:], start=1):
    _plot_roc(scores[0], pp_scores, "pp" + str(pp_index), 19,
              ppl[pp_index].mean())

# Per sample, keep the lowest score over all positive rows. argmin picks
# exactly one row per sample, so a tie between rows cannot make that sample
# count more than once in the mean perplexity (matching on equality with the
# minimum would return every tied position).
candidates = scores[1:]
best_rows = candidates.argmin(axis=0)
sample_index = np.arange(candidates.shape[1])
best = candidates[best_rows, sample_index]
_plot_roc(scores[0], best, "Best of ppi", 16,
          ppl[best_rows, sample_index].mean())

# Diagonal reference line for a chance-level detector.
plt.plot(np.linspace(0, 1, 100), np.linspace(0, 1, 100), ':', label='{:15s} AUROC:{:.3f}, TPR@1%FPR:{:.3f}'.format("Random", 0.5, 0), color='k')
    
# Finish the current figure: legend, axis labels, layout, then write the PNG.
current_axes = plt.gca()
current_axes.legend()
current_axes.set_xlabel('FPR')
current_axes.set_ylabel('TPR')
plt.gcf().tight_layout()
plt.savefig('wm_1000.png')