import jax
import jax.scipy as jsc
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from jax import random, grad, vmap
import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro.distributions.util import vec_to_tril_matrix
import numpy as np
from data import get_matrix_test
class BasicFullRank:
    """Full-rank multivariate-normal variational family over ``dim`` latents.

    The flat parameter vector (see ``gen_params``) is laid out as::

        [ loc (dim) | raw diagonal (dim) | strictly-lower entries (dim*(dim-1)/2) ]

    The last two pieces together form ``scale``, which ``_scale_tril`` turns
    into the lower-triangular Cholesky factor of the covariance.
    """

    # Two-letter codes for the 50 states plus DC, ordered alphabetically by
    # full name (so DC follows Delaware). Code i labels latent coordinate i in
    # the state-level plots of ``diagonosis``.
    _STATE_CODES = (
        "AL", "AK", "AZ", "AR", "CA", "CO", "CT", "DE", "DC", "FL", "GA",
        "HI", "ID", "IL", "IN", "IA", "KS", "KY", "LA", "ME", "MD", "MA",
        "MI", "MN", "MS", "MO", "MT", "NE", "NV", "NH", "NJ", "NM", "NY",
        "NC", "ND", "OH", "OK", "OR", "PA", "RI", "SC", "SD", "TN", "TX",
        "UT", "VT", "VA", "WA", "WV", "WI", "WY",
    )
    # Leading (converted) coordinates treated as per-state intercepts; the
    # coordinate at this index is also used as a shared offset (its mean).
    _N_STATES = 51
    # Latent dimensions recognised by ``diagonosis`` for the state-level plots:
    # one shared slope in the last column, or per-state slopes starting at
    # ``_SLOPE_START``. Layout inferred from the indexing used here — confirm
    # against the model that produces these latents.
    _CONSTANT_SLOPE_DIM = 56
    _VARYING_SLOPE_DIM = 106
    _SLOPE_START = 55
    # Number of income groups plotted on the x axis (x = 1..N).
    _N_INCOME_GROUPS = 5
    # Number of posterior draws used by ``diagonosis``.
    _N_DIAGNOSTIC_DRAWS = 10000

    def __init__(self, dim):
        """Create a family over ``dim`` latent coordinates."""
        self.dim = dim
        # Gradient of ``log_posterior`` with respect to ``params`` (argument 1).
        self.gradq = grad(self.log_posterior, argnums=1)

    def extract_params(self, params):
        """Split ``params`` into ``(loc, scale)``.

        ``loc`` is the first ``dim`` entries; ``scale`` is everything after
        (raw diagonal followed by the strictly-lower-triangular entries).
        """
        loc = params[:self.dim]
        scale = params[self.dim:]
        return loc, scale

    def _scale_tril(self, scale):
        """Build the lower-triangular Cholesky factor from the ``scale`` vector."""
        off_diagonal = vec_to_tril_matrix(scale[self.dim:], diagonal=-1)
        # exp maps the raw values to positive scales; the floor keeps every
        # diagonal entry >= 1e-2.
        diagonal = jnp.maximum(1e-2, jnp.exp(scale[:self.dim]))
        return off_diagonal + jnp.diag(diagonal)

    def _distribution(self, params):
        """Return the ``MultivariateNormal`` described by ``params``."""
        loc, scale = self.extract_params(params)
        return dist.MultivariateNormal(loc, scale_tril=self._scale_tril(scale))

    def log_posterior(self, theta, params):
        """Log-density of ``theta`` under the family, summed over any batch of ``theta``."""
        return jnp.sum(self._distribution(params).log_prob(theta))

    def sample(self, key, params, number=1):
        """Draw ``number`` samples using PRNG ``key``; result has leading axis ``number``."""
        return self._distribution(params).sample(key, (number,))

    def posterior_parameters(self, params):
        """Return ``(loc, scale_tril)`` of the distribution described by ``params``."""
        loc, scale = self.extract_params(params)
        return (loc, self._scale_tril(scale))

    def diagonosis(self, params, m=None, name=None, gt=None):
        """Print and plot diagnostics for the distribution described by ``params``.

        Prints each coordinate's mean and marginal variance, draws 10000
        samples with the fixed key ``PRNGKey(0)`` and plots the
        per-coordinate error ``sample - gt`` (mean +/- sd). When ``m`` is
        given, the expected log predictive density on ``m``'s test data is
        estimated and printed (NaN is used when ``m`` is None). For latent
        dimension 56 or 106 it also plots per-state lines over income groups.

        Args:
            params: flat parameter vector (see ``gen_params``).
            m: optional model object; when given it must provide
                ``test_log_likelihoods(theta)`` and ``convert(theta)``.
            name: label used in the state-plot titles; ``'VI'`` also turns on
                the legend of the constant-coefficient plot.
            gt: reference point subtracted from the draws; defaults to ``loc``.
        """
        loc, scale = self.extract_params(params)
        if gt is None:
            gt = loc
        scale_tril = self._scale_tril(scale)
        # Marginal variances are the diagonal of the covariance L @ L.T.
        variances = jnp.diagonal(scale_tril @ scale_tril.T)
        for i in range(self.dim):
            print(i, float(loc[i]), float(variances[i]))

        sample = dist.MultivariateNormal(loc, scale_tril=scale_tril).sample(
            random.PRNGKey(0), (self._N_DIAGNOSTIC_DRAWS,))

        sns.set(style="ticks", rc={"lines.linewidth": 0.5})
        sns.pointplot(data=self._error_frame(sample, gt), x='dim', y='val',
                      errorbar='sd', linestyle='none', capsize=0.5)
        plt.xticks([])
        plt.ylim([-2, 2])
        plt.show()
        plt.clf()

        # Scoring test data needs the model; without one report NaN instead
        # of failing on ``None.test_log_likelihoods``.
        if m is None:
            elpd = float('nan')
        else:
            elpd = self._elpd(m, sample)
            print(elpd)

        if self.dim == self._CONSTANT_SLOPE_DIM:
            draws = self._converted_draws(m, sample)
            frame = self._state_line_frame(
                draws[:, :self._N_STATES], draws[:, -1:],
                self._state_offset(sample), self._N_INCOME_GROUPS)
            self._plot_state_lines(
                frame, f'{name} constant coefficient \n elpd: {elpd: .2f}',
                legend=(name == 'VI'))
        elif self.dim == self._VARYING_SLOPE_DIM:
            draws = self._converted_draws(m, sample)
            slopes = draws[:, self._SLOPE_START:self._SLOPE_START + self._N_STATES]
            frame = self._state_line_frame(
                draws[:, :self._N_STATES], slopes,
                self._state_offset(sample), self._N_INCOME_GROUPS)
            self._plot_state_lines(
                frame, f'{name} varying coefficient \n elpd: {elpd: .2f}',
                legend=False)

    # Correctly spelled alias; ``diagonosis`` is kept for existing callers.
    diagnosis = diagonosis

    @staticmethod
    def _elpd(m, sample):
        """Monte Carlo elpd: log-mean-exp over draws, summed over test points."""
        log_liks = vmap(m.test_log_likelihoods)(sample)
        return jnp.sum(jsc.special.logsumexp(log_liks, axis=0) - jnp.log(len(sample)))

    def _state_offset(self, sample):
        """Mean over the raw draws of coordinate ``_N_STATES``, as a NumPy scalar."""
        # NOTE(review): taken from the unconverted draws while intercepts and
        # slopes come from ``m.convert`` output — kept as-is; confirm intended.
        return np.asarray(jnp.mean(sample, axis=0)[self._N_STATES])

    @staticmethod
    def _converted_draws(m, sample):
        """Return the draws as a 2-D NumPy array, mapped through ``m.convert`` if ``m`` is given."""
        if m is None:
            return np.asarray(sample)
        return np.stack([np.asarray(m.convert(s)) for s in sample])

    @staticmethod
    def _error_frame(sample, gt):
        """Long-format frame of errors ``sample[n, i] - gt[i]``.

        Rows are ordered draw-major, coordinate-minor; columns are ``dim``
        (coordinate index) and ``val`` (error as float64).
        """
        draws = np.asarray(sample)
        errors = draws - np.asarray(gt, dtype=draws.dtype)
        n_draws, n_dims = errors.shape
        return pd.DataFrame({
            'dim': np.tile(np.arange(n_dims), n_draws),
            'val': errors.ravel().astype(np.float64),
        })

    @classmethod
    def _state_line_frame(cls, intercepts, slopes, offset, n_groups=5):
        """Long-format frame of per-state lines ``y = a + offset + b * x``.

        Args:
            intercepts: array of shape (n_draws, n_states), per-draw intercepts.
            slopes: array broadcastable to (n_draws, n_states), per-draw slopes.
            offset: scalar added to every intercept.
            n_groups: lines are evaluated at ``x = 1..n_groups``.

        Rows are ordered draw, then state, then x; columns are ``state``
        (two-letter code), ``x`` (int) and ``y`` (float64).
        """
        a = np.asarray(intercepts)
        b = np.asarray(slopes)
        n_draws, n_states = a.shape
        xs = np.arange(1, n_groups + 1)
        # Same evaluation order as (a + offset) + b * x, in the draws' dtype.
        y = (a[:, :, None] + np.asarray(offset)) + b[:, :, None] * xs.astype(b.dtype)
        codes = np.asarray(cls._STATE_CODES[:n_states])
        return pd.DataFrame({
            'state': np.tile(np.repeat(codes, n_groups), n_draws),
            'x': np.tile(xs, n_draws * n_states),
            'y': y.ravel().astype(np.float64),
        })

    @staticmethod
    def _plot_state_lines(frame, title, legend):
        """Draw one line per state (mean +/- sd error bars), show, then clear the figure."""
        sns.set(style="ticks", rc={"lines.linewidth": 2}, font_scale=2)
        sns.lineplot(data=frame, x='x', y='y', hue='state', errorbar='sd',
                     err_style='bars', err_kws={'capsize': 5}, legend=legend)
        if legend:
            plt.legend(ncol=4, columnspacing=0.8, handletextpad=0.4)
        plt.ylim([-5, 5])
        plt.xlabel('Income group')
        plt.ylabel('Voting preference')
        plt.title(title)
        plt.xticks([], [])
        plt.tight_layout()
        plt.show()
        plt.clf()

    def gen_params(self):
        """Return an all-zero parameter vector.

        That is loc 0, diagonal exp(0) = 1 and zero off-diagonal entries,
        i.e. a standard normal.
        """
        return jnp.zeros(self.dim + self.dim * (self.dim + 1) // 2)

    def get_grad(self, theta, params):
        """Gradient of ``log_posterior(theta, params)`` with respect to ``params``."""
        return self.gradq(theta, params)



