import flax
from flax.training import orbax_utils
import jax.random
import optax

from networks.actor import SquashedGaussianPolicy, DeterministicPolicy
from networks.base import LagrangianCoefficient
from networks.critics import VectorCritic, VectorQuantileCritic, QCritic
from misc.rng_modules import PRNGSequence
from typing import Any, Type, Callable
from networks.optimize import RLTrainState, update
from flax.training.train_state import TrainState
from functools import partial
from copy import deepcopy
import orbax


class TD3Policy(object):
    """Container for a TD3-style deterministic actor and a critic ensemble.

    Owns the Flax modules, their ``RLTrainState`` objects (online params,
    target params and optimizer state) and thin wrappers for gradient
    updates, target updates and Orbax checkpointing. Subclasses swap the
    networks by overriding :meth:`build`.
    """
    actor: DeterministicPolicy
    actor_train_state: RLTrainState

    critic: VectorCritic
    critic_train_state: RLTrainState

    def __init__(self, observation_ph, action_ph,
                 opt_class: Callable[..., optax.GradientTransformation] = optax.adam,
                 learning_rate: float = 3e-4,
                 n_critics: int = 2,
                 seed: jax.Array = jax.random.PRNGKey(42)
                 ):
        """Store hyper-parameters and build all networks.

        Args:
            observation_ph: Example observation batch, used only to initialise
                parameters in :meth:`build`.
            action_ph: Example action batch; its last dimension sets the
                actor's output size.
            opt_class: Optimizer factory, called once as
                ``opt_class(learning_rate)``. The resulting transformation is
                shared by every train state created in :meth:`build`.
            learning_rate: Learning rate passed to ``opt_class``.
            n_critics: Number of critics passed to the critic ensemble.
            seed: PRNG key seeding ``self.rng``, from which initialisation
                keys are drawn.
        """
        self.n_critics = n_critics
        self.seed = seed
        self.rng = PRNGSequence(seed)
        # NOTE: despite its name this holds the optimizer *instance*
        # (``opt_class(learning_rate)``); subclasses read it under this name.
        self.opt_class = opt_class(learning_rate)
        self.build(observation_ph, action_ph)

    def build(self, observation_ph, action_ph):
        """Create the actor/critic modules and their train states.

        Two keys are drawn from ``self.rng`` (actor first, then critic).
        Target parameters start as deep copies of the online parameters.
        """
        self.actor = DeterministicPolicy(action_ph.shape[-1])
        self.actor.apply = jax.jit(self.actor.apply)
        self.critic = VectorCritic(n_critics=self.n_critics)
        self.critic.apply = jax.jit(self.critic.apply)

        actor_params = self.actor.init(next(self.rng), observation_ph)
        critic_params = self.critic.init(next(self.rng), observation_ph, action_ph)

        # ``apply`` was replaced by a jitted function above, so it is passed
        # through directly rather than being wrapped in a second ``jax.jit``.
        self.actor_train_state = RLTrainState.create(apply_fn=self.actor.apply,
                                                     params=actor_params, tx=self.opt_class,
                                                     target_params=deepcopy(actor_params),
                                                     )

        self.critic_train_state = RLTrainState.create(apply_fn=self.critic.apply,
                                                      params=critic_params, tx=self.opt_class,
                                                      target_params=deepcopy(critic_params)
                                                      )

    def actor_update(self, actor_loss_fn, *args, **kwargs):
        """Run one ``update`` step on the actor train state.

        Extra arguments are forwarded to ``update``. Stores the new train
        state and returns the auxiliary output produced by ``update``.
        """
        self.actor_train_state, aux = update(actor_loss_fn, self.actor_train_state, *args, **kwargs)
        return aux

    def critic_update(self, critic_loss_fn, *args, **kwargs):
        """Run one ``update`` step on the critic train state.

        Extra arguments are forwarded to ``update``. Stores the new train
        state and returns the auxiliary output produced by ``update``.
        """
        self.critic_train_state, aux = update(critic_loss_fn,
                                              self.critic_train_state, *args, **kwargs)
        return aux

    @staticmethod
    @jax.jit
    def _soft_update(critic_train_state, actor_train_state):
        """Jitted: call ``apply_target_update`` on both train states.

        NOTE(review): presumably a Polyak/soft target update; the rule lives
        in ``RLTrainState.apply_target_update`` (not visible here).
        """
        critic_train_state = critic_train_state.apply_target_update()
        actor_train_state = actor_train_state.apply_target_update()
        return critic_train_state, actor_train_state

    def soft_update(self):
        """Apply the target-network update to the critic and actor in place."""
        self.critic_train_state, self.actor_train_state = self._soft_update(self.critic_train_state,
                                                                            self.actor_train_state)

    @property
    def params_actor(self):
        """Current (online) actor parameters.

        Read from the train state: a Flax module does not hold parameters,
        and ``Module.param`` is the parameter-declaration method.
        """
        return self.actor_train_state.params

    @property
    def params_critic(self):
        """Current (online) critic parameters, read from the train state."""
        return self.critic_train_state.params

    @property
    def checkpoint(self):
        """Fresh dict of the train states that make up a checkpoint."""
        return {"actor": self.actor_train_state, "critic": self.critic_train_state}

    def save_checkpoint(self, path, overwrite: bool = True):
        """Save :attr:`checkpoint` to ``path`` with Orbax.

        Args:
            path: Checkpoint directory.
            overwrite: Forwarded to Orbax as ``force``; when True, an existing
                checkpoint at ``path`` is replaced.
        """
        ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())
        CKPT_PYTREE = self.checkpoint
        ckptr.save(path, CKPT_PYTREE,
                   save_args=flax.training.orbax_utils.save_args_from_target(CKPT_PYTREE), force=overwrite)

    def _load_checkpoint(self, path: str):
        """Restore a checkpoint from ``path``, using :attr:`checkpoint` as the target structure."""
        target = self.checkpoint

        ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())
        restored = ckptr.restore(path, item=target,
                                 restore_args=flax.training.orbax_utils.restore_args_from_target(target, mesh=None))

        return restored

    def load_checkpoint(self, path):
        """Restore the actor and critic train states from ``path``."""
        restored = self._load_checkpoint(path)
        self.actor_train_state = restored['actor']
        self.critic_train_state = restored['critic']


