import tensorflow as tf
from utils.basic_models import *
from copy import deepcopy
import os
import random


class LatentModel(tf.Module):
    """World model bundling latent dynamics, a reward head, an encoder and a decoder.

    ``config`` is read both by item (``config['stoch']``) and by attribute
    (``config.num_units_reward``), so it must support both access styles.
    """

    def __init__(self, config):
        # tf.Module.__init__ assigns the module name and name scope; without it
        # ``self.name`` / ``self.name_scope`` are never set.
        super().__init__()
        self.latent_dynamic = RSSME(stoch=config['stoch'],
                                    deter=config['deter'],
                                    hidden=config['hidden_latent'],
                                    num_models=config['num_models'])
        # NOTE(review): ``()`` and ``2`` are positional DenseDecoder arguments,
        # presumably a scalar output shape and a layer count — TODO confirm.
        self.reward = DenseDecoder((), 2, units=config.num_units_reward)
        self.encoder = ConvEncoder()
        self.decoder = ConvDecoder()

    def save(self, filename):
        """Save every sub-network as a ``.pkl`` file inside ``filename``.

        Args:
            filename: directory path supporting ``/`` joins (e.g. ``pathlib.Path``).
                The directory is not created here.
        """
        self.latent_dynamic.save(filename / 'latent.pkl')
        self.reward.save(filename / 'reward.pkl')
        self.encoder.save(filename / 'encoder.pkl')
        self.decoder.save(filename / 'decoder.pkl')

    def load(self, filename):
        """Load every sub-network from the files written by :meth:`save`.

        Args:
            filename: directory path supporting ``/`` joins (e.g. ``pathlib.Path``).
        """
        self.latent_dynamic.load(filename / 'latent.pkl')
        self.reward.load(filename / 'reward.pkl')
        self.encoder.load(filename / 'encoder.pkl')
        self.decoder.load(filename / 'decoder.pkl')


class ActorOraac(tf.Module):
    """ORAAC actor: an imitation policy plus a scaled CVaR correction.

    The emitted action is
    ``clip(imit(s) + lambda * cvar(concat(s, imit(s))), -max_action, max_action)``,
    where ``lambda`` is ``config.trade_off_cvar_imit_agent``. A target copy of
    the CVaR actor is kept and used by :meth:`next_action`.
    """

    def __init__(self, config):
        # tf.Module.__init__ assigns the module name and name scope.
        super().__init__()
        # Fixed trade-off weight between the imitation action and the CVaR correction.
        self._lambda = tf.stop_gradient(tf.constant(config.trade_off_cvar_imit_agent))
        self._max_action = config.max_action

        imit_act = self._get_activation(config['act_imit_agent'])
        self._imit_actor = BasicStochasticActor(shape=config['shape_action'],
                                                max_action=config['max_action'],
                                                num_layers=config['num_layers_imit_agent'],
                                                units=config['units_imit_agent'],
                                                act=imit_act)

        cvar_act = self._get_activation(config['act_cvar_agent'])
        self._cvar_actor = BasicActor(shape=config['shape_action'],
                                      max_action=config['max_action'],
                                      num_layers=config['num_layers_cvar_agent'],
                                      units=config['units_cvar_agent'],
                                      act=cvar_act)
        # The target network starts as an exact copy of the online CVaR actor.
        self._target_cvar_actor = deepcopy(self._cvar_actor)

    @staticmethod
    def _get_activation(name):
        """Map an activation name ('relu' or 'elu') to its TF function.

        Raises:
            NotImplementedError: for any other name.
        """
        if name == 'relu':
            return tf.nn.relu
        if name == 'elu':
            return tf.nn.elu
        raise NotImplementedError(name)

    def _clip(self, action):
        """Clip ``action`` element-wise to ``[-max_action, max_action]``."""
        return tf.clip_by_value(action, clip_value_min=-self._max_action,
                                clip_value_max=self._max_action)

    def update_cvar_target(self, tau):
        """Soft-update the target CVaR actor: ``target <- tau * online + (1 - tau) * target``.

        Variables are paired positionally. This relies on the target (a deepcopy
        of the online actor) listing its variables in the same order.
        """
        for params, new_params in zip(self._target_cvar_actor.variables, self._cvar_actor.variables):
            new_param_ = tau * tf.stop_gradient(new_params) + (1.0 - tau) * tf.stop_gradient(params)
            params.assign(new_param_)

    def __call__(self, state, cvar_training=False):
        """Return the clipped combined action for ``state``.

        Args:
            state: policy input. It is concatenated with the imitation action on
                the last axis, so leading dims are assumed to match — TODO confirm.
            cvar_training: if True, gradients are blocked from flowing into the
                imitation actor, so only the CVaR actor is trained through this path.
        """
        imit_action = self._imit_actor(state)
        if cvar_training:
            imit_action = tf.stop_gradient(imit_action)

        correction = self._cvar_actor(tf.concat((state, imit_action), axis=-1))
        return self._clip(imit_action + self._lambda * correction)

    def next_action(self, state):
        """Like :meth:`__call__`, but the correction comes from the target CVaR actor."""
        imit_action = self._imit_actor(state)
        correction = self._target_cvar_actor(tf.concat((state, imit_action), axis=-1))
        return self._clip(imit_action + self._lambda * correction)

    def imitation_action(self, state):
        """Return the raw (unclipped) imitation-actor output for ``state``."""
        return self._imit_actor(state)

    def save(self, filename):
        """Save the imitation and CVaR actors into directory ``filename``."""
        self._imit_actor.save(filename / 'imit_actor.pkl')
        self._cvar_actor.save(filename / 'cvar_actor.pkl')

    def save_imit_actor(self, filename):
        """Save only the imitation actor, creating ``filename`` (and parents) if needed."""
        # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also
        # creates missing parent directories.
        os.makedirs(filename, exist_ok=True)
        self._imit_actor.save(filename / 'imit_actor.pkl')

    def load_imit_actor(self, filename):
        """Load only the imitation actor from directory ``filename``."""
        self._imit_actor.load(filename / 'imit_actor.pkl')

    def load(self, filename):
        """Load the imitation and CVaR actors from directory ``filename``.

        The target CVaR actor is loaded from the same checkpoint as the online
        CVaR actor. Otherwise :meth:`next_action` would keep using a freshly
        initialised target after loading.
        """
        self._imit_actor.load(filename / 'imit_actor.pkl')
        self._cvar_actor.load(filename / 'cvar_actor.pkl')
        self._target_cvar_actor.load(filename / 'cvar_actor.pkl')


