import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from posterior import Basic, BasicFullRank
import jax.numpy as jnp
from jax import random
import scipy

def _read_golf_data(path):
    """Parse a whitespace-separated three-column data file.

    Each non-blank line must hold exactly three fields: a float ``x``, an
    integer ``n`` and an integer ``y``.

    Returns:
        Three parallel lists ``(xs, ns, ys)``.

    Raises:
        ValueError: if a non-blank line does not have exactly three fields
            or a field cannot be converted.
    """
    xs, ns, ys = [], [], []
    with open(path, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            a, b, c = fields
            xs.append(float(a))
            ns.append(int(b))
            ys.append(int(c))
    return xs, ns, ys


def _plot_proportions(xs, ns, ys, color):
    """Plot each proportion ``y/n`` against ``x`` with a vertical error bar.

    The bar spans ``p +/- sqrt(p * (1 - p) / n)`` where ``p = y / n``.
    """
    for x, n, y in zip(xs, ns, ys):
        p = y / n
        std = np.sqrt(p * (1 - p) / n)
        plt.plot(x, p, marker='o', markersize=2, color=color)
        plt.vlines(x, p - std, p + std, color=color, linewidth=0.5)


def _load_param(path):
    """Read the first line of ``path`` as a flat float array.

    The last whitespace-separated field of the line is discarded, as the
    original script did (presumably it is not a parameter -- TODO confirm
    against the code that writes these result files).
    """
    with open(path, 'r') as f:
        fields = f.readline().split()[:-1]
    return np.array([float(tok) for tok in fields])


def _curve_rows(samples, label, num_points=800, scale=10):
    """Build long-format rows of sigmoid curves for every sample.

    For each grid point ``x = i / scale`` (``i`` in ``range(num_points)``)
    and each sample, emits ``{'x': x, 'y': expit(s0 * x + s1), 'curve': label}``
    where ``s0 = samples[..., 0]`` and ``s1 = samples[..., 1]``.
    """
    # Explicit submodule import so this does not depend on a bare
    # ``import scipy`` having made ``scipy.special`` available.
    from scipy.special import expit

    # Loop-invariant slices hoisted out of the grid loop.
    slope = samples[..., 0]
    intercept = samples[..., 1]

    rows = []
    for i in range(num_points):
        x = i / scale
        probs = expit(slope * x + intercept)
        rows.extend({'x': x, 'y': prob, 'curve': label} for prob in probs)
    return rows


def main():
    """Plot the golf data sets and overlay three families of posterior curves.

    Points from ``data/golf.dat`` are drawn in red and points from
    ``data/golf2.dat`` in blue. Posterior samples from two ``BasicFullRank``
    parameter files ('VI', 'PVI') and one saved sample archive ('PVI-flow')
    are turned into sigmoid curves and drawn with seaborn, which aggregates
    the samples at each ``x`` and shades a standard-deviation band.
    """
    xtrain, ntrain, ytrain = _read_golf_data('data/golf.dat')
    _plot_proportions(xtrain, ntrain, ytrain, 'r')

    xtest, ntest, ytest = _read_golf_data('data/golf2.dat')
    _plot_proportions(xtest, ntest, ytest, 'b')

    post = BasicFullRank(2)
    data = []

    # Both parameter files are sampled with the same PRNG key, as before.
    param = _load_param(
        'result/Golf_0.0_100_BasicFullRank_VIBasic/VIBasic_0.0_0_100')
    samples = np.array(post.sample(random.PRNGKey(0), param, 100))
    data.extend(_curve_rows(samples, 'VI'))

    param = _load_param(
        'result/Golf_0.0_100_BasicFullRank_PACMVIBasic/VIBasic_0.0_0_100')
    samples = np.array(post.sample(random.PRNGKey(0), param, 100))
    data.extend(_curve_rows(samples, 'PVI'))

    # np.load on an .npz returns a file-backed archive; close it once the
    # array has been read into memory.
    with np.load('result/Golf_0.0_100_5_5/PVI_0_100/sample.npz') as archive:
        samples3 = archive['samples'][:100]
    data.extend(_curve_rows(samples3, 'PVI-flow'))

    data = pd.DataFrame(data)
    sns.lineplot(data=data, x='x', y='y', hue='curve', errorbar='sd')
    plt.title('Posterior using the red training points')
    plt.show()

# Run the plotting script only when executed directly, not when imported.
if __name__ == '__main__':
    main()