class IQNTD3Policy(TD3Policy):
    """TD3 variant whose critic ensemble is a ``VectorQuantileCritic``.

    Only :meth:`build` differs from the base policy: the critic takes an
    extra ``taus`` input (presumably quantile fractions), so a placeholder
    for it is supplied during initialisation.
    """
    critic: VectorQuantileCritic

    def build(self, observation_ph, action_ph):
        """Create the actor and quantile critic together with their train states.

        Keys are drawn from ``self.rng`` in the order actor, then critic.
        Each train state's targets start as a deep copy of its online params.
        """
        def make_train_state(module, params):
            # The shared optimizer instance is used for every network.
            return RLTrainState.create(
                apply_fn=jax.jit(module.apply),
                params=params,
                tx=self.opt_class,
                target_params=deepcopy(params),
            )

        action_dim = action_ph.shape[-1]
        self.actor = DeterministicPolicy(action_dim)
        self.actor.apply = jax.jit(self.actor.apply)
        self.critic = VectorQuantileCritic(n_critics=self.n_critics)
        self.critic.apply = jax.jit(self.critic.apply)

        actor_params = self.actor.init(next(self.rng), observation_ph)
        # One zero-valued tau per batch row, shape (batch, 1), only used to trace shapes.
        batch_size = observation_ph.shape[0]
        dummy_taus = jax.numpy.zeros((batch_size, 1))
        critic_params = self.critic.init(next(self.rng), observation_ph, action_ph, dummy_taus)

        self.actor_train_state = make_train_state(self.actor, actor_params)
        self.critic_train_state = make_train_state(self.critic, critic_params)


