from basics.layers import Lagrangian, Constant
from risk_morl.architecture.actor import Actor
from risk_morl.architecture.critic import VMapContinuousQNet
from flax import nnx
from stable_baselines3.common.preprocessing import get_action_dim, get_flattened_obs_dim
from risk_morl.buffer import ReplayBufferSamples
import jax.numpy as jnp
from risk_morl.utils.network_utils import copy_param, quanitle_regression_loss, polyak_update
import numpy as np
import optax
from functools import partial
import jax
from typing import Callable, Literal, Optional, Any
from dataclasses import dataclass
import cloudpickle
from risk_morl.utils.mo_utils import simplex_grid_points


@dataclass
class FlaxModule:
    """Serializable bundle of an NNX graph definition and its state.

    ``MOSACPolicy._save_attributes`` builds one instance per component
    (critic, actor, entropy coefficient) and ``MOSACPolicy.load`` reads the
    three fields back.
    """
    # Static structure half of an ``nnx.split`` result.
    graph: nnx.GraphDef
    # Array/state half of the same ``nnx.split`` result.
    state: nnx.State
    # Optional extra payload; in this file it carries the target critic
    # parameters (critic) or the target entropy (entropy coefficient).
    others: Optional[Any] = None


class DummyMetric(nnx.metrics.Average):
    """Metric that always reports a fixed value.

    It subclasses ``nnx.metrics.Average`` so it can sit inside an
    ``nnx.MultiMetric`` next to real averages, but ``compute`` ignores any
    accumulated statistics and simply returns the value given at
    construction time.
    """

    def __init__(self, value, argname: str = 'values'):
        """Store ``value`` and forward ``argname`` to ``Average``."""
        super().__init__(argname)
        self.value = value

    def compute(self):
        """Return the constant supplied to the constructor."""
        return self.value

