import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
from posterior import BasicFullRank
import scipy
import numpyro.distributions as dist
from jax import random
# Use a large default font for every figure this script produces.
plt.rcParams['font.size'] = 25

# Sizes of the coefficient blocks used to build the BasicFullRank dimensions
# in main() (da for '1|sta', da + db for '1|sta+1|eth', 2 * da + db for
# '(1+inc|sta)+1|eth'). dc is not referenced in the code shown here.
da, db, dc = 50, 5, 5
def _result_path(result_dir, model, method):
    """Return the path of the fitted-parameter file for one model/method pair.

    ``model`` is a run name such as ``'Voting1'``; ``method`` is the fitting
    method suffix: ``'PACMVIBasic'`` (labelled PVI in the plots) or
    ``'VIBasic'`` (labelled VI).
    """
    return (f'{result_dir}/{model}_0.0_1000_BasicFullRank_{method}'
            '/VIBasic_0.0_0_100')


def _read_result_line(path):
    """Parse the first line of ``path`` as whitespace-separated floats.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the first line is empty or holds a non-numeric token.
    """
    with open(path, 'r') as f:
        values = [float(tok) for tok in f.readline().split()]
    if not values:
        raise ValueError(f'{path}: first line holds no values')
    return values


def _read_mean_and_scale(path):
    """Return ``(mean, scale)`` from the first two values of ``path``.

    The second value is exponentiated before being used as the normal's
    standard deviation (presumably it is stored on a log scale -- confirm).
    """
    values = _read_result_line(path)
    return values[0], np.exp(values[1])


def _sample_posterior(path, dim, seed, n_draws=1000):
    """Load a ``BasicFullRank`` posterior from ``path`` and sample from it.

    The file's first line is the variational parameter vector followed by one
    trailing value (``test_val``). That value is split off before the
    parameters are passed to ``BasicFullRank(dim).posterior_parameters``.

    Returns:
        ``(params, loc, draws, test_val)``. ``params`` is the parameter list.
        ``loc`` is the Gaussian mean as a NumPy array. ``draws`` is a NumPy
        array with one row per sample, holding ``n_draws`` samples drawn with
        ``PRNGKey(seed)``.
    """
    *params, test_val = _read_result_line(path)
    loc, tril = BasicFullRank(dim).posterior_parameters(jnp.array(params))
    draws = dist.MultivariateNormal(loc, scale_tril=tril).sample(
        random.PRNGKey(seed), (n_draws,))
    return params, np.asarray(loc), np.asarray(draws), test_val


def _state_frame(method, ranks, values):
    """Build a long-format frame of per-state draws for one method.

    Args:
        method: label stored in the ``Method`` column (``'PVI'`` / ``'VI'``).
        ranks: one x-axis category per column of ``values``.
        values: 2-D array with one row per posterior draw and one column per
            state.

    Rows are draw-major (every state of draw 0, then draw 1, ...), the same
    order the original nested loops produced.
    """
    values = np.asarray(values, dtype=float)
    return pd.DataFrame({
        'Method': method,
        'State': np.tile(ranks, values.shape[0]),
        'Val': values.ravel(),
    })


def _save_state_boxplot(frames, title, out_path, **boxplot_kwargs):
    """Draw per-state boxplots (one hue per method) and save to ``out_path``.

    Extra keyword arguments (e.g. ``legend``, ``width``) are forwarded to
    ``sns.boxplot``.
    """
    data = pd.concat(frames, ignore_index=True)
    plt.figure(figsize=(5, 5))
    sns.boxplot(data=data, x='State', y='Val', hue='Method', showfliers=False,
                linewidth=0.7, fill=False, **boxplot_kwargs)
    plt.ylabel('')
    plt.yticks([])
    plt.xticks([])
    plt.title(title)
    plt.tight_layout()
    plt.savefig(out_path)
    # Close the figure, not just clear it, so each panel's figure is freed.
    plt.close()


def _print_moments(draws):
    """Print ``index mean std`` for every coordinate of ``draws`` (debug output)."""
    for i, (mu, sd) in enumerate(zip(draws.mean(axis=0), draws.std(axis=0))):
        print(i, float(mu), float(sd))


def _plot_intercept_kde(result_dir, figure_dir, n_draws=10000, seed=0):
    """Plot PVI vs. VI KDEs of beta_0 for the intercept-only model ('y ~ 1').

    A seeded generator is used so the figure is reproducible, matching the
    fixed PRNG keys used for the other panels.
    """
    rng = np.random.default_rng(seed)
    frames = []
    for label, method in (('PVI', 'PACMVIBasic'), ('VI', 'VIBasic')):
        mean, scale = _read_mean_and_scale(
            _result_path(result_dir, 'Voting0', method))
        frames.append(pd.DataFrame({
            'Method': label,
            '$\\beta_0$': rng.normal(mean, scale, n_draws),
        }))
    data = pd.concat(frames, ignore_index=True)
    plt.figure(figsize=(6, 6))
    sns.kdeplot(data=data, x='$\\beta_0$', hue='Method')
    plt.ylim([0, 4])
    plt.xlim([-1, 1])
    plt.title('y ~ 1')
    plt.ylabel('')
    plt.yticks([])
    plt.tight_layout()
    plt.savefig(f'{figure_dir}/voting0.pdf')
    plt.close()


