from dataclasses import dataclass
from typing import Callable, NamedTuple, Sequence, Tuple, Union

import jax, jax.numpy as jnp
import haiku as hk
import math

from relax.network.blocks import Activation, QNet


class DDSQParams(NamedTuple):
    """Parameter bundle returned by ``create_ddsq_net``.

    Field order is part of the interface: instances are plain tuples and can
    be unpacked positionally.
    """

    # Parameters of the first Q-network (initialised from its own RNG key).
    q1: hk.Params
    # Parameters of the second Q-network (initialised from a different RNG key).
    q2: hk.Params
    # Target copy of q1; create_ddsq_net initialises it to the same params as q1.
    target_q1: hk.Params
    # Target copy of q2; create_ddsq_net initialises it to the same params as q2.
    target_q2: hk.Params
    # Scalar float32 array holding log(init_alpha); DDSQNet exponentiates it to
    # get the temperature ``alpha`` used in sampling.
    log_alpha: jax.Array
    # Scalar float32 array holding log(init_beta). NOTE(review): not read by any
    # code in this file - presumably consumed by the training algorithm; confirm.
    log_beta: jax.Array


def norm_fn(x):
    """Return the per-row standard deviation of ``x`` plus a 1e-6 floor.

    Maps ``[B, N]`` to ``[B, ]`` by reducing over axis 1; the added constant
    keeps the result strictly positive so it can be used as a divisor.
    """
    row_std = jnp.std(x, axis=1)
    return row_std + 1e-6