class DistributionalCritic(tf.Module):
    """Twin distributional critics (``q1``, ``q2``) with target copies.

    The network type is chosen by ``config.latent_algo``:
    ``'oraac'`` -> ``DeterministicNN_IQN``, ``'lodac'`` -> ``QuantileMlp``.
    """

    def __init__(self, config):
        # tf.Module.__init__ assigns the module name and name scope.
        super().__init__()
        # ``act_critic`` is validated here, but neither network constructor
        # below receives it.
        if config['act_critic'] not in ('relu', 'elu'):
            raise NotImplementedError(config['act_critic'])

        if config.latent_algo == 'oraac':
            def make_critic():
                return DeterministicNN_IQN(dim_state=config.deter + config.stoch,
                                           dim_action=config.shape_action,
                                           layers_state=config.critic_hidden_units_state,
                                           layers_action=config.critic_hidden_units_action,
                                           layers_f=config.critic_hidden_units_f,
                                           tau_embed_dim=config.critic_tau_embedding_dim,
                                           embedding_dim=config.critic_embedding_dim,
                                           tau=config.target_update_tau)
        elif config.latent_algo == 'lodac':
            def make_critic():
                return QuantileMlp(hidden_sizes=[config.critic_units, config.critic_units],
                                   num_quantiles=config.N_quantile_critic,
                                   layer_norm=config.layer_norm,
                                   tau=config.target_update_tau)
        else:
            raise NotImplementedError(config.latent_algo)

        # Two independently constructed critics with identical hyper-parameters.
        self.q1 = make_critic()
        self.q2 = make_critic()
        self.target_q1 = deepcopy(self.q1)
        self.target_q2 = deepcopy(self.q2)

    def get_sampled_Z(self, obs, tau_actor_k, action):
        """Query ``get_sampled_Z`` on one of the two critics, picked uniformly at random.

        Uses the global ``random`` module RNG.
        """
        critic = self.q1 if random.randint(0, 1) else self.q2
        return critic.get_sampled_Z(obs, tau_actor_k, action)

    def update_target(self, tau):
        """Soft-update both targets: ``target <- tau * online + (1 - tau) * target``.

        Parameters are paired positionally via each network's ``params``.
        """
        for target, online in ((self.target_q1, self.q1), (self.target_q2, self.q2)):
            for params, new_params in zip(target.params, online.params):
                new_param_ = tau * tf.stop_gradient(new_params) + (1.0 - tau) * tf.stop_gradient(params)
                params.assign(new_param_)

    def save(self, filename):
        """Save the online critics into directory ``filename``."""
        self.q1.save(filename / 'q1.pkl')
        self.q2.save(filename / 'q2.pkl')

    def load(self, filename):
        """Load each online critic, and its target, from that critic's own checkpoint."""
        self.q1.load(filename / 'q1.pkl')
        self.target_q1.load(filename / 'q1.pkl')

        # Previously this loaded q2.pkl into q1/target_q1, leaving q2 untouched.
        self.q2.load(filename / 'q2.pkl')
        self.target_q2.load(filename / 'q2.pkl')