class IQNSACPolicy(IQNTD3Policy):
    """Squashed-Gaussian actor with a quantile critic ensemble and a learned entropy coefficient.

    Compared to :class:`IQNTD3Policy`, the actor is a ``SquashedGaussianPolicy``,
    the critic is built with ``monotone=False``, and an extra
    ``LagrangianCoefficient`` with its own ``TrainState`` is trained through
    :meth:`ent_coef_update`. The entropy-coefficient state is part of the
    checkpoint, so it survives save/restore.
    """
    ent_coef: LagrangianCoefficient
    ent_coef_train_state: TrainState
    # Learning rate of the dedicated Adam optimizer used for the entropy
    # coefficient (independent of the actor/critic optimizer). Override in a
    # subclass to change it.
    ent_coef_learning_rate: float = 3e-4

    def build(self, observation_ph, action_ph):
        """Create actor, critic and entropy-coefficient modules and their train states.

        Actor and critic keys come from ``self.rng`` (actor first); target
        parameters start as deep copies of the online ones.
        """
        self.actor = SquashedGaussianPolicy(action_ph.shape[-1])
        self.actor.apply = jax.jit(self.actor.apply)
        self.critic = VectorQuantileCritic(n_critics=self.n_critics, monotone=False)
        self.critic.apply = jax.jit(self.critic.apply)

        actor_params = self.actor.init(next(self.rng), observation_ph)
        # Placeholder taus of shape (batch, 1), used only to initialise the critic.
        taus_ph = jax.numpy.zeros(shape=(observation_ph.shape[0], 1))

        critic_params = self.critic.init(next(self.rng), observation_ph, action_ph, taus_ph)

        self.actor_train_state = RLTrainState.create(apply_fn=jax.jit(self.actor.apply),
                                                     params=actor_params, tx=self.opt_class,
                                                     target_params=deepcopy(actor_params),
                                                     )

        self.critic_train_state = RLTrainState.create(apply_fn=jax.jit(self.critic.apply),
                                                      params=critic_params, tx=self.opt_class,
                                                      target_params=deepcopy(critic_params)
                                                      )

        self.ent_coef = LagrangianCoefficient()
        # NOTE(review): fixed key, independent of ``seed``/``self.rng``; kept so the
        # RNG stream consumed from ``self.rng`` is unchanged.
        ent_coef_param = self.ent_coef.init(jax.random.PRNGKey(42))
        self.ent_coef_train_state = TrainState.create(apply_fn=jax.jit(self.ent_coef.apply),
                                                      params=ent_coef_param,
                                                      tx=optax.adam(learning_rate=self.ent_coef_learning_rate)
                                                      )

    @partial(jax.jit, static_argnums=(0,))
    def ent_coef_loss(self, param_ent, train_state, current_ent, target_ent):
        """Loss for the entropy coefficient: ``-log_ent_coef * sg(current_ent - target_ent)``.

        Under gradient descent this raises ``log_ent_coef`` when
        ``current_ent > target_ent`` and lowers it otherwise.
        NOTE(review): that is the usual SAC direction only if ``current_ent``
        is a log-probability-like quantity and ``target_ent`` the negated
        target entropy; confirm against the caller's convention.

        Returns:
            ``(loss, {"ent_coef_loss": loss})``.
        """
        log_ent_coef = self.ent_coef.apply(param_ent)
        loss = -log_ent_coef * jax.lax.stop_gradient(current_ent - target_ent)
        return loss, {"ent_coef_loss": loss}

    def ent_coef_update(self, current_ent, target_ent):
        """Run one ``update`` step on the entropy coefficient; returns ``update``'s aux output."""
        self.ent_coef_train_state, aux = update(self.ent_coef_loss,
                                                self.ent_coef_train_state, current_ent, target_ent)
        return aux

    @property
    def checkpoint(self):
        """Base checkpoint plus the entropy-coefficient train state under ``'ent_coef'``.

        NOTE(review): a checkpoint written without an ``'ent_coef'`` entry will
        likely not restore against this target — confirm with Orbax if older
        checkpoints need to be loaded.
        """
        checkpoint = super().checkpoint
        checkpoint['ent_coef'] = self.ent_coef_train_state
        return checkpoint

    def load_checkpoint(self, path):
        """Restore actor, critic and entropy-coefficient train states from ``path``."""
        restored = self._load_checkpoint(path)
        self.actor_train_state = restored['actor']
        self.critic_train_state = restored['critic']
        self.ent_coef_train_state = restored['ent_coef']


