"""Implementations of algorithms for continuous control."""

from typing import Optional, Sequence, Tuple, List

import jax
import jax.numpy as jnp
import numpy as np
import optax
import os

import policy
import value_net
from actor import update as awr_update_actor
from pagar_actor import pagar_update as awr_pagar_update_actor

from common import Batch, InfoDict, Model, PRNGKey

from critic import update_q, update_v 
from pagar_critic import pagar_update_r, pagar_update_v
# from dual_critic import update_q_dual, update_v_dual

from functools import partial


def target_update(critic: Model, target_critic: Model, tau: float) -> Model:
    """Polyak-average the parameters of ``critic`` into ``target_critic``.

    Every leaf of the new target parameters is ``tau * p + (1 - tau) * tp``,
    where ``p`` is the matching leaf of ``critic.params`` and ``tp`` the
    matching leaf of ``target_critic.params``. Used in this file both for the
    Q-critic target and for the PAGAR reward target.

    Args:
        critic: Online network whose parameters are blended in.
        target_critic: Target network to update; its ``params`` pytree must
            have the same structure as ``critic.params``.
        tau: Weight given to the online parameters (``tau=1`` copies
            ``critic.params``; ``tau=0`` keeps the target parameters).

    Returns:
        ``target_critic.replace(params=...)`` holding the blended parameters.
    """
    # ``jax.tree_util.tree_map`` is the stable spelling across JAX versions;
    # the top-level ``jax.tree_map`` alias is deprecated in recent releases.
    new_target_params = jax.tree_util.tree_map(
        lambda p, tp: p * tau + tp * (1 - tau),
        critic.params,
        target_critic.params,
    )

    return target_critic.replace(params=new_target_params)


@partial(jax.jit, static_argnames=['double', 'vanilla', 'args'])
def _update_jit(
    rng: PRNGKey, actor: Model, critic: Model, value: Model,
    target_critic: Model, expert_batch: Batch, suboptimal_batch: Batch, mix_batch: Batch, discount: float, tau: float,
    expectile: float, temperature: float, loss_temp: float, double: bool, vanilla: bool, args,
) -> Tuple[PRNGKey, Model, Model, Model, Model, InfoDict]:
    """Run one jitted training step of :class:`Learner`.

    Order of operations:
      1. ``args.num_v_updates`` successive value-network updates
         (``update_v``) against ``target_critic`` on ``mix_batch``;
      2. one actor update (``awr_update_actor``) using ``target_critic`` and
         the freshly updated value network on ``mix_batch``;
      3. one critic update (``update_q``) using the updated value network on
         the expert, suboptimal and mixed batches;
      4. a Polyak update of ``target_critic`` towards the new critic.

    ``double``, ``vanilla`` and ``args`` are static jit arguments, so they
    must be hashable and changing them triggers recompilation.

    Returns:
        ``(rng, new_actor, new_critic, new_value, new_target_critic, info)``
        where ``info`` merges the critic, value and actor diagnostics (later
        dicts win on duplicate keys). When ``num_v_updates`` is 0 the value
        network is returned unchanged and contributes no diagnostics.
    """
    num_v_updates = args.num_v_updates
    # One independent key per random consumer: JAX keys must never be reused,
    # otherwise every value step and the actor/critic updates would draw the
    # same randomness (e.g. identical dropout masks).
    rng, actor_key, critic_key, *value_keys = jax.random.split(
        rng, 3 + num_v_updates)

    # Defaults keep the function valid when no value update is requested.
    new_value = value
    value_info = {}
    for value_key in value_keys:
        new_value, value_info = update_v(target_critic, new_value, mix_batch,
                                         expectile, loss_temp, double, vanilla,
                                         value_key, args)

    new_actor, actor_info = awr_update_actor(actor_key, actor, target_critic,
                                             new_value, mix_batch, temperature,
                                             double)

    new_critic, critic_info = update_q(critic, new_value, expert_batch,
                                       suboptimal_batch, mix_batch, discount,
                                       double, critic_key, loss_temp, args)

    new_target_critic = target_update(new_critic, target_critic, tau)

    return rng, new_actor, new_critic, new_value, new_target_critic, {
        **critic_info,
        **value_info,
        **actor_info,
    }

