import jax
import jax.scipy as jsc
import pandas as pd
from jax import random, grad, vmap
import jax.numpy as jnp
import generax
import equinox as eqx

class NormalizingFlow1D:
    """Normalizing flow built from a rational-quadratic spline and a Gaussian prior.

    The flow is constructed with ``generax`` using ``input_shape=(1,)`` for
    both the transform and the prior, then passed through
    ``data_dependent_init`` on a batch of 1000 standard-normal draws.

    The flow object itself plays the role of "params" throughout this class:
    ``gen_params`` returns it, and the other methods take it as ``params``.
    """

    def __init__(self, dim, key = None, knots = 16, ):
        """Build and initialise the flow.

        Args:
            dim: Width of the standard-normal batch used for the
                data-dependent initialisation (shape ``(1000, dim)``).
                NOTE(review): the transform and prior are hard-coded to
                ``input_shape=(1,)``, so ``dim`` is presumably always 1 --
                confirm with callers.
            key: JAX PRNG key. When ``None`` a fixed ``PRNGKey(0)`` is used.
            knots: Passed as ``K`` to ``generax.RationalQuadraticSpline``.
        """
        self.dim = dim
        if key is None:
            # The signature advertises ``key`` as optional, but
            # ``random.split(None)`` raises; fall back to a fixed seed.
            key = random.PRNGKey(0)
        # One sub-key per consumer. The parent key must not be used again
        # once it has been split, so the init call gets its own stream.
        k_transform, k_data, k_init = random.split(key, 3)
        transform = generax.RationalQuadraticSpline(input_shape=(1,),
                                                    K=knots,
                                                    key=k_transform)
        prior = generax.Gaussian(input_shape=(1,))
        flow = generax.NormalizingFlow(transform=transform, prior=prior)
        initials = random.normal(k_data, (1000, dim))
        self.flow = flow.data_dependent_init(initials, key=k_init)
        # Jitted gradient of the log-posterior with respect to its first
        # argument, i.e. the flow ("params").
        self.gradf = eqx.filter_jit(eqx.filter_grad(self._log_posterior))

    def extract_params(self, params):
        """Not implemented for this model; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def log_posterior(self, theta, params):
        """Return ``params.log_prob(theta)``."""
        return params.log_prob(theta)

    def _log_posterior(self, params, theta):
        """``log_posterior`` with ``params`` first, so it is the differentiated argument."""
        return self.log_posterior(theta, params)

    def sample(self, key, params, number = 1):
        """Draw ``number`` samples from ``params``.

        ``key`` is split into ``number`` sub-keys and
        ``params.sample_and_log_prob`` is vmapped over them; the second
        element of its result is discarded.
        """
        keys = random.split(key, number)
        samples, _ = eqx.filter_vmap(params.sample_and_log_prob)(keys,)
        return samples

    def posterior_parameters(self, params):
        """Not implemented for this model; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def gen_params(self):
        """Return the initialised flow stored on this instance."""
        return self.flow

    def get_grad(self, theta, params):
        """Return the gradient of ``log_posterior(theta, params)`` with respect to ``params``."""
        return self.gradf(params, theta)


