import numpy as np
from baselines.a2c.utils import discount_with_dones
from baselines.common.runners import AbstractEnvRunner
import tensorflow as tf
import math
import time
from mpi4py import MPI
import gc
from array import array

#tf.compat.v1.enable_eager_execution()

#import mpi4py
#mpi4py.rc.recv_mprobe = False

class Runner(AbstractEnvRunner):
    """
    Generate batches of A2C experience while mixing in peer PPO2/ACER models.

    Each ``run`` call:

    - drains parameter tensors sent over MPI (tag 1) by rank 1 (PPO2) and
      rank 2 (ACER) and, once a complete set is available, loads it into the
      local ``ppo2_model`` / ``acer_model`` variables;
    - steps all three models at every environment step and, per environment,
      may execute the PPO2 or ACER action instead of the A2C one;
    - may overwrite the ``a2c_model`` variables with the peer whose actions
      were used most often, if that peer also has the best mean reward;
    - returns the A2C training batch.
    """

    def __init__(self, env, model, model_ppo2, model_acer, nsteps=5, gamma=0.99):
        super().__init__(env=env, model=model, nsteps=nsteps)
        self.gamma = gamma
        nenv = self.nenv
        self.batch_ob_shape_acer = (nenv * (nsteps + 1),) + env.observation_space.shape
        # reshape() needs -1 rather than None for unknown dimensions.
        self.batch_action_shape = [
            x if x is not None else -1
            for x in model.train_model.action.shape.as_list()
        ]
        self.ob_dtype = model.train_model.X.dtype.as_numpy_dtype
        self.nstack = self.env.nstack
        self.nc = self.batch_ob_shape_acer[-1] // self.nstack
        self.obs_dtype = env.observation_space.dtype
        self.ac_dtype = env.action_space.dtype
        # Index 0: local A2C model, 1: PPO2 model, 2: ACER model.
        self.models = [model, model_ppo2, model_acer]
        self.lam = 0.95
        self.nact = self.env.action_space.n
        # One reusable MPI receive buffer per parameter tensor, in send order.
        self.buf = [np.zeros(shape, dtype='double') for shape in self._param_shapes(self.nact)]

    @staticmethod
    def _param_shapes(nact):
        """Return the shapes of the parameter tensors sent by each peer, in send order.

        Only the action dimension is parameterized. The convolution/dense
        shapes are fixed; they presumably assume 84x84x4 stacked Atari frames
        (TODO confirm against the sending side).
        """
        return [
            (8, 8, 4, 32), (1, 32, 1, 1),
            (4, 4, 32, 64), (1, 64, 1, 1),
            (3, 3, 64, 64), (1, 64, 1, 1),
            (3136, 512), (512,),
            (512, nact), (nact,),
            (512, nact), (nact,),
            (512, 1), (1,),
        ]

    def _post_recv(self, comm, source, index):
        """Zero receive buffer ``index`` and post a non-blocking receive into it."""
        self.buf[index].fill(0)
        return comm.Irecv(self.buf[index], source=source, tag=1)

    def _receive_params(self, comm, source, Arr_index, Params_cur, Params_pre):
        """Drain the parameter tensors currently queued by MPI rank ``source``.

        ``source`` also serves as the slot in ``Arr_index``/``Params_cur``/
        ``Params_pre``. Tensors arrive one message at a time in ``self.buf``
        order. Each one is copied into ``Params_cur[source]``. After the last
        tensor of a set, that list becomes ``Params_pre[source]`` and a new
        set is started. Polling stops at the first receive that has not
        completed after a short wait. ``Arr_index[source]`` is updated in place
        with the index of the next expected tensor.
        """
        nparams = len(self.buf)
        arr_count = Arr_index[source]
        req = self._post_recv(comm, source, arr_count)
        while True:
            if not req.Get_status():
                # Give the peer a moment before declaring the queue empty.
                time.sleep(0.7)
                print('BEFORE wait!')
                print('Arr index: ' + str(arr_count))
                done = req.Test()
                print('AFTER wait!')
            else:
                print('Arr index: ' + str(arr_count))
                done = req.Test()
            if not done:
                print('test false')
                req.Cancel()
                status = MPI.Status()
                # Completing the cancelled request tells us whether the
                # message slipped in between Test() and Cancel().
                req.Wait(status)
                if status.Is_cancelled():
                    break
                # Cancellation failed: the message was received into the
                # buffer. Keep it so arr_count stays in sync with the sender.
            # Copy: self.buf is reused for later receives (from both peers).
            Params_cur[source].append(self.buf[arr_count].copy())
            if arr_count >= nparams - 1:
                Params_pre[source] = Params_cur[source]
                Params_cur[source] = []
                arr_count = 0
            else:
                arr_count += 1
            req = self._post_recv(comm, source, arr_count)
        Arr_index[source] = arr_count

    @staticmethod
    def _load_values(sess, variables, values, count):
        """Write ``values[i]`` into ``variables[i]`` for every ``i < count``.

        ``Variable.load`` feeds the value through the variable's existing
        initializer. Unlike creating a fresh ``tf.assign`` op per call, it does
        not add nodes to the graph, so repeated calls do not grow memory.
        """
        for i in range(count):
            variables[i].load(values[i], sess)

    def run(self, rewmean, Arr_index, Params_cur, Params_pre):
        """
        Collect one mini-batch of ``nsteps`` x ``nenv`` experience.

        Parameters
        ----------
        rewmean : list of 3 numbers
            Mean reward of [A2C, PPO2, ACER]; falsy entries are treated as
            "not available yet".
        Arr_index : mutable sequence
            ``Arr_index[1]`` / ``Arr_index[2]`` hold the index of the next
            tensor expected from rank 1 / rank 2; updated in place.
        Params_cur, Params_pre : mutable sequences of lists
            ``Params_cur[k]`` accumulates the set being received from rank
            ``k``; ``Params_pre[k]`` holds the last complete set. Updated in
            place.

        Returns
        -------
        list
            ``[[mb_obs, mb_states, mb_rewards, mb_masks, mb_actions,
            mb_values, epinfos]]``
        """
        comm = MPI.COMM_WORLD
        nparams = len(self.buf)

        self._receive_params(comm, 1, Arr_index, Params_cur, Params_pre)
        if len(Params_pre[1]) == nparams:
            print('a2c recv ppo2 model!')
            params = tf.trainable_variables("ppo2_model")
            self._load_values(self.models[1].sess, params, Params_pre[1], len(params))

        self._receive_params(comm, 2, Arr_index, Params_cur, Params_pre)
        if len(Params_pre[2]) == nparams:
            print('a2c recv acer model!')
            params = tf.trainable_variables("acer_model")
            # The last two ACER variables are left untouched.
            self._load_values(self.models[2].sess, params, Params_pre[2], len(params) - 2)

        mb_obs, mb_rewards, mb_actions, mb_values, mb_values_ppo2, mb_dones = [], [], [], [], [], []
        mb_states = self.states
        epinfos = []
        count = [0, 0, 0]
        value_sum = [0, 0, 0]
        # likelihood_list presumably holds negative log-probabilities; a peer
        # action qualifies only below 80% of the uniform-policy value.
        threshold = -1 * math.log(1 / self.nact) * 0.8

        for _ in range(self.nsteps):
            action_list, value_list, state_list = [], [], []
            likelihood_list, mus_list, action_values_list = [], [], []
            for model in self.models:
                actions_k, mus_k, states_k, action_values_k, values_k, nlp_k = model._step(
                    self.obs, S=self.states, M=self.dones)
                if states_k.size == 0:
                    states_k = None
                action_list.append(actions_k)
                value_list.append(values_k)
                state_list.append(states_k)
                likelihood_list.append(nlp_k)
                mus_list.append(mus_k)
                action_values_list.append(action_values_k)

            value_sum[0] += sum(value_list[0])
            value_sum[1] += sum(value_list[1])
            # ACER value estimate: per-action values weighted by its policy.
            value_sum[2] += float(np.sum(action_values_list[2] * mus_list[2]))

            best_value_sum = max(value_sum)
            for env_idx in range(self.nenv):
                index = 0
                for j in range(3):
                    if (rewmean[j] and rewmean[0] and rewmean[j] > rewmean[0]
                            and likelihood_list[j][env_idx] < threshold
                            and value_sum[j] == best_value_sum):
                        index = j
                count[index] += 1
                # The A2C action array is overwritten in place with the chosen agent's action.
                action_list[0][env_idx] = action_list[index][env_idx]

            actions = action_list[0]
            mb_obs.append(np.copy(self.obs))
            mb_actions.append(actions)
            mb_values.append(value_list[0])
            mb_values_ppo2.append(value_list[1])
            mb_dones.append(self.dones)

            obs, rewards, dones, infos = self.env.step(actions)
            for info in infos:
                maybeepinfo = info.get('episode')
                if maybeepinfo:
                    epinfos.append(maybeepinfo)
            self.states = state_list[0]
            self.dones = dones
            self.obs = obs
            mb_rewards.append(rewards)
        mb_dones.append(self.dones)

        bestagent = None
        if rewmean[0] and rewmean[1] and rewmean[2]:
            bestagent = rewmean.index(max(rewmean))

        mb_values = np.asarray(mb_values, dtype=np.float32).swapaxes(1, 0)
        agent = count.index(max(count))
        if agent != 0 and max(count) > (self.nenv * self.nsteps / 2) and agent == bestagent:
            sess = self.model.sess
            params = tf.trainable_variables("a2c_model")
            if agent == 1:
                print("a2c use ppo2 model")
                source_vars = tf.trainable_variables("ppo2_model")
                ncopy = len(params)
            else:
                print("a2c use acer model")
                source_vars = tf.trainable_variables("acer_model")
                ncopy = len(params) - 2
            values = sess.run([source_vars[i] for i in range(ncopy)])
            self._load_values(sess, params, values, ncopy)
            if agent == 1:
                mb_values = np.asarray(mb_values_ppo2, dtype=np.float32).swapaxes(1, 0)

        # Batch of steps to batch of rollouts.
        mb_obs = np.asarray(mb_obs, dtype=self.ob_dtype).swapaxes(1, 0).reshape(self.batch_ob_shape)
        mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
        mb_actions = np.asarray(mb_actions, dtype=self.model.train_model.action.dtype.name).swapaxes(1, 0)
        # Builtin bool: the np.bool alias was removed in NumPy 1.24.
        mb_dones = np.asarray(mb_dones, dtype=bool).swapaxes(1, 0)
        mb_masks = mb_dones[:, :-1]
        mb_dones = mb_dones[:, 1:]

        last_values = self.model.value(self.obs, S=self.states, M=self.dones)

        if self.gamma > 0.0:
            # Discount/bootstrap off value fn.
            last_value = last_values.tolist()
            for n, (rewards, dones, value) in enumerate(zip(mb_rewards, mb_dones, last_value)):
                rewards = rewards.tolist()
                dones = dones.tolist()
                if dones[-1] == 0:
                    rewards = discount_with_dones(rewards + [value], dones + [0], self.gamma)[:-1]
                else:
                    rewards = discount_with_dones(rewards, dones, self.gamma)
                mb_rewards[n] = rewards

        mb_actions = mb_actions.reshape(self.batch_action_shape)
        mb_rewards = mb_rewards.flatten()
        mb_values = mb_values.flatten()
        mb_masks = mb_masks.flatten()

        return [[mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values, epinfos]]

def sf01(arr):
    """
    Exchange the first two axes of ``arr`` and merge them into one.

    An array of shape ``(a, b, *rest)`` comes back with shape
    ``(a * b, *rest)``, ordered so that the original axis-1 index varies
    slowest.
    """
    dims = arr.shape
    swapped = arr.swapaxes(0, 1)
    merged_len = dims[0] * dims[1]
    return swapped.reshape((merged_len,) + tuple(dims[2:]))
