import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm


from datetime import datetime
# Timestamp of this run, rendered as YYYY-MM-DD-HH-MM-SS (safe for file names).
now = datetime.now()
now_str = f"{now:%Y-%m-%d-%H-%M-%S}"


# How many distinct prompts to keep from each score matrix.
N = 1_000
# Best-of-n budget: samples kept per prompt (the score files hold at most 10 000).
cutoff = 10_000


# Scores from the two gold reward models. Each is divided by its own std
# (presumably so the two are on a comparable scale) before being averaged.
fname1 = '/data/private_models/[ANONYMIZED]_models/misc/generations/scores/combined_scores_llama_7b_large_gold1.pkl'
std1 = 2.6550517082214355
fname2 = '/data/private_models/[ANONYMIZED]_models/misc/generations/scores/combined_scores_llama_7b_large_gold2.pkl'
std2 = 1.784685730934143


def _load_scaled_scores(path, std):
    """Unpickle the score array stored at *path* and divide it by *std*.

    NOTE: ``pickle.load`` can execute arbitrary code; only use on trusted files.
    """
    with open(path, 'rb') as f:
        return np.array(pickle.load(f)) / std


data1 = _load_scaled_scores(fname1, std1)
data2 = _load_scaled_scores(fname2, std2)
# Explicit check instead of `assert`, which is stripped under `python -O`.
if data1.shape != data2.shape:
    raise ValueError(
        f"gold score arrays differ in shape: {data1.shape} vs {data2.shape}"
    )
# Keep the first N rows (prompts) and the first `cutoff` columns (samples).
data1, data2 = data1[:N, :cutoff], data2[:N, :cutoff]
# Gold reward = mean of the two normalised gold scores.
gold_data = (data1 + data2) / 2


# Folder that holds the pickled score files.
base_dir = (
    '/data/private_models/[ANONYMIZED]_models/misc/generations/scores/'
)
# Proxy reward-model score files. Each name is also the prefix of the cached
# curve files read from plot_files/ when plotting.
proxy_names = list((
    'combined_scores_pythia_1-4_small_sam1.pkl',
    'combined_scores_pythia_1-4_large_sam1.pkl',
))
# proxy_stds = [
#     3.580739736557007,
#     5.127547740936279,
#     8.5927095413208
# ]
# proxy_datas = []
# for i, name in enumerate(proxy_names):
#     full_dir = base_dir + name
#     with open(full_dir, 'rb') as f:
#         proxy_data = pickle.load(f)
#     proxy_data = np.array(proxy_data) / proxy_stds[i]
#     proxy_datas.append(proxy_data)


import matplotlib.pyplot as plt


# Analytic best-of-n KL curve: log(n) - (n - 1) / n for n = 1 .. cutoff.
kl = [np.log(n) - (n - 1) / n for n in range(1, cutoff + 1)]


# Choose the x-axis shared by every reward plot below: either the analytic
# best-of-n KL (`kl`) or the raw sample count.
use_kl = False
if use_kl:
    xlabel = 'kl'
    xplot = kl
else:
    xlabel = "num examples"
    # `y_gold` is not assigned until the cached curves are loaded further
    # down, so it cannot size the axis here (that raised NameError). Use the
    # length of `kl` instead, so both branches produce `cutoff` points.
    # NOTE(review): assumes the cached curves in plot_files/ have `cutoff`
    # points each, which the `use_kl` branch already requires — confirm.
    xplot = range(len(kl))

# One colour per proxy model. A model's gold (solid) and proxy (dashed) curves
# share its colour. Indexed modulo the list length, so extra proxies reuse
# colours instead of raising IndexError.
colors = ['green', 'blue', 'purple', 'pink', 'black', 'red', 'orange']


exp_title = "llama_no_sam"


def _plot_reward_curves(kind, title, names, x, x_label, experiment):
    """Plot cached gold/proxy reward curves for each proxy model and save them.

    For every ``name`` in ``names``, this loads
    ``plot_files/{name}_{kind}.pkl``, a pickled ``(y_proxy, y_gold)`` pair.
    It shifts both curves so their first point is 0, then draws gold as a
    solid line and proxy as a dashed line in the same colour.

    The figure is saved to ``plots/{experiment}_{kind}.png`` *before*
    ``plt.show()`` is called. After ``show()`` returns, the figure has
    typically been closed, so a later ``plt.savefig()`` acts on a new, empty
    figure and writes a blank image. ``bbox_inches='tight'`` keeps the legend,
    which is anchored outside the axes, inside the saved image.

    Args:
        kind: suffix of the cached curve files and of the output image
            (e.g. 'aggregate', 'learned', 'gamed').
        title: title prefix; " - {experiment}" is appended.
        names: proxy score-file names used to build the cache paths.
        x: x-axis values shared by every curve.
        x_label: label for the x-axis.
        experiment: experiment name used in the title and output file name.
    """
    fig, ax = plt.subplots()
    for i, name in enumerate(names):
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(f"plot_files/{name}_{kind}.pkl", "rb") as f:
            y_proxy, y_gold = pickle.load(f)
        # Shift each curve so it starts at 0 (gain relative to its first point).
        y_proxy = np.array(y_proxy) - y_proxy[0]
        y_gold = np.array(y_gold) - y_gold[0]

        color = colors[i % len(colors)]
        ax.plot(x, y_gold, label=f"gold{name}", color=color)
        ax.plot(x, y_proxy, '--', label=f"proxy{name}", color=color)

    ax.set_xlabel(x_label)
    ax.set_ylabel('Reward')
    ax.set_title(f"{title} - {experiment}")
    ax.legend(bbox_to_anchor=(1.1, 1.03), loc='upper left')

    os.makedirs('plots', exist_ok=True)
    fig.savefig(f'plots/{experiment}_{kind}.png', bbox_inches='tight')
    plt.show()
    plt.close(fig)


for _kind, _title in (
    ('aggregate', 'Aggregate'),
    ('learned', 'Learned'),
    ('gamed', 'Gamed'),
):
    _plot_reward_curves(_kind, _title, proxy_names, xplot, xlabel, exp_title)