class ModelRiskPolicy(IQNTD3Policy):
    """IQN-TD3 policy extended with a CMV critic and a reward model.

    On top of the parent's actor and quantile critic, :meth:`build` adds a
    ``VectorCritic(cmv=True)`` with target parameters and a ``QCritic``
    reward model without target parameters. All of them share the
    optimizer instance created from ``opt_class`` (``optax.adabelief`` by
    default). The extra states are included in the target update and in the
    checkpoint.
    """
    cmv_critic: VectorCritic
    cmv_critic_state: RLTrainState
    reward_model: QCritic
    # NOTE(review): build() creates this with a plain ``TrainState`` (no target params).
    reward_model_state: RLTrainState

    def __init__(self, observation_ph, action_ph,
                 opt_class: Type[optax.adabelief] = optax.adabelief,
                 learning_rate: float = 3e-4,
                 use_reward_model: bool = False,
                 n_critics: int = 2,
                 seed: jax.Array = jax.random.PRNGKey(42)
                 ):
        """Record ``use_reward_model`` and delegate the rest to the parent constructor.

        ``use_reward_model`` is not read by any method of this class (the
        reward model is always built); presumably training code consumes it.
        """
        self.use_reward_model = use_reward_model
        super().__init__(
            observation_ph,
            action_ph,
            opt_class,
            learning_rate,
            n_critics=n_critics,
            seed=seed,
        )

    def build(self, observation_ph, action_ph):
        """Build the parent networks, then the CMV critic and the reward model.

        The parent's two ``self.rng`` draws run first, followed by one draw for
        the CMV critic and one for the reward model.
        """
        super().build(observation_ph, action_ph)

        cmv_module = VectorCritic(cmv=True)
        self.cmv_critic = cmv_module
        cmv_module.apply = jax.jit(cmv_module.apply)
        cmv_params = cmv_module.init(next(self.rng), observation_ph, action_ph)
        self.cmv_critic_state = RLTrainState.create(
            apply_fn=jax.jit(cmv_module.apply),
            params=cmv_params,
            tx=self.opt_class,
            target_params=deepcopy(cmv_params),
        )

        reward_module = QCritic()
        self.reward_model = reward_module
        reward_module.apply = jax.jit(reward_module.apply)
        reward_params = reward_module.init(next(self.rng), observation_ph, action_ph)
        self.reward_model_state = TrainState.create(
            apply_fn=jax.jit(reward_module.apply),
            params=reward_params,
            tx=self.opt_class,
        )

    def actor_update(self, actor_loss_fn, *args, **kwargs):
        """One ``update`` step on the actor; stores the new state and returns aux."""
        new_state, aux = update(actor_loss_fn, self.actor_train_state, *args, **kwargs)
        self.actor_train_state = new_state
        return aux

    def cmv_model_update(self, model_loss_fn, *args, **kwargs):
        """One ``update`` step on the CMV critic; stores the new state and returns aux."""
        new_state, aux = update(model_loss_fn, self.cmv_critic_state, *args, **kwargs)
        self.cmv_critic_state = new_state
        return aux

    def reward_model_update(self, reward_loss_fn, *args, **kwargs):
        """One ``update`` step on the reward model; stores the new state and returns aux."""
        new_state, aux = update(reward_loss_fn, self.reward_model_state, *args, **kwargs)
        self.reward_model_state = new_state
        return aux

    @staticmethod
    @jax.jit
    def _soft_update(cmv_train_state, critic_train_state, actor_train_state):
        """Jitted: ``apply_target_update`` on the CMV, critic and actor states (in that output order)."""
        return tuple(
            state.apply_target_update()
            for state in (cmv_train_state, critic_train_state, actor_train_state)
        )

    def soft_update(self):
        """Apply the target-network update to the CMV critic, critic and actor in place."""
        updated = self._soft_update(self.cmv_critic_state, self.critic_train_state, self.actor_train_state)
        self.cmv_critic_state, self.critic_train_state, self.actor_train_state = updated

    @property
    def checkpoint(self):
        """Parent checkpoint plus the CMV critic (``'cmv'``) and reward model (``'reward'``) states."""
        return {
            **super().checkpoint,
            'cmv': self.cmv_critic_state,
            'reward': self.reward_model_state,
        }

    def load_checkpoint(self, path):
        """Restore every train state listed in :attr:`checkpoint` from ``path``."""
        restored = self._load_checkpoint(path)
        attribute_for_key = (
            ('actor_train_state', 'actor'),
            ('critic_train_state', 'critic'),
            ('cmv_critic_state', 'cmv'),
            ('reward_model_state', 'reward'),
        )
        for attribute, key in attribute_for_key:
            setattr(self, attribute, restored[key])




