from functools import partial

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Momentum for tf.layers.batch_normalization; the batch-norm layers in the
# model-building code are currently disabled.
batch_norm_momentum: float = 0.9
class Network(object):
    """Entropy-regularised actor-critic agent with learned dynamics and reward models.

    Built in TensorFlow 1.x graph mode. Only the ``'global'`` scope builds the
    graph (see ``net_reset``): a stochastic actor, a stochastic one-step
    dynamics model ("Lambda"), twin critics and a reward model ("R"), with
    eval/target copies where applicable. Evaluation curves are kept in Python
    lists and persisted as CSV files under ``./data/Arm_<arm_num>/``.
    """

    def __init__(self, env, TAU_a, TAU_c, TAU_p, TAU_l, LR_A, LR_C, LR_P, LR_L, LR_R, Threshold, arm_num, Seed, Seed_num, Algo, units, interval, TEST, Runout, scope=None, session=None, Update_freq=1, TAU_c1=0, TAU_c2=0, LR_C1=0, LR_C2=0, observe=0):
        """Store hyper-parameters, restore earlier curves and build the graph.

        Args:
            env: environment wrapper exposing ``observation_space``,
                ``action_space``, ``reset``, ``step`` and ``penalty``.
            TAU_a, TAU_l, TAU_c1, TAU_c2: soft target-update rates for the
                actor, the dynamics model and the two critics.
            LR_A, LR_L, LR_R, LR_C1, LR_C2: Adam learning rates for the actor,
                dynamics model, reward model and the two critics.
            TAU_c, TAU_p, LR_C, LR_P, Threshold: stored on the instance but not
                used by the code in this class.
            arm_num: index used to build the ``Arm_<arm_num>`` data directory.
            Seed, Seed_num: current seed index and number of seeds; for
                ``Seed != 0`` the curves of earlier seeds are loaded from CSV.
            Algo: variant name ('MBAC-R', 'MBAC-P', 'MBAC-Q' or other).
            units: hidden-layer width of every sub-network.
            interval: evaluate every ``interval`` training steps.
            TEST, Runout: evaluation episodes and max steps per episode.
            scope: 'global' builds the graph; other values name a worker.
            session: ``tf.Session`` used for every ``run`` call.
            Update_freq: apply soft target updates every N training steps.
            observe: training steps after which the real-data mixing weight
                ``par`` is forced to 0.

        NOTE(review): the critics use LR_C1/LR_C2 and TAU_c1/TAU_c2, whose
        defaults are 0 (no critic learning, frozen critic targets); LR_C and
        TAU_c are ignored. Confirm callers pass the *_c1/*_c2 values.
        NOTE(review): a non-'global' ``scope`` of ``None`` raises TypeError
        when ``self.name`` is built by string concatenation.
        """
        self.transmit_par(env, TAU_a, TAU_c, TAU_p, TAU_l, LR_A, LR_C, LR_P, LR_L, LR_R, Threshold, Seed, Seed_num, Algo, units, interval, TEST, Runout, scope, session, Update_freq, TAU_c1, TAU_c2, LR_C1, LR_C2, observe)

        self.ini_k = tf.contrib.layers.xavier_initializer(False)  # normal (not uniform) Xavier
        self.ini_b = tf.constant_initializer(0)
        self.he_init = tf.variance_scaling_initializer()
        self.var_ini = 0.2  # initial std of exploration noise (see choose_action_exp)
        self.par_net = 10
        self.par = 1.0  # weight of real data vs. model predictions in the critic targets
        self.decay_rate = 1  # per-step decay of self.var (1 = no decay)
        self.alpha = 0.1  # entropy temperature
        self.bias = 1e-6  # keeps predicted std strictly positive
        self.gamma = 0.99
        self.gamma_c = 0.95
        self.state_dim, self.action_dim, self.a_bound = env.observation_space.shape[0], env.action_space.shape[0], env.action_space.high[0]
        self.dim = 25
        self.label = 'Arm_%i' % arm_num
        if Seed == 0:
            self.train_num = []
            self.reward_file = []
            self.penalty_file = []
        else:
            # Continue the cross-seed curves written by global_save.
            self.train_num = list(np.array(pd.read_csv('./data/' + self.label + '/Train_num_' + Algo + '.csv')).reshape(-1))
            self.reward_file = list(np.array(pd.read_csv('./data/' + self.label + '/Aver_reward_' + Algo + '.csv')).reshape(-1))
            self.penalty_file = list(np.array(pd.read_csv('./data/' + self.label + '/Aver_penalty_' + Algo + '.csv')).reshape(-1))
        if scope == 'global':
            self.name = ''
        else:
            self.name = '/workers_' + Algo + '/' + self.scope
        self.var_reset()
        self.net_reset(scope)

    def transmit_par(self, env, TAU_a, TAU_c, TAU_p, TAU_l, LR_A, LR_C, LR_P, LR_L, LR_R, Threshold, Seed, Seed_num, Algo, units, interval, TEST, Runout, scope, session, Update_freq, TAU_c1, TAU_c2, LR_C1, LR_C2, observe):
        """Copy the constructor's hyper-parameters onto the instance."""
        self.LR_A = LR_A
        self.LR_C = LR_C
        self.LR_C1 = LR_C1
        self.LR_C2 = LR_C2
        self.LR_P = LR_P
        self.LR_L = LR_L
        self.LR_R = LR_R
        self.TAU_a = TAU_a
        self.TAU_c = TAU_c
        self.TAU_c1 = TAU_c1
        self.TAU_c2 = TAU_c2
        self.TAU_p = TAU_p
        self.TAU_l = TAU_l
        self.Threshold = Threshold
        self.Seed = Seed
        self.Seed_num = Seed_num
        self.Algo = Algo
        self.units = units
        self.interval = interval
        self.TEST = TEST
        self.Runout = Runout
        self.Update_freq = Update_freq
        self.observe = observe
        self.env = env
        self.scope = scope
        self.session = session

    def var_reset(self):
        """Clear the per-seed curves, the step counter and the exploration std."""
        self.train_lo = []
        self.rew_lo = []
        self.pen_lo = []
        self.train_counter = 0
        self.var = self.var_ini

    def net_reset(self, scope=None):
        """Build the graph for the 'global' scope, then create a default ``tf.train.Saver``."""
        if scope == 'global':
            self.creat_network(scope)
        self.saver = tf.train.Saver()

    def creat_network(self, scope):
        """Build placeholders, sub-networks, losses, optimisers and target updates under ``scope``."""
        # Batch inputs: state, next state, real reward and the scalar mixing
        # weight between real data and model predictions.
        self.S = tf.placeholder(tf.float32, [None, self.state_dim], 's')
        self.S_ = tf.placeholder(tf.float32, [None, self.state_dim], 's_')
        self.R = tf.placeholder(tf.float32, [None, 1], 'r')
        self.Pmt = tf.placeholder(tf.float32, None, 'Par')

        with tf.variable_scope(scope):
            with tf.variable_scope('Actor'):
                self.a, self.a_alt, log_p = self._build_a_rand(self.S, scope='eval', trainable=True)
                self.a_, self.a_alt_, log_p_ = self._build_a_rand(self.S_, scope='target', trainable=False)
            with tf.variable_scope('Lambda'):
                # Learned dynamics model: next state predicted from (state, action).
                self.s_next, _ = self._build_lam(self.S, self.a_alt, scope='eval', trainable=True)
                self.s_next_, _ = self._build_lam(self.S_, self.a_alt_, scope='target', trainable=False)
            with tf.variable_scope('Critic1'):
                # AUTO_REUSE in _build_c makes the calls with the same scope share weights.
                q1 = self._build_c(self.S, self.a, scope='eval', trainable=True)
                q1_alt = self._build_c(self.S, self.a_alt, scope='eval', trainable=True)
                q1_ = self._build_c(self.S_, self.a_, scope='target', trainable=False)
                q1_alt_ = self._build_c(self.s_next, self.a_alt_, scope='target', trainable=False)
            with tf.variable_scope('Critic2'):
                q2 = self._build_c(self.S, self.a, scope='eval', trainable=True)
                q2_alt = self._build_c(self.S, self.a_alt, scope='eval', trainable=True)
                q2_ = self._build_c(self.S_, self.a_, scope='target', trainable=False)
                q2_alt_ = self._build_c(self.s_next, self.a_alt_, scope='target', trainable=False)
            with tf.variable_scope('R'):
                # Learned reward model r(s, a, s_next).
                self.R_on = self._build_r(self.S, self.a_alt, self.s_next, scope='eval', trainable=True)

        # Network parameters, collected by variable-scope prefix.
        self.ae_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope + '/' + 'Actor/eval')
        self.at_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope + '/' + 'Actor/target')
        self.c1e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope + '/' + 'Critic1/eval')
        self.c1t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope + '/' + 'Critic1/target')
        self.c2e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope + '/' + 'Critic2/eval')
        self.c2t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope + '/' + 'Critic2/target')
        self.le_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope + '/' + 'Lambda/eval')
        self.lt_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope + '/' + 'Lambda/target')
        self.re_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope + '/' + 'R/eval')

        # Polyak (soft) target updates: t <- (1 - tau) * t + tau * e.
        self.soft_replace_a = [tf.assign(t, (1 - self.TAU_a) * t + self.TAU_a * e) for t, e in zip(self.at_params, self.ae_params)]
        self.soft_replace_c1 = [tf.assign(t, (1 - self.TAU_c1) * t + self.TAU_c1 * e) for t, e in zip(self.c1t_params, self.c1e_params)]
        self.soft_replace_c2 = [tf.assign(t, (1 - self.TAU_c2) * t + self.TAU_c2 * e) for t, e in zip(self.c2t_params, self.c2e_params)]
        self.soft_replace_l = [tf.assign(t, (1 - self.TAU_l) * t + self.TAU_l * e) for t, e in zip(self.lt_params, self.le_params)]

        self.done = tf.placeholder(tf.float32, [None, 1], 'done')
        # Clipped double-Q: use the smaller of the two target critics.
        q_ = tf.minimum(q1_, q2_)
        q_alt_ = tf.minimum(q1_alt_, q2_alt_)

        # Critic targets mix real data (weight Pmt) with model predictions
        # (weight 1 - Pmt); the variant decides which terms are mixed.
        if self.Algo == 'MBAC-R':
            y_batch = self.R + self.gamma * (self.Pmt * q_ + (1 - self.Pmt) * q_alt_ - self.alpha * log_p_) * (1 - self.done)
        elif self.Algo == 'MBAC-P':
            y_batch = self.Pmt * self.R + (1 - self.Pmt) * self.R_on + self.gamma * (q_ - self.alpha * log_p_) * (1 - self.done)
        else:
            y_batch = self.Pmt * self.R + (1 - self.Pmt) * self.R_on + self.gamma * (self.Pmt * q_ + (1 - self.Pmt) * q_alt_ - self.alpha * log_p_) * (1 - self.done)
        if self.Algo == 'MBAC-Q':
            td1_error = tf.losses.mean_squared_error(labels=y_batch, predictions=q1)
            td2_error = tf.losses.mean_squared_error(labels=y_batch, predictions=q2)
        else:
            td1_error = tf.losses.mean_squared_error(labels=y_batch, predictions=self.Pmt * q1 + (1 - self.Pmt) * q1_alt)
            td2_error = tf.losses.mean_squared_error(labels=y_batch, predictions=self.Pmt * q2 + (1 - self.Pmt) * q2_alt)
        self.c1train = tf.train.AdamOptimizer(self.LR_C1).minimize(td1_error, var_list=self.c1e_params)
        self.c2train = tf.train.AdamOptimizer(self.LR_C2).minimize(td2_error, var_list=self.c2e_params)

        # Actor: maximise min-Q minus the entropy penalty alpha * log pi.
        self.q = tf.minimum(q1, q2)
        a_loss = tf.reduce_mean(self.alpha * log_p - self.q)
        self.atrain = tf.train.AdamOptimizer(self.LR_A).minimize(a_loss, var_list=self.ae_params)

        # Dynamics model: squared error between predicted and observed next state.
        l_loss = tf.reduce_mean(tf.reduce_sum((self.s_next - self.S_) ** 2, axis=1, keepdims=True))
        self.ltrain = tf.train.AdamOptimizer(self.LR_L).minimize(l_loss, var_list=self.le_params)

        # Reward model: squared error against the observed reward.
        r_loss = tf.reduce_mean((self.R_on - self.R) ** 2)
        self.rtrain = tf.train.AdamOptimizer(self.LR_R).minimize(r_loss, var_list=self.re_params)

    def var_run(self, state, action, next_state, reward, done):
        """Run one training step on a batch of transitions and update schedules.

        Trains both critics, the actor, the dynamics model and the reward
        model (in that order, each with its own ``session.run``), applies the
        soft target updates every ``Update_freq`` steps, decays the schedules
        and evaluates every ``interval`` steps.
        """
        self.session.run(self.c1train, {self.S: state, self.a: action, self.R: reward, self.S_: next_state, self.done: done, self.Pmt: self.par})
        self.session.run(self.c2train, {self.S: state, self.a: action, self.R: reward, self.S_: next_state, self.done: done, self.Pmt: self.par})
        self.session.run(self.atrain, {self.S: state, self.Pmt: self.par})
        # Feeding a_alt / s_next overrides the sampled tensors with real data.
        self.session.run(self.ltrain, {self.S: state, self.S_: next_state, self.a_alt: action})
        self.session.run(self.rtrain, {self.S: state, self.a_alt: action, self.s_next: next_state, self.R: reward})
        if self.train_counter % self.Update_freq == 0:
            self.session.run([self.soft_replace_a, self.soft_replace_c1, self.soft_replace_c2, self.soft_replace_l])

        self.train_counter += 1
        self.var *= self.decay_rate
        self.par *= 0.9999999995
        if self.train_counter > self.observe:
            # Experimental ("test"): after the observation phase the targets
            # rely entirely on the learned models.
            self.par *= 0
        self.par_net *= 0.95
        if self.train_counter % self.interval == 0:
            ave_reward, acc_penalty = self.per_eval()
            print(self.scope, 'episode:', self.train_counter, 'Average Reward:', ave_reward, 'Average Penalty:', acc_penalty)

    def per_eval(self):
        """Run ``TEST`` evaluation episodes with ``choose_action`` and record the averages.

        Each episode lasts at most ``Runout`` steps. The accumulated reward
        adds ``env.env_.control_cost(a[0])`` to the wrapper's reward
        (presumably to undo a control-cost term inside the wrapper — confirm).

        Returns:
            (ave_reward, acc_penalty): per-episode averages over ``TEST``
            episodes; both are also appended to the local and cross-seed curves.
        """
        total_reward = 0
        total_penalty = 0
        for _episode in range(self.TEST):
            local_penalty = 0
            s = self.env.reset()
            for _step in range(self.Runout):
                a = self.choose_action(s)
                s, r, done, _info = self.env.step(a)
                p = self.env.penalty(a)
                total_reward += r[0][0] + self.env.env_.control_cost(a[0])
                local_penalty += p[0][0]
                if done:
                    break
            total_penalty += local_penalty
        ave_reward = total_reward / self.TEST
        acc_penalty = total_penalty / self.TEST
        self.train_lo.append(self.train_counter)
        self.rew_lo.append(ave_reward)
        self.pen_lo.append(acc_penalty)
        self.train_num.append(self.train_counter)
        self.reward_file.append(ave_reward)
        self.penalty_file.append(acc_penalty)
        return ave_reward, acc_penalty

    def choose_action_exp(self, s):
        """Sample an actor action, add Gaussian noise with std ``self.var`` and clip to the bounds."""
        a = self.session.run(self.a, {self.S: s})
        a = np.random.normal(a, self.var)
        a = np.clip(a, -self.a_bound, self.a_bound)
        return a

    def choose_action(self, s):
        """Return the actor's (sampled, clipped) action for the state batch ``s``."""
        return self.session.run(self.a, {self.S: s})

    def choose_action_u(self, s):
        """Return uniform random actions in [-a_bound, a_bound], one row per row of ``s``."""
        size = s.shape[0]
        a = np.random.uniform(-self.a_bound, self.a_bound, (size, self.action_dim))
        return a

    def step(self, s, a=None):
        """Roll the learned models one step from ``s``.

        Args:
            s: state batch.
            a: action batch; sampled from the actor when ``None``.

        Returns:
            (s_next, r, a): predicted next state, predicted reward, action used.
        """
        if a is None:
            a = self.choose_action(s)
        s_next = self.session.run(self.s_next, {self.S: s, self.a_alt: a})
        # Feed the same predicted next state into the reward model.
        r = self.session.run(self.R_on, {self.S: s, self.a_alt: a, self.s_next: s_next})
        return s_next, r, a

    def _build_c(self, s, a, scope=None, trainable=True):
        """Build a Q(s, a) network: joint first layer, one hidden layer, linear output of width 1."""
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            hidden_units = self.units
            w1_a = tf.get_variable('w1_a', [self.action_dim, hidden_units], initializer=self.ini_k, trainable=trainable)
            w1_s = tf.get_variable('w1_s', [self.state_dim, hidden_units], initializer=self.ini_k, trainable=trainable)
            b1 = tf.get_variable('b1', [1, hidden_units], initializer=self.ini_b, trainable=trainable)
            net = tf.nn.relu(tf.matmul(a, w1_a) + tf.matmul(s, w1_s) + b1)
            net = tf.layers.dense(net, hidden_units, tf.nn.relu, kernel_initializer=self.ini_k, bias_initializer=self.ini_b, trainable=trainable)
            q = tf.layers.dense(net, 1, kernel_initializer=self.ini_k, bias_initializer=self.ini_b, trainable=trainable)
            return q

    def _build_a_rand(self, s, scope, trainable):
        """Build the Gaussian policy head.

        Returns:
            (action, action_alt, log_p): two clipped copies of the same sample
            and its summed log-probability, shape ``[batch, 1]``.
        """
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            net = tf.layers.dense(s, self.units, tf.nn.relu, kernel_initializer=self.ini_k, bias_initializer=self.ini_b, trainable=trainable)
            net = tf.layers.dense(net, self.units, tf.nn.relu, kernel_initializer=self.ini_k, bias_initializer=self.ini_b, trainable=trainable)
            mu = self.a_bound * tf.layers.dense(net, self.action_dim, tf.nn.tanh, kernel_initializer=self.ini_k, bias_initializer=self.ini_b, trainable=trainable)
            sigma = tf.layers.dense(net, self.action_dim, activation=tf.nn.softplus,
                                    kernel_initializer=self.ini_k,
                                    bias_initializer=self.ini_b, trainable=trainable) + self.bias
            norm_dist = tf.distributions.Normal(loc=mu, scale=sigma)
            a = norm_dist.sample()
            # Two distinct clip ops on purpose: var_run/step feed ``self.a``
            # and ``self.a_alt`` independently, so they must be separate tensors.
            action = tf.clip_by_value(a, -self.a_bound, self.a_bound)
            action_alt = tf.clip_by_value(a, -self.a_bound, self.a_bound)
            out = tf.nn.tanh(a)
            # NOTE(review): the tanh log-det correction assumes the emitted
            # action is tanh(a), but the action is clip(a) — confirm intended.
            log_p = tf.reduce_sum(norm_dist.log_prob(a) - tf.log(1 - out ** 2 + 1e-6), 1, keepdims=True)

            return action, action_alt, log_p

    def _build_lam(self, s, a, scope, trainable):
        """Build the stochastic dynamics model predicting ``s + delta``.

        ``delta`` is sampled from a diagonal Gaussian conditioned on ``(s, a)``.

        Returns:
            (s_next, log_s): predicted next state and the summed
            log-probability of the sampled ``delta``, shape ``[batch, 1]``.
        """
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            units = self.units
            net = tf.concat([s, a], 1)
            # Batch normalisation between these layers is currently disabled.
            net = tf.layers.dense(net, units, tf.nn.relu, kernel_initializer=self.he_init, bias_initializer=self.ini_b, trainable=trainable)
            net = tf.layers.dense(net, units, tf.nn.relu, kernel_initializer=self.he_init, bias_initializer=self.ini_b, trainable=trainable)
            net = tf.layers.dense(net, units, tf.nn.relu, kernel_initializer=self.he_init, bias_initializer=self.ini_b, trainable=trainable)
            mu = tf.layers.dense(net, self.state_dim, kernel_initializer=self.he_init, bias_initializer=self.ini_b, trainable=trainable)
            sigma = tf.layers.dense(net, self.state_dim, activation=tf.nn.softplus,
                                    kernel_initializer=self.he_init,
                                    bias_initializer=self.ini_b, trainable=trainable) + self.bias
            norm_dist = tf.distributions.Normal(loc=mu, scale=sigma)
            delta = norm_dist.sample()
            s_next = delta + s
            # The distribution is over the state change, so evaluate its
            # density at the sampled delta, not at delta + s.
            log_s = tf.reduce_sum(norm_dist.log_prob(delta), 1, keepdims=True)
            return s_next, log_s

    def _build_r(self, s, a, s_, scope, trainable):
        """Build the reward model r(s, a, s_): two ReLU layers and a linear output of width 1."""
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            units = self.units
            net = tf.concat([s, a, s_], 1)
            net = tf.layers.dense(net, units, tf.nn.relu, kernel_initializer=self.he_init, bias_initializer=self.ini_b, trainable=trainable)
            net = tf.layers.dense(net, units, tf.nn.relu, kernel_initializer=self.he_init, bias_initializer=self.ini_b, trainable=trainable)
            r = tf.layers.dense(net, 1, kernel_initializer=self.he_init, bias_initializer=self.ini_b, trainable=trainable)
            return r

    def save(self):
        """Save variables to ``./data/<label><name>/params_<Algo>`` (no meta graph)."""
        self.saver.save(self.session, './data/' + self.label + self.name + '/params_' + self.Algo, write_meta_graph=False)

    def load(self):
        """Restore variables from ``./data/<label><name>/params_<Algo>``."""
        self.saver.restore(self.session, './data/' + self.label + self.name + '/params_' + self.Algo)

    def plot_cost(self, on_train=True):
        """Persist state at the end of a seed, or plot the saved reward curve.

        With ``on_train`` the global instance writes the cross-seed CSVs. Local
        state is reset unless this is the last seed, and the global instance
        then saves parameters and local state. Otherwise the saved rewards are
        scattered and overlaid with a degree-20 polynomial fit.
        """
        if on_train:
            if self.scope == 'global':
                self.global_save()
            if self.Seed < self.Seed_num - 1:
                self.var_reset()
            if self.scope == 'global':
                self.per_save()
        else:
            # NOTE(review): Train_num / Aver_reward / Train_counter are read
            # under self.name, but global_save/per_save write them without it
            # (identical only for the 'global' scope) — confirm the paths.
            x = np.array(pd.read_csv('./data/' + self.label + self.name + '/Train_num_' + self.Algo + '.csv')).reshape(-1)
            y = np.array(pd.read_csv('./data/' + self.label + self.name + '/Aver_reward_' + self.Algo + '.csv')).reshape(-1)
            z = np.array(pd.read_csv('./data/' + self.label + '/Aver_penalty_' + self.Algo + '.csv')).reshape(-1)
            parameter = np.polyfit(x, y, 20)
            parameter1 = np.polyfit(x, z, 20)
            p = np.poly1d(parameter)
            p1 = np.poly1d(parameter1)  # penalty fit; plotting it is disabled
            plt.scatter(x, y)
            # The CSV holds a single value; take it as a scalar so np.arange
            # receives a number instead of a one-element array.
            train_counter = np.array(pd.read_csv('./data/' + self.label + self.name + '/Train_counter_' + self.Algo + '.csv')).reshape(-1)[0]
            x = (np.arange(train_counter / self.interval) + 1) * self.interval
            plt.plot(x, p(x), color='g')
            plt.ylabel('Average reward')
            plt.xlabel('training nums')
            plt.title(self.Algo + '-' + self.env.env_name)
            plt.show()

    def per_save(self):
        """Save parameters, exploration std, step counter and per-seed curves (CSV)."""
        var = pd.DataFrame(np.array([self.var]))
        train_counter = pd.DataFrame(np.array([self.train_counter]))
        x_lo = pd.DataFrame(self.train_lo)
        y_lo = pd.DataFrame(self.rew_lo)
        z_lo = pd.DataFrame(self.pen_lo)
        self.save()
        var.to_csv('./data/' + self.label + '/var_' + self.Algo + '.csv', index=False, header=True)
        train_counter.to_csv('./data/' + self.label + '/Train_counter_' + self.Algo + '.csv', index=False, header=True)
        x_lo.to_csv('./data/' + self.label + '/Local_Train_num_' + self.Algo + '.csv', index=False, header=True)
        y_lo.to_csv('./data/' + self.label + '/Local_Aver_reward_' + self.Algo + '.csv', index=False, header=True)
        z_lo.to_csv('./data/' + self.label + '/Local_Aver_penalty_' + self.Algo + '.csv', index=False, header=True)

    def global_save(self):
        """Write the cross-seed curves (training steps, reward, penalty) to CSV."""
        x_data = pd.DataFrame(self.train_num)
        y_data = pd.DataFrame(self.reward_file)
        z_data = pd.DataFrame(self.penalty_file)
        x_data.to_csv('./data/' + self.label + '/Train_num_' + self.Algo + '.csv', index=False, header=True)
        y_data.to_csv('./data/' + self.label + '/Aver_reward_' + self.Algo + '.csv', index=False, header=True)
        z_data.to_csv('./data/' + self.label + '/Aver_penalty_' + self.Algo + '.csv', index=False, header=True)

    def eval(self):
        """Plot the saved curve, restore parameters and render the policy indefinitely."""
        self.plot_cost(False)
        self.load()
        s = self.env.reset()
        while True:
            self.env.env.render()
            a = self.choose_action(s)
            s, r, done, _ = self.env.step(a)
            # Start a new episode instead of stepping a finished one.
            if done:
                s = self.env.reset()

    def load_par(self):
        """Restore the state written by per_save/global_save, then the parameters."""
        self.var = np.array(pd.read_csv('./data/' + self.label + '/var_' + self.Algo + '.csv')).reshape(-1)[0]
        self.train_counter = np.array(pd.read_csv('./data/' + self.label + '/Train_counter_' + self.Algo + '.csv')).reshape(-1)[0]
        self.train_lo = list(np.array(pd.read_csv('./data/' + self.label + '/Local_Train_num_' + self.Algo + '.csv')).reshape(-1))
        self.rew_lo = list(np.array(pd.read_csv('./data/' + self.label + '/Local_Aver_reward_' + self.Algo + '.csv')).reshape(-1))
        self.pen_lo = list(np.array(pd.read_csv('./data/' + self.label + '/Local_Aver_penalty_' + self.Algo + '.csv')).reshape(-1))
        self.train_num = list(np.array(pd.read_csv('./data/' + self.label + '/Train_num_' + self.Algo + '.csv')).reshape(-1))
        self.reward_file = list(np.array(pd.read_csv('./data/' + self.label + '/Aver_reward_' + self.Algo + '.csv')).reshape(-1))
        self.penalty_file = list(np.array(pd.read_csv('./data/' + self.label + '/Aver_penalty_' + self.Algo + '.csv')).reshape(-1))
        self.load()