from functools import partial
import math
import warnings
from markovsbi.tasks.base import Task

import jax
from jax import Array
from jax.typing import ArrayLike
import jax.numpy as jnp

from markovsbi.utils.prior_utils import GeneralDistribution, Normal

from probjax.inference.filter import filter, filter_log_likelihood
from probjax.inference.filtering.particle_filter import ParticleFilter
from probjax.inference.kernels import gibbs, slice
import blackjax


def drift_lotka_volterra(t, x, alpha, beta, gamma, delta):
    """Evaluate the deterministic Lotka-Volterra drift at state ``x``.

    Args:
        t: Time; accepted for a uniform drift signature but not used.
        x: State whose first two entries are read as ``x[0]`` and ``x[1]``.
        alpha: Linear growth rate of ``x[0]``.
        beta: Interaction rate that reduces ``x[0]``.
        gamma: Linear decay rate of ``x[1]``.
        delta: Interaction rate that increases ``x[1]``.

    Returns:
        Length-2 array ``[dx0, dx1]``.
    """
    # x[0] obeys the prey equation (self-growth, loses to the interaction
    # term) and x[1] the predator equation (self-decay, gains from it).
    prey, predator = x[0], x[1]

    prey_rate = alpha * prey - beta * prey * predator
    predator_rate = -gamma * predator + delta * prey * predator

    return jnp.array([prey_rate, predator_rate])


def diffusion_lotka_volterra(t, x, sigma1=0.05, sigma2=0.05):
    """Return the 2x2 diagonal diffusion matrix at state ``x``.

    Both diagonal entries scale with the product ``x[0] * x[1]``; the
    off-diagonal entries are zero.

    Args:
        t: Time; accepted for a uniform signature but not used.
        x: State whose first two entries are read.
        sigma1: Noise scale of the first component.
        sigma2: Noise scale of the second component.

    Returns:
        Array of shape ``(2, 2)``.
    """
    top_left = sigma1 * x[0] * x[1]
    bottom_right = sigma2 * x[1] * x[0]
    return jnp.array(
        [
            [top_left, 0.0],
            [0.0, bottom_right],
        ]
    )


def build_log_likelihood(task: Task, n_particles=100, obs_sigma=5e-2):
    """Create a particle-filter based log-likelihood estimator for ``task``.

    Args:
        task: Task whose ``get_simulator()`` provides the state transition.
        n_particles: Number of particles carried by the filter.
        obs_sigma: Standard deviation of the Gaussian that scores a particle
            against an observation.

    Returns:
        A function ``(theta, key, x_o)`` returning the sum of the values
        produced by ``filter_log_likelihood`` for the observations
        ``x_o[1:]``.
    """
    simulate = task.get_simulator()

    def transition_fn(key, x, t, theta=None):
        # Move every particle forward by one simulator step: simulate a
        # length-2 trajectory that starts at the particle and keep its end.
        def advance(particle_key, particle):
            return simulate(particle_key, theta, 2, x0=particle)[-1]

        particle_keys = jax.random.split(key, x.shape[0])
        return jax.vmap(advance)(particle_keys, x)

    def observed_log_likelihood_fn(x, y, t):
        # A narrow Gaussian around the observation stands in for a Dirac
        # delta; the per-dimension log densities are summed.
        per_dim = jax.scipy.stats.norm.logpdf(x, y, obs_sigma)
        return per_dim.sum(axis=-1)

    # All particles start from the same fixed state [1.0, 0.5].
    # NOTE(review): x_o[0] is not used for initialisation -- confirm that
    # observations always start at this state.
    start_state = jnp.array([1.0, 0.5])
    init_particles = jnp.ones((n_particles, 2)) * start_state

    def unbiased_log_likelihood_fn(theta, key, x_o):
        kernel = ParticleFilter(
            observed_log_likelihood_fn, partial(transition_fn, theta=theta)
        )
        # One integer time index per observation; the first one is the
        # initial time, the remaining ones are filtered.
        ts = jnp.arange(x_o.shape[0])
        increments = filter_log_likelihood(
            key, ts, ts[1:], x_o[1:], kernel, init_particles
        )
        return increments.sum()

    return unbiased_log_likelihood_fn