@dataclass(init=False)
class DDSQNet:
    """Action sampler that refines particle-selected actions with Q-gradient steps.

    Every sampling method takes ``policy_params``, the 4-tuple
    ``(log_alpha, log_beta, q1_params, q2_params)``. ``log_beta`` is unpacked
    but never read inside this class, and ``target_entropy`` is only stored.

    Sampling pipeline (see ``get_action``):
      1. ``init_act`` draws ``num_particles`` uniform actions in ``[-1, 1]`` and
         picks one per observation using ``min(Q1, Q2)``.
      2. ``diffusion`` (built in ``__init__``) applies ``num_timesteps`` updates
         along the action-gradient of ``min(Q1, Q2)`` plus optional noise.
      3. ``grid_search`` repeats step 2 for several step-size schedules and
         keeps, per sample, the action with the highest ``min(Q1, Q2)``.
    """

    q: Callable[[hk.Params, jax.Array, jax.Array], jax.Array]  # q(params, obs, act) -> Q-value(s)
    num_timesteps: int  # number of update steps in one diffusion run (T)
    act_dim: int  # action dimensionality (A)
    target_entropy: float  # stored for callers; not read by this class
    with_reflect: bool  # True: reflect at the [-1, 1] walls; False: clip to [-1, 1]
    prob_get_softmax: float  # per-sample probability that get_action takes the stochastic path
    num_particles: int  # number of uniform candidate actions drawn in init_act (N)

    # Number of step-size schedules tried by the public sampling methods.
    # Deliberately un-annotated so that @dataclass does not treat it as a field.
    _NUM_STEP_CANDIDATES = 10

    def __init__(self, q, num_timesteps, act_dim, target_entropy, with_reflect, prob_get_softmax, num_particles):
        """Store the configuration and build the jitted ``self.diffusion`` sampler."""
        self.q = q
        self.num_timesteps = num_timesteps
        self.act_dim = act_dim
        self.target_entropy = target_entropy
        self.with_reflect = with_reflect
        self.prob_get_softmax = prob_get_softmax
        self.num_particles = num_particles

        def get_sum_log_density(q1_params, q2_params, obs, a: jax.Array) -> jax.Array:
            # Reduce min(Q1, Q2) to a scalar so that jax.grad can differentiate it
            # with respect to the action argument.
            q = jnp.minimum(self.q(q1_params, obs, a), self.q(q2_params, obs, a))
            return jnp.sum(q)

        # No explicit jit here: `score` is only ever called inside `diffusion`,
        # which is jitted as a whole below.
        score = jax.grad(get_sum_log_density, argnums=3)

        def diffusion(
            policy_params: hk.Params,
            obs: jax.Array,
            init_acts: jax.Array,
            noises: jax.Array,
            normalizer,
            step_start,
            step_end,
        ) -> jax.Array:
            """Run ``num_timesteps`` gradient-plus-noise updates starting at ``init_acts``.

            Args:
                policy_params: ``(log_alpha, log_beta, q1_params, q2_params)``.
                obs: observation(s) passed unchanged to the Q-networks.
                init_acts: starting action(s).
                noises: per-step noise; the leading axis must have length
                    ``num_timesteps`` and each slice must broadcast against
                    the action.
                normalizer: scale dividing the gradient term (see ``init_act``).
                step_start, step_end: first and last step size; intermediate
                    step sizes are linearly interpolated.

            Returns:
                The action(s) after the final update, kept inside ``[-1, 1]``
                by reflection or clipping.
            """
            log_alpha, _, q1_params, q2_params = policy_params
            lam = jnp.exp(log_alpha)
            steps = jnp.linspace(step_start, step_end, self.num_timesteps)

            def mcmc(a, noise_and_step):
                noise, step_size = noise_and_step
                grads = score(q1_params, q2_params, obs, a)
                # Drift along dQ/da scaled by step / (2 * alpha * normalizer),
                # plus noise scaled by sqrt(step).
                shift = 0.5 / (lam * normalizer) * step_size * grads + jnp.sqrt(step_size) * noise

                # `with_reflect` is a Python bool, so this branch is resolved at trace time.
                if self.with_reflect:
                    a = self.reflect(a, shift)
                else:
                    a = jnp.clip(a + shift, -1, 1)

                return a, None

            # Scan over (noise, step_size) pairs directly instead of indexing
            # `steps` with a scanned counter.
            a, _ = jax.lax.scan(mcmc, init_acts, (noises, steps))

            return a

        self.diffusion = jax.jit(diffusion)

    @staticmethod
    def reflect(x, dx):
        """Move ``x`` by ``dx`` and fold the result back into ``[-1, 1]`` by reflection.

        A point that would leave the interval bounces off the wall it hits, as
        many times as needed. Points that stay inside are returned as
        ``x + dx``. NOTE(review): the folding assumes ``x`` itself lies in
        ``[-1, 1]`` - confirm for new callers.
        """
        virtual = x + dx  # [B, A]
        # Wall the move is heading towards.
        boundary = jnp.where(dx >= 0, 1, -1)
        # Signed overshoot past that wall.
        distance = virtual - boundary
        # Number of full interval widths (2.0) contained in the overshoot.
        count = jnp.floor(distance / 2.0)
        remain = distance - count * 2.0  # leftover travel, in [0, 2)
        # Each full width flips the wall from which the leftover is measured.
        last_boundary = jnp.where(count % 2 == 0, boundary, -boundary)
        direction = -jnp.sign(last_boundary)  # always back into the interval
        target = jnp.where((virtual >= -1) & (virtual <= 1), virtual, last_boundary + direction * remain)
        return target

    def _noise_shape(self, obs):
        """Return the per-run noise shape: ``[T, B, A]`` for batched ``obs``, else ``[T, A]``."""
        if len(obs.shape) > 1:
            return (self.num_timesteps, obs.shape[0], self.act_dim)
        return (self.num_timesteps, self.act_dim)

    def init_act(self, key, policy_params, obs, from_softmax):
        """Pick a starting action per observation from uniformly drawn particles.

        Args:
            key: PRNG key; split into one key for the particles and one for
                the categorical draw.
            policy_params: ``(log_alpha, log_beta, q1_params, q2_params)``.
            obs: ``[B, S]`` batch or a single ``[S]`` observation.
            from_softmax: Python bool (used in a Python ``if``). ``True``
                samples a particle with logits ``Q / (alpha * normalizer)``;
                ``False`` takes the arg-max particle.

        Returns:
            ``(action, normalizer)``, both ``[B, A]`` for batched input and
            ``[A]`` for a single observation. ``normalizer`` is the spread of
            the particles' Q-values, broadcast across the action axis.
        """
        log_alpha, _, q1_params, q2_params = policy_params
        k1, k2 = jax.random.split(key, 2)
        alpha = jnp.exp(log_alpha)

        def get_batch(batch_obs):
            # One particle set shared by every observation in the batch.
            a_particles = jax.random.uniform(
                k1, shape=(self.num_particles, self.act_dim), minval=-1, maxval=1
            )  # [N, A]

            def q_for_obs(obs):
                obs_tiled = jnp.tile(obs[None, :], (self.num_particles, 1))  # [N, S]
                q1 = self.q(q1_params, obs_tiled, a_particles)  # [N,]
                q2 = self.q(q2_params, obs_tiled, a_particles)  # [N,]
                q_min = jnp.minimum(q1, q2)
                return q_min  # [N,]

            qs = jax.vmap(q_for_obs)(batch_obs)  # [B, N]

            # NOTE: norm_fn, as defined in this file, already adds 1e-6; the
            # extra 1e-6 is kept so numerical results stay unchanged.
            normalizer = norm_fn(qs) + 1e-6  # [B,]

            if from_softmax:
                idx = jax.random.categorical(k2, qs / (alpha * normalizer[:, None]), axis=1)  # [B,]
            else:
                idx = jnp.argmax(qs, axis=1)  # [B,]

            selected_actions = a_particles[idx, :]  # [B, A]
            return selected_actions, normalizer

        if len(obs.shape) > 1:
            a, normalizer = get_batch(obs)
            return a, jnp.broadcast_to(
                jnp.expand_dims(normalizer, axis=1), shape=(obs.shape[0], self.act_dim)
            )  # [B, A]
        else:
            # Run the batched path on a batch of one, then strip the batch axis.
            a, normalizer = get_batch(jnp.expand_dims(obs, axis=0))
            return a[0], jnp.broadcast_to(normalizer, (self.act_dim,))  # [A, ], [A, ]

    def candidate_steps(self, num):
        """Return ``num`` candidate ``(start, end)`` step-size pairs as two flat arrays.

        Start sizes are log-spaced from ``1`` down to ``1e-4``; the end size is
        always ``1e-4``.
        """
        start_steps = 10 ** (jnp.linspace(0, -4, num))
        end_steps = jnp.array([1e-4])
        x, y = jnp.meshgrid(start_steps, end_steps, indexing="ij")
        return x.ravel(), y.ravel()

    def grid_search(self, policy_params: hk.Params, obs: jax.Array, init_acts, noises, normalizer, num):
        """Run ``diffusion`` for ``num`` step-size schedules and keep the best action.

        "Best" is decided independently per sample by the highest
        ``min(Q1, Q2)`` of the refined action. Arguments other than ``num``
        are forwarded to ``self.diffusion``.
        """
        _, _, q1_params, q2_params = policy_params
        steps_start, steps_end = self.candidate_steps(num)

        def sample(step_start, step_end):
            act = self.diffusion(policy_params, obs, init_acts, noises, normalizer, step_start, step_end)
            q1 = self.q(q1_params, obs, act)
            q2 = self.q(q2_params, obs, act)
            q = jnp.minimum(q1, q2)
            return act, q

        # Leading axis G indexes the candidate schedules.
        # Shapes below assume self.q returns one value per sample - confirm against QNet.
        acts, qs = jax.vmap(sample, in_axes=(0, 0))(steps_start, steps_end)  # [G, B, A], [G, B]
        q_best_ind = jnp.argmax(qs, axis=0, keepdims=True)  # [1, B]
        act = jnp.take_along_axis(acts, q_best_ind[..., None], axis=0).squeeze(axis=0)  # [B, A]
        return act

    def get_action(self, key: jax.Array, policy_params: hk.Params, obs: jax.Array) -> jax.Array:
        """Sample an exploratory action, mixing stochastic and greedy paths per sample.

        Each sample independently takes the stochastic path (softmax particle
        selection plus Gaussian step noise) with probability
        ``prob_get_softmax`` and the greedy path (arg-max particle, zero
        noise) otherwise.
        """
        prob_key, init_key, noise_key = jax.random.split(key, 3)
        batched = len(obs.shape) > 1
        if batched:
            probs = jax.random.uniform(prob_key, shape=(obs.shape[0],), minval=0, maxval=1)
        else:
            probs = jax.random.uniform(prob_key, minval=0, maxval=1)
        noise_shape = self._noise_shape(obs)

        use_softmax = probs <= self.prob_get_softmax  # [B, ] or [, ]

        softmax_init_acts, softmax_normalizer = self.init_act(init_key, policy_params, obs, from_softmax=True)
        # A fixed key gives the greedy path the same particle set on every call.
        greedy_init_acts, greedy_normalizer = self.init_act(
            jax.random.PRNGKey(0), policy_params, obs, from_softmax=False
        )

        softmax_noises = jax.random.normal(noise_key, shape=noise_shape)
        greedy_noises = jnp.zeros(shape=noise_shape)

        if batched:
            act_mask = use_softmax[:, None]  # [B, 1]
            noise_mask = use_softmax[None, :, None]  # [1, B, 1]
        else:
            act_mask = noise_mask = use_softmax

        # Select with jnp.where rather than mask arithmetic so that a non-finite
        # value in the unselected branch cannot leak into the result (0 * inf = nan).
        noises = jnp.where(noise_mask, softmax_noises, greedy_noises)
        init_acts = jnp.where(act_mask, softmax_init_acts, greedy_init_acts)
        normalizer = jnp.where(act_mask, softmax_normalizer, greedy_normalizer)

        return self.grid_search(policy_params, obs, init_acts, noises, normalizer, self._NUM_STEP_CANDIDATES)

    def get_softmax(self, key, policy_params: hk.Params, obs: jax.Array) -> jax.Array:
        """Sample an action using only the stochastic path (softmax init, Gaussian noise)."""
        init_key, noise_key = jax.random.split(key, 2)
        init_acts, normalizer = self.init_act(init_key, policy_params, obs, from_softmax=True)
        noises = jax.random.normal(noise_key, shape=self._noise_shape(obs))
        return self.grid_search(policy_params, obs, init_acts, noises, normalizer, self._NUM_STEP_CANDIDATES)

    def get_deterministic_action(self, policy_params: hk.Params, obs: jax.Array) -> jax.Array:
        """Return the greedy action: arg-max particle refined without any noise."""
        init_acts, normalizer = self.init_act(jax.random.PRNGKey(0), policy_params, obs, from_softmax=False)
        # Same noise layout as the stochastic paths, just filled with zeros.
        noises = jnp.zeros(shape=self._noise_shape(obs))
        return self.grid_search(policy_params, obs, init_acts, noises, normalizer, self._NUM_STEP_CANDIDATES)


