from posterior import BasicFullRank, vec_to_tril_matrix
from model import Election2, Election3
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import numpyro.distributions as dist
from jax import random, vmap
import jax.scipy as jsc

# Latent dimensions handed to BasicFullRank: dim1 goes with the Election2
# result files and dim2 with the Election3 result files read below.
dim1, dim2 = 56, 106
# Number of draws taken from each fitted approximation.
n_sample = 1000

# One posterior helper per dimension; used to split a flat parameter vector
# into its location and scale parts.
post1, post2 = BasicFullRank(dim1), BasicFullRank(dim2)

def _load_params(path):
    """Read a flat parameter vector from the first line of a result file.

    The first line is split on whitespace and every token except the last is
    parsed as a float. The trailing token is dropped without being parsed
    (presumably a non-parameter value such as the final objective -- confirm
    against the code that writes these files).

    Args:
        path: Path of the result file to read.

    Returns:
        A 1-D ``numpy`` array holding the parsed values.

    Raises:
        ValueError: If the first line contains no tokens, or a parameter
            token is not a valid float.
    """
    with open(path, 'r') as f:
        tokens = f.readline().split()
    if not tokens:
        # Fail here with a clear message instead of passing an empty vector
        # on to the posterior helper.
        raise ValueError(f'no parameters found on the first line of {path!r}')
    return np.array([float(tok) for tok in tokens[:-1]])


# NOTE(review): the PACMVIBasic directories also use a "VIBasic_..." file
# name -- confirm this matches what the training script writes.
par11 = _load_params('result/Election2_0.0_100_BasicFullRank_VIBasic/VIBasic_0.0_0_100')
par12 = _load_params('result/Election2_0.0_100_BasicFullRank_PACMVIBasic/VIBasic_0.0_0_100')
par21 = _load_params('result/Election3_0.0_100_BasicFullRank_VIBasic/VIBasic_0.0_0_100')
par22 = _load_params('result/Election3_0.0_100_BasicFullRank_PACMVIBasic/VIBasic_0.0_0_100')

def _scale_tril_and_variances(scale, dim):
    """Build the lower-triangular scale factor and its implied variances.

    ``scale[:dim]`` is exponentiated and floored at 1e-2 to form the
    diagonal; ``scale[dim:]`` fills the strictly lower triangle
    (``diagonal=-1``).

    Returns:
        ``(scale_tril, variances)`` where ``variances`` is the diagonal of
        ``scale_tril @ scale_tril.T``.
    """
    scale_tril = vec_to_tril_matrix(scale[dim:], diagonal=-1) + jnp.diag(
        jnp.maximum(1e-2, jnp.exp(scale[:dim])))
    variances = jnp.diagonal(jnp.matmul(scale_tril, scale_tril.transpose()))
    return scale_tril, variances


# Split every flat parameter vector into its location and scale parts.
loc11, scale11 = post1.extract_params(par11)
loc12, scale12 = post1.extract_params(par12)
loc21, scale21 = post2.extract_params(par21)
loc22, scale22 = post2.extract_params(par22)

# Scale factor and variance diagonal for each of the four fits.
scale_tril11, d11 = _scale_tril_and_variances(scale11, dim1)
scale_tril12, d12 = _scale_tril_and_variances(scale12, dim1)
scale_tril21, d21 = _scale_tril_and_variances(scale21, dim2)
scale_tril22, d22 = _scale_tril_and_variances(scale22, dim2)

# One row per dimension: index, then location and variance of the VIBasic
# fit, then location and variance of the PACMVIBasic fit.
for i in range(dim1):
    print(i, *(float(col[i]) for col in (loc11, d11, loc12, d12)))

# Same table for the two fits that use dim2.
for i in range(dim2):
    print(i, *(float(col[i]) for col in (loc21, d21, loc22, d22)))


def _draw_samples(loc, scale_tril, seed):
    """Draw ``n_sample`` vectors from the Gaussian with the given factor.

    ``seed`` is turned into a JAX PRNG key, so the draws are reproducible.
    """
    gaussian = dist.MultivariateNormal(loc, scale_tril=scale_tril)
    return gaussian.sample(random.PRNGKey(seed), (n_sample, ))


