import tensorflow as tf
import numpy as np
import time
from .squash_bijector import SquashBijector
from .utils import evaluate_training_rollouts
import tensorflow_probability as tfp
from collections import OrderedDict, deque
import os
from copy import deepcopy
import sys
sys.path.append("..")
from robustness_eval import training_evaluation
from disturber.disturber import Disturber
from pool.pool import REINFOCE_Pool
import logger
from variant import *

# (min, max) bounds used by L_REINFORCE._build_cliped_a to clip the policy's
# log standard deviation before it is exponentiated.
SCALE_DIAG_MIN_MAX = (-20, 2)
# (min, max) bounds used in L_REINFORCE.__init__ to clip exp(log_labda),
# i.e. the multiplier exposed as L_REINFORCE.labda.
SCALE_lambda_MIN_MAX = (0, 1)


class L_REINFORCE(object):
    """Policy-gradient agent with a learned Lyapunov function (TF1 static graph).

    The constructor builds the whole graph and owns a ``tf.Session``:

    * a Gaussian policy ``pi`` and a frozen copy ``oldpi`` (see ``_build_anet``),
    * a Lyapunov network (see ``_build_l``) evaluated at ``S``, ``S_`` and ``S_T``,
    * a trainable scalar ``log_labda`` whose clipped exponential is ``self.labda``,
    * three Adam train ops (actor, Lyapunov, multiplier) gathered in ``self.opt``.

    Observations passed to ``choose_action`` / ``evaluate_value`` are stacked with
    the previous ``history_horizon`` observations before being fed to the graph.
    """

    def __init__(self,
                 a_dim,
                 s_dim,
                 variant,
                 action_prior='uniform',
                 ):
        """Build the graph, initialise all variables and create the saver.

        Args:
            a_dim: Dimension of the action vector.
            s_dim: Dimension of a single observation; the network input size is
                ``s_dim * (variant['history_horizon'] + 1)``.
            variant: Mapping of hyper-parameters. Keys read here: ``tau``,
                ``approx_value`` (optional, defaults to True),
                ``history_horizon``, ``N_path_num``, ``use_soft_clip``,
                ``cost_as_Lyapunov``, ``discounted_value_as_Lyapunov``,
                ``gamma``, ``c_bar``, ``target_horizon``, ``labda``, ``alpha3``
                and ``epsilon``.
            action_prior: Stored on ``self._action_prior``; not used by the
                graph built here.
        """
        ###############################  Model parameters  ####################################
        tau = variant['tau']
        self.approx_value = variant.get('approx_value', True)
        self.sess = tf.Session()
        self._action_prior = action_prior
        # From here on s_dim is the size of the stacked (history) state.
        s_dim = s_dim * (variant['history_horizon'] + 1)
        self.a_dim, self.s_dim = a_dim, s_dim
        self.history_horizon = variant['history_horizon']
        # Newest observation sits at the left end; see _stack_history.
        self.working_memory = deque(maxlen=variant['history_horizon'] + 1)
        self.N_path_num = variant['N_path_num']
        self.use_soft_clip = variant['use_soft_clip']
        self.cost_as_Lyapunov = variant['cost_as_Lyapunov']
        self.discounted_value_as_Lyapunov = variant['discounted_value_as_Lyapunov']
        self.gamma = variant['gamma']
        pool_params = {
            's_dim': s_dim,
            'a_dim': a_dim,
            'store_last_n_paths': variant['N_path_num'],
            'c_bar': variant['c_bar'],
            'history_horizon': variant['history_horizon'],
            'target_horizon': variant['target_horizon'],
        }
        self.pool = REINFOCE_Pool(pool_params)

        with tf.variable_scope('L_REINFORCE_NETWORK'):
            self.S = tf.placeholder(tf.float32, [None, s_dim], 's')
            self.S_ = tf.placeholder(tf.float32, [None, s_dim], 's_')
            self.S_T = tf.placeholder(tf.float32, [None, s_dim], 's_T')
            self.a_input = tf.placeholder(tf.float32, [None, a_dim], 'a_input')
            self.R = tf.placeholder(tf.float32, [None, 1], 'return')
            self.c = tf.placeholder(tf.float32, [None, 1], 'c')
            self.c_ = tf.placeholder(tf.float32, [None, 1], 'c_')
            self.c_T = tf.placeholder(tf.float32, [None, 1], 'c_T')
            self.L_target = tf.placeholder(tf.float32, [None, 1], 'L_target')
            self.LR_A = tf.placeholder(tf.float32, None, 'LR_A')
            self.LR_lag = tf.placeholder(tf.float32, None, 'LR_lag')
            self.LR_C = tf.placeholder(tf.float32, None, 'LR_C')
            self.LR_L = tf.placeholder(tf.float32, None, 'LR_L')
            # Not consumed by any op built below; kept as a public attribute.
            self.epsilon = tf.placeholder(tf.float32, None, 'epsilon')
            self.terminal = tf.placeholder(tf.float32, [None, 1], 'terminal')
            origin = tf.zeros_like(self.S, tf.float32)

            labda = variant['labda']
            alpha3 = variant['alpha3']
            epsilon = variant['epsilon']
            # The multiplier is parameterised by its logarithm and clipped
            # after exponentiation.
            log_labda = tf.get_variable('lambda', None, tf.float32, initializer=tf.log(labda))
            self.labda = tf.clip_by_value(tf.exp(log_labda), *SCALE_lambda_MIN_MAX)

            # Current policy and a non-trainable copy used in the probability ratio.
            pi, pi_params = self._build_anet('pi', trainable=True)
            oldpi, oldpi_params = self._build_anet('oldpi', trainable=False)
            self.a = tf.squeeze(pi.sample(1), axis=0)  # operation of choosing action
            self.update_oldpi_op = [oldp.assign(p) for p, oldp in zip(pi_params, oldpi_params)]
            self.ratio = ratio = pi.prob(self.a_input) / (oldpi.prob(self.a_input) + 1e-5)

            # Lyapunov network, shared across the three state inputs and the origin.
            self.l_net = self._build_l(self.S)
            self.l_net_ = self._build_l(self.S_, reuse=True)
            self.l_net_T = self._build_l(self.S_T, reuse=True)
            self.l_origin_value = self._build_l(origin, reuse=True)

            # Quantity used in the decrease condition: either the raw cost
            # placeholders or the network output.
            if self.cost_as_Lyapunov:
                lyapunov_value, lyapunov_value_ = self.c, self.c_
            else:
                lyapunov_value, lyapunov_value_ = self.l_net, self.l_net_

            # NOTE(review): the policy is built by _build_anet under the 'pi'
            # scope, so nothing appears to live under '.../actor' and this
            # collection is presumably empty - confirm.
            a_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='L_REINFORCE_NETWORK/actor')
            l_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='L_REINFORCE_NETWORK/Lyapunov')

            ###############################  Model Learning Setting  ####################################
            ema = tf.train.ExponentialMovingAverage(decay=1 - tau)  # soft replacement

            def ema_getter(getter, name, *args, **kwargs):
                # Custom getter that returns the moving-average shadow of a variable.
                return ema.average(getter(name, *args, **kwargs))

            target_update = [ema.apply(a_params), ema.apply(l_params)]  # soft update operation

            # Lyapunov constraint: mean of L(s') - L(s) + alpha3 * c + epsilon.
            self.delta_l = tf.reduce_mean(lyapunov_value_ - lyapunov_value + alpha3 * self.c + epsilon)

            labda_loss = -tf.reduce_mean(log_labda * self.delta_l)
            self.lambda_train = tf.train.AdamOptimizer(self.LR_lag).minimize(labda_loss, var_list=[log_labda])

            # Actor objective: unclipped ratio-weighted surrogate. It always uses
            # the network outputs l_net_T / l_net, regardless of cost_as_Lyapunov.
            target_loss = self.R + self.l_net_T - self.l_net
            surr = ratio * target_loss  # surrogate loss
            self.a_loss = self.labda * tf.reduce_mean(surr)
            self.atrain = tf.train.AdamOptimizer(self.LR_A).minimize(self.a_loss, var_list=pi_params)

            if self.discounted_value_as_Lyapunov:
                # Bootstrapped target computed with the moving-average weights.
                target_l_net_ = self._build_l(self.S_, reuse=True, custom_getter=ema_getter)
                l_target = self.c + self.gamma * (1 - self.terminal) * tf.stop_gradient(target_l_net_)
            else:
                l_target = self.L_target
            with tf.control_dependencies(target_update):  # soft replacement happened at here
                self.l_error = tf.losses.mean_squared_error(labels=l_target, predictions=self.l_net)
                self.ltrain = tf.train.AdamOptimizer(self.LR_L).minimize(self.l_error, var_list=l_params)

            # Mean of the policy distribution, fetched by choose_action when
            # evaluation=True. Creates no variables, so checkpoints are unaffected.
            self.deterministic_a = pi.mean()

            self.sess.run(tf.global_variables_initializer())
            self.saver = tf.train.Saver()
            self.diagnotics = {'labda': self.labda,
                               'lyapunov_error': self.l_error,
                               'actor_loss': self.a_loss,
                               'delta_lyapunov': self.delta_l,
                               }

            self.opt = [self.ltrain, self.lambda_train]
            self.opt.append(self.atrain)

    def _stack_history(self, s):
        """Push ``s`` into the working memory and return the stacked state.

        When the memory holds fewer than ``history_horizon`` entries it is first
        padded with copies of ``s``. If concatenation raises ``ValueError`` the
        observation is printed and returned unstacked (legacy best-effort
        behaviour).
        """
        if len(self.working_memory) < self.history_horizon:
            self.working_memory.extendleft([s] * self.history_horizon)
        self.working_memory.appendleft(s)
        try:
            return np.concatenate(self.working_memory)
        except ValueError:
            print(s)
            return s

    @staticmethod
    def _soft_clip(a):
        """Clamp every action component to ``[-(1 - 1e-16), 1 - 1e-16]``.

        Uses broadcasting so it works for any action dimension.
        """
        bound = np.array([1 - 1e-16])
        return np.maximum(np.minimum(a, bound), -bound)

    def choose_action(self, s, evaluation=False):
        """Return an action for observation ``s``.

        Args:
            s: Single (unstacked) observation; it is pushed into the working
                memory as a side effect.
            evaluation: When exactly ``True`` the policy mean is returned;
                otherwise an action is sampled and, if ``use_soft_clip`` is set,
                clamped by ``_soft_clip``.

        Returns:
            The action array, or ``None`` in evaluation mode when running the
            graph raises ``ValueError``.
        """
        s = self._stack_history(s)

        if evaluation is True:
            try:
                return self.sess.run(self.deterministic_a, {self.S: s[np.newaxis, :]})[0]
            except ValueError:
                return None

        a = self.sess.run(self.a, {self.S: s[np.newaxis, :]})[0]
        if self.use_soft_clip:
            a = self._soft_clip(a)
        return a

    def learn(self, LR_A, LR_L, LR_lag):
        """Run one update on a batch drawn from the pool.

        Diagnostics are evaluated before the train ops run; afterwards ``oldpi``
        is synchronised with ``pi``.

        Returns:
            dict mapping each key of ``self.diagnotics`` to its fetched value.
        """
        feed_dict = self.sample(LR_A, LR_L, LR_lag)
        names = list(self.diagnotics.keys())
        values = self.sess.run([self.diagnotics[name] for name in names], feed_dict)
        self.sess.run(self.opt, feed_dict)
        self.sess.run(self.update_oldpi_op, feed_dict)
        return dict(zip(names, values))

    def sample(self, LR_A, LR_L, LR_lag, inds=None):
        """Draw a batch from the pool and map it onto the graph placeholders.

        Args:
            LR_A, LR_L, LR_lag: Learning rates fed to the actor, Lyapunov and
                multiplier optimisers.
            inds: Optional argument forwarded to ``self.pool.sample``.

        Returns:
            A feed dict for ``self.sess.run``.
        """
        batch = self.pool.sample() if inds is None else self.pool.sample(inds)
        return {
            self.a_input: batch['a'],
            self.S: batch['s'],
            self.S_: batch['s_'],
            self.R: batch['C'],
            self.c: batch['c'],
            self.c_: batch['c_'],
            self.c_T: batch['c_T'],
            self.L_target: batch['L_target'],
            self.S_T: batch['s_T'],
            self.terminal: batch['terminal'],
            self.LR_A: LR_A,
            self.LR_L: LR_L,
            self.LR_lag: LR_lag,
        }

    def store(self, s, a, norm, norm_, terminal, s_):
        """Forward one transition to the pool unchanged."""
        self.pool.store(s, a, norm, norm_, terminal, s_)

    def _build_cliped_a(self, s, name='actor', reuse=None, custom_getter=None):
        """Build a squashed Gaussian actor head on input ``s``.

        Not called from ``__init__``. Layers are trainable only when ``reuse``
        is None.

        Returns:
            (squashed sampled action, squashed mean, transformed distribution).
        """
        trainable = reuse is None

        with tf.variable_scope(name, reuse=reuse, custom_getter=custom_getter):
            batch_size = tf.shape(s)[0]
            squash_bijector = (SquashBijector())
            base_distribution = tfp.distributions.MultivariateNormalDiag(loc=tf.zeros(self.a_dim), scale_diag=tf.ones(self.a_dim))
            epsilon = base_distribution.sample(batch_size)
            ## Construct the feedforward action
            net_0 = tf.layers.dense(s, 256, activation=tf.nn.relu, name='l1', trainable=trainable)  # originally 30
            net_1 = tf.layers.dense(net_0, 256, activation=tf.nn.relu, name='l4', trainable=trainable)  # originally 30
            mu = tf.layers.dense(net_1, self.a_dim, activation=None, name='a', trainable=trainable)
            log_sigma = tf.layers.dense(net_1, self.a_dim, None, trainable=trainable)
            log_sigma = tf.clip_by_value(log_sigma, *SCALE_DIAG_MIN_MAX)
            sigma = tf.exp(log_sigma)

            bijector = tfp.bijectors.Affine(shift=mu, scale_diag=sigma)
            raw_action = bijector.forward(epsilon)
            clipped_a = squash_bijector.forward(raw_action)

            ## Construct the distribution
            bijector = tfp.bijectors.Chain((
                squash_bijector,
                tfp.bijectors.Affine(
                    shift=mu,
                    scale_diag=sigma),
            ))
            distribution = tfp.distributions.ConditionalTransformedDistribution(
                    distribution=base_distribution,
                    bijector=bijector)

            clipped_mu = squash_bijector.forward(mu)

        return clipped_a, clipped_mu, distribution

    def _build_a(self, s, name='actor', reuse=None, custom_getter=None):
        """Build an unsquashed Gaussian actor head on input ``s``.

        Not called from ``__init__``. Unlike ``_build_cliped_a`` the log standard
        deviation is not clipped and no squashing is applied.

        Returns:
            (sampled action, mean, transformed distribution).
        """
        trainable = reuse is None

        with tf.variable_scope(name, reuse=reuse, custom_getter=custom_getter):
            batch_size = tf.shape(s)[0]
            base_distribution = tfp.distributions.MultivariateNormalDiag(loc=tf.zeros(self.a_dim), scale_diag=tf.ones(self.a_dim))
            epsilon = base_distribution.sample(batch_size)
            ## Construct the feedforward action
            net_1 = tf.layers.dense(s, 64, activation=tf.nn.relu, name='l4', trainable=trainable)  # originally 30
            mu = tf.layers.dense(net_1, self.a_dim, activation=None, name='a', trainable=trainable)
            log_sigma = tf.layers.dense(net_1, self.a_dim, None, trainable=trainable)
            sigma = tf.exp(log_sigma)

            bijector = tfp.bijectors.Affine(shift=mu, scale_diag=sigma)
            raw_action = bijector.forward(epsilon)

            ## Construct the distribution
            distribution = tfp.distributions.ConditionalTransformedDistribution(
                    distribution=base_distribution,
                    bijector=bijector)

        return raw_action, mu, distribution

    def _build_anet(self, name, trainable):
        """Build the Gaussian policy used by ``__init__`` on ``self.S``.

        One 200-unit ReLU layer feeds a mean head (``2 * tanh``, so each
        component lies in [-2, 2]) and a softplus scale head.

        Returns:
            (``tf.distributions.Normal``, list of global variables under
            ``'L_REINFORCE_NETWORK/' + name``).
        """
        with tf.variable_scope(name):
            l1 = tf.layers.dense(self.S, 200, tf.nn.relu, trainable=trainable)
            mu = 2 * tf.layers.dense(l1, self.a_dim, tf.nn.tanh, trainable=trainable)
            sigma = tf.layers.dense(l1, self.a_dim, tf.nn.softplus, trainable=trainable)
            norm_dist = tf.distributions.Normal(loc=mu, scale=sigma)
        params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='L_REINFORCE_NETWORK/' + name)
        return norm_dist, params

    def evaluate_value(self, s, a):
        """Return the Lyapunov network output for observation ``s``.

        ``s`` is pushed into the working memory as a side effect. ``a`` is fed
        to ``a_input`` for interface compatibility; the Lyapunov network built
        by ``_build_l`` takes only the state as input.
        """
        s = self._stack_history(s)
        return self.sess.run(self.l_net, {self.S: s[np.newaxis, :], self.a_input: a[np.newaxis, :]})[0]

    def _build_l(self, s, reuse=None, custom_getter=None):
        """Build (or reuse) the Lyapunov network on input ``s``.

        A 64-unit ReLU layer is followed by a 16-unit linear layer; the output
        is the sum of squares of that layer, so it is non-negative and has
        shape ``[batch, 1]``. Layers are trainable only when ``reuse`` is None.
        """
        trainable = reuse is None
        with tf.variable_scope('Lyapunov', reuse=reuse, custom_getter=custom_getter):
            net_1 = tf.layers.dense(s, 64, activation=tf.nn.relu, name='l2', trainable=trainable)  # originally 30
            net_2 = tf.layers.dense(net_1, 16, activation=None, name='l3', trainable=trainable)  # originally 30
            return tf.expand_dims(tf.reduce_sum(tf.square(net_2), axis=1), axis=1)

    def save_result(self, path):
        """Save all variables to ``path + "/policy/model.ckpt"`` and print the path."""
        save_path = self.saver.save(self.sess, path + "/policy/model.ckpt")
        print("Save to path: ", save_path)

    def restore(self, path):
        """Restore the latest checkpoint found in ``path + '/'``.

        NOTE(review): ``save_result`` writes under ``path + "/policy/"`` while
        this looks directly in ``path + '/'``; presumably callers pass the
        ``policy`` directory here - confirm.

        Returns:
            True when a checkpoint was found and restored, False otherwise.
        """
        model_file = tf.train.latest_checkpoint(path + '/')
        if model_file is None:
            return False
        self.saver.restore(self.sess, model_file)
        return True

