import jax
from flax import nnx
from basics.autoregressive_iqn import KRIQN
from typing import Callable, Optional, Any, Sequence
from stable_baselines3.common.preprocessing import get_action_dim, get_flattened_obs_dim
from risk_morl.buffer import ReplayBufferSamples
import jax.numpy as jnp
from risk_morl.utils.network_utils import copy_param, quanitle_regression_loss, polyak_update
import numpy as np
import optax
from functools import partial
import cloudpickle
from flax.struct import dataclass


@dataclass
class FlaxModule:
    """Container pairing an ``nnx`` graph definition with its state.

    ``MODQNPolicy._save_attributes`` in this file builds one of these from the
    policy's graph definition, its state, and the target parameters before
    pickling; ``MODQNPolicy.load`` reads the same three fields back.
    """
    # Graph definition (the static half of an ``nnx.split`` result).
    graph: nnx.GraphDef
    # State matching ``graph`` (the array half of an ``nnx.split`` result).
    state: nnx.State
    # Optional extra payload; the policy in this file stores its target parameters here.
    others: Optional[Any] = None


class DQNNetwork(KRIQN):
    """Discrete-action critic: a ``KRIQN`` driven by learned action embeddings.

    Integer actions are looked up in an ``nnx.Embed`` table and the resulting
    vectors are passed to the parent network where it expects actions.

    Shape notes in the comments use the letters from the original code:
    ``a`` actions, ``b`` batch, ``w`` objectives/weights, ``n`` quantiles,
    ``c`` critics. They assume the parent forward pass returns ``(b, w, n, c)``
    -- TODO confirm against ``KRIQN``.
    """

    def __init__(self,
                 obs_dim: int,
                 action_dim: int,
                 reward_dim: int,
                 n_critics: int = 2,
                 *,
                 rngs
                 ):
        """Build the parent network and the action-embedding table.

        Args:
            obs_dim: Forwarded to the parent constructor as its first argument.
            action_dim: Number of discrete actions (rows of the embedding table).
            reward_dim: Forwarded to the parent constructor.
            n_critics: Forwarded to the parent constructor.
            rngs: Random streams used by the parent and by ``nnx.Embed``.
        """
        # Set before the parent constructor runs because it is passed to it
        # as the second positional argument.
        self.embedding_dim = 64
        super().__init__(
            obs_dim,
            self.embedding_dim,
            reward_dim,
            n_critics,
            ff_action=False,
            rngs=rngs,
        )
        self.n_action = action_dim
        self.actions = nnx.Embed(self.n_action, self.embedding_dim, rngs=rngs)
        self.rngs = rngs

    def __call__(self, observations, actions, taus, weight):
        """Embed ``actions`` and run the parent forward pass on the embeddings."""
        embedded = self.actions(actions)
        return super().__call__(observations, embedded, taus, weight)

    def q_values(self, observation, taus, weight):
        """Run the parent forward pass once per action in the embedding table.

        Returns:
            The stacked parent outputs with the action axis first
            (``out_axes=0``).
        """
        table = self.actions.embedding.value
        n_batch = observation.shape[0]

        # Give every batch row a copy of the full table: (b, a, embedding_dim).
        per_row_actions = jnp.broadcast_to(table[None], (n_batch,) + table.shape)

        # Map over the action axis (axis 1) only; all other inputs are shared.
        parent_forward = super().__call__
        over_actions = jax.vmap(parent_forward, in_axes=(None, 1, None, None),
                                out_axes=0)
        return over_actions(observation, per_row_actions, taus, weight)

    def predict_action(self, observation, risk_taus, weight):
        """Pick, per batch row, the action with the best weighted risk value."""
        risk_q = self.q_values(observation, risk_taus, weight)

        # (1, b, w, 1, 1) * (a, b, w, n, c)
        weighted = weight[None, ..., None, None] * risk_q

        scalarized = weighted.sum(axis=-3)       # sum over w      -> (a, b, n, c)
        per_critic = scalarized.mean(axis=-2)    # mean over n     -> (a, b, c)
        pessimistic = per_critic.min(axis=-1)    # min over c      -> (a, b)
        return pessimistic.argmax(axis=0)        # best action     -> (b,)

    def next_q(self, next_observations, risk_taus, taus, weight):
        """Return the next-state distribution of the greedy action, sorted.

        The quantile and critic axes are flattened together, then every
        objective is reordered by the ascending order of the weight-scalarized
        values. The result has shape ``(b, w, n * c)``.
        """
        greedy_action = self.predict_action(next_observations, risk_taus, weight)

        distr = self(next_observations, greedy_action, taus, weight)
        # (b, w, n, c) -> (b, w, n * c)
        distr = distr.reshape(*distr.shape[:2], -1)

        # (b, w, 1) * (b, w, n * c), summed over w -> (b, n * c)
        scalarized = (weight[..., None] * distr).sum(axis=1)

        # (b, 1, n * c): one ordering per row, shared by all objectives.
        order = scalarized.argsort(axis=-1)[..., None, :]
        return jnp.take_along_axis(distr, order, axis=-1)

    def td_target(self, next_observation, reward: jax.Array, taus: jax.Array, risk_taus: jax.Array, weight,
                  done, gamma, ):
        """Compute ``reward + (1 - done) * gamma * next_q`` with broadcasting."""
        next_q_distr = self.next_q(next_observation, risk_taus, taus, weight)

        reward_col = reward.reshape(-1, reward.shape[1], 1)   # (b, w, 1)
        not_done = 1 - done.reshape(-1, 1, 1)                 # (b, 1, 1)
        return reward_col + not_done * gamma * next_q_distr

    def loss_fn(self, observations, actions, td_target, weight, key):
        """Embed ``actions`` and delegate to the parent ``loss_fn``."""
        embedded = self.actions(actions)
        return super().loss_fn(observations, embedded, td_target, weight, key)