class Actor(tf.Module):
    """Stochastic actor whose ``__call__`` output is clipped to ``[-max_action, max_action]``."""

    def __init__(self, config):
        # tf.Module.__init__ assigns the module name and name scope.
        super().__init__()
        self._max_action = config.max_action
        # Read the activation name once. The 'elu' branch previously looked up
        # the typo key 'actor_actor', and the error message reported the
        # unrelated key 'act_imit_agent'.
        act_name = config['act_actor']
        if act_name == 'relu':
            _act = tf.nn.relu
        elif act_name == 'elu':
            _act = tf.nn.elu
        else:
            raise NotImplementedError(act_name)

        self._actor = BasicStochasticActor(shape=config['shape_action'],
                                           max_action=config['max_action'],
                                           num_layers=config['num_layers_actor'],
                                           units=config['units_actor'],
                                           act=_act)

    def __call__(self, state):
        """Return the actor's action for ``state``, clipped element-wise."""
        action = self._actor(state)
        return tf.clip_by_value(action, clip_value_min=-self._max_action, clip_value_max=self._max_action)

    def action_log_prob(self, state):
        """Return ``(action, log_prob)`` from the underlying actor.

        Unlike :meth:`__call__`, the action here is not clipped.
        """
        action, action_log_prob = self._actor.action_log_prob(state)
        return action, action_log_prob

    def save(self, filename):
        """Save the actor into directory ``filename``."""
        self._actor.save(filename / '_actor.pkl')

    def load(self, filename):
        """Load the actor from directory ``filename``."""
        self._actor.load(filename / '_actor.pkl')


class Critic(tf.Module):
    """Twin scalar Q-networks (``qf1``, ``qf2``) with their own target copies."""

    def __init__(self, config):
        # tf.Module.__init__ assigns the module name and name scope.
        super().__init__()
        if config['act_critic'] == 'relu':
            _act = tf.nn.relu
        elif config['act_critic'] == 'elu':
            _act = tf.nn.elu
        else:
            raise NotImplementedError(config['act_critic'])

        self.qf1 = DenseNetwork(shape=1,
                                num_layers=config.critic_num_layers,
                                units=config.critic_units,
                                act=_act)

        self.qf2 = DenseNetwork(shape=1,
                                num_layers=config.critic_num_layers,
                                units=config.critic_units,
                                act=_act)

        # Each target must start as a copy of its own online network; they were
        # previously cross-copied (target_qf1 <- qf2, target_qf2 <- qf1).
        self.target_qf1 = deepcopy(self.qf1)
        self.target_qf2 = deepcopy(self.qf2)

    def save(self, filename):
        """Save online and target Q-networks into directory ``filename``."""
        self.qf1.save(filename / 'qf1.pkl')
        self.qf2.save(filename / 'qf2.pkl')
        self.target_qf1.save(filename / 'target_qf1.pkl')
        self.target_qf2.save(filename / 'target_qf2.pkl')

    def load(self, filename):
        """Load online and target Q-networks from directory ``filename``."""
        self.qf1.load(filename / 'qf1.pkl')
        self.qf2.load(filename / 'qf2.pkl')
        self.target_qf1.load(filename / 'target_qf1.pkl')
        self.target_qf2.load(filename / 'target_qf2.pkl')