class Learner(object):
    """Offline actor-critic learner with Q-critic, state-value and policy nets.

    State advanced by :meth:`update` in a single jitted step (``_update_jit``):

    * ``actor`` -- a ``policy.NormalTanhPolicy`` trained with the
      advantage-weighted actor update imported from ``actor.update``;
    * ``critic`` / ``target_critic`` -- a ``value_net.DoubleCritic`` and its
      Polyak-averaged copy (rate ``tau``);
    * ``value`` -- a ``value_net.ValueCritic`` state-value network whose loss
      is implemented in ``critic.update_v`` (it receives ``expectile``,
      ``loss_temp`` and ``vanilla``; see that module for the exact objective);
    * ``rng`` -- the PRNG key threaded through updates and action sampling.
    """

    def __init__(self,
                 seed: int,
                 observations: jnp.ndarray,
                 actions: jnp.ndarray,
                 actor_lr: float = 3e-4,
                 value_lr: float = 3e-4,
                 critic_lr: float = 3e-4,
                 hidden_dims: Sequence[int] = (256, 256),
                 discount: float = 0.99,
                 tau: float = 0.005,
                 expectile: float = 0.8,
                 temperature: float = 0.1,
                 dropout_rate: Optional[float] = None,
                 layernorm: bool = False,
                 value_dropout_rate: Optional[float] = None,
                 max_steps: Optional[int] = None,
                 loss_temp: float = 1.0,
                 double_q: bool = True,
                 vanilla: bool = True,
                 opt_decay_schedule: str = "cosine",
                 args=None):
        """Build and initialise the actor, critic, target critic and value nets.

        Args:
            seed: Seed for ``jax.random.PRNGKey``.
            observations: Example observation batch, used only to initialise
                network parameters.
            actions: Example action batch; its last dimension sets the
                policy's action dimension.
            actor_lr: Actor learning rate (initial value when cosine decay is
                used).
            value_lr: Adam learning rate of the value network.
            critic_lr: Adam learning rate of the critic.
            hidden_dims: Hidden layer sizes shared by all networks.
            discount: Discount factor forwarded to the critic update.
            tau: Polyak coefficient for the target critic.
            expectile: Forwarded to the value update.
            temperature: Forwarded to the actor update.
            dropout_rate: Dropout rate of the policy network.
            layernorm: Whether the value network uses layer norm.
            value_dropout_rate: Dropout rate of the value network.
            max_steps: Number of decay steps of the actor's cosine schedule;
                required when ``opt_decay_schedule == "cosine"``.
            loss_temp: Forwarded to the value and critic updates.
            double_q: Forwarded (as ``double``) to the value, actor and critic
                updates.
            vanilla: Forwarded to the value update.
            opt_decay_schedule: ``"cosine"`` for Adam with a cosine-decayed
                actor step size; any other value gives constant-rate Adam.
            args: Extra config passed as a *static* jit argument (so it must be
                hashable); must provide ``num_v_updates``.

        Raises:
            ValueError: If ``opt_decay_schedule == "cosine"`` and
                ``max_steps`` is None.
        """
        # Fail fast with a clear message instead of an obscure error from optax.
        if opt_decay_schedule == "cosine" and max_steps is None:
            raise ValueError(
                'max_steps must be set when opt_decay_schedule == "cosine"')

        self.expectile = expectile
        self.tau = tau
        self.discount = discount
        self.temperature = temperature
        self.loss_temp = loss_temp
        self.double_q = double_q
        self.vanilla = vanilla
        self.args = args

        rng = jax.random.PRNGKey(seed)
        rng, actor_key, critic_key, value_key = jax.random.split(rng, 4)

        action_dim = actions.shape[-1]
        actor_def = policy.NormalTanhPolicy(hidden_dims,
                                            action_dim,
                                            log_std_scale=1e-3,
                                            log_std_min=-5.0,
                                            dropout_rate=dropout_rate,
                                            state_dependent_std=False,
                                            tanh_squash_distribution=False)

        if opt_decay_schedule == "cosine":
            # scale_by_schedule multiplies the Adam direction by the schedule
            # value, so the schedule starts at -actor_lr to make the applied
            # update a descent step.
            schedule_fn = optax.cosine_decay_schedule(-actor_lr, max_steps)
            optimiser = optax.chain(optax.scale_by_adam(),
                                    optax.scale_by_schedule(schedule_fn))
        else:
            optimiser = optax.adam(learning_rate=actor_lr)

        actor = Model.create(actor_def,
                             inputs=[actor_key, observations],
                             tx=optimiser)

        critic_def = value_net.DoubleCritic(hidden_dims)

        critic = Model.create(critic_def,
                              inputs=[critic_key, observations, actions],
                              tx=optax.adam(learning_rate=critic_lr))

        value_def = value_net.ValueCritic(hidden_dims,
                                          layer_norm=layernorm,
                                          dropout_rate=value_dropout_rate)
        value = Model.create(value_def,
                             inputs=[value_key, observations],
                             tx=optax.adam(learning_rate=value_lr))

        # Same definition, init key and inputs as ``critic`` (so it should
        # start from the same parameters); created without a ``tx`` because it
        # is only ever updated through ``target_update``.
        target_critic = Model.create(
            critic_def, inputs=[critic_key, observations, actions])

        self.actor = actor
        self.critic = critic
        self.value = value
        self.target_critic = target_critic
        self.rng = rng

    def sample_actions(self,
                       observations: np.ndarray,
                       temperature: float = 1.0) -> np.ndarray:
        """Sample actions from the current policy.

        Advances ``self.rng`` and returns a NumPy array clipped to ``[-1, 1]``.
        """
        rng, actions = policy.sample_actions(self.rng, self.actor.apply_fn,
                                             self.actor.params, observations,
                                             temperature)
        self.rng = rng

        actions = np.asarray(actions)
        return np.clip(actions, -1, 1)

    def update(self, expert_batch, suboptimal_batch, mix_batch) -> InfoDict:
        """Run one jitted training step, store the new state, return diagnostics."""
        new_rng, new_actor, new_critic, new_value, new_target_critic, info = _update_jit(
            self.rng, self.actor, self.critic, self.value, self.target_critic,
            expert_batch, suboptimal_batch, mix_batch, self.discount, self.tau,
            self.expectile, self.temperature, self.loss_temp, self.double_q,
            self.vanilla, self.args)

        self.rng = new_rng
        self.actor = new_actor
        self.critic = new_critic
        self.value = new_value
        self.target_critic = new_target_critic

        return info

    def load(self, save_dir: str):
        """Restore networks from files written by :meth:`save` in ``save_dir``."""
        self.actor = self.actor.load(os.path.join(save_dir, 'actor'))
        self.critic = self.critic.load(os.path.join(save_dir, 'critic'))
        self.value = self.value.load(os.path.join(save_dir, 'value'))
        # ``save`` does not write the target separately, so it is restored
        # from the critic's file.
        self.target_critic = self.target_critic.load(os.path.join(save_dir, 'critic'))

    def save(self, save_dir: str):
        """Write actor, critic and value to ``save_dir`` (target is not saved)."""
        self.actor.save(os.path.join(save_dir, 'actor'))
        self.critic.save(os.path.join(save_dir, 'critic'))
        self.value.save(os.path.join(save_dir, 'value'))


