from typing import NamedTuple, Callable
from functools import partial

import jax
import jax.numpy as np
from jaxtyping import Array, PyTree, PRNGKeyArray
from blackjax.util import generate_gaussian_noise
from blackjax.smc import resampling
import optax

from exptax.models.base import BaseExperiment
from exptax.base import ParticlesApprox
from exptax.optimizers.base import Optim
from exptax.inference.gibbs import as_top_level_api

class ImplicitState(NamedTuple):
    """State carried from one design-optimisation step to the next.

    An instance is built by ``init_averaged`` and a fresh one is returned by
    every call to ``step_diffusion_gibbs``.
    """
    # Particle set moved by the Gibbs step whose target adds the log-probability
    # of the current (y, design) pair to the history potential.
    particles: ParticlesApprox
    # Particle set moved by the Gibbs step that targets the history potential alone.
    particles_prior: ParticlesApprox
    # Current design; updated through the optax optimiser at every step.
    design: PyTree
    # Most recently simulated outcomes (drawn in ``init_averaged`` and then
    # replaced at each step by the samples returned from ``mean_log_score``).
    y: Array
    # Optimiser state matching ``design``.
    opt_state: optax.OptState
    # Step counter; incremented by one in ``step_diffusion_gibbs``.
    k: int = 0

def normalize(log_weights):
    """Convert unnormalised log-weights into weights that sum to one.

    The largest log-weight is subtracted before exponentiating, so the
    biggest term is ``exp(0)`` and the exponential cannot overflow.

    Args:
        log_weights: Array of unnormalised log-weights.

    Returns:
        Array with the same shape as ``log_weights`` holding the
        normalised weights.
    """
    shifted = log_weights - np.max(log_weights)
    unnormalised = np.exp(shifted)
    # size * mean is the total mass of the shifted weights.
    total_mass = unnormalised.size * unnormalised.mean()
    return unnormalised / total_mass

def langevin(rng_key, position: PyTree, logdensity_grad: PyTree, step_size: float) -> PyTree:
    """Apply one Langevin update to every leaf of ``position``.

    Each leaf ``p`` with matching gradient leaf ``g`` and noise leaf ``n``
    becomes ``p + step_size * g + sqrt(2 * step_size) * n``.

    Args:
        rng_key: PRNG key handed to ``generate_gaussian_noise``.
        position: Current position pytree.
        logdensity_grad: Pytree with the same structure as ``position``
            holding the log-density gradient.
        step_size: Step size of the update.

    Returns:
        The updated position pytree.
    """
    # The 0.0 / 1.0 arguments are presumably the mean and scale of the
    # noise -- see blackjax.util.generate_gaussian_noise to confirm.
    zero, one = np.array(0.0), np.array(1.0)
    noise = generate_gaussian_noise(rng_key, position, zero, one)
    noise_scale = np.sqrt(2 * step_size)

    def _move(p, g, n):
        return p + step_size * g + noise_scale * n

    return jax.tree_util.tree_map(_move, position, logdensity_grad, noise)

def y_sampler(rng_key: PRNGKeyArray, sampler: Callable, thetas: PyTree) -> PyTree:
    """Draw one sample per particle by mapping ``sampler`` over ``thetas``.

    Args:
        rng_key: PRNG key; it is split into one sub-key per particle.
        sampler: Callable invoked as ``sampler(theta, key)`` for a single
            particle.
        thetas: Pytree of particles. The leading axis of its first leaf
            sets the number of particles.

    Returns:
        The outputs of ``sampler`` stacked along a new leading axis.
    """
    first_leaf = jax.tree_util.tree_leaves(thetas)[0]
    per_particle_keys = jax.random.split(rng_key, first_leaf.shape[0])
    batched_sampler = jax.vmap(sampler)
    return batched_sampler(thetas, per_particle_keys)