# Each fit gets its own fixed seed (0..3).
sample11 = _draw_samples(loc11, scale_tril11, 0)
sample12 = _draw_samples(loc12, scale_tril12, 1)

sample21 = _draw_samples(loc21, scale_tril21, 2)
sample22 = _draw_samples(loc22, scale_tril22, 3)

#data = []
#for s in sample:
    #if m is not None:
    #    s = m.convert(s)
#    for i in range(self.dim):
#        data.append({'dim': i, 'val' : float(s[i]-gt[i])})
#data = pd.DataFrame(data)
#sns.set(style="ticks", rc={"lines.linewidth": 0.5})
#sns.pointplot(data=data, x='dim', y = 'val', errorbar='sd', linestyle = 'none', capsize = 0.5)
#plt.xticks([])
#plt.ylim([-2,2])
#plt.show()
#plt.clf()
#d = np.array(d)
#sns.distplot(d)
#plt.show()
#plt.clf()

# The 50 US states plus the District of Columbia, in alphabetical order.
# Positions in this list serve as state indices in the plotting code, so the
# order must not be changed.
states = [
    "Alabama", "Alaska", "Arizona", "Arkansas",
    "California", "Colorado", "Connecticut",
    "Delaware", "District of Columbia",
    "Florida", "Georgia", "Hawaii",
    "Idaho", "Illinois", "Indiana", "Iowa",
    "Kansas", "Kentucky", "Louisiana",
    "Maine", "Maryland", "Massachusetts", "Michigan", "Minnesota",
    "Mississippi", "Missouri", "Montana",
    "Nebraska", "Nevada", "New Hampshire", "New Jersey", "New Mexico",
    "New York", "North Carolina", "North Dakota",
    "Ohio", "Oklahoma", "Oregon",
    "Pennsylvania", "Rhode Island",
    "South Carolina", "South Dakota",
    "Tennessee", "Texas", "Utah",
    "Vermont", "Virginia",
    "Washington", "West Virginia", "Wisconsin", "Wyoming",
]

# Full name -> two-letter abbreviation. Covers every entry of `states` and
# additionally a handful of territories.
mapping = {
    # States
    "Alabama": "AL", "Alaska": "AK", "Arizona": "AZ", "Arkansas": "AR",
    "California": "CA", "Colorado": "CO", "Connecticut": "CT",
    "Delaware": "DE",
    "Florida": "FL", "Georgia": "GA", "Hawaii": "HI",
    "Idaho": "ID", "Illinois": "IL", "Indiana": "IN", "Iowa": "IA",
    "Kansas": "KS", "Kentucky": "KY", "Louisiana": "LA",
    "Maine": "ME", "Maryland": "MD", "Massachusetts": "MA",
    "Michigan": "MI", "Minnesota": "MN", "Mississippi": "MS",
    "Missouri": "MO", "Montana": "MT",
    "Nebraska": "NE", "Nevada": "NV", "New Hampshire": "NH",
    "New Jersey": "NJ", "New Mexico": "NM", "New York": "NY",
    "North Carolina": "NC", "North Dakota": "ND",
    "Ohio": "OH", "Oklahoma": "OK", "Oregon": "OR",
    "Pennsylvania": "PA", "Rhode Island": "RI",
    "South Carolina": "SC", "South Dakota": "SD",
    "Tennessee": "TN", "Texas": "TX", "Utah": "UT",
    "Vermont": "VT", "Virginia": "VA",
    "Washington": "WA", "West Virginia": "WV", "Wisconsin": "WI",
    "Wyoming": "WY",
    # Federal district
    "District of Columbia": "DC",
    # Territories
    "American Samoa": "AS",
    "Guam": "GU",
    "Northern Mariana Islands": "MP",
    "Puerto Rico": "PR",
    "United States Minor Outlying Islands": "UM",
    "U.S. Virgin Islands": "VI",
}

m1 = Election2()
m2 = Election3()


def _elpd(model, draws):
    """Evaluate ``model.test_log_likelihoods`` on every draw and summarise.

    Returns:
        ``(per_draw, total)``. ``per_draw`` stacks the result of
        ``model.test_log_likelihoods`` for each draw along axis 0.
        ``total`` is the sum of ``logsumexp(per_draw, axis=0) - log(len(draws))``,
        i.e. the log of the draw-averaged likelihood, summed over the
        remaining entries.
    """
    per_draw = vmap(model.test_log_likelihoods)(draws)
    total = jnp.sum(jsc.special.logsumexp(per_draw, axis=0) - jnp.log(len(draws)))
    return per_draw, total