if __name__ == '__main__':
    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    import optax
    from tqdm import tqdm
    class DiscreteDensity:
        """Categorical distribution over ``0 .. N-1`` with a fixed weight table.

        The unnormalised weights have two peaks (at indices 5 and 15) and
        are normalised into ``probs``.
        """

        def __init__(self, N=20):
            """Store the weight table and its normalised probabilities.

            Args:
                N: Number of categories. Must equal the length of the
                    built-in weight table (20).

            Raises:
                ValueError: If ``N`` does not match the weight table, which
                    would otherwise only surface later as a size-mismatch
                    error inside ``np.random.choice``.
            """
            self.cat = np.array([1, 2, 4, 7, 10, 12, 10, 7, 4, 2,
                                 1, 2, 3, 5, 7, 8, 7, 5, 3, 2])
            if N != len(self.cat):
                raise ValueError(
                    f"N must be {len(self.cat)} to match the weight table, got {N}"
                )
            self.N = N
            self.probs = self.cat / np.sum(self.cat)

        def sample(self, num, shuffle=True):
            """Draw ``num`` category indices according to ``probs``.

            ``shuffle`` is accepted for interface compatibility but is unused.
            Uses NumPy's global random state.
            """
            return np.random.choice(np.arange(self.N), num, p=self.probs)
    # Synthetic data: 10000 categorical draws plus small Gaussian jitter,
    # shaped (10000, 1), then squashed to (0, 1) and mapped through logit.
    disc = DiscreteDensity()
    categories = disc.sample(10000)
    jitter = np.random.randn(10000) * 0.1
    y = jnp.expand_dims(categories + jitter, axis=1)
    y = jsc.special.logit((y + 1) / 21)
    print(y, jnp.min(y), jnp.max(y))

    # Model and its initial "params" (the flow object itself).
    nf = NormalizingFlow1D(dim=1, key=random.PRNGKey(2), knots=32)
    params = nf.gen_params()
    rng_key = random.PRNGKey(0)

    # Multiplier applied on top of Adam's updates: warm up from 0.0 to 1.0
    # over 1000 steps, then decay towards 0.1.
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=1.0,
        warmup_steps=1000,
        decay_steps=10000,
        end_value=0.1,
        exponent=1.0,
    )
    # Order matters: clip gradients, apply Adam, then scale by the schedule.
    chain = [
        optax.clip_by_global_norm(10.0),
        optax.adam(1e-4, nesterov=True),
        optax.scale_by_schedule(schedule),
    ]
    optimizer = optax.chain(*chain)

    # Training state is a triple: (optimizer state, flow, PRNG key).
    # The optimizer is initialised on the inexact-array leaves only.
    trainable = eqx.filter(params, eqx.is_inexact_array)
    opt_state = (optimizer.init(trainable), params, rng_key)
    def objective(flow, key):
        """Negative summed log-probability of the data ``y`` under ``flow``.

        ``key`` is accepted but not used.
        """
        per_point_logp = vmap(flow.log_prob)(y)
        return -per_point_logp.sum()


    # Jitted function returning (loss, gradient with respect to ``flow``).
    jit_obj = eqx.filter_jit(eqx.filter_value_and_grad(objective))


    def step(step, opt_state):
        """Run one optimisation step, skipping the update on non-finite gradients.

        Args:
            step: Iteration index; only used in the "skipped" message.
            opt_state: Triple ``(optimizer state, flow, PRNG key)``.

        Returns:
            ``(loss, new_opt_state)`` where ``new_opt_state`` has the same
            triple layout. On a skipped step the optimizer state and flow are
            returned unchanged; the PRNG key is advanced either way.
        """
        optim_state, model, rng = opt_state
        batch_key, rng = random.split(rng)
        loss, gradients = jit_obj(model, batch_key)

        should_skip, _info = optax.skip_not_finite(gradients, optim_state, model)
        if should_skip:
            print(step, 'skipped')
            return loss, (optim_state, model, rng)

        deltas, next_optim_state = optimizer.update(gradients, optim_state, model)
        next_model = eqx.apply_updates(model, deltas)
        return loss, (next_optim_state, next_model, rng)


    # Run 10000 optimisation steps, recording the loss at every step and
    # printing it every 1000th step, then plot the loss curve.
    history = []
    for i in tqdm(range(10000)):
        value, opt_state = step(i, opt_state)
        _, flow, rng_key = opt_state
        loss = float(value)
        if i % 1000 == 0:
            print(loss)
        history.append({'step': i, 'loss': loss})
    data = pd.DataFrame(history)
    sns.lineplot(data=data, x='step', y='loss')
    plt.show()
    plt.clf()

    # Take the trained flow out of the final training state.
    _, flow, rng_key = opt_state

    # Draw 10000 samples from the flow, one PRNG key per sample.
    sample_keys = random.split(random.PRNGKey(0), 10000)
    all_samples, logp = eqx.filter_vmap(flow.sample_and_log_prob)(sample_keys)

    # Undo the logit preprocessing (expit, then *21 - 1) so model samples and
    # data are compared on the original 0..20 scale.
    plotdata = np.array(all_samples).flatten()
    model_values = np.array(jsc.special.expit(plotdata)) * 21 - 1
    observed_values = np.array(jsc.special.expit(y)).flatten() * 21 - 1
    sns.kdeplot(data=model_values)
    sns.kdeplot(data=observed_values)
    plt.xlim([0, 20])
    plt.show()
    plt.clf()



