import numpy as np
import jax.numpy as jnp
from model import OutlierRegression
from posterior import Basic
from jax import random, vmap
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import os

# Global matplotlib default: larger font for every text element of the figure.
plt.rcParams['font.size'] = 22

# Experiment grid: every (N, seed, objective) combination is looked up on disk.
Ns = [10, 20, 40, 80, 160]
seeds = list(range(5))
objs = [
    'IntervalScore',
    'PACMVIBasic',
    'PVICRPS',
    'VIBasic',
]

# Objective identifier (as used in the result directory name) -> LaTeX label
# stored in the 'objective' column and therefore shown in the legend.
objmapping = {
    'IntervalScore': '$\\text{PVI}_{\\text{IS}}$',
    'PACMVIBasic': '$\\text{PVI}_{\\text{log}}$',
    'VIBasic': '$\\text{VI}$',
    'PVICRPS': '$\\text{PVI}_{\\text{CRPS}}$',
}

# NOTE(review): pos, prediction_sample and alpha are not referenced in the
# visible part of this script; kept in case later code relies on them.
pos = Basic(3)
prediction_sample = 100
alpha = 0.1

# alpha2 and scale are interpolated into the result directory name.
alpha2 = 0.5
scale = 10

# Accumulator for one record (dict) per result file that is found.
data = []
# Walk the full (N, seed, objective) grid and collect the test log-likelihood
# of every run whose result file exists; missing runs are skipped silently.
for n in tqdm(Ns):
    for seed in seeds:
        for obj in objs:
            # The same lambda is used for every objective.
            lamb = 0.01
            pvipath = f'result/OutlierRegression_{scale}_{n}_{alpha2}_20_1000_Basic_{obj}/VIBasic_{lamb}_{seed}_{100}'
            if os.path.exists(pvipath):
                with open(pvipath, 'r') as f:
                    first_line = f.readline()
                # The value is the last whitespace-separated token of line 1.
                test_ll = first_line.split()[-1]
                record = {
                    'outlier N': str(n),
                    'objective': objmapping[obj],
                    'test ll': float(test_ll),
                }
                data.append(record)
                # NOTE(review): this prints the constant `scale`, not `n` --
                # confirm that is the intended debug output.
                print(scale, obj, test_ll)

# Turn the collected records into a table and plot the test log-likelihood
# against N, one line per objective.

# Without any result file the frame would have no columns and the column
# access below would fail with an opaque KeyError; stop with a clear message.
if not data:
    raise SystemExit('No result files found under result/; nothing to plot.')

data = pd.DataFrame(data)
print(data)

# Make the x variable an ordered categorical whose levels follow Ns.  The
# levels are derived from Ns (instead of a second hard-coded list) so the two
# cannot drift apart and no level without data is declared.
n_levels = [str(level) for level in Ns]
data["outlier N"] = pd.Categorical(data["outlier N"], categories=n_levels, ordered=True)

# errorbar=None: draw only the aggregated line per objective, no error band.
ax = sns.lineplot(data=data, x='outlier N', y='test ll', hue='objective', errorbar=None)

# Compact two-column legend without title or frame, anchored at the lower left.
legend = ax.legend(
    loc='lower left',
    bbox_to_anchor=(-0.03, -0.09),
    handletextpad=0.4,
    labelspacing=0.3,
    columnspacing=0.8,
    ncol=2,
)
legend.set_title(None)
legend.set_frame_on(False)

# Scientific notation on the y axis; the plotted range is fixed by hand.
ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 3))
ax.set_ylim([-3e4, -1e4])
ax.set_title('Test log likelihood')

plt.tight_layout()
plt.show()
plt.clf()