@partial(jax.jit, static_argnames=['double', 'vanilla', 'args'])
def _update_pagar_jit(
    rng: PRNGKey, protagonist_actor: Model, antagonist_actor: Model, reward: Model, value: Model,
    target_reward: Model, expert_batch: Batch, suboptimal_batch: Batch, mix_batch: Batch, discount: float, tau: float,
    expectile: float, temperature: float, loss_temp: float, double: bool, vanilla: bool, args,
) -> Tuple[PRNGKey, Model, Model, Model, Model, Model, InfoDict]:
    """Run one jitted training step of :class:`PAGAR_Learner`.

    Order of operations:
      1. ``args.num_v_updates`` successive value updates (``pagar_update_v``)
         against ``target_reward`` on ``mix_batch``;
      2. one update of the antagonist and one of the protagonist actor
         (``awr_pagar_update_actor``) using ``target_reward`` and the updated
         value network;
      3. one reward update (``pagar_update_r``) on the expert, suboptimal and
         mixed batches;
      4. a Polyak update of ``target_reward`` towards the new reward network.

    ``double``, ``vanilla`` and ``args`` are static jit arguments (must be
    hashable).

    Returns:
        ``(rng, new_protagonist_actor, new_antagonist_actor, new_reward,
        new_value, new_target_reward, info)``. ``info`` merges the reward,
        value and actor diagnostics; actor keys are prefixed with
        ``antagonist_`` / ``protagonist_``. When ``num_v_updates`` is 0 the
        value network is returned unchanged and contributes no diagnostics.
    """
    num_v_updates = args.num_v_updates
    # One independent key per random consumer -- JAX keys must not be reused
    # (previously both actors, every value step and the reward update shared
    # one key and therefore the same randomness).
    rng, antagonist_key, protagonist_key, reward_key, *value_keys = jax.random.split(
        rng, 4 + num_v_updates)

    # Defaults keep the function valid when no value update is requested.
    new_value = value
    value_info = {}
    for value_key in value_keys:
        new_value, value_info = pagar_update_v(target_reward, new_value, mix_batch,
                                               discount, expectile, loss_temp,
                                               double, vanilla, value_key, args)

    # NOTE(review): both actors receive identical update inputs apart from
    # their own parameters and key -- confirm that is the intended objective.
    new_antagonist_actor, antagonist_actor_info = awr_pagar_update_actor(
        antagonist_key, antagonist_actor, target_reward, new_value, discount,
        mix_batch, temperature, double)

    new_protagonist_actor, protagonist_actor_info = awr_pagar_update_actor(
        protagonist_key, protagonist_actor, target_reward, new_value, discount,
        mix_batch, temperature, double)

    # Prefix each actor's diagnostics so the two sets of keys cannot collide.
    actor_info = {
        **{'antagonist_' + k: v for k, v in antagonist_actor_info.items()},
        **{'protagonist_' + k: v for k, v in protagonist_actor_info.items()},
    }

    # NOTE(review): the reward update uses the actors from *before* this
    # step's actor updates (not new_protagonist_actor / new_antagonist_actor);
    # confirm this ordering is intended.
    new_reward, reward_info = pagar_update_r(
        reward, new_value, protagonist_actor, antagonist_actor, expert_batch,
        suboptimal_batch, mix_batch, discount, double, reward_key, loss_temp,
        temperature, args)

    new_target_reward = target_update(new_reward, target_reward, tau)

    return rng, new_protagonist_actor, new_antagonist_actor, new_reward, new_value, new_target_reward, {
        **reward_info,
        **value_info,
        **actor_info,
    }