def mean_log_score(rng_key: PRNGKeyArray, theta_ref: Array, theta_target: Array, design: Array, y: Array, model: BaseExperiment):
    """Evaluate the design objective and return the outcomes it simulated.

    Outcomes are simulated from ``theta_ref`` at ``design``. Their
    log-probability under the generating parameters is contrasted with a
    self-normalised weighted average of their log-probability under
    ``theta_target``.

    Args:
        rng_key: PRNG key used to simulate the outcomes.
        theta_ref: Parameters the outcomes are simulated from.
        theta_target: Parameters the simulated outcomes are scored against.
        design: Design at which the model is sampled and scored. The
            caller differentiates this function with respect to it.
        y: Not read by this function; kept so the signature stays stable.
        model: Experiment exposing ``sample`` and ``log_prob``.

    Returns:
        ``(objective, y_ref)``: the scalar objective and the freshly
        simulated outcomes (returned as auxiliary data).
    """
    # One simulated outcome per reference particle at the current design.
    y_ref = y_sampler(rng_key, partial(model.sample, xi=design), theta_ref)

    # Score each outcome under the parameters that produced it.
    ref_scores = model.log_prob(theta_ref, y_ref, design)

    # Map over the simulated outcomes while holding theta_target fixed, so
    # axis 0 indexes outcomes (axis 1 presumably indexes target particles).
    score_against_targets = jax.vmap(model.log_prob, in_axes=(None, 0, None))
    target_scores = score_against_targets(theta_target, y_ref, design)

    # Self-normalised weights along axis 1; gradients are blocked so the
    # weights act as constants when differentiating.
    centred = target_scores - np.mean(target_scores, axis=1, keepdims=True)
    log_w = jax.lax.stop_gradient(centred)
    log_z = jax.scipy.special.logsumexp(log_w, axis=1, keepdims=True)
    w = np.exp(log_w - log_z)

    # NOTE(review): w already sums to one along axis 1, so taking the mean
    # (not the sum) divides by the axis length as well -- confirm intended.
    contrast = np.mean(w * target_scores, axis=1)
    objective = (ref_scores - contrast).mean()
    return objective, y_ref

def step_diffusion_gibbs(rng_key: PRNGKeyArray, history_meas: PyTree, state: ImplicitState, n_meas: int, model: BaseExperiment, optx_opt: optax.GradientTransformation):
    """Run one joint update of both particle sets and of the design.

    The step has three stages:

    1. one Gibbs move of ``state.particles_prior`` targeting the potential
       built from the measurement history;
    2. one Gibbs move of ``state.particles`` targeting that potential plus
       the mean log-probability of the current outcomes at the current
       design;
    3. one optax update of the design using the gradient of
       ``mean_log_score`` with respect to the design.

    Args:
        rng_key: PRNG key for this step.
        history_meas: Measurement history, forwarded to
            ``model.make_potential``.
        state: State produced by ``init_averaged`` or by the previous step.
        n_meas: Forwarded to ``model.make_potential`` together with
            ``history_meas``.
        model: Experiment exposing ``make_potential``, ``log_prob`` and
            ``sample``.
        optx_opt: optax gradient transformation used to update the design.

    Returns:
        ``(new_state, xi)`` where ``xi`` is the updated design. The pair has
        the ``(carry, output)`` shape expected by ``jax.lax.scan``.
    """
    particles, particles_prior, design, ys, opt_state, k = state
    # Four sub-keys are still drawn so that key_p keeps the value it had
    # before; the second one is currently unused.
    key_theta, _, key_target, key_p = jax.random.split(rng_key, 4)
    past = model.make_potential(history_meas, n_meas)

    # Gibbs on theta_O
    # NOTE(review): the meaning of the second argument (1) is defined by
    # exptax.inference.gibbs.as_top_level_api -- confirm there.
    gibbs = as_top_level_api(past, 1)
    state_0 = jax.vmap(gibbs.init)(particles_prior.thetas)
    # jax.tree_util.tree_leaves replaces the removed top-level alias
    # jax.tree_leaves and matches its use in y_sampler.
    num_prior = jax.tree_util.tree_leaves(particles_prior.thetas)[0].shape[0]
    keys_gibbs = jax.random.split(key_theta, num_prior)
    state_0 = jax.vmap(gibbs.step)(keys_gibbs, state_0)
    # Self-normalise the log-densities into weights that sum to one.
    _norm = jax.scipy.special.logsumexp(state_0.logdensity, keepdims=True)
    weights = np.exp(state_0.logdensity - _norm)
    particles_prior = ParticlesApprox(state_0.position, weights)

    # Gibbs theta_l
    def vec_loglik(theta):
        # Average log-probability of the current outcomes under theta,
        # added to the history potential.
        logprobs_n = model.log_prob(theta, ys, design)
        logarithmic_pull = np.mean(logprobs_n, axis=0)
        return logarithmic_pull + past(theta)

    vec_gibbs = as_top_level_api(vec_loglik, 1)
    state_l = jax.vmap(vec_gibbs.init)(particles.thetas)
    num_contrastive = jax.tree_util.tree_leaves(particles.thetas)[0].shape[0]
    # Use a key of its own: reusing key_theta here would feed both chains
    # the same random stream.
    keys_gibbs = jax.random.split(key_target, num_contrastive)
    state_l = jax.vmap(vec_gibbs.step)(keys_gibbs, state_l)
    _norm = jax.scipy.special.logsumexp(state_l.logdensity, keepdims=True)
    weights_l = np.exp(state_l.logdensity - _norm)
    particles = ParticlesApprox(state_l.position, weights_l)

    # SGD xi
    # argnums=3 selects ``design``; the auxiliary output is the batch of
    # outcomes simulated inside mean_log_score, which replaces ``ys``.
    grad_xi_score = jax.grad(mean_log_score, argnums=3, has_aux=True)
    grad_xi, ys = grad_xi_score(key_p, state_0.position, state_l.position, design, ys, model)
    updates, opt_state = optx_opt.update(grad_xi, opt_state, design)
    xi = optax.apply_updates(design, updates)

    return ImplicitState(particles, particles_prior, xi, ys, opt_state, k + 1), xi


