import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from posterior import Basic, BasicFullRank
import jax.numpy as jnp
import scipy
from jax import random
import scipy
import matplotlib
# Enlarge the default font size for every figure this script creates.
matplotlib.rcParams['font.size'] = 22
def _read_putt_counts(path):
    """Read a whitespace-separated three-column data file.

    Each non-blank line is parsed as ``<float> <int> <int>``; ``main`` uses
    the columns as distance, number of attempts and number of successes.

    Returns:
        Three NumPy arrays ``(distances, attempts, successes)``.
    """
    distances, attempts, successes = [], [], []
    with open(path, 'r') as f:
        for line in f:
            if not line.strip():
                # Tolerate blank/trailing lines instead of failing to unpack.
                continue
            a, b, c = line.split()
            distances.append(float(a))
            attempts.append(int(b))
            successes.append(int(c))
    return np.asarray(distances), np.asarray(attempts), np.asarray(successes)


def _plot_empirical_rates(distances, attempts, successes, color):
    """Plot observed success rates with +/- one binomial standard error."""
    rate = successes / attempts
    std = np.sqrt(rate * (1 - rate) / attempts)
    # linestyle='None' keeps the points unconnected (markers only).
    plt.plot(distances, rate, linestyle='None', marker='o', markersize=2,
             color=color)
    plt.vlines(distances, rate - std, rate + std, color=color, linewidth=0.5)


def _load_variational_params(path):
    """Load the parameter vector stored on the first line of *path*.

    The last whitespace-separated field of that line is discarded; the
    remaining fields are parsed as floats.
    NOTE(review): the dropped field is presumably not a parameter (e.g. a
    score or marker) -- confirm against the code that writes these files.
    """
    with open(path, 'r') as f:
        fields = f.readline().split()[:-1]
    return np.array([float(v) for v in fields])


def _success_probability_rows(samples, label, ball_radius, hole_radius,
                              overshot, distance_tolerance):
    """Evaluate the geometric success probability for every posterior sample.

    The probability is evaluated on the grid x = 0.1, 0.2, ..., 79.9 and
    returned in long format, ready for ``pandas.DataFrame``.

    Args:
        samples: array whose last axis holds the log standard deviations;
            index 0 is the angular one and index 1 the distance one.  Any
            further components are not used here.
        label: value stored under the ``'curve'`` key of every row.
        ball_radius, hole_radius: radii in the same unit as ``x``.
        overshot, distance_tolerance: parameters of the distance term.

    Returns:
        A list of ``{'x': ..., 'y': ..., 'curve': label}`` dicts, one per
        (grid point, sample) pair.
    """
    # Imported explicitly: plain ``import scipy`` does not guarantee that
    # the ``scipy.stats`` submodule is loaded on every SciPy version.
    from scipy import stats

    # Loop-invariant: the samples do not depend on the distance x.
    sigma_angle = np.exp(samples[..., 0])
    sigma_distance = np.exp(samples[..., 1])

    rows = []
    for i in range(1, 800):
        x = i / 10
        # For x <= hole_radius - ball_radius the ratio exceeds 1 and arcsin
        # would return NaN (this happens at x = 0.1 with the radii used in
        # main).  Clamp so the threshold angle saturates at pi/2 instead.
        threshold_angle = np.arcsin(min((hole_radius - ball_radius) / x, 1.0))
        p_angle = 2 * stats.norm.cdf(threshold_angle / sigma_angle) - 1

        distance_scale = (x + overshot) * sigma_distance
        p_distance = (
            stats.norm.cdf((distance_tolerance - overshot) / distance_scale)
            - stats.norm.cdf(-overshot / distance_scale)
        )

        p = p_angle * p_distance
        rows.extend({'x': x, 'y': prob, 'curve': label} for prob in p)
    return rows


def main():
    """Plot empirical putting success rates against two fitted model curves.

    Reads the two data sets ``data/golf.dat`` (blue) and ``data/golf2.dat``
    (red), draws their observed success rates with error bars, then overlays
    the mean +/- sd of the geometric model's success probability under 100
    samples from each of two stored ``BasicFullRank`` parameter vectors
    (labelled 'Bayes' and 'PVI'), and finally shows the figure.
    """
    # Presumably ball radius (1.68 in diameter) and hole radius (4.25 in
    # diameter) converted from inches to feet, matching the x-axis label.
    r = (1.68 / 2) / 12
    R = (4.25 / 2) / 12
    overshot = 1.
    distance_tolerance = 3.

    _plot_empirical_rates(*_read_putt_counts('data/golf.dat'), color='b')
    _plot_empirical_rates(*_read_putt_counts('data/golf2.dat'), color='r')

    post = BasicFullRank(3)
    curves = (
        ('Bayes',
         'result/GolfGeo3_0.0_100_BasicFullRank_VIBasic/VIBasic_0.0_0_100'),
        ('PVI',
         'result/GolfGeo3_0.0_100_BasicFullRank_PACMVIBasic/VIBasic_0.0_0_100'),
    )
    data = []
    for label, path in curves:
        param = _load_variational_params(path)
        # The same PRNG key is used for both curves, as in the original.
        samples = np.array(post.sample(random.PRNGKey(0), param, 100))
        data.extend(_success_probability_rows(
            samples, label, r, R, overshot, distance_tolerance))

    data = pd.DataFrame(data)
    sns.lineplot(data=data, x='x', y='y', hue='curve', errorbar='sd')
    plt.xlabel('Distance from hole (feet)')
    plt.ylabel('Probability of success')
    plt.title('Geometric model')
    plt.tight_layout()
    plt.show()

# Run the plotting routine only when this file is executed as a script.
if __name__ == "__main__":
    main()