import numpy as np
import jax.numpy as jnp
from model import OutlierRegression
from posterior import Basic
from jax import random, vmap
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import os

# Global plot styling: large fonts for every figure this script draws.
plt.rcParams.update({'font.size': 22})

# Sweep grid: every (outlier scale, seed, objective) combination is looked up.
scales = [1, 5, 10, 20, 40]
seeds = list(range(5))
objs = ['IntervalScore', 'PACMVIBasic', 'PVICRPS', 'VIBasic']

# Legend label (LaTeX) for each objective; raw strings keep the backslashes literal.
objmapping = dict(
    IntervalScore=r'$\text{PVI}_{\text{IS}}$',
    PACMVIBasic=r'$\text{PVI}_{\text{log}}$',
    VIBasic=r'$\text{VI}$',
    PVICRPS=r'$\text{PVI}_{\text{CRPS}}$',
)

# NOTE(review): `pos`, `prediction_sample` and `alpha` are not referenced anywhere
# else in this script. Basic(3) is still constructed in case it has side effects;
# confirm before removing.
pos = Basic(3)
prediction_sample = 100
# n and alpha2 are interpolated into the result-file paths.
n, alpha, alpha2 = 40, 0.1, 0.5
def _read_test_ll(path):
    """Return the test log-likelihood stored in the result file at *path*.

    The value is the last whitespace-separated token on the file's first line.

    Raises:
        ValueError: if the first line is empty or its last token is not a
            number. The message includes *path* so the bad run is easy to find.
    """
    with open(path, 'r') as f:
        tokens = f.readline().split()
    if not tokens:
        raise ValueError(f'{path}: first line is empty, expected a test log-likelihood')
    try:
        return float(tokens[-1])
    except ValueError as err:
        raise ValueError(
            f'{path}: could not parse test log-likelihood from {tokens[-1]!r}'
        ) from err


# One record per (outlier scale, seed, objective) run whose result file exists.
data = []
for scale in tqdm(scales):
    for seed in seeds:
        for obj in objs:
            # `lamb` is embedded in the result file name. NOTE(review): an earlier
            # variant used 0.0 for 'VIBasic' and 0.01 otherwise; confirm 0.01 for all.
            lamb = 0.01
            pvipath = (
                f'result/OutlierRegression_{scale}_{n}_{alpha2}_20_1000_Basic_{obj}'
                f'/VIBasic_{lamb}_{seed}_100'
            )
            if not os.path.exists(pvipath):
                # Missing runs are skipped so a partial sweep can still be plotted.
                continue
            test_ll = _read_test_ll(pvipath)
            data.append({'outlier scale': str(scale), 'objective': objmapping[obj], 'test ll': test_ll})
            print(scale, obj, test_ll)

data = pd.DataFrame(data)
if data.empty:
    # Without this guard, the column access below fails with an opaque KeyError,
    # because an empty DataFrame has no columns.
    raise SystemExit('No result files found under result/; nothing to plot.')

# Order the x axis by numeric scale, not lexically ('10' < '5' as strings).
# Categories are derived from `scales`, so adding a scale there cannot silently
# turn its rows into NaN.
data['outlier scale'] = pd.Categorical(
    data['outlier scale'],
    categories=[str(s) for s in scales],
    ordered=True,
)
print(data)

# With several seeds per scale, seaborn aggregates them per x value;
# errorbar=None suppresses the error band.
ax = sns.lineplot(data=data, x='outlier scale', y='test ll', hue='objective', errorbar=None)
legend = ax.legend(
    loc='lower left',
    bbox_to_anchor=(-0.03, -0.09),
    handletextpad=0.4,
    labelspacing=0.3,
    columnspacing=0.8,
    ncol=2,
)
legend.set_title(None)
legend.set_frame_on(False)

# Scientific notation on the y axis for magnitudes outside [1, 1e3].
ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 3))
# Fixed y-range; values outside [-3e4, -1e4] are clipped from view.
ax.set_ylim([-3e4, -1e4])
ax.set_title('Test log likelihood')

plt.tight_layout()
plt.show()
plt.clf()