def init_averaged(model: BaseExperiment, optx_opt: optax.GradientTransformation, rng_key: PRNGKeyArray, particles: ParticlesApprox, design: Array, n_prior: int = 200) -> ImplicitState:
    """Build the initial optimisation state.

    A subsample of ``n_prior`` entries is drawn from ``particles``, one
    outcome is simulated for each subsampled particle at ``design``, and the
    optimiser state is initialised for ``design``.

    Args:
        model: Experiment exposing ``sample``.
        optx_opt: optax gradient transformation used for the design.
        rng_key: PRNG key; split internally so that subsampling and outcome
            simulation use independent randomness.
        particles: Current particle approximation.
        design: Initial design.
        n_prior: Number of particles to draw for the subsample.

    Returns:
        The initial ``ImplicitState`` (its step counter keeps the default).
    """
    # Separate keys for the two random operations; using rng_key for both
    # would reuse the same key.
    key_subsample, key_y = jax.random.split(rng_key)
    sampler = partial(model.sample, xi=design)
    opt_state = optx_opt.init(design)
    # Every leaf is sampled with the same key on purpose, so that leaves
    # sharing a leading dimension are subsampled consistently.
    particles_prior = jax.tree.map(
        lambda x: jax.random.choice(key_subsample, x, (n_prior,)), particles
    )
    y = y_sampler(key_y, sampler, particles_prior.thetas)
    return ImplicitState(particles, particles_prior, design, y, opt_state)


class ImplicitDiffusion:
    """Factory that bundles the optimiser into an ``Optim``.

    Calling ``ImplicitDiffusion(...)`` does not create an instance of this
    class; ``__new__`` returns ``Optim(init, run, logger)`` whose functions
    close over ``model``, ``opt_steps`` and ``optx_opt``.
    """

    def __new__(cls, model: BaseExperiment, opt_steps: int, optx_opt: optax.GradientTransformation, log_dir: str = "logs") -> Optim:
        """Build the ``Optim`` triple.

        Args:
            model: Experiment forwarded to the init and step functions.
            opt_steps: Number of optimisation steps performed by ``run``.
            optx_opt: optax gradient transformation used for the design.
            log_dir: Currently unused.

        Returns:
            ``Optim(init, run, logger)``.
        """

        def init(rng_key: PRNGKeyArray, particles: ParticlesApprox, design: Array, n_prior: int = 200) -> ImplicitState:
            """Create the initial state via ``init_averaged``."""
            return init_averaged(model, optx_opt, rng_key, particles, design, n_prior)

        def run(rng_key: PRNGKeyArray, state: ImplicitState, hist: PyTree, n_meas: int):
            """Run ``opt_steps`` steps and return the final state and designs.

            Args:
                rng_key: PRNG key; split into one sub-key per step.
                state: Starting state.
                hist: Measurement history forwarded to every step.
                n_meas: Forwarded to every step.

            Returns:
                ``(end_state, designs)`` where ``designs`` stacks the design
                produced by each step along a leading axis.
            """

            def step(carry, key):
                return step_diffusion_gibbs(key, hist, carry, n_meas, model, optx_opt)

            keys = jax.random.split(rng_key, opt_steps)
            # The stacked outputs get their own name so they do not rebind
            # ``hist``, which the ``step`` closure reads.
            end_state, designs = jax.lax.scan(step, state, keys)
            return end_state, designs

        def logger(*args, **kwargs):
            # Implement logger if needed
            pass

        return Optim(init, run, logger)