import functools
import haiku as hk
import jax
import jax.numpy as jnp
import jax.random as jrandom
import optax
import tree
from typing import NamedTuple, Sequence, Tuple, Dict, Iterable
import distrax
import chex
from utils import MinMaxDenormalizationLayer, MinMaxNormalizationLayer

import numpy as np
EPS = np.finfo(np.float32).eps

from policies.nnets import energy_network


def l2_loss(params: Iterable[jnp.ndarray]) -> jnp.ndarray:
    """Return half the sum of squared entries over all arrays in ``params``.

    Args:
        params: Iterable of arrays (e.g. the leaves of a parameter tree).

    Returns:
        ``0.5 * sum_i ||p_i||^2`` accumulated over every array in ``params``.
    """
    total = 0
    for leaf in params:
        total = total + jnp.sum(jnp.square(leaf))
    return 0.5 * total


class TrainingState(NamedTuple):
    """Holds the agent's training state.

    Instances are created by ``IRCP.get_init_state`` and replaced (never
    mutated) by ``IRCP.sgd_step``, which returns a new tuple with ``step``
    incremented by one.
    """

    # Parameters of the energy network, as returned by ``network.init`` and
    # updated by the optimizer in ``sgd_step``.
    params: hk.Params
    # Parameter tree read by ``get_action``; ``sgd_step`` overwrites it with
    # the output of the EMA transform applied to the freshly updated params.
    ema_params: hk.Params
    # Haiku state of the energy network (second output of ``network.init``).
    state: hk.State
    # Haiku state of the EMA transform (holds the running averages).
    ema_state: hk.State
    # Optax optimizer state for ``params`` (from ``net_optimizer.init``).
    net_opt_state: optax.OptState
    # Number of ``sgd_step`` calls applied so far.
    step: int