def train(variant):
    """Train an ``L_REINFORCE`` policy on the environment named in ``variant``.

    Episodes are rolled out one at a time. At the end of each episode, once the
    pool's ``memory_pointer`` exceeds ``alg_params['batch_size']``,
    ``policy.learn`` is called ``train_per_cycle`` times and the diagnostics are
    logged. The policy is saved to ``variant['log_path']`` when the loop ends.

    Args:
        variant: Mapping with keys ``env_name``, ``env_params``
            (``max_episodes``, ``max_ep_steps``, ``max_global_steps``,
            ``eval_render``), ``alg_params`` (``N_path_num``,
            ``train_per_cycle``, ``lr_a``, ``lr_l``, ``batch_size``, ``tau``,
            ``alpha3``, ``epsilon`` plus everything ``L_REINFORCE`` reads) and
            ``log_path``.
    """
    env_name = variant['env_name']
    env = get_env_from_name(env_name)

    env_params = variant['env_params']
    max_episodes = env_params['max_episodes']
    max_ep_steps = env_params['max_ep_steps']
    max_global_steps = env_params['max_global_steps']

    policy_params = variant['alg_params']
    store_last_n_paths = policy_params['N_path_num']
    train_per_cycle = policy_params['train_per_cycle']

    lr_a, lr_l = policy_params['lr_a'], policy_params['lr_l']
    lr_a_now = lr_a  # learning rate for actor
    lr_l_now = lr_l  # learning rate for lyapunov

    # These environments return dict observations that are flattened by
    # concatenating their entries.
    dict_observation = 'Fetch' in env_name or 'Hand' in env_name

    if dict_observation:
        obs_spaces = env.observation_space.spaces
        s_dim = (obs_spaces['observation'].shape[0]
                 + obs_spaces['achieved_goal'].shape[0]
                 + obs_spaces['desired_goal'].shape[0])
    else:
        s_dim = env.observation_space.shape[0]
    a_dim = env.action_space.shape[0]

    a_upperbound = env.action_space.high
    a_lowerbound = env.action_space.low
    policy = L_REINFORCE(a_dim, s_dim, policy_params)

    # For analyse
    Render = env_params['eval_render']

    # Training setting
    t1 = time.time()
    global_step = 0
    last_training_paths = deque(maxlen=store_last_n_paths)

    log_path = variant['log_path']
    logger.configure(dir=log_path, format_strs=['csv'])
    logger.logkv('tau', policy_params['tau'])
    logger.logkv('alpha3', policy_params['alpha3'])
    logger.logkv('N_path_num', policy_params['N_path_num'])
    logger.logkv('epsilon', policy_params['epsilon'])
    finished_rendering_this_epoch = False

    for _ in range(max_episodes):
        # The step budget is only checked between episodes, so the last
        # episode may overshoot it.
        if global_step > max_global_steps:
            break

        current_path = {'rewards': []}

        # NOTE(review): policy.working_memory is not cleared here; choose_action
        # only pads it when it holds fewer than history_horizon entries, so with
        # history_horizon > 0 the first stacked states of an episode still
        # contain observations from the previous one - confirm this is intended.
        s = env.reset()
        if dict_observation:
            s = np.concatenate([s[key] for key in s.keys()])

        for j in range(max_ep_steps):
            if (not finished_rendering_this_epoch) and Render:
                env.render()
            a = policy.choose_action(s)
            if policy.use_soft_clip:
                # Rescale from [-1, 1] to the environment's action bounds.
                action = a_lowerbound + (a + 1.) * (a_upperbound - a_lowerbound) / 2
            else:
                action = np.clip(a, a_lowerbound, a_upperbound)

            # Run in simulator
            s_, r, done, info = env.step(action)
            last_r = info['last_cost']
            if dict_observation:
                s_ = np.concatenate([s_[key] for key in s_.keys()])
                if info['done'] > 0:
                    done = True

            global_step += 1

            # Reaching the step limit always ends the episode.
            if j == max_ep_steps - 1:
                done = True

            terminal = 1. if done else 0.
            policy.store(s, a, last_r, r, terminal, s_)

            current_path['rewards'].append(r)

            # State update
            s = s_

            if not done:
                continue

            # OUTPUT TRAINING INFORMATION AND LEARNING RATE DECAY
            finished_rendering_this_epoch = True
            last_training_paths.appendleft(current_path)

            # Linear decay, clamped at zero: the last episode can push
            # global_step past max_global_steps, and a negative fraction would
            # hand the optimisers a negative learning rate.
            frac = max(1.0 - (global_step - 1.0) / max_global_steps, 0.0)
            lr_a_now = lr_a * frac  # learning rate for actor
            lr_l_now = lr_l * frac  # learning rate for lyapunov

            if policy.pool.memory_pointer > policy_params['batch_size']:
                # Stays empty when train_per_cycle is 0, so logging still works.
                diagnotics = {}
                for _ in range(train_per_cycle):
                    # NOTE(review): the multiplier learning rate is the
                    # undecayed lr_a - confirm this is intended.
                    diagnotics = policy.learn(lr_a_now, lr_l_now, lr_a)
                logger.logkv("total_timesteps", global_step)

                training_diagnotic = evaluate_training_rollouts(last_training_paths)
                if training_diagnotic is not None:
                    for key, value in training_diagnotic.items():
                        logger.logkv(key, value)
                    for key, value in diagnotics.items():
                        logger.logkv(key, value)
                    logger.logkv('lr_a', lr_a_now)
                    logger.logkv('lr_l', lr_l_now)

                    string_to_print = ['time_step:', str(global_step), '|']
                    for key, value in training_diagnotic.items():
                        string_to_print.extend([key, ':', str(round(value, 2)), '|'])
                    for key, value in diagnotics.items():
                        string_to_print.extend([key, ':', str(round(value, 2)), '|'])
                    print(''.join(string_to_print))
                finished_rendering_this_epoch = False
                logger.dumpkvs()
            break
    policy.save_result(log_path)

    print('Running time: ', time.time() - t1)
    return


