import functools

import jax
import jax.random as jrandom
import optax
import jax.numpy as jnp
import haiku as hk
from distrax import MixtureSameFamily, Categorical, Normal
from typing import NamedTuple, Sequence, Any, Tuple
from policies.nnets import rcp_network, policy_network
from typing import NamedTuple, Sequence, Tuple, Dict, Iterable
import tree


class TrainingState(NamedTuple):
    """Holds the agent's training state.

    Fields:
      params: Haiku parameters of the return-conditioned policy network.
      q_params: Haiku parameters of the network producing return-to-go logits.
      state: Haiku mutable state of the policy network (it is built with
        ``hk.transform_with_state``), hence ``hk.State`` rather than params.
      net_opt_state: optax optimiser state for the ``(params, q_params)`` pair.
      step: number of SGD updates applied so far.
    """

    params: hk.Params
    q_params: hk.Params
    state: hk.State
    net_opt_state: optax.OptState
    step: int


def l2_loss(params: Iterable[jnp.ndarray]) -> jnp.ndarray:
    """Return one half of the summed squares of every element in ``params``.

    Args:
      params: iterable of arrays (e.g. flattened parameter leaves).

    Returns:
      ``0.5 * sum_i ||p_i||^2``; the Python float ``0.0`` when ``params`` is empty.
    """
    squared_norm = 0
    for leaf in params:
        squared_norm = squared_norm + jnp.sum(jnp.square(leaf))
    return 0.5 * squared_norm