elpds11, elpd11 = _elpd(m1, sample11)
elpds12, elpd12 = _elpd(m1, sample12)
elpds21, elpd21 = _elpd(m2, sample21)
elpds22, elpd22 = _elpd(m2, sample22)
print(elpd11, elpd12, elpd21, elpd22)
# Indices into `states` of the states that are drawn by `plot`.
lst = [1, 17, 22, 49]

# Average of each set of draws along the draw axis (axis 0).
means11, means12, means21, means22 = (
    np.mean(draws, axis=0)
    for draws in (sample11, sample12, sample21, sample22)
)

# Global seaborn theme shared by all figures.
sns.set(style="ticks", rc={"lines.linewidth": 2}, font_scale=3)

# Model key -> description shown in the figure title.
model_mapping = {
    'Election2': 'constant coeff',
    'Election3': 'varying coeff',
}

def plot(sample, name, means, model, elpd, id):
    """Plot per-state lines over the income groups and save them as a PDF.

    Only the states whose index (below 51) appears in the module-level
    ``lst`` are drawn. For every draw ``s`` in ``sample``, every selected
    state index ``i`` and every ``x`` in 1..5, the plotted value is
    ``s[i] + means[51] + slope * x``, where ``slope`` is ``s[-1]`` when
    ``model == 'Election2'`` and ``s[55 + i]`` otherwise. seaborn then
    aggregates the draws per state and ``x`` and shows standard-deviation
    error bars.

    The figure is written to ``figure2/{id}.pdf``; the directory is not
    created here and must already exist.

    Args:
        sample: Iterable of draws, each an indexable vector.
        name: Method label used in the title. The legend is shown only when
            ``name == 'PVI'`` and ``model == 'Election3'``.
        means: Vector whose entry 51 is added to every plotted value.
        model: Key of ``model_mapping``; also selects the slope term.
        elpd: Unused; kept so existing callers keep working.
        id: Stem of the output file name. Shadows the builtin ``id``, but is
            kept unchanged for interface compatibility.
    """
    shared_slope = model == 'Election2'
    # Resolve the selected states once instead of filtering all 51 entries
    # again for every draw; ascending order keeps the hue order stable.
    selected = [(i, mapping[states[i]]) for i in range(51) if i in lst]
    offset = means[51]

    rows = []
    for s in sample:
        for i, code in selected:
            slope = s[-1] if shared_slope else s[55 + i]
            base = s[i] + offset
            for x in range(1, 6):
                rows.append({'state': code, 'x': x, 'y': float(base + slope * x)})
    data = pd.DataFrame(rows)

    show_legend = name == 'PVI' and model == 'Election3'
    fig = plt.figure(figsize=(5, 5))
    try:
        sns.lineplot(data=data, x='x', y='y', hue='state', errorbar='sd',
                     err_style='bars', err_kws={'capsize': 5}, legend=show_legend)
        if show_legend:
            plt.legend(ncol=1, columnspacing=0.8, handletextpad=0.4,
                       prop={'size': 16}, frameon=True)
        plt.ylim([-1, 3])
        plt.xlabel('Income group')
        plt.ylabel('')
        plt.yticks([])
        plt.title(f'{name} {model_mapping[model]}')
        plt.xticks([], [])
        plt.tight_layout()
        plt.savefig(f'figure2/{id}.pdf')
    finally:
        # Close (not merely clear) the figure: every call opens a new one,
        # and cleared-but-open figures would otherwise accumulate in pyplot.
        plt.close(fig)

# One figure per (method, model) combination; the last entry is the output
# file stem.
for _plot_args in (
    (sample11, 'VI', means11, 'Election2', elpd11, '11'),
    (sample12, 'PVI', means12, 'Election2', elpd12, '12'),
    (sample21, 'VI', means21, 'Election3', elpd21, '21'),
    (sample22, 'PVI', means22, 'Election3', elpd22, '22'),
):
    plot(*_plot_args)