class MODQNPolicy(object):
    """Multi-objective, risk-aware DQN policy built around ``DQNNetwork``.

    The network, its optimizer and its metrics are kept as one ``nnx`` split
    (``g_q_net`` / ``s_q_net``) so that the jitted update functions can take
    them as plain arguments. ``target_param`` holds the target-network
    parameters produced by ``copy_param`` and refreshed by ``polyak_update``.
    """
    q_net: DQNNetwork
    s_q_net: nnx.State
    g_q_net: nnx.GraphDef

    opt_q_net: nnx.Optimizer
    metric_q_net: nnx.Metric
    target_param: nnx.Param

    update_fn: Callable
    train_step_fn: Callable

    # Number of evenly spaced quantile fractions evaluated when acting.
    N_PREDICT_QUANTILES = 32
    # Number of quantile fractions drawn per objective when building TD targets.
    N_TARGET_QUANTILES = 6

    def __init__(self,
                 env,
                 reward_dim: int,
                 n_env: int = 1,
                 gamma: float = 0.99,
                 soft_update_ratio: float = 5e-3,
                 critic_lr: float = 3e-4,
                 custom_weights: Optional[Sequence[jax.Array]] = None,
                 risk_measure: Callable = lambda x: x * 0.1,  # TV@R 10%
                 *,
                 seed: int = 42,
                 ):
        """Store hyper-parameters, pick the preference sampler and build the networks.

        Args:
            env: Environment exposing ``observation_space`` and a discrete
                ``action_space`` (its ``n`` attribute is read).
            reward_dim: Number of reward objectives.
            n_env: Stored on the instance; not otherwise used in this class.
            gamma: Discount factor used in the TD target.
            soft_update_ratio: Passed to ``polyak_update`` for the target update.
            critic_lr: Learning rate of the Adam optimizer.
            custom_weights: Optional fixed set of preference vectors to sample
                from; when ``None`` preferences are drawn from a flat Dirichlet.
            risk_measure: Function applied to quantile fractions before they
                are used for action selection.
            seed: Seed for the ``nnx.Rngs`` stream.
        """
        self.custom_weights = custom_weights
        if self.custom_weights is None:
            self.random_weight = self._random_weight
        else:
            self.random_weight = self.build_preference_fn()
        self.reward_dim = reward_dim
        self.learning_rate = critic_lr
        self.n_env = n_env
        self.env = env
        self.gamma = gamma
        self.soft_update_ratio = soft_update_ratio
        self.obs_dim = get_flattened_obs_dim(self.env.observation_space)
        self.action_space = self.env.action_space
        self.n_action = self.action_space.n

        self.seed = seed
        self.risk_measure = risk_measure
        self.rng = nnx.Rngs(self.seed)
        self.build()

    def build_preference_fn(self):
        """Return a jitted sampler that picks rows of ``custom_weights`` at random.

        The sampler takes ``(key, place_holder)`` and returns one row per
        leading element of ``place_holder``.
        """
        arrays = jnp.asarray(self.custom_weights)

        def sampler(key, place_holder):
            return jax.random.choice(key, arrays, axis=0, shape=(place_holder.shape[0],))

        return jax.jit(sampler)

    @partial(jax.jit, static_argnums=(0,))
    def _random_weight(self, key, place_holder):
        """Draw one preference vector per leading element of ``place_holder``.

        Vectors come from a Dirichlet distribution with all concentration
        parameters equal to one. ``self`` is a static jit argument, so
        ``reward_dim`` is read when the function is traced.
        """
        alphas = jnp.ones(shape=(self.reward_dim,))
        return jax.random.dirichlet(key, alpha=alphas,
                                    shape=(place_holder.shape[0],))

    def predict(self, observation, weight):
        """Select actions for ``observation`` under preference ``weight``.

        A 1-D observation or weight gets a leading batch axis. A single 1-D
        weight combined with a batch of several observations is repeated for
        every observation.

        Returns:
            ``(action, weight)``: the chosen action(s) and the (possibly
            expanded) weight that was used. When an input was un-batched and
            the result has a single row, the action is returned as a NumPy
            array without the batch axis.
        """
        squeeze = False
        if len(observation.shape) == 1:
            observation = observation[None]
            squeeze = True
        if len(weight.shape) == 1:
            weight = weight[None]
            squeeze = True
            n_obs = observation.shape[0]
            if n_obs > 1:
                # One preference shared by the whole batch.
                weight = jnp.broadcast_to(weight, (n_obs, weight.shape[1]))

        action, self.s_q_net = self._predict(
            self.g_q_net, self.s_q_net,
            observation, weight,
        )
        # Only drop the batch axis when it really has length one; squeezing a
        # longer axis would raise.
        if squeeze and action.shape[0] == 1:
            action = np.asarray(action.squeeze(axis=0).copy())

        return action, weight

    @partial(jax.jit, static_argnums=(0,))
    def _predict(self, g_q_net, s_q_net, obs, weight):
        """Jitted action selection on the merged network.

        Evenly spaced quantile fractions in ``[0, 1]`` are passed through
        ``risk_measure`` and handed to ``DQNNetwork.predict_action``.

        Returns:
            ``(action, state)`` where ``state`` is the re-split network state.
        """
        q_net, o, m = nnx.merge(g_q_net, s_q_net)
        n_quantiles = self.N_PREDICT_QUANTILES
        # Same fractions for every batch row and objective: (b, w, n_quantiles).
        taus = jnp.broadcast_to(
            jnp.linspace(0, 1, n_quantiles),
            (obs.shape[0], weight.shape[1], n_quantiles),
        )
        taus = self.risk_measure(taus)

        new_action = q_net.predict_action(obs, taus, weight)
        _, actor_state = nnx.split((q_net, o, m))
        return new_action, actor_state

    def build(self):
        """Create the network, optimizer, metrics, target copy and jitted updates."""
        self.q_net = DQNNetwork(self.obs_dim, self.n_action, self.reward_dim, rngs=self.rng)
        self.opt_q_net = nnx.Optimizer(self.q_net, optax.adam(self.learning_rate, b1=0.5, b2=0.9))
        q_met = {f"q_val_{i}": nnx.metrics.Average(f'q_val_{i}') for i in range(self.reward_dim)}

        self.metric_q_net = nnx.MultiMetric(
            q_loss=nnx.metrics.Average('q_loss'),
            **q_met
        )
        self.g_q_net, self.s_q_net = nnx.split(
            (self.q_net, self.opt_q_net, self.metric_q_net)
        )
        self.target_param = copy_param(self.q_net)
        # The update functions are built once as closures and jitted, so that
        # the hot path works on plain (graph, state) arguments.
        self.update_fn = self.build_q_net_update_fn()
        self.train_step_fn = self.build_train_step()

    def train_step(self, batch: ReplayBufferSamples):
        """Run one jitted training step on ``batch`` and store the new state."""
        self.s_q_net, self.target_param = self.train_step_fn(
            batch, self.g_q_net, self.s_q_net, self.target_param,
            self.rng()
        )

    def build_train_step(self, ):
        """Return a jitted function doing one critic update plus a target update.

        The returned function takes
        ``(batch, g_qnet, s_qnet, target_param, key)`` and returns the new
        ``(s_qnet, target_param)``.
        """
        q_net_update_fn = self.update_fn
        polyak_update_fn = jax.jit(partial(polyak_update, soft_update_ratio=self.soft_update_ratio))

        def train_step(batch: ReplayBufferSamples,
                       g_qnet: nnx.GraphDef, s_qnet: nnx.State, target_param: nnx.Param,
                       key: jax.Array
                       ):
            keys = jax.random.split(key, 2)
            s_qnet = q_net_update_fn(g_qnet, s_qnet, target_param, batch, keys[0])

            target_param = polyak_update_fn(
                g_qnet, s_qnet, target_param
            )

            return s_qnet, target_param

        return jax.jit(train_step)

    def sample_preference(self, nums: int = 1):
        """Sample ``nums`` preference vectors with the configured sampler."""
        return self.random_weight(self.rng(), np.ones(nums, ))

    def build_q_net_update_fn(self,
                              ):
        """Return a jitted function performing one gradient step on the critic.

        The returned function takes
        ``(g_q_net, s_q_net, target_critic_params, batch, key)`` and returns
        the updated state of ``(q_net, optimizer, metrics)``.
        """
        random_weight = self.random_weight

        def update_fn(g_q_net: nnx.GraphDef,
                      s_q_net: nnx.State,
                      target_critic_params,
                      batch: ReplayBufferSamples,
                      key):
            # Every random consumer gets its own key; a key must never be
            # used for two different draws.
            tau_key, loss_key, weight_key = jax.random.split(key, 3)

            q_net, opt_q_net, metric_q_net = nnx.merge(g_q_net, s_q_net)
            q_net: DQNNetwork
            # Rebuild a target network: same graph, target parameters swapped in.
            graph, param, *others = nnx.split(q_net, nnx.Param, ...)
            target_q_network: DQNNetwork = nnx.merge(graph, target_critic_params, *others)

            w = random_weight(weight_key, batch.observations)
            B = batch.observations.shape[0]
            next_taus = jax.random.uniform(
                tau_key, shape=(B, self.reward_dim, self.N_TARGET_QUANTILES))
            next_risk_taus = self.risk_measure(next_taus)
            td_target = jax.lax.stop_gradient(
                target_q_network.td_target(batch.next_observations, batch.rewards, next_taus,
                                           next_risk_taus, w, batch.dones, self.gamma))

            actions = batch.actions.squeeze(axis=1)

            def loss_fn(model):
                loss = model.loss_fn(batch.observations, actions, td_target, w, loss_key)
                loss = loss.sum(axis=(-1, -2)).mean()
                # Logged only: mean over quantiles, min over critics.
                q_val = model(batch.observations, actions, next_taus, w).mean(axis=-2).min(axis=-1)
                return loss, (loss, q_val)

            grads, (loss, q_val) = nnx.grad(loss_fn, argnums=0, has_aux=True)(q_net)
            q_val = q_val.mean(axis=0)
            opt_q_net.update(grads=grads)
            kwargs = {f"q_val_{i}": q_val[i] for i in range(self.reward_dim)}
            metric_q_net.update(q_loss=loss, **kwargs)
            _, critic_state = nnx.split((q_net, opt_q_net, metric_q_net))
            return critic_state

        return jax.jit(update_fn)

    @staticmethod
    @jax.jit
    def compute_metric(graph, state):
        """Compute the accumulated metrics, reset them, and return the new state.

        Returns:
            ``(result, state)``: the computed metric values and the re-split
            state with the metrics reset.
        """
        module, opt, metric = nnx.merge(graph, state)
        metric: nnx.Metric
        result = metric.compute()
        metric.reset()
        _, state = nnx.split((module, opt, metric))
        return result, state

    def log_all(self):
        """Return the current metric values as a dict and reset the metrics."""
        result = {}

        critic_metric, self.s_q_net = self.compute_metric(self.g_q_net, self.s_q_net)
        result.update(critic_metric)

        return result

    def save(self, path):
        """Pickle the critic graph, state and target parameters to ``path``."""
        save_attr = self._save_attributes()
        with open(path, 'wb') as f:
            cloudpickle.dump(save_attr, f)

    def _save_attributes(self):
        """Bundle everything ``save`` writes into a dict keyed by ``"critic"``."""
        q_net = FlaxModule(self.g_q_net, self.s_q_net, self.target_param)

        return {"critic": q_net, }

    def load(self, path):
        """Restore graph, state and target parameters from a file written by ``save``.

        Only ``g_q_net``, ``s_q_net`` and ``target_param`` are replaced.
        """
        # SECURITY: unpickling executes arbitrary code from the file; only
        # load checkpoints from trusted sources.
        with open(path, 'rb') as f:
            attributes = cloudpickle.load(f)

        critic_attr = attributes["critic"]

        self.g_q_net, self.s_q_net, self.target_param = (
            critic_attr.graph,
            critic_attr.state,
            critic_attr.others,
        )
