import numpy as np

from risk_morl.mo_sac.policy import MOSACPolicy
from flax import nnx
import jax.numpy as jnp
import jax
from risk_morl.ewps.ewp_critic import EWPCritic
from risk_morl.kr_iqn_sac.policy import KRCriticPolicy
from basics.layers import Lagrangian, SmoothLagrangian
from risk_morl.architecture.actor import Actor
from risk_morl.utils.network_utils import copy_param, polyak_update
from functools import partial
import optax
from risk_morl.buffer import ReplayBufferSamples

from basics.weight_sampler import PreferenceSampler
from typing import Callable, Literal


class EWPPolicy(KRCriticPolicy):
    """KR-critic SAC policy built around an ``EWPCritic`` and a preference-conditioned actor.

    Overrides how ``KRCriticPolicy`` builds the actor, the critic and their
    update functions. Networks, optimizers and metrics are kept in split
    ``(GraphDef, State)`` form (``g_*`` / ``s_*`` attributes) so that the update
    functions can be ``jax.jit``-compiled and return new states functionally.
    """

    critic: EWPCritic
    preference_sampler: PreferenceSampler
    s_sampler: nnx.State
    g_sampler: nnx.GraphDef

    def __init__(self,
                 env,
                 reward_dim: int,
                 n_env: int = 1,
                 gamma: float = 0.99,
                 soft_update_ratio: float = 5e-3,
                 critic_lr: float = 3e-4,
                 actor_lr: float = 3e-4,
                 risk_measure: Callable = lambda x: x * 0.1,  # TV@R 10%
                 ent_coef: float | Literal['auto'] = 'auto',
                 target_entropy: float | Literal['auto'] = 'auto',
                 truncation_upper: int = 1,
                 truncation_lower: int = 0,
                 *,
                 discrete_weights: bool = False,
                 num_grids: int = 20,
                 seed: int = 42,
                 ):
        """Forward every argument unchanged to ``KRCriticPolicy.__init__``.

        Args:
            truncation_upper: Number of largest target samples dropped per
                objective when building the critic TD target (0 keeps them all).
            truncation_lower: Number of smallest target samples dropped per
                objective when building the critic TD target.

        All other arguments are interpreted by the base class.
        """
        super().__init__(
            env=env,
            reward_dim=reward_dim,
            n_env=n_env,
            gamma=gamma,
            soft_update_ratio=soft_update_ratio, critic_lr=critic_lr,
            actor_lr=actor_lr, risk_measure=risk_measure, ent_coef=ent_coef,
            target_entropy=target_entropy,
            truncation_lower=truncation_lower, truncation_upper=truncation_upper,
            discrete_weights=discrete_weights, num_grids=num_grids, seed=seed
        )

    def build_actor(self):
        """Create the actor, its Adam optimizer and metrics, then build its jitted update fn."""
        self.actor = Actor(self.obs_dim, self.n_action, self.reward_dim, rngs=self.rng)
        self.opt_actor = nnx.Optimizer(
            self.actor,
            optax.chain(optax.adam(self.actor_lr, b1=0.5, b2=0.9)),
        )
        # One running average per objective: qf_1 ... qf_{reward_dim}.
        q_infos = {f"qf_{i}": nnx.metrics.Average(f'qf_{i}') for i in range(1, self.reward_dim + 1)}

        self.metric_actor = nnx.MultiMetric(
            pi_loss=nnx.metrics.Average('pi_loss'),
            corr=nnx.metrics.Average('corr'),
            # Stored under "ent_" but fed by the ``ent=`` keyword at update time.
            ent_=nnx.metrics.Average('ent'),
            **q_infos
        )

        self.g_actor, self.s_actor = nnx.split(
            (self.actor, self.opt_actor, self.metric_actor)
        )
        self.actor_update_fn = self.build_actor_update_fn()

    def build_critic(self):
        """Create the critic, optimizer, metrics, target params, update fn and preference sampler."""
        self.critic = EWPCritic(self.obs_dim, self.n_action, self.reward_dim, rngs=self.rng)
        self.opt_critic = nnx.Optimizer(
            self.critic,
            optax.chain(optax.adam(self.critic_lr)),
        )

        self.metric_critic = nnx.MultiMetric(
            q_loss=nnx.metrics.Average('q_loss'),
        )
        self.g_critic, self.s_critic = nnx.split(
            (self.critic, self.opt_critic, self.metric_critic)
        )
        # Target network parameters start as a copy of the online critic's.
        self.target_param = copy_param(self.critic)
        self.critic_update_fn = self.build_critic_update_fn()
        self.preference_sampler = PreferenceSampler(self.reward_dim, rngs=self.rng)
        self.g_sampler, self.s_sampler = nnx.split(self.preference_sampler)

    def build_critic_update_fn(self):
        """Return a jitted function performing one critic gradient step.

        The returned function builds a truncated, entropy-regularized TD target
        from the target critic and minimizes ``EWPCritic.loss_fn`` against it.
        It returns ``(new critic/optimizer/metric state, sampler state)``.
        """
        lower = self.truncation_lower
        upper = self.truncation_upper
        n_taus = 8  # quantile fractions drawn per objective (target and online)
        n_drop_scalarized = 3  # largest preference-scalarized target samples removed

        def critic_update_fn(g_critic, s_critic, target_critic_params,
                             g_actor, s_actor,
                             g_ent_coef, s_ent_coef,
                             g_sampler, s_sampler,
                             batch: ReplayBufferSamples,
                             key):
            keys = jax.random.split(key, 3)
            critic, opt_critic, metric_critic = nnx.merge(g_critic, s_critic)
            # Target network = online graph + non-Param state, with the
            # target (Polyak-averaged) parameters swapped in.
            graph, _, *others = nnx.split(critic, nnx.Param, ...)
            target_q_network = nnx.merge(graph, target_critic_params, *others)
            preference_sampler = nnx.merge(g_sampler, s_sampler)

            actor, _, _ = nnx.merge(g_actor, s_actor)
            ent_coef_module, _, _ = nnx.merge(g_ent_coef, s_ent_coef)
            ent_coef = ent_coef_module()

            # Preferences come from ``self.random_weight``; the preference
            # sampler is only merged/split to thread its state through.
            w = self.random_weight(keys[2], batch.observations)

            B = batch.observations.shape[0]
            next_taus = jax.random.uniform(keys[0], shape=(B, self.reward_dim, n_taus))
            next_actions, next_log_prob = actor.sample_and_log_prob(batch.next_observations, w)

            next_quantile = target_q_network(batch.next_observations, next_actions, w, next_taus)
            # (batch, reward_dim, K): every target sample, grouped per objective.
            next_quantile = next_quantile.reshape(w.shape[0], self.reward_dim, -1)

            # For each objective d, rank samples by objective d and drop the
            # ``lower`` smallest and ``upper`` largest. The same indices are
            # applied to all objectives so samples stay aligned; the sample
            # axis therefore shrinks by (lower + upper) per objective.
            for d in range(self.reward_dim):
                n_samples = next_quantile.shape[-1]
                order = jnp.argsort(next_quantile[:, d, :], axis=-1)
                # Explicit stop index: ``-upper`` would select nothing when upper == 0.
                keep = order[..., lower:max(n_samples - upper, 0)]
                next_quantile = jnp.take_along_axis(next_quantile, keep[..., None, :], axis=-1)

            # Scalarize each remaining sample with the preference: (batch, K').
            scalarized = (w[..., None] * next_quantile).sum(axis=1)
            # Indices of all but the ``n_drop_scalarized`` largest scalarized
            # samples, broadcast over the objective axis: (batch, 1, K' - drop).
            n_samples = scalarized.shape[-1]
            keep = scalarized.argsort(axis=-1)[..., :max(n_samples - n_drop_scalarized, 0)]
            next_quantile = jnp.take_along_axis(next_quantile, indices=keep[:, None, :], axis=-1)

            # Soft (entropy-regularized) target for every remaining sample.
            next_quantile = next_quantile - ent_coef * next_log_prob.reshape(B, 1, 1)
            reward = batch.rewards.reshape(w.shape[0], self.reward_dim, 1)
            non_terminal = 1 - batch.dones.reshape(B, 1, 1)
            td_target = reward + self.gamma * non_terminal * next_quantile

            taus = jax.random.uniform(keys[1], shape=(B, self.reward_dim, n_taus))

            def loss_fn(model):
                loss = model.loss_fn(batch.observations, batch.actions, td_target, w, taus)
                return loss.mean()

            loss, grads = nnx.value_and_grad(loss_fn)(critic)
            metric_critic.update(q_loss=loss)
            opt_critic.update(grads)
            _, critic_state = nnx.split((critic, opt_critic, metric_critic))
            _, s_sampler = nnx.split(preference_sampler)
            return critic_state, s_sampler

        return jax.jit(critic_update_fn)

    def build_actor_update_fn(self):
        """Return a jitted function performing one actor gradient step.

        The actor minimizes ``ent_coef * log_prob - Q_w`` where ``Q_w`` is the
        preference-weighted critic value, averaged over quantile samples and
        minimized over the last critic axis. It returns
        ``(new actor/optimizer/metric state, mean log-prob, sampler state)``.
        """
        n_taus = 32  # quantile fractions drawn per objective for the actor loss

        def actor_update_fn(g_critic, s_critic,
                            g_actor, s_actor,
                            g_ent_coef, s_ent_coef,
                            g_sampler, s_sampler,
                            batch: ReplayBufferSamples,
                            key: jax.Array):
            keys = jax.random.split(key, 3)

            critic, _, _ = nnx.merge(g_critic, s_critic)
            actor, opt_actor, metric_actor = nnx.merge(g_actor, s_actor)
            ent_coef_module, _, _ = nnx.merge(g_ent_coef, s_ent_coef)
            ent_coef = ent_coef_module()
            preference_sampler = nnx.merge(g_sampler, s_sampler)

            B = batch.observations.shape[0]
            taus = jax.random.uniform(keys[0], shape=(B, self.reward_dim, n_taus))
            # Use an independent key: reusing keys[0] would correlate w with taus.
            w = self.random_weight(keys[1], batch.observations)

            def loss_fn(model: Actor):
                action, log_prob = model.sample_and_log_prob(batch.observations, w)
                # Layout assumed (batch, reward_dim, n_taus, n_critics) — TODO confirm.
                raw_q = critic(batch.observations, action, w, taus)
                # Per-objective value: mean over quantiles, min over last axis.
                per_objective_q = raw_q.mean(axis=-2).min(axis=-1)
                q_info = per_objective_q.mean(axis=0)

                # Correlation between each objective's value and its preference weight.
                corr = jax.vmap(lambda x, y: jnp.corrcoef(x, y)[0, 1], in_axes=(1, 1), out_axes=0)(
                    jax.lax.stop_gradient(per_objective_q), w)
                qf = raw_q * w[..., None, None]
                qf = qf.sum(axis=1).mean(axis=-2).min(axis=-1, keepdims=True)

                kl_loss = (ent_coef * log_prob - qf).mean(axis=-1).mean()
                return kl_loss, (kl_loss, log_prob, corr.mean(), q_info)

            grads, (kl_loss, log_prob, corr, q_info) = nnx.grad(loss_fn, has_aux=True)(actor)

            opt_actor.update(grads)
            log_prob = log_prob.squeeze(axis=-1).mean()
            q_info_kwargs = {f"qf_{i}": q_info[i - 1] for i in range(1, self.reward_dim + 1)}
            metric_actor.update(pi_loss=kl_loss, ent=-log_prob, corr=corr, **q_info_kwargs)
            _, s_actor = nnx.split((actor, opt_actor, metric_actor))
            _, s_sampler = nnx.split(preference_sampler)
            return s_actor, log_prob, s_sampler

        return jax.jit(actor_update_fn)

    def train_step(self, batch: ReplayBufferSamples):
        """Run one full update (critic, target, actor, entropy) and store the new states."""
        self.s_critic, self.target_param, self.s_actor, self.s_ent, self.s_sampler = self.train_step_fn(
            batch, self.g_critic, self.s_critic, self.target_param,
            self.g_actor, self.s_actor,
            self.g_ent, self.s_ent,
            self.g_sampler, self.s_sampler,
            self.rng()
        )

    @staticmethod
    @jax.jit
    def _sample_preference(g_preference, s_preference, place_holder):
        """Draw preferences from a split ``PreferenceSampler`` and return its updated state."""
        sampler = nnx.merge(g_preference, s_preference)
        w = sampler.sample_weight(place_holder)
        # Re-split so the sampler's advanced RNG state is returned to the caller.
        _, s_preference = nnx.split(sampler)
        return w, s_preference

    def sample_preference(self, nums: int = 1):
        """Return ``nums`` preference vectors drawn via ``self.random_weight``."""
        return self.random_weight(self.rng(), np.ones(nums, ))

    def build_train_step(self):
        """Return a jitted function chaining critic, Polyak, actor and entropy updates."""
        critic_update_fn = self.critic_update_fn
        actor_update_fn = self.actor_update_fn
        ent_coef_update_fn = self.ent_coef_update_fn
        polyak_update_fn = jax.jit(partial(polyak_update, soft_update_ratio=self.soft_update_ratio))

        def update_fn(batch: ReplayBufferSamples,
                      g_critic: nnx.GraphDef, s_critic: nnx.State, target_param: nnx.Param,
                      g_actor: nnx.GraphDef, s_actor: nnx.State,
                      g_ent: nnx.GraphDef, s_ent: nnx.State,
                      g_sampler: nnx.GraphDef, s_sampler: nnx.State,
                      key: jax.Array
                      ):
            critic_key, actor_key = jax.random.split(key, 2)
            s_critic, s_sampler = critic_update_fn(
                g_critic, s_critic, target_param,
                g_actor, s_actor,
                g_ent, s_ent,
                g_sampler, s_sampler,
                batch, critic_key,
            )

            target_param = polyak_update_fn(g_critic, s_critic, target_param)

            # The actor is updated against the freshly updated critic.
            s_actor, log_prob, s_sampler = actor_update_fn(
                g_critic, s_critic,
                g_actor, s_actor,
                g_ent, s_ent,
                g_sampler, s_sampler,
                batch,
                actor_key,
            )
            s_ent = ent_coef_update_fn(g_ent, s_ent, log_prob)
            return s_critic, target_param, s_actor, s_ent, s_sampler

        return jax.jit(update_fn)
