import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from posterior import Basic, BasicFullRank
import jax.numpy as jnp
from jax import random
import scipy
import matplotlib
# Global base font size (points) for every matplotlib figure created by this script.
matplotlib.rcParams['font.size'] = 22
def _read_golf_data(path):
    """Read a whitespace-separated golf data file.

    Each non-blank line must hold exactly three fields ``x n y``; ``x`` is
    parsed as float, ``n`` and ``y`` as int.

    Returns:
        Three lists ``(x, n, y)`` in file order.

    Raises:
        ValueError: if a non-blank line does not have exactly three fields
            or a field fails to parse.
    """
    xs, ns, ys = [], [], []
    with open(path, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate blank / trailing lines instead of failing the unpack.
                continue
            a, b, c = fields
            xs.append(float(a))
            ns.append(int(b))
            ys.append(int(c))
    return xs, ns, ys


def _plot_empirical(x, n, y, color):
    """Plot observed success rates ``y / n`` at ``x`` with +/-1 standard-error bars.

    The bar half-width is the binomial standard error ``sqrt(p * (1 - p) / n)``.
    """
    x = np.asarray(x, dtype=float)
    n = np.asarray(n, dtype=float)
    p = np.asarray(y, dtype=float) / n
    std = np.sqrt(p * (1 - p) / n)
    # One vectorized call per artist instead of one per data point;
    # markers only, no connecting line (matches the old single-point plots).
    plt.plot(x, p, marker='o', markersize=2, linestyle='none', color=color)
    plt.vlines(x, p - std, p + std, color=color, linewidth=0.5)


def _load_param(path):
    """Load a flat parameter vector from the first line of ``path``.

    The final whitespace-separated token of that line is discarded.
    NOTE(review): presumably that token is not part of the parameter vector
    (e.g. a trailing score) -- confirm against the code that writes these files.
    """
    with open(path, 'r') as f:
        tokens = f.readline().split()[:-1]
    return np.array([float(t) for t in tokens])


def _curve_frame(samples, xs, label):
    """Evaluate logistic curves for every parameter sample on the grid ``xs``.

    ``samples[..., 0]`` is used as the slope and ``samples[..., 1]`` as the
    intercept of the logit. Assumes ``samples`` is 2-D with shape
    ``(num_samples, >=2)`` -- TODO confirm for other posteriors.

    Returns:
        Long-format DataFrame with columns ``x``, ``y`` (success probability)
        and ``curve`` (= ``label``), ordered by x and then by sample.
    """
    # Explicit submodule import: a bare `import scipy` need not expose `scipy.special`.
    from scipy.special import expit

    samples = np.asarray(samples)
    # Broadcast to shape (len(xs), num_samples).
    logits = xs[:, None] * samples[..., 0] + samples[..., 1]
    probs = expit(logits)
    return pd.DataFrame({
        'x': np.repeat(xs, probs.shape[1]),
        'y': probs.ravel(),
        'curve': label,
    })


def main():
    """Plot golf putting data together with three posterior predictive curves.

    Reads ``data/golf.dat`` (plotted in blue) and ``data/golf2.dat`` (red) and
    draws each empirical success rate with a +/-1 standard-error bar. It then
    overlays the mean +/- one standard deviation of the logistic success
    probability on x = 0.0, 0.1, ..., 79.9. Three sets of 100 posterior samples
    are used, labelled 'Bayes', 'PVI' and 'PVI-flow'. All inputs come from
    hard-coded paths relative to the current working directory.
    """
    for path, color in (('data/golf.dat', 'b'), ('data/golf2.dat', 'r')):
        x, n, y = _read_golf_data(path)
        _plot_empirical(x, n, y, color)

    # Same grid values as i/10 for i in range(800).
    xs = np.arange(800) / 10
    post = BasicFullRank(2)
    frames = []
    for param_path, label in (
        ('result/Golf2_0.0_100_BasicFullRank_VIBasic/VIBasic_0.0_0_100', 'Bayes'),
        ('result/Golf2_0.0_100_BasicFullRank_PACMVIBasic/KLPrior_0.1_0_100', 'PVI'),
    ):
        param = _load_param(param_path)
        samples = np.array(post.sample(random.PRNGKey(0), param, 100))
        frames.append(_curve_frame(samples, xs, label))

    # The context manager closes the file handle that NpzFile keeps open.
    with np.load('result/Golf2/PVI_0_100_0.0_VIBasic/sample.npz') as npz:
        flow_samples = npz['samples']
    # Seeded random subset of 100 samples so the figure is reproducible,
    # matching the fixed PRNGKey(0) used for the two curves above.
    rng = np.random.default_rng(0)
    flow_samples = flow_samples[rng.permutation(len(flow_samples))[:100]]
    frames.append(_curve_frame(flow_samples, xs, 'PVI-flow'))

    data = pd.concat(frames, ignore_index=True)
    sns.lineplot(data=data, x='x', y='y', hue='curve', errorbar='sd')
    plt.xlabel('Distance from hole (feet)')
    plt.ylabel('Probability of success')
    plt.title('Logistic regression')
    plt.tight_layout()
    plt.show()


if __name__ == '__main__':
    main()