import jax
import jax.scipy as jsc
from jax import random, grad
import jax.numpy as jnp
import numpyro.distributions as dist
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
class Basic:
    """Factorised Gaussian family over ``dim`` independent coordinates.

    All parameters live in one flat vector ``params``: the first ``dim``
    entries are the means (``loc``) and the remaining entries are the log
    standard deviations (``logscale``), so a well-formed ``params`` has
    length ``2 * dim`` (see ``gen_params``).
    """

    # Labels for the first 51 coordinates in the per-state plot of
    # ``diagonosis``: the 50 U.S. states plus the District of Columbia,
    # alphabetically.
    _STATES = (
        "Alabama", "Alaska", "Arizona", "Arkansas", "California", "Colorado",
        "Connecticut", "Delaware", "District of Columbia", "Florida",
        "Georgia", "Hawaii", "Idaho", "Illinois", "Indiana", "Iowa", "Kansas",
        "Kentucky", "Louisiana", "Maine", "Maryland", "Massachusetts",
        "Michigan", "Minnesota", "Mississippi", "Missouri", "Montana",
        "Nebraska", "Nevada", "New Hampshire", "New Jersey", "New Mexico",
        "New York", "North Carolina", "North Dakota", "Ohio", "Oklahoma",
        "Oregon", "Pennsylvania", "Rhode Island", "South Carolina",
        "South Dakota", "Tennessee", "Texas", "Utah", "Vermont", "Virginia",
        "Washington", "West Virginia", "Wisconsin", "Wyoming",
    )

    def __init__(self, dim):
        """Create the family for ``dim``-dimensional ``theta``."""
        self.dim = dim
        # Gradient of ``log_posterior`` with respect to ``params`` (argument 1).
        self.gradq = grad(self.log_posterior, argnums=1)

    def extract_params(self, params):
        """Split ``params`` into ``(loc, logscale)`` at index ``dim``."""
        loc = params[:self.dim]
        logscale = params[self.dim:]
        return loc, logscale

    def log_posterior(self, theta, params):
        """Return the log density of ``theta`` under the Gaussian given by ``params``.

        Each coordinate contributes ``Normal(loc, exp(logscale)).logpdf``; the
        per-coordinate terms are summed into a scalar.
        """
        loc, logscale = self.extract_params(params)
        scale = jnp.exp(logscale)
        return jnp.sum(jsc.stats.norm.logpdf(theta, loc, scale))

    def sample(self, key, params, number=1):
        """Draw ``number`` samples, returned with shape ``(number, dim)``.

        Uses the reparameterisation ``loc + scale * eps`` with standard-normal
        ``eps`` drawn from ``key``.
        """
        loc, logscale = self.extract_params(params)
        scale = jnp.exp(logscale)
        return random.normal(key, shape=(number, self.dim)) * scale + loc

    def posterior_parameters(self, params):
        """Return ``(loc, scale)`` with ``scale = exp(logscale)``."""
        loc, logscale = self.extract_params(params)
        scale = jnp.exp(logscale)
        return (loc, scale)

    def diagonosis(self, params, m=None, name=None, gt=None,
                   num_samples=1000, key=None):
        """Show diagnostic plots for the distribution given by ``params``.

        Draws ``num_samples`` samples and optionally maps each one through
        ``m.convert``. It then shows two plots:

        1. A point plot of ``sample[i] - gt[i]`` for every ``i < dim``
           (mean +/- 1 sd across samples).
        2. Only when the raw sample length is 57 or 107: a line plot of
           ``s[i] + s[51] + slope * x + s[-1]`` for ``x`` in 1..5. It covers
           every sixth of the first 51 coordinates, labelled by ``_STATES``.

        Args:
            params: flat parameter vector ``[loc, logscale]``.
            m: optional object whose ``convert`` method is applied to each
                sample before plotting.
            name: unused; kept for interface compatibility.
            gt: reference values subtracted from each sample (defaults to
                ``loc``).
            num_samples: number of samples to draw (default 1000).
            key: JAX PRNG key (defaults to ``random.PRNGKey(0)``, so repeated
                calls produce the same plots).
        """
        loc, logscale = self.extract_params(params)
        if gt is None:
            # NOTE(review): samples go through ``m.convert`` but this default
            # ``gt`` does not -- confirm that is intended when ``m`` is given.
            gt = loc
        if key is None:
            key = random.PRNGKey(0)
        sample = dist.Normal(loc, jnp.exp(logscale)).sample(key, (num_samples,))
        # Convert each draw once and reuse it for both plots. Working on NumPy
        # rows avoids one JAX dispatch + host sync per scalar in the loops.
        rows = self._converted_rows(sample, m)

        sns.set_theme(style="ticks", rc={"lines.linewidth": 0.5})
        sns.pointplot(data=self._residual_frame(rows, gt), x='dim', y='val',
                      errorbar='sd', linestyle='none', capsize=0.5)
        plt.xticks([])
        plt.ylim([-2, 2])
        plt.show()
        plt.clf()

        # The per-state plot is only defined for these raw sample lengths.
        # Each entry gives the slope multiplying ``x`` for coordinate ``i``.
        # NOTE(review): the meaning of coordinates 51, 55+i, -2 and -1
        # (presumably intercepts/slopes of a hierarchical model) is not
        # visible here -- TODO confirm.
        slope_by_len = {
            57: lambda s, i: s[-2],        # one slope shared by every state
            107: lambda s, i: s[55 + i],   # a separate slope per state
        }
        slope_of = slope_by_len.get(sample.shape[-1])
        if slope_of is None:
            return
        sns.set_theme(style="ticks", rc={"lines.linewidth": 2})
        sns.lineplot(data=self._state_line_frame(rows, slope_of), x='x', y='y',
                     hue='state', errorbar='sd', err_style='bars',
                     err_kws={'capsize': 5})
        plt.show()
        plt.clf()

    # Correctly spelled alias; ``diagonosis`` is kept for existing callers.
    diagnosis = diagonosis

    @staticmethod
    def _converted_rows(sample, m=None):
        """Return ``sample`` as a list of NumPy rows, each mapped through ``m.convert`` if ``m`` is given."""
        if m is None:
            return list(np.asarray(sample))
        return [np.asarray(m.convert(s)) for s in sample]

    def _residual_frame(self, rows, gt):
        """Build a long-format frame with columns ``dim`` and ``val``.

        There is one record per (row, i) for ``i < dim``, in row-major order,
        with ``val = row[i] - gt[i]``.
        """
        dim = self.dim
        draws = np.asarray([np.asarray(r)[:dim] for r in rows]).reshape(-1, dim)
        residuals = draws - np.asarray(gt)[:dim]
        return pd.DataFrame({
            'dim': np.tile(np.arange(dim), len(draws)),
            'val': residuals.ravel().astype(float),
        })

    @classmethod
    def _state_line_frame(cls, rows, slope_of):
        """Build a long-format frame with columns ``state``, ``x`` and ``y``.

        For each row ``s``, every sixth index ``i`` among the first 51, and
        each ``x`` in 1..5, it adds the record
        ``y = s[i] + s[51] + slope_of(s, i) * x + s[-1]``.
        """
        records = []
        for s in rows:
            for i in range(0, min(len(s), 51), 6):
                slope = slope_of(s, i)
                for x in range(1, 6):
                    y = float(s[i] + s[51] + slope * x + s[-1])
                    records.append({'state': cls._STATES[i], 'x': x, 'y': y})
        return pd.DataFrame(records)

    def gen_params(self):
        """Return initial parameters: zero means and zero log-scales (unit scale)."""
        return jnp.zeros(self.dim * 2)

    def get_grad(self, theta, params):
        """Return d log_posterior(theta, params) / d params (length ``2 * dim``)."""
        return self.gradq(theta, params)