class RCP:
    """Return-conditioned policy agent.

    Owns two Haiku networks that are trained jointly by :meth:`sgd_step`:

    * ``self.network`` (from ``rcp_network``, ``transform_with_state``) maps
      ``(obs, return-to-go)`` to a predicted action of size ``actions_dims``.
    * ``self.q_network`` (from ``policy_network``) maps ``obs`` to
      ``num_atoms`` logits of a Gaussian mixture over return-to-go whose
      component means ``self.loc`` are evenly spaced on ``[-1, 1]`` and whose
      standard deviations all equal ``scale``.

    Args:
      state_dims: observation size (used to build the dummy init observation).
      actions_dims: output size of the policy network.
      scale: standard deviation shared by every mixture component.
      alpha: logit tilt used by ``self.get_sample`` (``logits + alpha * loc``);
        with ``alpha > 0`` atoms with larger ``loc`` become more likely.
      use_layer_norm: forwarded to both network constructors.
      denorm_act: callable applied to the predicted action in :meth:`get_action`.
      l2_coef: keyword-only weight of an L2 penalty on all parameters outside
        modules whose name contains ``"nosn_ln"``. The default ``0.0`` disables
        the penalty (and skips computing it), matching the previous behaviour.
      **kwargs: accepted for config compatibility and ignored.
    """

    def __init__(
        self,
        state_dims,
        actions_dims,
        scale,
        alpha,
        use_layer_norm,
        denorm_act,
        *,
        l2_coef: float = 0.0,
        **kwargs,
    ) -> None:
        self.state_dims = state_dims
        self.actions_dims = actions_dims
        self.denorm_act = denorm_act
        self.l2_coef = l2_coef

        policy_net = functools.partial(rcp_network, use_layer_norm=use_layer_norm, out_dims=self.actions_dims)
        self.network = hk.without_apply_rng(hk.transform_with_state(policy_net))

        # Fixed support of the return-to-go mixture: num_atoms Normal
        # components with means evenly spaced on [-1, 1].
        self.num_atoms = 21
        loc = jnp.linspace(-1, 1, self.num_atoms)
        self.loc = loc
        components_dist = Normal(loc=loc, scale=jnp.ones(self.num_atoms) * scale)

        q_net = functools.partial(policy_network, use_layer_norm=use_layer_norm, out_dims=self.num_atoms)
        self.q_network = hk.without_apply_rng(hk.transform(q_net))

        def get_sample(rng, n, logits):
            """Draw ``n`` return-to-go samples from the alpha-tilted mixture."""
            tilted_logits = jnp.reshape(logits, loc.shape) + alpha * loc
            dist = MixtureSameFamily(Categorical(logits=tilted_logits), components_dist)
            # Public distrax API; same result as the private ``_sample_n(rng, n)``.
            return dist.sample(seed=rng, sample_shape=(n,))

        def get_log_prob(x, logits):
            """Log-density of ``x`` under the (untilted) mixture defined by ``logits``."""
            dist = MixtureSameFamily(Categorical(logits=logits), components_dist)
            return dist.log_prob(x)

        self.get_sample = get_sample
        self.get_log_prob = get_log_prob

    def get_init_state(self, rng: jnp.ndarray, learning_rate: float) -> TrainingState:
        """Initialise both networks and the shared Adam optimiser.

        Args:
          rng: JAX PRNG key; split once to seed the two networks.
          learning_rate: Adam learning rate.

        Returns:
          A fresh :class:`TrainingState` with ``step=0``.

        NOTE: this binds ``self.net_optimizer``, which :meth:`sgd_step` reads, so
        it must be called before training. ``sgd_step`` is jitted with ``self``
        static, so the optimiser is captured when it is first compiled;
        re-calling this with a different ``learning_rate`` on the same instance
        does not affect already-compiled ``sgd_step`` traces for matching
        input shapes. Build a new ``RCP`` to change the learning rate.
        """
        dummy_obs = jnp.zeros((1, self.state_dims))
        net_rng, q_rng = jrandom.split(rng)
        params, state = self.network.init(rng=net_rng, obs=dummy_obs, returns=jnp.zeros((1, 1)))
        q_params = self.q_network.init(rng=q_rng, obs=dummy_obs)
        self.net_optimizer = optax.adam(learning_rate)
        # A single optimiser over the (params, q_params) pair; sgd_step relies on this layout.
        net_opt_state = self.net_optimizer.init((params, q_params))
        return TrainingState(
            params=params, q_params=q_params, state=state, net_opt_state=net_opt_state, step=0
        )

    @staticmethod
    def _l2_penalty(params, q_params):
        """Half squared L2 norm of all leaves outside modules named ``*nosn_ln*``."""
        leaves = []
        for tree_params in (params, q_params):
            leaves.extend(
                p for (mod_name, _), p in tree.flatten_with_path(tree_params) if "nosn_ln" not in mod_name
            )
        return l2_loss(leaves)

    def loss(self, rng, params, q_params, state, transitions):
        """Joint training loss and diagnostics.

        ``loss = MSE(action) + NLL(rtg)`` (plus ``l2_coef * L2`` when
        ``l2_coef`` is non-zero). The NLL is that of the observed return-to-go
        under the q-network's Gaussian mixture.

        Args:
          rng: unused; kept so the signature matches :meth:`sgd_step`'s call.
          params: policy-network parameters.
          q_params: return-to-go network parameters.
          state: Haiku state of the policy network.
          transitions: dict with ``"states"``, ``"actions"`` and ``"rtg"`` arrays.

        Returns:
          ``(loss, (new_state, stats))``, the aux layout expected by
          ``jax.grad(..., has_aux=True)``.

        Raises:
          ValueError: if the predicted action shape differs from the dataset
            action shape (checked at trace time).
        """
        del rng
        rtg = transitions["rtg"]
        action_pred, state = self.network.apply(params, state, transitions["states"], rtg)
        q_out = self.q_network.apply(q_params, transitions["states"])

        log_p = jax.nn.log_softmax(q_out, -1)
        probs = jnp.exp(log_p)
        entropy = -(probs * log_p).sum(-1).mean()
        log_prob = jax.vmap(self.get_log_prob)(rtg, q_out)
        nll = -jnp.mean(log_prob)

        # Mean of the predicted (untilted) mixture: component i has mean loc[i],
        # so E[g_t] = sum_i p_i * loc[i]. Used for the "MSE g_t" diagnostic,
        # which previously duplicated the action MSE.
        rtg_pred = probs @ self.loc
        rtg_mse = jnp.mean(jnp.square(jnp.reshape(rtg, rtg_pred.shape) - rtg_pred))

        # Explicit check instead of `assert` (stripped under -O); a mismatch
        # would otherwise broadcast silently inside the MSE.
        if action_pred.shape != transitions["actions"].shape:
            raise ValueError(
                f"action_pred shape {action_pred.shape} != actions shape {transitions['actions'].shape}"
            )
        mse = jnp.mean(jnp.square(transitions["actions"] - action_pred))

        loss = mse + nll
        if self.l2_coef:
            loss = loss + self.l2_coef * self._l2_penalty(params, q_params)

        stats = {
            "loss": loss,
            "MSE a_t": mse,
            "MSE g_t": rtg_mse,
            "nll": nll,
            "entropy_rtg": entropy,
            "max_rtg": jnp.max(rtg),
            "min_rtg": jnp.min(rtg),
            "mean_rtg": jnp.mean(rtg),
        }
        return loss, (state, stats)

    @functools.partial(jax.jit, static_argnums=(0,))
    def sgd_step(
        self, rng: jnp.ndarray, agent_state: TrainingState, transitions: Dict[str, jnp.ndarray]
    ) -> Tuple[TrainingState, Any]:
        """Apply one Adam update to ``params`` and ``q_params`` jointly.

        Requires :meth:`get_init_state` to have been called (it creates
        ``self.net_optimizer``).

        Args:
          rng: JAX PRNG key, forwarded to :meth:`loss`.
          agent_state: current training state.
          transitions: batch dict with ``"states"``, ``"actions"`` and ``"rtg"``.

        Returns:
          ``(new_agent_state, stats)``, where ``stats`` is the dict built by :meth:`loss`.
        """
        grads, (state, stats) = jax.grad(self.loss, argnums=(1, 2), has_aux=True)(
            rng, agent_state.params, agent_state.q_params, agent_state.state, transitions
        )
        updates, new_opt_state = self.net_optimizer.update(grads, agent_state.net_opt_state)
        new_params, new_q_params = optax.apply_updates((agent_state.params, agent_state.q_params), updates)
        new_agent_state = TrainingState(
            params=new_params,
            q_params=new_q_params,
            state=state,
            net_opt_state=new_opt_state,
            step=agent_state.step + 1,
        )
        return new_agent_state, stats

    @functools.partial(jax.jit, static_argnums=(0,))
    def get_action(self, rng, agent_state, obs, alpha):
        """Greedy action for a single observation.

        Picks the return-to-go atom with the highest tilted logit
        (``logits + alpha * loc``) and conditions the policy network on it.
        Selection is an argmax, so ``rng`` is currently unused.

        Returns:
          ``(denorm_act(action), rtg)`` with ``rtg`` of shape ``(1, 1)``.
        """
        del rng
        obs = obs.reshape((1, -1))
        logits = self.q_network.apply(agent_state.q_params, obs)
        logits = logits + alpha * jnp.reshape(self.loc, logits.shape)
        rtg = self.loc[jnp.argmax(logits)].reshape((1, 1))

        # The updated Haiku state is not needed for inference; discard it
        # rather than shadowing `agent_state`.
        action_pred, _ = self.network.apply(agent_state.params, agent_state.state, obs, rtg)
        return self.denorm_act(action_pred.squeeze()), rtg
