import numpy as np
import jax.numpy as jnp
from model import IntervalRegression
from posterior import Basic
from jax import random, vmap
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import os

# Global figure styling: large fonts for all text in the plot.
plt.rcParams.update({'font.size': 22})

# Experimental grid swept by the collection loop.
dfs = list(range(5))    # df values; df == 0 is labelled as infinity on the plot
seeds = list(range(10))
objs = ['IntervalScore', 'PACMVIBasic', 'PVICRPS', 'VIBasic']

# LaTeX legend label for every objective name in `objs`.
objmapping = {
    'IntervalScore': '$\\text{PVI}_{\\text{IS}}$',
    'PACMVIBasic': '$\\text{PVI}_{\\text{log}}$',
    'VIBasic': '$\\text{VI}$',
    'PVICRPS': '$\\text{PVI}_{\\text{CRPS}}$',
}

# Posterior family used to draw parameter samples from the saved results.
pos = Basic(3)
prediction_sample = 10000  # number of posterior draws per result file
alpha = 0.1    # evaluated intervals are the central (1 - alpha) predictive quantiles
alpha2 = 0.5   # `alpha` passed to IntervalRegression; also part of the result path
# Collect one coverage record per (df, seed, objective) result file found on disk.
data = []
# Result paths that did not exist; reported after the sweep instead of being
# skipped silently, so an incomplete plot is not mistaken for a complete one.
missing = []
for df in tqdm(dfs):
    mol = IntervalRegression(df=df, n=1000, alpha=alpha2, g=20)
    # Fixed key, so every seed/objective is scored against the same test set.
    _, test_y = mol.data(random.PRNGKey(1))
    # Loop-invariant x-axis label; df == 0 is shown as infinity.
    dflabel = str(df) if df > 0 else '$\\infty$'
    for seed in seeds:
        for obj in objs:
            pvipath = f'result/intervalRegression_{df}_{alpha2}_20_1000_Basic_{obj}/VIBasic_0.01_{seed}_{100}'
            if not os.path.exists(pvipath):
                missing.append(pvipath)
                continue
            # The first line holds whitespace-separated parameter values. The last
            # token is dropped (presumably a trailing non-parameter field; confirm).
            with open(pvipath, 'r', encoding='utf-8') as f:
                tokens = f.readline().split()[:-1]
            params = jnp.array([float(tok) for tok in tokens])
            theta_sample = pos.sample(random.PRNGKey(0), params, prediction_sample)

            # One sample_test_datapoint call per posterior draw, each with its own key.
            rng_key2 = random.split(random.PRNGKey(2), theta_sample.shape[0])
            ys = vmap(mol.sample_test_datapoint)(rng_key2, theta_sample)
            if ys.shape[1] == 1:
                ys = ys[:, 0, ...]
            # NOTE(review): assumes ys is (n_draws, n_test) here, so y0 is
            # (n_test, n_draws) and quantiles run over draws; TODO confirm.
            y0 = jnp.swapaxes(ys, 0, 1)
            y = test_y
            lower = jnp.quantile(y0, alpha / 2, axis=1)
            upper = jnp.quantile(y0, 1 - alpha / 2, axis=1)
            # Empirical coverage: fraction of test points inside their interval.
            coverage = np.sum((lower <= y) & (y <= upper)) / y.shape[0]
            data.append({'df': dflabel, 'objective': objmapping[obj], 'coverage': float(coverage)})
            print(dflabel, obj, coverage)

if missing:
    print(f'Skipped {len(missing)} missing result file(s), e.g. {missing[0]}')

# Plot empirical test-set coverage against df, one line per objective.
if not data:
    # Without this, pandas fails later with an opaque KeyError: 'df'.
    raise SystemExit('No coverage results were collected; nothing to plot.')
data = pd.DataFrame(data)

# Order the x-axis by increasing df, with the df == 0 (infinity) case last.
# The order is derived from `dfs` so it cannot drift from the swept values.
df_order = [str(d) for d in sorted(dfs) if d > 0] + ['$\\infty$']
data["df"] = pd.Categorical(data["df"], categories=df_order, ordered=True)

ax = sns.lineplot(data=data, x='df', y='coverage', hue='objective', errorbar=None)
legend = ax.legend(loc='upper right', bbox_to_anchor=(1.05, 1.05), handletextpad=0.4,
                   labelspacing=0.3, columnspacing=0.8, ncol=2)
legend.set_title(None)
legend.set_frame_on(False)
# Dashed reference line at the nominal coverage of the central (1 - alpha) intervals.
ax.axhline(y=1 - alpha, color='black', linestyle='--')
ax.set_ylim(0.8, 1.0)
ax.set_title('Coverage on test set')
plt.tight_layout()
plt.show()
plt.clf()