import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from posterior import Basic, BasicFullRank
import jax.numpy as jnp
from jax import random
import scipy

def _read_param(path):
    """Return the numeric vector stored on the first line of the text file *path*.

    The first line is split on whitespace and its last token is discarded
    before conversion to float.
    NOTE(review): the dropped token is presumably a trailing metric (e.g. a
    loss value) rather than a parameter — confirm against the writer.
    """
    with open(path, 'r') as f:
        tokens = f.readline().split()[:-1]
    return np.array([float(t) for t in tokens])


def _frame(samples, method):
    """Return a DataFrame of *samples* labelled with *method*.

    Column 'b' is the first coordinate of each sample and 'a' the second;
    every row gets the same 'method' label so seaborn can use it as ``hue``.
    Assumes *samples* is 2-D with shape (n, d), d >= 2.
    """
    samples = np.asarray(samples)
    return pd.DataFrame({'b': samples[:, 0], 'a': samples[:, 1], 'method': method})


def main(*, n_samples=1000, seed=0, xlim=None, ylim=None):
    """Overlay 2-D KDE plots of the samples produced by three methods.

    Draws samples from ``BasicFullRank(2)`` using two saved parameter
    vectors (labelled 'VI' and 'PVI'). It also loads pre-computed samples
    from an ``.npz`` archive (labelled 'PVI-flow'). All three are plotted
    as per-method KDEs over the first two coordinates ('b', 'a').

    Args:
        n_samples: Number of draws per sampled method (originally fixed at 1000).
        seed: Seed for the jax PRNG key used to sample both fitted posteriors.
        xlim: Optional ``(lo, hi)`` x-axis limits; matplotlib's default if None.
        ylim: Optional ``(lo, hi)`` y-axis limits; matplotlib's default if None.
    """
    post = BasicFullRank(2)
    # The same key is used for both parameter sets, as in the original
    # script, so the two sample sets share their underlying random draws.
    key = random.PRNGKey(seed)

    vi_param = _read_param('result/Golf2_0.0_100_BasicFullRank_VIBasic/VIBasic_0.0_0_100')
    vi_samples = np.array(post.sample(key, vi_param, n_samples))

    pvi_param = _read_param('result/Golf2_0.0_100_BasicFullRank_PACMVIBasic/KLPrior_0.1_0_100')
    pvi_samples = np.array(post.sample(key, pvi_param, n_samples))

    # np.load on an .npz returns an NpzFile that holds an open file handle;
    # the context manager closes it. Indexing loads the array fully into
    # memory, so it stays valid after the archive is closed.
    with np.load('result/Golf2/PVI_0_100_0.0_VIBasic/sample.npz') as npz:
        flow_samples = npz['samples']

    data = pd.concat(
        [
            _frame(vi_samples, 'VI'),
            _frame(pvi_samples, 'PVI'),
            _frame(flow_samples, 'PVI-flow'),
        ],
        ignore_index=True,
    )

    # common_norm=False: each method's density is normalised on its own.
    sns.kdeplot(data=data, x='b', y='a', hue='method', common_norm=False)
    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)
    plt.show()

if __name__ == "__main__":
    # Build and show the plot only when run as a script, not when imported.
    main()