class PAGAR_Learner(Learner):
    """PAGAR variant of :class:`Learner` with a learned reward and two actors.

    Replaces the Q-critic with a ``value_net.DoubleReward`` network (plus a
    Polyak-averaged target) and trains two ``policy.NormalTanhPolicy`` actors,
    a protagonist and an antagonist, via ``_update_pagar_jit``. Actions are
    sampled from the protagonist only.

    ``__init__`` does not call ``Learner.__init__`` and every public method is
    overridden, so the base-class attributes ``actor``, ``critic`` and
    ``target_critic`` do not exist on instances of this class.
    """

    def __init__(self,
                 seed: int,
                 observations: jnp.ndarray,
                 actions: jnp.ndarray,
                 actor_lr: float = 3e-4,
                 value_lr: float = 3e-4,
                 critic_lr: float = 3e-4,
                 hidden_dims: Sequence[int] = (256, 256),
                 discount: float = 0.99,
                 tau: float = 0.005,
                 expectile: float = 0.8,
                 temperature: float = 0.1,
                 dropout_rate: Optional[float] = None,
                 layernorm: bool = False,
                 value_dropout_rate: Optional[float] = None,
                 max_steps: Optional[int] = None,
                 loss_temp: float = 1.0,
                 double_q: bool = True,
                 vanilla: bool = True,
                 opt_decay_schedule: str = "cosine",
                 args=None):
        """Build and initialise both actors, the reward, target reward and value nets.

        Parameters mirror :class:`Learner`; ``critic_lr`` is the Adam learning
        rate of the reward network. ``args`` is a static (hashable) jit
        argument and must provide ``num_v_updates``.

        Raises:
            ValueError: If ``opt_decay_schedule == "cosine"`` and
                ``max_steps`` is None.
        """
        # Fail fast with a clear message instead of an obscure error from optax.
        if opt_decay_schedule == "cosine" and max_steps is None:
            raise ValueError(
                'max_steps must be set when opt_decay_schedule == "cosine"')

        self.expectile = expectile
        self.tau = tau
        self.discount = discount
        self.temperature = temperature
        self.loss_temp = loss_temp
        self.double_q = double_q
        self.vanilla = vanilla
        self.args = args

        rng = jax.random.PRNGKey(seed)
        rng, protagonist_actor_key, antagonist_actor_key, reward_key, value_key = jax.random.split(rng, 5)

        action_dim = actions.shape[-1]
        actor_def = policy.NormalTanhPolicy(hidden_dims,
                                            action_dim,
                                            log_std_scale=1e-3,
                                            log_std_min=-5.0,
                                            dropout_rate=dropout_rate,
                                            state_dependent_std=False,
                                            tanh_squash_distribution=False)

        if opt_decay_schedule == "cosine":
            # scale_by_schedule multiplies the Adam direction by the schedule
            # value, so it starts at -actor_lr to make the update a descent step.
            schedule_fn = optax.cosine_decay_schedule(-actor_lr, max_steps)
            optimiser = optax.chain(optax.scale_by_adam(),
                                    optax.scale_by_schedule(schedule_fn))
        else:
            optimiser = optax.adam(learning_rate=actor_lr)

        # Both actors share the optimiser definition; each Model.create call
        # builds its own optimiser state from its own parameters.
        protagonist_actor = Model.create(actor_def,
                                         inputs=[protagonist_actor_key, observations],
                                         tx=optimiser)

        antagonist_actor = Model.create(actor_def,
                                        inputs=[antagonist_actor_key, observations],
                                        tx=optimiser)

        reward_def = value_net.DoubleReward(hidden_dims)

        # Initialised with (observations, actions, observations) -- presumably
        # (s, a, s'); NOTE(review): confirm against value_net.DoubleReward.
        reward = Model.create(reward_def,
                              inputs=[reward_key, observations, actions, observations],
                              tx=optax.adam(learning_rate=critic_lr))

        value_def = value_net.ValueCritic(hidden_dims,
                                          layer_norm=layernorm,
                                          dropout_rate=value_dropout_rate)
        value = Model.create(value_def,
                             inputs=[value_key, observations],
                             tx=optax.adam(learning_rate=value_lr))

        # Same definition, key and inputs as ``reward``; no ``tx`` because it is
        # only updated through ``target_update``.
        target_reward = Model.create(
            reward_def, inputs=[reward_key, observations, actions, observations])

        self.protagonist_actor = protagonist_actor
        self.antagonist_actor = antagonist_actor

        self.value = value
        self.reward = reward
        self.target_reward = target_reward
        self.rng = rng

    def sample_actions(self,
                       observations: np.ndarray,
                       temperature: float = 1.0) -> np.ndarray:
        """Sample actions from the protagonist policy.

        Advances ``self.rng`` and returns a NumPy array clipped to ``[-1, 1]``.
        """
        rng, actions = policy.sample_actions(self.rng, self.protagonist_actor.apply_fn,
                                             self.protagonist_actor.params, observations,
                                             temperature)
        self.rng = rng

        actions = np.asarray(actions)
        return np.clip(actions, -1, 1)

    def update(self, expert_batch, suboptimal_batch, mix_batch) -> InfoDict:
        """Run one jitted PAGAR training step, store the new state, return diagnostics."""
        new_rng, new_protagonist_actor, new_antagonist_actor, new_reward, new_value, new_target_reward, info = _update_pagar_jit(
            self.rng, self.protagonist_actor, self.antagonist_actor, self.reward, self.value, self.target_reward,
            expert_batch, suboptimal_batch, mix_batch, self.discount, self.tau, self.expectile,
            self.temperature, self.loss_temp, self.double_q, self.vanilla, self.args)

        self.rng = new_rng
        self.protagonist_actor = new_protagonist_actor
        self.antagonist_actor = new_antagonist_actor
        self.reward = new_reward
        self.value = new_value
        self.target_reward = new_target_reward

        return info

    def load(self, save_dir: str):
        """Restore networks from files written by :meth:`save` in ``save_dir``."""
        self.protagonist_actor = self.protagonist_actor.load(os.path.join(save_dir, 'protagonist_actor'))
        # Load through the antagonist's own Model so it keeps its own
        # non-parameter state rather than the protagonist's.
        self.antagonist_actor = self.antagonist_actor.load(os.path.join(save_dir, 'antagonist_actor'))
        self.reward = self.reward.load(os.path.join(save_dir, 'reward'))
        self.value = self.value.load(os.path.join(save_dir, 'value'))
        # ``save`` does not write the target separately, so it is restored
        # from the reward network's file.
        self.target_reward = self.target_reward.load(os.path.join(save_dir, 'reward'))

    def save(self, save_dir: str):
        """Write both actors, reward and value to ``save_dir`` (target is not saved)."""
        self.protagonist_actor.save(os.path.join(save_dir, 'protagonist_actor'))
        self.antagonist_actor.save(os.path.join(save_dir, 'antagonist_actor'))
        self.reward.save(os.path.join(save_dir, 'reward'))
        self.value.save(os.path.join(save_dir, 'value'))