def main(result_dir='result', figure_dir='figure2'):
    """Render voting0.pdf ... voting4.pdf comparing PVI and VI posteriors.

    Args:
        result_dir: directory holding the fitted-parameter files.
        figure_dir: directory the PDF figures are written to.

    In every state-level panel the x-axis category is each state's rank
    under the PVI fit. The VI draws reuse those ranks so both methods share
    one ordering.
    """
    _plot_intercept_kde(result_dir, figure_dir)

    # voting4 ('1|sta'): the model has da coordinates, one per state.
    _, loc, draws, _ = _sample_posterior(
        _result_path(result_dir, 'Voting4', 'PACMVIBasic'), da, seed=0)
    rk = scipy.stats.rankdata(loc[:da])
    frames = [_state_frame('PVI', rk, draws[:, :da])]
    try:
        _, _, draws, _ = _sample_posterior(
            _result_path(result_dir, 'Voting4', 'VIBasic'), da, seed=1)
    except (OSError, ValueError) as err:
        # Best effort, as before: plot PVI alone if the VI fit cannot be
        # read, but report it instead of silently swallowing every error.
        print(f'voting4: skipping VI ({err})')
    else:
        frames.append(_state_frame('VI', rk, draws[:, :da]))
    _save_state_boxplot(frames, '1|sta', f'{figure_dir}/voting4.pdf',
                        legend=False, width=.8)

    # voting1 ('1|sta+1|eth'): da state coordinates followed by db more.
    # Each state is shifted by the sample mean of coordinate da.
    params, loc, draws, test_val = _sample_posterior(
        _result_path(result_dir, 'Voting1', 'PACMVIBasic'), da + db, seed=0)
    print(test_val)
    print(len(params), (da + db) + (da + db) * (da + db - 1) // 2)
    rk = scipy.stats.rankdata(loc[:da])
    frames = [_state_frame('PVI', rk, draws[:, :da] + draws.mean(axis=0)[da])]
    _, _, draws, _ = _sample_posterior(
        _result_path(result_dir, 'Voting1', 'VIBasic'), da + db, seed=1)
    frames.append(
        _state_frame('VI', rk, draws[:, :da] + draws.mean(axis=0)[da]))
    _save_state_boxplot(frames, '1|sta+1|eth', f'{figure_dir}/voting1.pdf',
                        legend=False)

    # voting2 ('1|sta+1|eth+inc'): one extra trailing coordinate. Each draw's
    # last coordinate is added to every state of that draw.
    _, loc, draws, _ = _sample_posterior(
        _result_path(result_dir, 'Voting2', 'PACMVIBasic'), da + db + 1, seed=2)
    _print_moments(draws)
    rk = scipy.stats.rankdata(loc[:da])
    frames = [_state_frame(
        'PVI', rk, draws[:, :da] + draws.mean(axis=0)[da] + draws[:, -1:])]
    _, _, draws, _ = _sample_posterior(
        _result_path(result_dir, 'Voting2', 'VIBasic'), da + db + 1, seed=3)
    frames.append(_state_frame(
        'VI', rk, draws[:, :da] + draws.mean(axis=0)[da] + draws[:, -1:]))
    _save_state_boxplot(frames, '1|sta+1|eth+inc',
                        f'{figure_dir}/voting2.pdf', legend=False)

    # voting3 ('(1+inc|sta)+1|eth'): da state coordinates, db coordinates,
    # then a second per-state block (columns da + db .. 2 * da + db - 1).
    per_state = slice(da + db, 2 * da + db)
    # NOTE(review): this panel shifts by the mean of coordinate da + 1 (index
    # 51 in the original), while voting1/voting2 shift by coordinate da --
    # confirm which coordinate is intended.
    shift = da + 1
    _, loc, draws, _ = _sample_posterior(
        _result_path(result_dir, 'Voting3', 'PACMVIBasic'), 2 * da + db, seed=2)
    _print_moments(draws)
    rk = scipy.stats.rankdata(loc[:da] + loc[per_state])
    frames = [_state_frame(
        'PVI', rk,
        draws[:, :da] + draws.mean(axis=0)[shift] + draws[:, per_state])]
    _, _, draws, _ = _sample_posterior(
        _result_path(result_dir, 'Voting3', 'VIBasic'), 2 * da + db, seed=3)
    frames.append(_state_frame(
        'VI', rk,
        draws[:, :da] + draws.mean(axis=0)[shift] + draws[:, per_state]))
    # This is the only state panel drawn with seaborn's default legend.
    _save_state_boxplot(frames, '(1+inc|sta)+1|eth',
                        f'{figure_dir}/voting3.pdf')


if __name__ == '__main__':
    main()