def build_sampler(
    task: Task,
    x_obs: ArrayLike,
    burn_in: int = 500,
    num_chains: int = 50,
    num_sir=10,
):
    """Build an MCMC sampler and an unnormalised log-density for ``x_obs``.

    The target is ``log_likelihood(theta, key, x_obs) + prior.log_prob(theta)``
    where the likelihood comes from ``build_log_likelihood`` and depends on a
    PRNG key. The chain state therefore holds both ``theta`` and a ``key``;
    a ``gibbs`` kernel updates ``theta`` with ``blackjax.rmh`` (Gaussian
    random-walk proposal) and ``key`` with ``blackjax.irmh``.

    Args:
        task: Task providing ``get_simulator()`` and ``get_prior()``.
        x_obs: Observed trajectory; ``x_obs.shape[0]`` is its length.
        burn_in: Number of initial steps discarded from every chain.
        num_chains: Maximum number of chains run in parallel.
        num_sir: Number of prior draws scored to pick each chain's start.

    Returns:
        ``(sample, log_prob)`` where ``sample(key, shape)`` returns an array
        of shape ``shape + (theta_dim,)`` and ``log_prob(theta, key=None)``
        returns the unnormalised log-density estimate.
    """
    # Longer observation sequences get a wider observation kernel and a
    # smaller random-walk step.
    num_obs = x_obs.shape[0]
    if num_obs <= 3:
        obs_sigma = 1e-2
        step_size = 0.4
    elif num_obs <= 20:
        obs_sigma = 5e-2
        step_size = 0.3
    else:
        obs_sigma = 1e-1
        step_size = 0.1
    log_likelihood = build_log_likelihood(task, obs_sigma=obs_sigma)
    prior = task.get_prior()

    def log_density_fn(theta, key):
        return log_likelihood(theta, key, x_obs) + prior.log_prob(theta)

    kernels = {"theta": blackjax.rmh, "key": blackjax.irmh}
    inner_kernel_kwargs = {
        "theta": {
            # Isotropic Gaussian random walk around the current theta.
            "transition_generator": lambda k, x: x
            + step_size * jax.random.normal(k, shape=x.shape)
        },
        "key": {
            # Propose a fresh key derived from the proposal key; the proposal
            # log-density is a constant.
            "proposal_distribution": lambda key: jax.random.split(key)[1],
            "proposal_logdensity_fn": lambda x, x_old: 0.0,
        },
    }
    inner_kernel_steps = {"theta": 5, "key": 1}

    kernel = gibbs(
        log_density_fn,
        kernels,
        inner_kernel_kwargs=inner_kernel_kwargs,
        inner_kernel_steps=inner_kernel_steps,
    )

    def mcmc(key, num_steps):
        """Run one chain for ``num_steps`` steps and return every theta."""
        key, subkey1, subkey2, subkey3 = jax.random.split(key, 4)
        subkey1_propose, subkey_sir = jax.random.split(subkey1)
        subkeys_sir = jax.random.split(subkey_sir, num_sir)

        # Start from the best of ``num_sir`` prior draws under the target.
        candidates = prior.sample(subkey1_propose, (num_sir,))
        weights = jax.vmap(log_density_fn)(candidates, subkeys_sir)
        theta0 = candidates[jnp.argmax(weights)]

        position = {"theta": theta0, "key": subkey3}
        state = kernel.init(position, subkey2)

        def one_step(state, key):
            state, _ = kernel.step(key, state)
            return state, state.position["theta"]

        keys = jax.random.split(key, num_steps)
        _, thetas = jax.lax.scan(one_step, state, keys)
        return thetas

    @partial(jax.jit, static_argnums=(1,))
    def sample(key, shape):
        """Draw ``prod(shape)`` thetas; ``shape`` must be a (hashable) tuple."""
        num_samples = math.prod(shape)
        # At least one chain, so an empty request does not divide by zero.
        chains = max(1, min(num_samples, num_chains))
        # Round up so that a request that is not a multiple of the chain
        # count is still fully covered; the surplus is dropped below.
        per_chain = -(-num_samples // chains)
        keys = jax.random.split(key, chains)
        thetas = jax.vmap(partial(mcmc, num_steps=per_chain + burn_in))(keys)
        thetas = thetas[:, burn_in:]
        theta_dim = thetas.shape[-1]
        flat = thetas.reshape((chains * per_chain, theta_dim))[:num_samples]
        return flat.reshape(shape + (theta_dim,))

    def log_prob(theta, key=None):
        """Evaluate the unnormalised log-density estimate at ``theta``.

        ``theta`` may carry leading batch dimensions. The likelihood term
        depends on a PRNG key; the same ``key`` is used for every batch
        element, and a fixed default key is used when none is given.
        """
        warnings.warn("log_prob is not normalized")
        if key is None:
            key = jax.random.PRNGKey(0)
        batched = log_density_fn
        for _ in range(theta.ndim - 1):
            # Map over theta only; the key is shared across the batch.
            batched = jax.vmap(batched, in_axes=(0, None))
        return batched(theta, key)

    return sample, log_prob


class LotkaVolterra(Task):
    """Stochastic Lotka-Volterra task with 4 parameters and a 2-d state."""

    def __init__(
        self, x0_min=0.0, x0_max=10.0, dt=0.02, inner_steps=5, sigma1=0.1, sigma2=0.1
    ):
        super().__init__("LotkaVolterra")
        # Range of the uniform initial-state distribution used by get_data.
        self.x0_min, self.x0_max = x0_min, x0_max
        # Integration step size and number of sub-steps per emitted state.
        self.dt, self.inner_steps = dt, inner_steps
        # Noise scales handed to diffusion_lotka_volterra.
        self.sigma1, self.sigma2 = sigma1, sigma2

    @property
    def input_shape(self):
        """Shape of a parameter vector."""
        return (4,)

    @property
    def condition_shape(self):
        """Shape of a single state."""
        return (2,)

    def get_prior(self):
        """Return ``Normal`` built from a length-4 zero vector and ones vector."""
        zeros = jnp.array([0.0, 0.0, 0.0, 0.0])
        ones = jnp.array([1.0, 1.0, 1.0, 1.0])
        return Normal(zeros, ones)

    def get_simulator(self):
        """Return ``simulator(rng_key, theta, T, x0=None)``.

        The simulator integrates the SDE with Euler-Maruyama steps and
        returns ``T`` states (the initial state followed by ``T - 1``
        simulated ones), clipped below at 0.
        """

        def simulator(rng_key, theta: Array, T: int, x0=None):
            # Squash each unconstrained parameter into the interval (1, 3).
            alpha, beta, gamma, delta = (
                jax.nn.sigmoid(raw) * 2.0 + 1.0 for raw in theta
            )

            start = jnp.array([1.0, 0.5]) if x0 is None else x0

            def euler_maruyama(state, key):
                # state + drift * dt + diffusion @ dW, with dW ~ N(0, dt).
                drift = drift_lotka_volterra(0, state, alpha, beta, gamma, delta)
                noise = jax.random.normal(key, shape=(2,)) * jnp.sqrt(self.dt)
                diffusion = diffusion_lotka_volterra(
                    0, state, self.sigma1, self.sigma2
                )
                return state + drift * self.dt + jnp.dot(diffusion, noise)

            def one_step(x, key):
                # Take ``inner_steps`` sub-steps, each with its own key, and
                # emit only the resulting state.
                substep_keys = jax.random.split(key, self.inner_steps)

                def body(i, state):
                    return euler_maruyama(state, substep_keys[i])

                x = jax.lax.fori_loop(0, self.inner_steps, body, x)
                return x, x

            step_keys = jax.random.split(rng_key, T - 1)
            _, later_states = jax.lax.scan(one_step, start, step_keys)
            trajectory = jnp.concatenate([start[None], later_states], axis=0)
            # The clip is applied once to the finished trajectory.
            return jnp.clip(trajectory, 0)

        return simulator

    def get_data(self, key, num_simulations, T):
        """Simulate ``num_simulations`` trajectories of length ``T``.

        Parameters come from the prior and initial states are uniform on
        ``[x0_min, x0_max)`` per coordinate.

        Returns:
            Dict with keys ``"thetas"`` and ``"xs"``.
        """
        theta_key, x0_key, sim_key = jax.random.split(key, 3)

        thetas = self.get_prior().sample(theta_key, (num_simulations,))
        x0 = jax.random.uniform(
            x0_key,
            (num_simulations, 2),
            minval=self.x0_min,
            maxval=self.x0_max,
        )
        sim_keys = jax.random.split(sim_key, num_simulations)

        # Batch over keys, thetas and initial states; T is shared.
        batched_simulator = jax.vmap(self.get_simulator(), in_axes=(0, 0, None, 0))
        xs = batched_simulator(sim_keys, thetas, T, x0)

        return {"thetas": thetas, "xs": xs}

    def get_true_posterior(
        self, x_o: ArrayLike, burn_in=500, num_chains=10, num_sir=50
    ):
        """Wrap the sampler from ``build_sampler`` in a ``GeneralDistribution``."""
        sample_fn, log_prob_fn = build_sampler(
            self,
            x_o,
            burn_in=burn_in,
            num_chains=num_chains,
            num_sir=num_sir,
        )
        return GeneralDistribution(self.input_shape, sample_fn, log_prob_fn)