class IRCP:
    def __init__(
        self,
        state_dims,
        actions_dims,
        spectral_norm,
        all_grad_penalty,
        density_penalty,
        ema=0.0,
        num_mcmc_chains=8,
        use_layer_norm=True,
        use_bias=True,
        grad_penalty=1.,
        num_action_samples=512,
        max_g=1.,
        dims=512,
        loss_type="infoNCE",
        weight_decay=0.0,
        num_layers=8,
        gradient_scaling=False,
        u_net=False,
        use_skip=True,
        lr_schedule=True,
        denorm_act=None,
        uniform_buffer=0.05,
        warmup=0,
        temperature=1.,
        activation='relu',
        rtg=True,
        eta=0.,
        **kwargs
    ):
        """Store the hyper-parameters and build the Haiku transforms.

        Args:
            state_dims: Size of an observation vector (used for the dummy
                input in ``get_init_state``).
            actions_dims: Size of an action vector; also forwarded to
                ``energy_network`` as ``action_dims``.
            spectral_norm: Forwarded to ``energy_network`` as
                ``use_spectral_norm``.
            all_grad_penalty: If truthy, ``loss`` penalises input gradients
                on every stored chain step instead of only the last one.
            density_penalty: Weight of the squared log-density term in
                ``loss``.
            ema: Decay passed to ``hk.EMAParamsTree``.
            num_mcmc_chains: Number of chains used for negatives in
                ``sgd_step``.
            grad_penalty: Threshold subtracted from the input-gradient norm
                before it is squared in ``loss``.
            num_action_samples: Number of chains used by ``get_action``.
            max_g: Bound of the symmetric interval the sampled ``g`` values
                are drawn from and clipped to.
            loss_type: ``"infoNCE"`` or ``"CD"`` (checked in ``loss``).
            denorm_act: Callable applied to the selected action in
                ``get_action``.
            warmup: ``transition_begin`` of the learning-rate schedule.
            activation: One of ``'relu'``, ``'swish'``, ``'leaky_relu'``;
                resolved to the function of that name in ``jax.nn``.
            use_layer_norm, use_bias, dims, num_layers, u_net, use_skip,
            temperature, rtg: Forwarded unchanged to ``energy_network``.
            weight_decay, gradient_scaling, lr_schedule, eta: Stored on the
                instance; not otherwise used in this constructor.
            uniform_buffer, **kwargs: Accepted but not used here.

        Raises:
            ValueError: If ``activation`` is not one of the supported names.
        """
        # Problem dimensions.
        self.state_dims = state_dims
        self.actions_dims = actions_dims

        # Loss configuration.
        self.loss_type = loss_type
        self.all_grad_penalty = all_grad_penalty
        self.grad_penalty = grad_penalty
        self.density_penalty = density_penalty
        self.eta = eta

        # Optimisation configuration.
        self.weight_decay = weight_decay
        self.warmup = warmup
        self.lr_schedule = lr_schedule
        self.gradient_scaling = gradient_scaling
        self.ema = ema
        self.spectral_norm = spectral_norm

        # Sampler configuration.
        self.num_mcmc_chains = num_mcmc_chains
        self.num_action_samples = num_action_samples
        self.min_sampling_action = -1
        self.max_sampling_action = 1
        self.max_g = max_g

        self.denorm_act = denorm_act

        # The supported names coincide with the attribute names in jax.nn.
        supported_activations = ('relu', 'swish', 'leaky_relu')
        if activation not in supported_activations:
            raise ValueError("Activation Type Not Define")
        activation_fn = getattr(jax.nn, activation)

        energy_fn = functools.partial(
            energy_network,
            action_dims=actions_dims,
            dims=dims,
            num_layers=num_layers,
            activation=activation_fn,
            temperature=temperature,
            rtg=rtg,
            u_net=u_net,
            use_skip=use_skip,
            use_bias=use_bias,
            use_layer_norm=use_layer_norm,
            use_spectral_norm=spectral_norm,
        )
        self.network = hk.without_apply_rng(hk.transform_with_state(energy_fn))

        def _ema_forward(param_tree):
            return hk.EMAParamsTree(decay=ema)(param_tree)

        self.ema_fn = hk.transform_with_state(_ema_forward)

    def get_init_state(self, rng: jax.random.PRNGKey, learning_rate: float) -> TrainingState:
        """Initialise network parameters, optimizer and EMA state.

        As a side effect this sets ``self.net_optimizer``, which ``sgd_step``
        relies on, so it must be called before the first ``sgd_step``.

        Args:
            rng: JAX PRNG key; split once for the network and the EMA
                transform initialisation.
            learning_rate: Positive initial learning rate. It is negated when
                building the schedule so that ``optax.apply_updates``
                performs descent.

        Returns:
            A ``TrainingState`` with ``step=0``. ``ema_params`` starts as the
            initial network parameters so that ``get_action`` can be used
            before any ``sgd_step`` has run.
        """
        rng1, rng2 = jrandom.split(rng)
        dummy_observation = jnp.zeros((1, self.state_dims), jnp.float32)
        dummy_action = jnp.zeros((1, self.actions_dims), jnp.float32)
        dummy_rtg = jnp.zeros((1, 1))
        initial_params, state = self.network.init(rng1, dummy_observation, dummy_action, dummy_rtg)

        # Exponentially decayed step size (x0.99 every 100 steps after the
        # warm-up). NOTE(review): the schedule is always applied here;
        # ``self.lr_schedule`` is not consulted.
        schedule_fn = optax.exponential_decay(init_value=-learning_rate,
                                              transition_begin=self.warmup,
                                              transition_steps=100,
                                              decay_rate=0.99)

        self.net_optimizer = optax.chain(optax.scale_by_adam(eps=1e-7),
                                         optax.scale_by_schedule(schedule_fn))

        net_opt_state = self.net_optimizer.init(initial_params)

        # ``init`` of a transformed function returns that transform's own
        # (params, state). The first element is therefore NOT an averaged
        # copy of ``initial_params`` and must not be used as ``ema_params``;
        # only the EMA state is kept and the average is seeded with the
        # initial parameters instead.
        _, ema_state = self.ema_fn.init(rng2, initial_params)

        return TrainingState(
            params=initial_params,
            ema_params=initial_params,
            state=state,
            ema_state=ema_state,
            net_opt_state=net_opt_state,
            step=0,
        )

    def sampling(
        self,
        rng,
        params,
        state,
        obs,
        a_t,
        g_t,
        num_samples=100,
        num_samples2=0,
        alpha=0.0,
        num_mcmc_chains=8,
        return_all=True,
    ):
        """Draw (action, g) samples from the energy model with noisy gradient steps.

        ``num_mcmc_chains`` chains are started uniformly at random inside the
        sampling box and moved along the gradient of the negative log-density
        with Gaussian noise added. After every step the positions are clipped
        back into the box.

        Args:
            rng: JAX PRNG key.
            params: Network parameters passed to ``self.network.apply``.
            state: Network state passed to ``self.network.apply``.
            obs: Observation batch the samples are conditioned on.
            a_t: Array whose shape is used as the per-chain action shape.
            g_t: Array whose shape is used as the per-chain ``g`` shape.
            num_samples: Number of steps in the first phase (decaying step
                size). Must be a concrete Python int (it is a scan length).
            num_samples2: Number of steps in the optional second phase
                (constant small step size). Must be a concrete Python int.
            alpha: Exponential-tilting coefficient; ``alpha * g`` is added to
                the log-density.
            num_mcmc_chains: Number of parallel chains.
            return_all: If true, return every position visited during the
                FIRST phase (second-phase positions are not included);
                otherwise return only the final positions.

        Returns:
            ``(state, (a, g, v))`` where ``state`` is the input state
            unchanged and ``v`` is the negative log-density of each chain's
            final position. ``a`` and ``g`` have a leading ``num_samples``
            axis when ``return_all`` is true.
        """
        rng1, rng2, rng3, rng4 = jrandom.split(rng, 4)

        a_0 = jrandom.uniform(key=rng1,
                              shape=(num_mcmc_chains,) + a_t.shape,
                              minval=self.min_sampling_action,
                              maxval=self.max_sampling_action)

        g_0 = jrandom.uniform(key=rng2,
                              shape=(num_mcmc_chains,) + g_t.shape,
                              minval=-self.max_g,
                              maxval=self.max_g)

        initial_position = {"a": a_0, "g": g_0}

        # exponential tilt to sample from the tail of the distribution
        # see https://en.wikipedia.org/wiki/Exponential_tilting
        def neg_log_prob_fn(a, g, o, p, s):
            log_density, _ = self.network.apply(p, s, o, a, g)
            assert log_density.shape == g.shape
            log_density += alpha * g
            return -jnp.mean(log_density)

        neg_log_prob_fn = functools.partial(neg_log_prob_fn, o=obs, p=params, s=state)

        def neg_log_prob(x):
            # ``x`` is a position dict with keys "a" and "g".
            return neg_log_prob_fn(**x)

        # Negative values so that ``optax.apply_updates`` moves downhill.
        # The decay length is fixed at 100 steps regardless of num_samples.
        schedule_fn = optax.polynomial_schedule(init_value=-0.5, end_value=-1e-5, power=2, transition_steps=100)

        # optax.add_noise takes an integer seed, so derive one from the key.
        seed = jrandom.randint(rng3, (1,), 1, 2**16)
        seed = seed[0].astype(int)

        # Phase 1: halve the gradient, add noise, apply the decaying step
        # size, then bound each update to [-0.5, 0.5].
        sample_optimizer1 = optax.chain(
            optax.scale(0.5),
            optax.add_noise(1.0, gamma=0.0, seed=seed),
            optax.scale_by_schedule(schedule_fn),
            optax.clip(0.5),
        )

        opt_state1 = sample_optimizer1.init(initial_position)

        seed = jrandom.randint(rng4, (1,), 1, 2**16)
        seed = seed[0].astype(int)
        # Phase 2: same structure with a constant step size of 1e-5. The
        # first argument of add_noise is a variance, so 0.5**2 corresponds to
        # a noise standard deviation of 0.5.
        sample_optimizer2 = optax.chain(
            optax.scale(0.5),
            optax.add_noise(0.5**2, gamma=0.0, seed=seed),
            optax.scale(-1e-5),
            optax.clip(0.5),
        )

        opt_state2 = sample_optimizer2.init(initial_position)

        def sgld(inputs, timestep, sample_optimizer):
            """One noisy gradient step for all chains (``lax.scan`` body)."""
            del timestep
            dict_x, opt_state = inputs
            g = jax.vmap(jax.grad(neg_log_prob))(dict_x)
            u, new_opt_state = sample_optimizer.update(g, opt_state)
            new_x = optax.apply_updates(dict_x, u)
            # Project the new positions back into the sampling box.
            new_x["a"] = jnp.clip(new_x["a"],
                                  self.min_sampling_action,
                                  self.max_sampling_action)
            new_x["g"] = jnp.clip(new_x["g"],
                                  -self.max_g,
                                  self.max_g)
            return (new_x, new_opt_state), new_x

        sgld1 = functools.partial(sgld, sample_optimizer=sample_optimizer1)
        sgld2 = functools.partial(sgld, sample_optimizer=sample_optimizer2)

        last, chain = jax.lax.scan(sgld1, (initial_position, opt_state1), jnp.arange(num_samples))

        if num_samples2 > 0:
            last, _ = jax.lax.scan(sgld2, (last[0], opt_state2), jnp.arange(num_samples2))

        # Score the final position of every chain.
        v = jax.vmap(neg_log_prob)(last[0])
        if return_all:
            a = chain["a"]
            g = chain["g"]
        else:
            a = last[0]["a"]
            g = last[0]["g"]
        return state, (a, g, v)

    @functools.partial(
        jax.jit,
        static_argnums=(0,),
        static_argnames=("num_samples", "num_samples2"),
    )
    def get_action(
        self,
        rng,
        agent_state,
        obs,
        alpha,
        num_samples=100,
        num_samples2=100,
    ):
        """Sample candidate actions for one observation and return the best one.

        Runs ``self.num_action_samples`` chains with the EMA parameters and
        picks the chain whose final position has the lowest negative
        log-density.

        Args:
            rng: JAX PRNG key.
            agent_state: ``TrainingState``; ``ema_params`` and ``state`` are
                used.
            obs: A single observation; reshaped to ``(1, -1)``.
            alpha: Exponential-tilting coefficient forwarded to ``sampling``.
            num_samples: Steps of the first sampling phase.
            num_samples2: Steps of the second sampling phase.

        ``num_samples`` and ``num_samples2`` set scan lengths and Python
        control flow inside ``sampling``, so they must be concrete under
        ``jit``. They are declared static by name: pass them as keyword
        arguments when overriding the defaults.

        Returns:
            ``(action, g)``: the selected action after ``self.denorm_act``
            (left as sampled when ``denorm_act`` is ``None``), squeezed, and
            the ``g`` value sampled alongside it.
        """
        a_t = jnp.zeros((1, self.actions_dims))
        g_t = jnp.zeros((1, 1))
        obs = jnp.reshape(obs, (1, -1))
        state, (a, rtg, v) = self.sampling(
            rng,
            agent_state.ema_params,
            agent_state.state,
            obs,
            a_t,
            g_t,
            num_samples=num_samples,
            num_samples2=num_samples2,
            alpha=alpha,
            num_mcmc_chains=self.num_action_samples,
            return_all=False,
        )
        # ``v`` is a negative log-density, so the best chain maximises ``-v``.
        idx = jnp.argmax(-v)
        a = jnp.reshape(a[idx], (-1, ))
        # ``denorm_act`` defaults to None in the constructor; in that case
        # the action is returned in the sampler's own range.
        if self.denorm_act is not None:
            a = self.denorm_act(a)
        return a.squeeze(), rtg[idx]

    @functools.partial(jax.jit, static_argnums=(0,))
    def sgd_step(
        self,
        rng: jax.random.PRNGKey,
        agent_state: TrainingState,
        transitions: Dict[str, jnp.ndarray],
    ) -> Tuple[TrainingState, Dict]:
        """Performs an SGD step on a batch of transitions.

        Negative (action, rtg) samples are drawn per batch element with
        ``self.sampling`` under the current parameters, the loss gradient is
        taken with respect to the parameters only, and the optimizer and EMA
        are then advanced. ``get_init_state`` must have been called first
        because it creates ``self.net_optimizer``.

        Args:
            rng: JAX PRNG key.
            agent_state: Current training state.
            transitions: Mapping with keys ``"states"``, ``"actions"`` and
                ``"rtg"``; the leading axis of each is the batch axis.

        Returns:
            ``(new_agent_state, stats)`` where ``stats`` is the dictionary
            produced by ``loss`` plus ``"update_grad_norm"`` (the mean of the
            per-leaf gradient norms).
        """
        o_t = transitions["states"]
        a_t = transitions["actions"]
        g_t = transitions["rtg"]
        # Three-way split kept so the key streams match earlier behaviour;
        # the first key is not used.
        _, loss_rng, sample_rng = jrandom.split(rng, 3)

        sampling = functools.partial(self.sampling, num_mcmc_chains=self.num_mcmc_chains)
        # One independent sampler per batch element: inputs get a leading
        # singleton axis and are mapped over axis 1 (the batch axis).
        parallel_sampling = jax.vmap(sampling, (0, None, None, 1, 1, 1))

        sample_rngs = jrandom.split(sample_rng, o_t.shape[0])

        new_state, negative_samples = parallel_sampling(
            sample_rngs,
            agent_state.params,
            agent_state.state,
            o_t[None],
            a_t[None],
            g_t[None],
        )

        # vmap stacks one copy of the state per batch element; collapse them.
        new_state = jax.tree_util.tree_map(lambda s: jnp.mean(s, 0), new_state)

        gradients, (new_state, stats) = jax.grad(self.loss, 1, has_aux=True)(
            loss_rng, agent_state.params, new_state, transitions, negative_samples
        )

        # Diagnostic only: mean over parameter leaves of each leaf's norm.
        leaf_norms = [jnp.linalg.norm(g) for g in jax.tree_util.tree_leaves(gradients)]
        grad_norm = jnp.array(leaf_norms).mean()

        updates, new_opt_state = self.net_optimizer.update(gradients, agent_state.net_opt_state)
        new_params = optax.apply_updates(agent_state.params, updates)

        # The EMA transform has no parameters or rng of its own, hence None.
        new_ema_params, new_ema_state = self.ema_fn.apply(None, agent_state.ema_state, None, new_params)

        new_agent_state = TrainingState(
            params=new_params,
            ema_params=new_ema_params,
            state=new_state,
            net_opt_state=new_opt_state,
            ema_state=new_ema_state,
            step=agent_state.step + 1,
        )
        return new_agent_state, {**stats, "update_grad_norm": grad_norm}

    def loss(self, rng, params, state, transitions, negative_samples):
        """Contrastive loss between dataset actions and sampled negatives.

        Args:
            rng: Unused; kept so the signature matches the call in
                ``sgd_step``.
            params: Network parameters (the argument differentiated by
                ``sgd_step``).
            state: Network state.
            transitions: Mapping with keys ``"states"``, ``"actions"`` and
                ``"rtg"``.
            negative_samples: ``(actions, rtgs, chain_values)`` as returned
                by the batched ``sampling`` call in ``sgd_step``. The code
                below assumes the first two are laid out as
                ``(batch, step, chain, 1, dim)`` — TODO confirm against the
                caller if the sampler changes.

        Returns:
            ``(loss, (new_state, stats))`` where ``stats`` is a dictionary of
            scalar diagnostics.

        Raises:
            ValueError: If ``self.loss_type`` is neither ``"CD"`` nor
                ``"infoNCE"``.
        """
        del rng  # Nothing in this loss is stochastic.

        o_t = transitions["states"]
        a_t = transitions["actions"]
        r_t = transitions["rtg"]
        noise_a_t, noise_r_t, _ = negative_samples
        batch_size = o_t.shape[0]

        # Points at which the input-gradient penalty is evaluated: every
        # stored chain step, or only the last one.
        if self.all_grad_penalty:
            noise_a_t2 = jnp.reshape(noise_a_t, (batch_size, -1, self.actions_dims))
            noise_r_t2 = jnp.reshape(noise_r_t, (batch_size, -1, 1))
        else:
            noise_a_t2 = jnp.reshape(noise_a_t[:, -1], (batch_size, -1, self.actions_dims))
            noise_r_t2 = jnp.reshape(noise_r_t[:, -1], (batch_size, -1, 1))

        # Negatives for the contrastive term: last chain step only.
        noise_a_t = noise_a_t[:, -1, :, 0]
        noise_r_t = noise_r_t[:, -1, :, 0]

        # B x S x dims, with the dataset sample at index 0 along axis 1.
        all_a_t = jnp.concatenate((a_t[:, None], noise_a_t), 1)
        all_r_t = jnp.concatenate((r_t[:, None], noise_r_t), 1)

        noise_a_t2 = jnp.concatenate((a_t[:, None], noise_a_t2), 1)
        noise_r_t2 = jnp.concatenate((r_t[:, None], noise_r_t2), 1)

        # NOTE(review): ``fwd`` averages the network output over the batch,
        # so ``log_density`` below holds one batch-averaged value per
        # candidate and the softmax is taken across candidates. An earlier
        # variant of this code returned the un-averaged output instead;
        # confirm the averaging is intended.
        def fwd(o, a, r, p, s):
            x, s = self.network.apply(p, s, o, a, r)
            return jnp.mean(x), s

        fwd = functools.partial(fwd, p=params, s=state)

        # Gradients w.r.t. action and rtg inputs, mapped over candidates.
        grad_wrt_ar, _ = jax.vmap(jax.grad(fwd, (1, 2), has_aux=True), (None, 1, 1))(
            o_t, noise_a_t2, noise_r_t2
        )

        log_density, new_state = jax.vmap(fwd, (None, 1, 1))(o_t, all_a_t, all_r_t)
        # vmap stacks one copy of the state per candidate; collapse them.
        new_state = jax.tree_util.tree_map(lambda s: jnp.mean(s, 0), new_state)
        grad_wrt_ar = jnp.concatenate(grad_wrt_ar, -1)
        log_p = jax.nn.log_softmax(log_density, 0)
        grad_norm = jnp.linalg.norm(grad_wrt_ar, ord=jnp.inf, axis=-1)
        # max(0, grad_norm - threshold)**2, with self.grad_penalty acting as
        # the threshold.
        grad_penalty = (jnp.maximum(grad_norm - self.grad_penalty, 0) ** 2).mean()
        target_log_density = log_density[0]
        noise_log_density = log_density[1:]

        if self.loss_type == "CD":
            loss = (
                - jnp.mean(target_log_density)
                + jnp.mean(noise_log_density)
                + grad_penalty
                + self.density_penalty * (log_density**2).mean()
            )
        elif self.loss_type == "infoNCE":
            loss = - jnp.mean(log_p[0]) + grad_penalty + self.density_penalty * (log_density**2).mean()
        else:
            raise ValueError("Loss Type Not Define")

        # Reported as a diagnostic only; it is not added to the loss.
        l2_params = [p for ((mod_name, _), p) in tree.flatten_with_path(params) if "nosn_ln" not in mod_name]
        norm_params = l2_loss(l2_params)

        return loss, (
            new_state,
            {
                "loss": jnp.mean(loss),
                "target_log_density": jax.nn.logsumexp(target_log_density),
                "noise_log_density": jax.nn.logsumexp(noise_log_density) - jnp.log(noise_log_density.shape[0]),
                "max_log_density": jnp.max(log_density),
                "min_log_density": jnp.min(log_density),
                "MSE a_t": jnp.mean((noise_a_t - a_t[:, None]) ** 2),
                "MSE g_t": jnp.mean((noise_r_t - r_t[:, None]) ** 2),
                "max_input_grad_norm": jnp.max(grad_norm),
                "mean_input_grad_norm": jnp.mean(grad_norm),
                "min_input_grad_norm": jnp.min(grad_norm),
                "mean_noise_g_t": jnp.mean(noise_r_t),
                "min_o_t": jnp.min(o_t),
                "mean_o_t": jnp.mean(o_t),
                "max_o_t": jnp.max(o_t),
                "min_g_t": jnp.min(r_t),
                "mean_g_t": jnp.mean(r_t),
                "max_g_t": jnp.max(r_t),
                "min_a_t": jnp.min(a_t),
                "mean_a_t": jnp.mean(a_t),
                "max_a_t": jnp.max(a_t),
                "mean_noise_a_t": jnp.mean(noise_a_t),
                "max_noise_a_t": jnp.max(noise_a_t),
                "grad_penalty": grad_penalty,
                "norm_params": norm_params,
            },
        )