def create_ddsq_net(
    key: jax.Array,
    obs_dim: int,
    act_dim: int,
    num_timesteps: int,
    with_reflect: bool,
    hidden_sizes: Sequence[int],
    activation: Activation = jax.nn.relu,
    init_alpha: float = 0.05,
    init_beta: float = 0.05,
    prob_get_softmax: float = 0.15,
    num_particles: int = 100,
) -> Tuple[DDSQNet, DDSQParams]:
    """Build a ``DDSQNet`` together with freshly initialised ``DDSQParams``.

    Two Q-networks with the same architecture (``hidden_sizes``,
    ``activation``) are initialised from independent RNG keys derived from
    ``key``. Target parameters start as the same values as the online ones,
    and ``log_alpha`` / ``log_beta`` are float32 scalars holding the logs of
    ``init_alpha`` / ``init_beta``. The network's ``target_entropy`` is set to
    ``-act_dim``.

    Returns:
        ``(net, params)``.
    """

    def q_forward(obs, act):
        return QNet(hidden_sizes, activation)(obs, act)

    q = hk.without_apply_rng(hk.transform(q_forward))

    @jax.jit
    def init(key, obs, act):
        q1_key, q2_key = jax.random.split(key, 2)
        q1_params = q.init(q1_key, obs, act)
        q2_params = q.init(q2_key, obs, act)
        return DDSQParams(
            q1=q1_params,
            q2=q2_params,
            # Targets begin as the same parameters as the online networks.
            target_q1=q1_params,
            target_q2=q2_params,
            log_alpha=jnp.array(math.log(init_alpha), dtype=jnp.float32),
            log_beta=jnp.array(math.log(init_beta), dtype=jnp.float32),
        )

    # Single-row dummy inputs, used only to initialise the parameters.
    dummy_obs = jnp.zeros((1, obs_dim))
    dummy_act = jnp.zeros((1, act_dim))
    params = init(key, dummy_obs, dummy_act)

    net = DDSQNet(
        q=q.apply,
        num_timesteps=num_timesteps,
        act_dim=act_dim,
        target_entropy=-act_dim,
        with_reflect=with_reflect,
        prob_get_softmax=prob_get_softmax,
        num_particles=num_particles,
    )
    return net, params