class MOSACPolicy(object):
    """Actor, critic and entropy coefficient with jitted update functions.

    Every component is kept both as live NNX objects and in split form: the
    ``g_*`` attribute is the ``nnx.GraphDef`` and the ``s_*`` attribute is the
    ``nnx.State`` obtained from ``nnx.split((module, optimizer, metric))``.
    The jitted update functions take these pairs as arguments and return the
    new state, which is then stored back on the instance.
    """

    # --- critic -----------------------------------------------------------
    critic: VMapContinuousQNet
    s_critic: nnx.State
    g_critic: nnx.GraphDef

    opt_critic: nnx.Optimizer
    metric_critic: nnx.Metric
    # Copy of the critic parameters (made with ``copy_param``) used to build
    # the target network; refreshed through ``polyak_update`` in the train step.
    target_param: nnx.Param

    # --- actor ------------------------------------------------------------
    actor: Actor
    s_actor: nnx.State
    g_actor: nnx.GraphDef
    opt_actor: nnx.Optimizer
    metric_actor: nnx.Metric

    # --- entropy coefficient ------------------------------------------------
    # Either a ``Lagrangian`` (learned) or a ``Constant`` (fixed), chosen in
    # ``build_ent_coef``.
    ent_coef: Lagrangian
    opt_ent: nnx.Optimizer
    metric_ent: nnx.Metric
    g_ent: nnx.GraphDef
    s_ent: nnx.State
    target_entropy: float

    # --- jitted closures produced by the ``build_*`` methods ----------------
    critic_update_fn: Callable
    actor_update_fn: Callable
    ent_coef_update_fn: Callable
    # Full train step (critic -> target -> actor -> entropy coefficient);
    # set by ``build``.
    train_step_fn: Callable | None = None

    def __init__(self,
                 env,
                 reward_dim: int,
                 n_env: int = 1,
                 gamma: float = 0.99,
                 soft_update_ratio: float = 5e-3,
                 critic_lr: float = 3e-4,
                 actor_lr: float = 3e-4,
                 risk_measure: Callable = lambda x: x * 0.1,  # TV@R 10%
                 ent_coef: float | Literal['auto'] = 'auto',
                 target_entropy: float | Literal['auto'] = 'auto',
                 *,
                 discrete_weights: bool = False,
                 num_grids: int = 20,
                 seed: int = 42,
                 ):
        """Store the hyper-parameters and build all networks and update functions.

        Args:
            env: Object exposing ``observation_space`` and ``action_space``.
            reward_dim: Number of reward components (objectives).
            n_env: Stored on the instance; not otherwise used in this class.
            gamma: Discount factor, stored as ``self.gamma``.
            soft_update_ratio: Forwarded to ``polyak_update`` for the target
                parameters.
            critic_lr: Adam learning rate of the critic.
            actor_lr: Adam learning rate of the actor and of the entropy
                coefficient.
            risk_measure: Callable applied to uniform samples to produce the
                quantile fractions used in the actor update.
            ent_coef: A real number fixes the entropy coefficient to that
                value; any other value (normally ``'auto'``) makes it learned.
            target_entropy: Target entropy, or ``'auto'`` to derive it from the
                action-space shape.
            discrete_weights: If true, preference weights are drawn from a
                fixed grid (see ``build_discrete``); otherwise from a uniform
                Dirichlet distribution.
            num_grids: Grid resolution forwarded to ``simplex_grid_points``.
            seed: Seed of the ``nnx.Rngs`` stream.
        """
        # Accept every real number as a fixed coefficient. Checking for
        # ``float`` only would silently turn e.g. ``ent_coef=1`` or a NumPy
        # float32 scalar into a *learned* coefficient and drop the value.
        # ``bool`` is excluded so it keeps its previous (learned) treatment.
        is_fixed = (
            isinstance(ent_coef, (int, float, np.integer, np.floating))
            and not isinstance(ent_coef, bool)
        )
        self.learn_ent_coef = not is_fixed
        self.given_ent_coef = float(ent_coef) if is_fixed else ent_coef
        self.discrete_weights = discrete_weights
        self.num_grids = num_grids
        self.reward_dim = reward_dim

        # ``build_discrete`` reads ``reward_dim`` and ``num_grids``, so both
        # must already be set at this point.
        if self.discrete_weights:
            self.random_weight = self.build_discrete()
        else:
            self.random_weight = self._random_weight

        self.critic_lr = critic_lr

        self.actor_lr = actor_lr
        self.n_env = n_env
        self.env = env
        self.gamma = gamma
        self.soft_update_ratio = soft_update_ratio
        self.obs_dim = get_flattened_obs_dim(self.env.observation_space)
        self.action_space = self.env.action_space
        self.n_action = get_action_dim(self.action_space)
        self.seed = seed
        self.risk_measure = risk_measure
        self.rng = nnx.Rngs(self.seed)
        self.build(target_entropy)

    def build_discrete(self):
        """Return a jitted sampler that draws weights from a fixed set of points.

        The candidate points come from ``simplex_grid_points`` called with
        ``n=self.reward_dim`` and ``k=self.num_grids``. The returned function
        takes ``(key, placeholder)`` and draws ``placeholder.shape[0]`` rows
        from the candidates with ``jax.random.choice``.
        """
        points = simplex_grid_points(n=self.reward_dim, k=self.num_grids)
        candidates = jnp.asarray(
            [jnp.asarray(point, dtype=jnp.float32) for point in points]
        )

        @jax.jit
        def sample(key, placeholder):
            # Only the leading dimension of the placeholder is used.
            batch_size = placeholder.shape[0]
            return jax.random.choice(key, candidates, shape=(batch_size,), axis=0)

        return sample

    def predict(self, observation, weight):
        action, self.s_actor = self._predict(
            self.g_actor, self.s_actor,
            observation, weight)
        return np.asarray(action).copy(), weight

    @staticmethod
    @jax.jit
    def _predict(g_actor, s_actor, obs, weight):
        """Jitted forward pass of the actor.

        Rebuilds ``(actor, optimizer, metric)`` from the graph definition and
        state, calls the actor on ``(obs, weight)`` and returns the result
        together with the re-split state of all three objects.
        """
        policy, optimizer, metric = nnx.merge(g_actor, s_actor)
        action = policy(obs, weight)
        # Split again after the call so any state changed by it is returned.
        updated_state = nnx.split((policy, optimizer, metric))[1]
        return action, updated_state

    def build(self, target_entropy):
        self.build_critic()
        self.build_actor()
        self.build_ent_coef(target_entropy)
        self.train_step_fn = self.build_train_step()

    def build_critic(self):
        """Create the critic, its optimizer, metric, target copy and update fn."""
        critic = VMapContinuousQNet(self.obs_dim, self.n_action, self.reward_dim, rngs=self.rng)
        optimizer = nnx.Optimizer(critic, optax.adam(self.critic_lr))
        metrics = nnx.MultiMetric(q_loss=nnx.metrics.Average('q_loss'))

        self.critic = critic
        self.opt_critic = optimizer
        self.metric_critic = metrics

        # Keep the (graphdef, state) form next to the live objects: the jitted
        # update functions take these as arguments and hand back a new state.
        self.g_critic, self.s_critic = nnx.split((critic, optimizer, metrics))
        self.target_param = copy_param(critic)

        # The update function is built once and stored, rather than being
        # recreated on every training step.
        self.critic_update_fn = self.build_critic_update_fn()

    def build_actor(self):
        """Create the actor, its optimizer, metrics and jitted update function."""
        actor = Actor(self.obs_dim, self.n_action, self.reward_dim, rngs=self.rng)
        optimizer = nnx.Optimizer(actor, optax.adam(self.actor_lr))
        # The second metric is registered under the key ``ent_`` but is fed
        # through the ``ent`` keyword of ``update``.
        metrics = nnx.MultiMetric(
            pi_loss=nnx.metrics.Average('pi_loss'),
            ent_=nnx.metrics.Average('ent'),
        )

        self.actor = actor
        self.opt_actor = optimizer
        self.metric_actor = metrics
        self.g_actor, self.s_actor = nnx.split((actor, optimizer, metrics))
        self.actor_update_fn = self.build_actor_update_fn()

    def build_ent_coef(self, target_entropy: float | Literal['auto']):
        """Create the entropy coefficient, its optimizer, metrics and update fn.

        When ``self.learn_ent_coef`` is true the coefficient is a
        ``Lagrangian`` whose value and loss are tracked; otherwise it is a
        ``Constant`` holding ``self.given_ent_coef`` and only that fixed value
        is reported.

        Args:
            target_entropy: Explicit target entropy, or ``'auto'`` to use the
                negative product of the action-space shape.
        """
        if self.learn_ent_coef:
            self.ent_coef = Lagrangian()
            self.metric_ent = nnx.MultiMetric(
                ent_coef=nnx.metrics.Average('current_ent'),
                ent_loss=nnx.metrics.Average('ent_loss')
            )
        else:
            self.ent_coef = Constant(self.given_ent_coef)
            self.metric_ent = nnx.MultiMetric(
                ent_coef=DummyMetric(self.given_ent_coef, 'current_ent'),
            )
        # Both variants get an optimizer so the split state always has the
        # same (module, optimizer, metric) layout.
        self.opt_ent = nnx.Optimizer(self.ent_coef, optax.adam(self.actor_lr))
        self.g_ent, self.s_ent = nnx.split(
            (self.ent_coef, self.opt_ent, self.metric_ent)
        )

        if isinstance(target_entropy, str) and target_entropy == 'auto':
            # ``np.prod`` of an integer shape yields a NumPy integer scalar;
            # convert to a plain float so the attribute matches its ``float``
            # annotation and pickles as a builtin.
            self.target_entropy = -float(np.prod(self.action_space.shape))
        else:
            self.target_entropy = target_entropy
        self.ent_coef_update_fn = self.build_ent_coef_update_fn()

    def train_step(self, batch: ReplayBufferSamples):


        self.s_critic, self.target_param, self.s_actor, self.s_ent = self.train_step_fn(
            batch, self.g_critic, self.s_critic, self.target_param,
            self.g_actor, self.s_actor,
            self.g_ent, self.s_ent,
            self.rng()
        )


    def build_train_step(self,):
        """Compose the individual update functions into one jitted train step.

        The returned function updates, in order: the critic, the target
        parameters (via ``polyak_update``), the actor, and the entropy
        coefficient. It returns
        ``(s_critic, target_param, s_actor, s_ent)``.
        """
        update_critic = self.critic_update_fn
        update_actor = self.actor_update_fn
        update_ent_coef = self.ent_coef_update_fn
        soft_update = jax.jit(
            partial(polyak_update, soft_update_ratio=self.soft_update_ratio)
        )

        @jax.jit
        def update_fn(batch: ReplayBufferSamples,
                      g_critic: nnx.GraphDef, s_critic: nnx.State, target_param: nnx.Param,
                      g_actor: nnx.GraphDef, s_actor: nnx.State,
                      g_ent: nnx.GraphDef, s_ent: nnx.State,
                      key: jax.Array):
            # One independent key each for the critic and the actor update.
            subkeys = jax.random.split(key, 2)
            critic_key = subkeys[0]
            actor_key = subkeys[1]

            s_critic = update_critic(
                g_critic, s_critic, target_param,
                g_actor, s_actor,
                g_ent, s_ent,
                batch, critic_key,
            )
            # The target follows the critic that was just updated.
            target_param = soft_update(g_critic, s_critic, target_param)

            s_actor, log_prob = update_actor(
                g_critic, s_critic,
                g_actor, s_actor,
                g_ent, s_ent,
                batch, actor_key,
            )
            # The entropy coefficient is tuned on the log-probability reported
            # by the actor update.
            s_ent = update_ent_coef(g_ent, s_ent, log_prob)
            return s_critic, target_param, s_actor, s_ent

        return update_fn

    @partial(jax.jit, static_argnums=(0,))
    def _random_weight(self, key, place_holder):

        w = jax.random.dirichlet(key, alpha=jnp.ones(shape=(self.reward_dim, ))   ,
                                 shape=(place_holder.shape[0], ))

        return w



    def build_critic_update_fn(self,
                               ):
        """Build the jitted critic update.

        The returned function has the signature
        ``(g_critic, s_critic, target_critic_params, g_actor, s_actor,
        g_ent_coef, s_ent_coef, batch, key) -> critic_state``.
        It samples preference weights and quantile fractions, forms a
        discounted, entropy-regularised target from the target network with
        the largest pooled quantiles removed, and takes one optimizer step on
        the quantile-regression loss.
        """
        n_quantiles = 32  # quantile fractions drawn per transition
        n_dropped = 2     # largest pooled target quantiles that are discarded

        def update_fn(g_critic, s_critic, target_critic_params,
                      g_actor, s_actor,
                      g_ent_coef, s_ent_coef,
                      batch: ReplayBufferSamples,
                      key):
            keys = jax.random.split(key, 3)
            critic, opt_critic, metric_critic = nnx.merge(g_critic, s_critic)
            # Target network = critic structure with the target parameters
            # swapped in; the critic's own parameters are not needed here.
            graph, _, *others = nnx.split(critic, nnx.Param, ...)
            target_q_network = nnx.merge(graph, target_critic_params, *others)

            actor, _, _ = nnx.merge(g_actor, s_actor)
            __ent_coef, _, _ = nnx.merge(g_ent_coef, s_ent_coef)
            ent_coef = __ent_coef()

            w = self.random_weight(keys[-1], batch.observations)
            B = batch.observations.shape[0]
            next_taus = jax.random.uniform(keys[0], shape=(B, n_quantiles))
            current_taus = jax.random.uniform(keys[1], shape=next_taus.shape)
            next_actions, next_log_prob = actor.sample_and_log_prob(batch.next_observations, w)

            next_quantile = target_q_network(batch.next_observations, next_actions, next_taus, w)

            # Pool everything after the reward axis, sort, and drop the
            # largest values: result is (B, reward_dim, K - n_dropped).
            next_quantile = next_quantile.reshape(B, self.reward_dim, -1).sort(axis=-1)[..., :-n_dropped]

            next_quantile = next_quantile - ent_coef * next_log_prob.reshape(B, 1, 1)
            # Bootstrap with the discount factor. Without ``gamma`` the target
            # would be an undiscounted return and the configured discount
            # would have no effect.
            # NOTE(review): assumes batch.rewards broadcasts against
            # (B, reward_dim, n) - confirm against the replay buffer.
            not_done = 1 - batch.dones.reshape(B, 1, 1)
            td_target = batch.rewards + not_done * self.gamma * next_quantile

            def loss_fn(model):
                current_qf = model(batch.observations, batch.actions, current_taus, w)
                # Inner vmap runs over axis 1 of target and prediction; the
                # outer vmap runs over the last axis of the prediction while
                # sharing the same target.
                reward_vmap = jax.vmap(quanitle_regression_loss, (1, 1, None), out_axes=1)

                loss = jax.vmap(reward_vmap, in_axes=(None, -1, None), out_axes=-1)(td_target, current_qf, current_taus)
                loss = loss.sum(axis=-1).mean()
                return loss, loss

            grads, loss = nnx.grad(loss_fn, argnums=0, has_aux=True)(critic)
            opt_critic.update(grads)
            metric_critic.update(q_loss=loss)
            _, critic_state = nnx.split((critic, opt_critic, metric_critic))
            return critic_state

        return jax.jit(update_fn)

    def build_actor_update_fn(self,
                              ):
        """Build the jitted actor update.

        The returned function has the signature
        ``(g_critic, s_critic, g_actor, s_actor, g_ent_coef, s_ent_coef,
        batch, key) -> (actor_state, mean_log_prob)``.
        """
        def actor_update_fn(g_critic, s_critic,
                            g_actor, s_actor,
                            g_ent_coef, s_ent_coef,
                            batch: ReplayBufferSamples,
                            key: jax.Array):
            subkeys = jax.random.split(key, 2)
            tau_key = subkeys[0]
            weight_key = subkeys[-1]

            w = self.random_weight(weight_key, batch.observations)
            critic = nnx.merge(g_critic, s_critic)[0]
            actor, opt_actor, metric_actor = nnx.merge(g_actor, s_actor)
            coef_module = nnx.merge(g_ent_coef, s_ent_coef)[0]
            ent_coef = coef_module()

            # Uniform samples are mapped through the risk measure to obtain
            # the quantile fractions at which the critic is evaluated.
            batch_size = batch.observations.shape[0]
            uniform_taus = jax.random.uniform(tau_key, shape=(batch_size, 32))
            taus = self.risk_measure(uniform_taus)

            def loss_fn(model: Actor):
                action, log_prob = model.sample_and_log_prob(batch.observations, w)
                quantiles = critic(batch.observations, action, taus, w)
                # Average over the second-to-last axis, then take the minimum
                # over the last axis.
                qf = quantiles.mean(axis=-2).min(axis=-1)

                # Standardise over the batch axis; the statistics are treated
                # as constants for differentiation.
                center = jax.lax.stop_gradient(jnp.mean(qf, axis=0, keepdims=True))
                spread = jax.lax.stop_gradient(jnp.std(qf, axis=0, keepdims=True))
                standardized = (qf - center) / (spread + 1e-7)

                # Scalarise with the sampled preference weights.
                scalar_q = (standardized * w).sum(axis=-1, keepdims=True)

                per_sample = (ent_coef * log_prob - scalar_q).mean(axis=-1)
                loss = per_sample.mean()
                return loss, (loss, log_prob)

            grads, (loss, log_prob) = nnx.grad(loss_fn, has_aux=True)(actor)
            opt_actor.update(grads)

            mean_log_prob = log_prob.squeeze(axis=-1).mean()
            # The negative mean log-probability is logged as the entropy
            # estimate.
            metric_actor.update(pi_loss=loss, ent=-mean_log_prob)

            new_actor_state = nnx.split((actor, opt_actor, metric_actor))[1]
            return new_actor_state, mean_log_prob

        return jax.jit(actor_update_fn)

    def build_ent_coef_update_fn(self):
        if self.learn_ent_coef:
            def update_fn(
                    g_ent_coef, s_ent_coef,
                    log_prob
            ):
                ent_coef, opt_ent, metric_ent = nnx.merge(g_ent_coef, s_ent_coef)

                def loss_fn(model):
                    alpha = model()
                    loss = -jnp.log(alpha) * (log_prob + self.target_entropy).mean()
                    return loss, (loss, alpha)

                grads, (loss, alpha) = nnx.grad(loss_fn, has_aux=True)(ent_coef)
                opt_ent.update(grads)
                metric_ent.update(current_ent=alpha, ent_loss=loss)

                _, s_ent_coef = nnx.split((ent_coef, opt_ent, metric_ent))
                return s_ent_coef

            return jax.jit(update_fn)
        else:
            def update_fn(
                    g_ent_coef, s_ent_coef,
                    log_prob
            ):

                return s_ent_coef

            return jax.jit(update_fn)

    @staticmethod
    @jax.jit
    def compute_metric(graph, state):
        """Compute and reset the metric stored in a ``(module, opt, metric)`` state.

        Returns ``(result, new_state)`` where ``result`` is whatever
        ``metric.compute()`` yields and ``new_state`` is the re-split state
        after ``metric.reset()``.
        """
        module, optimizer, tracker = nnx.merge(graph, state)
        summary = tracker.compute()
        # Reset after reading so the next call starts from fresh statistics.
        tracker.reset()
        refreshed_state = nnx.split((module, optimizer, tracker))[1]
        return summary, refreshed_state

    def log_all(self):
        result = { }

        critic_metric, self.s_critic = self.compute_metric(self.g_critic, self.s_critic)
        result.update(critic_metric)

        actor_metric, self.s_actor = self.compute_metric(self.g_actor, self.s_actor)
        result.update(actor_metric)

        ent_metric, self.s_ent = self.compute_metric(self.g_ent, self.s_ent)
        result.update(ent_metric)

        return result

    def save(self, path):
        """Serialize the policy's graph definitions and states to ``path``.

        The payload is the dict from ``_save_attributes``, written with
        ``cloudpickle``.
        """
        payload = self._save_attributes()
        with open(path, 'wb') as out_file:
            cloudpickle.dump(payload, out_file)

    def _save_attributes(self):
        """Return the objects to persist, keyed by component name.

        Each value is a ``FlaxModule`` holding the graph definition and state;
        the critic also carries the target parameters and the entropy
        coefficient carries the target entropy in ``others``.
        """
        return {
            "critic": FlaxModule(self.g_critic, self.s_critic, self.target_param),
            "actor": FlaxModule(self.g_actor, self.s_actor),
            "ent": FlaxModule(self.g_ent, self.s_ent, self.target_entropy),
        }

    def load(self, path):
        """Restore graph definitions, states and target entropy from ``path``.

        The file must have been written by ``save``: a cloudpickled dict with
        the keys ``"critic"``, ``"actor"`` and ``"ent"``, each exposing
        ``graph``, ``state`` and ``others``.

        Raises:
            KeyError: If one of the three expected keys is missing.
        """
        # SECURITY: unpickling executes arbitrary code embedded in the file.
        # Only load checkpoints that come from a trusted source.
        with open(path, 'rb') as f:
            attributes = cloudpickle.load(f)

        # Look up all three entries before assigning anything, so a malformed
        # file cannot leave the policy half-restored.
        critic_attr = attributes["critic"]
        actor_attr = attributes["actor"]
        ent_attr = attributes["ent"]

        self.g_critic, self.s_critic, self.target_param = (
            critic_attr.graph,
            critic_attr.state,
            critic_attr.others,
        )
        self.g_actor, self.s_actor = actor_attr.graph, actor_attr.state
        self.g_ent, self.s_ent, self.target_entropy = (
            ent_attr.graph,
            ent_attr.state,
            ent_attr.others,
        )

        # The entropy-coefficient update reads ``self.target_entropy`` while it
        # is traced, so a function that was already compiled would keep using
        # the old value. Rebuild it, and the train step that wraps it, so the
        # restored target entropy takes effect.
        self.ent_coef_update_fn = self.build_ent_coef_update_fn()
        self.train_step_fn = self.build_train_step()


