import os
import time
import numpy as np
import os.path as osp
from baselines.logger import configure
#from baselines import logger
from collections import deque
from baselines.common import explained_variance, set_global_seeds
from baselines.common.policies import build_policy
from importlib import import_module
import sys
import multiprocessing
from baselines.common.cmd_util import make_vec_env, make_env
from baselines.common.tf_util import get_session
import tensorflow as tf
from baselines.common.vec_env import VecFrameStack, VecNormalize
import gym
import re
from collections import defaultdict
from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder

import math

import os
import gc
from array import array

# Cores reserved by the batch system: the scheduler exports NSLOTS; fall back to 1 when unset.
NUMCORES = int(os.environ.get("NSLOTS", 1))
print("Using", NUMCORES, "core(s)")

#tf.config.threading.set_inter_op_parallelism_threads(NUMCORES) 
#tf.config.threading.set_intra_op_parallelism_threads(NUMCORES)
#tf.config.set_soft_device_placement(1)

#import mpi4py
#mpi4py.rc.recv_mprobe = False

try:
    from mpi4py import MPI
except ImportError:
    MPI = None
from baselines.ppo2.runner_exp2 import Runner

def get_alg_module(alg, submodule=None):
    """Import and return the module ``<package>.<alg>.<submodule>``.

    ``submodule`` defaults to ``alg``. The ``baselines`` package is tried first; on
    ImportError the same path under ``rl_algs`` is imported instead, and an ImportError
    from that fallback propagates to the caller.
    """
    leaf = submodule or alg
    try:
        return import_module('.'.join(('baselines', alg, leaf)))
    except ImportError:
        # Fallback package for algorithms that are not shipped with baselines.
        return import_module('.'.join(('rl_' + 'algs', alg, leaf)))

def get_learn_function_defaults(alg, env_type):
    """Return the default ``learn`` kwargs for ``alg`` on environments of type ``env_type``.

    Calls the function named ``env_type`` in the algorithm's ``defaults`` module. If that
    module cannot be imported, or the attribute is missing (including an AttributeError
    raised while calling it), an empty dict is returned.
    """
    try:
        defaults_module = get_alg_module(alg, 'defaults')
        make_defaults = getattr(defaults_module, env_type)
        return make_defaults()
    except (ImportError, AttributeError):
        return {}

def get_default_network(env_type):
    """Return ``'cnn'`` for atari/retro environments and ``'mlp'`` for everything else."""
    return 'cnn' if env_type in {'atari', 'retro'} else 'mlp'

def constfn(val):
    """Turn ``val`` into a one-argument schedule that ignores its argument and returns ``val``."""
    return lambda _: val

# Registry of gym env ids grouped by "env type". The type is the last dotted component of
# the module part of each spec's entry_point (an entry point of the form
# "pkg.sub.module:Class" is filed under "module").
# NOTE: this loop runs at import time and leaves `env` and `env_type` bound as module globals.
_game_envs = defaultdict(set)
for env in gym.envs.registry.all():
    # TODO: solve this with regexes
    env_type = env.entry_point.split(':')[0].split('.')[-1]
    _game_envs[env_type].add(env.id)

# reading benchmark names directly from retro requires
# importing retro here, and for some reason that crashes tensorflow
# in ubuntu
# Hard-coded retro game list; this assignment replaces any 'retro' entry collected by the
# registry scan above.
_game_envs['retro'] = {
    'BubbleBobble-Nes',
    'SuperMarioBros-Nes',
    'TwinBee3PokoPokoDaimaou-Nes',
    'SpaceHarrier-Nes',
    'SonicTheHedgehog-Genesis',
    'Vectorman-Genesis',
    'FinalFight-Snes',
    'SpaceInvaders-Snes',
}

def get_env_type(args):
    """Resolve ``(env_type, env_id)`` for the environment requested in ``args``.

    If ``args.env_type`` is set it is returned unchanged together with ``args.env``.
    Otherwise the gym registry is re-scanned into ``_game_envs`` (envs may have been
    registered since import) and ``args.env`` is resolved:

    * if it is itself an env type (a key of ``_game_envs``), that type is used and one id of
      that type is picked (``_game_envs`` values are sets, so which one is arbitrary);
    * otherwise the type whose id set contains it is used, and an id containing ``':'``
      overrides that with the text before the first colon.

    Raises:
        AssertionError: if no env type can be determined for ``args.env``.
    """
    env_id = args.env

    if args.env_type is not None:
        return args.env_type, env_id

    # Re-parse the gym registry, since we could have new envs since last time.
    for spec in gym.envs.registry.all():
        spec_type = spec.entry_point.split(':')[0].split('.')[-1]
        _game_envs[spec_type].add(spec.id)  # This is a set so add is idempotent

    if env_id in _game_envs:
        env_type = env_id
        env_id = list(_game_envs[env_type])[0]
    else:
        env_type = None
        for candidate_type, ids in _game_envs.items():
            if env_id in ids:
                env_type = candidate_type
                break
        if ':' in env_id:
            env_type = re.sub(r':.*', '', env_id)
        # The original format string had one placeholder for two arguments, so the list of
        # known env types never appeared in the message.
        assert env_type is not None, 'env_id {} is not recognized in env types {}'.format(
            env_id, sorted(_game_envs))

    return env_type, env_id

def build_env(args):
    """Create the training environment described by ``args``.

    Reads ``args.num_env``, ``args.alg``, ``args.seed``, ``args.reward_scale`` and (atari/retro
    vectorized path only) ``args.gamestate``; the env type/id come from ``get_env_type``.

    * atari/retro + ``deepq``: ``make_env`` with frame stacking requested via wrapper kwargs.
    * atari/retro + ``trpo_mpi``: plain ``make_env``.
    * other atari/retro: ``num_env`` vectorized envs (default: CPU count, halved on macOS)
      wrapped in a 4-frame ``VecFrameStack``.
    * any other type: a single-threaded TF session is configured first, then ``num_env``
      vectorized envs (default 1) are built; mujoco envs are wrapped in ``VecNormalize``.
    """
    ncpu = multiprocessing.cpu_count()
    if sys.platform == 'darwin':
        # Halve the count on macOS, but keep at least one env: on a single-CPU machine
        # `ncpu // 2` is 0, which would request zero vectorized environments.
        ncpu = max(1, ncpu // 2)
    nenv = args.num_env or ncpu
    alg = args.alg
    seed = args.seed

    env_type, env_id = get_env_type(args)

    if env_type in {'atari', 'retro'}:
        if alg == 'deepq':
            env = make_env(env_id, env_type, seed=seed, wrapper_kwargs={'frame_stack': True})
        elif alg == 'trpo_mpi':
            env = make_env(env_id, env_type, seed=seed)
        else:
            frame_stack_size = 4
            env = make_vec_env(env_id, env_type, nenv, seed, gamestate=args.gamestate, reward_scale=args.reward_scale)
            env = VecFrameStack(env, frame_stack_size)

    else:
        # One intra-/inter-op thread; GPU memory is allocated on demand (allow_growth).
        config = tf.ConfigProto(allow_soft_placement=True,
                               intra_op_parallelism_threads=1,
                               inter_op_parallelism_threads=1)
        config.gpu_options.allow_growth = True
        get_session(config=config)

        # Dict observations are flattened for every algorithm except 'her'.
        flatten_dict_observations = alg not in {'her'}
        env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale, flatten_dict_observations=flatten_dict_observations)

        if env_type == 'mujoco':
            env = VecNormalize(env, use_tf=True)

    return env


def _drain_latest_reward(comm, recv_buf, source, latest, poll_interval=0.4):
    """Receive all pending reward messages (tag 0) from rank `source` and return the newest.

    A non-blocking receive is posted into `recv_buf`. While messages keep arriving they are
    consumed one after another. When no message is ready, the function waits `poll_interval`
    seconds once more and then cancels the receive. The cancelled request is completed with
    `Wait` (not `Free`), so that:
      * a message that arrives between the last `Test` and `Cancel` is kept instead of lost;
      * MPI no longer owns `recv_buf` when this returns, so the caller may reuse it.

    Parameters:
        comm:          MPI communicator to receive on.
        recv_buf:      numpy receive buffer; it is zeroed before every receive and element 0
                       is read after each completed message.
        source:        rank to receive from.
        latest:        value returned when no message arrives.
        poll_interval: seconds to wait for a message before giving up.

    Returns:
        The value carried by the last message received, or `latest` if none arrived.
    """
    while True:
        recv_buf.fill(0)
        req = comm.Irecv(recv_buf, source=source, tag=0)
        if not req.Test():
            time.sleep(poll_interval)
            if not req.Test():
                req.Cancel()
                status = MPI.Status()
                req.Wait(status)
                if status.Is_cancelled():
                    return latest
                # The message landed between Test() and Cancel(): keep it and keep draining.
        latest = recv_buf[0]


def _prune_completed(requests):
    """Return the MPI requests from `requests` that are still in flight.

    `Test()` completes and releases finished requests. Unfinished ones are kept because the
    request object holds a reference to its send buffer, which must stay alive until MPI is
    done with it.
    """
    return [req for req in requests if not req.Test()]


def learn(args, extra_args, uni_tstart, network, suffix=None, eval_env = None, nsteps=2048, ent_coef=0.0, lr=3e-4,
            vf_coef=0.5,  max_grad_norm=0.5, gamma=0.99, lam=0.95,
            log_interval=1, nminibatches=4, noptepochs=4, cliprange=0.2,
            save_interval=0, load_path=None, model_fn=None, update_fn=None, init_fn=None, mpi_rank_weight=1, comm=None):
    '''
    Learn policy using PPO algorithm (https://arxiv.org/abs/1707.06347) while exchanging
    rewards and parameters with peer processes over MPI.

    Each update:
      1. drains the newest mean episode rewards sent by ranks 0 and 2 (MPI tag 0) into
         rewmean[0] / rewmean[2];
      2. calls runner.run(rewmean, Arr_index, Params_cur, Params_pre); the runner was also
         given the a2c and acer models;
      3. trains the PPO model on every batch the runner returned. Entry 0 gets full shuffled
         epochs. Every other entry gets only its first, unshuffled minibatch per epoch;
      4. sends the trainable variables of scope 'ppo2_model' (tag 1, one message per variable)
         and, when it is defined, this rank's mean episode reward (tag 0) to every other rank;
      5. logs and optionally checkpoints.

    NOTE(review): the hard-coded peer ranks 0 and 2 and rewmean[1] for this learner suggest a
    3-process layout with this learner on rank 1 -- confirm against the launcher.

    Parameters:
    ----------

    args:                             argument namespace (presumably parsed CLI args). Reads log_path, num_timesteps,
                                      seed, save_video_interval, save_video_length, plus whatever build_env reads.
                                      The environment and total_timesteps come from here.

    extra_args:                       unused.

    uni_tstart: float                 reference time.time() value; "wallclock_time" is logged relative to it.

    network:                          policy network architecture. Either string (mlp, lstm, lnlstm, cnn_lstm, cnn, cnn_small, conv_only - see baselines.common/models.py for full list)
                                      specifying the standard network architecture, or a function that takes tensorflow tensor as input and returns
                                      tuple (output_tensor, extra_feed) where output tensor is the last network layer output, extra_feed is None for feed-forward
                                      neural nets, and extra_feed is a dictionary describing how to feed state into the network for recurrent neural nets.
                                      See common/models.py/lstm for more details on using recurrent nets in policies

    suffix: str or None               appended to '-ppo2' to form the logger's log_suffix (None means no extra suffix).

    eval_env:                         optional evaluation environment, run once per update with its own Runner.

    nsteps: int                       number of steps of the vectorized environment per update (i.e. batch size is nsteps * nenv where
                                      nenv is number of environment copies simulated in parallel)

    ent_coef: float                   policy entropy coefficient in the optimization objective

    lr: float or function             learning rate, constant or a schedule function [0,1] -> R+ where 1 is beginning of the
                                      training and 0 is the end of the training.

    vf_coef: float                    value function loss coefficient in the optimization objective

    max_grad_norm: float or None      gradient norm clipping coefficient

    gamma: float                      discounting factor

    lam: float                        advantage estimation discounting factor (lambda in the paper)

    log_interval: int                 number of updates between logging events

    nminibatches: int                 number of training minibatches per update. For recurrent policies,
                                      should be smaller or equal than number of environments run in parallel.

    noptepochs: int                   number of training epochs per update

    cliprange: float or function      clipping range, constant or schedule function [0,1] -> R+ where 1 is beginning of the training
                                      and 0 is the end of the training

    save_interval: int                number of updates between saving events (0 disables saving)

    load_path: str                    path to load the model from

    model_fn, update_fn, init_fn:     optional model factory (defaults to baselines.ppo2.model.Model), per-update callback
                                      (called with the update index) and one-off initialisation callback.

    mpi_rank_weight, comm:            forwarded to the model factory; comm defaults to MPI.COMM_WORLD and is also
                                      used for the reward/parameter exchange.

    Returns:
        the trained PPO model.

    Raises:
        RuntimeError: if mpi4py is not available.
    '''
    if MPI is None:
        raise RuntimeError('ppo2 learn() exchanges rewards and parameters over MPI and requires mpi4py')
    # Respect an explicitly passed communicator (it was previously always overwritten).
    if comm is None:
        comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    logger = configure(args.log_path, log_suffix='-ppo2' + (suffix or ''))

    env = build_env(args)
    if args.save_video_interval != 0:
        env = VecVideoRecorder(env, osp.join(logger.get_dir(), "videos"), record_video_trigger=lambda x: x % args.save_video_interval == 0, video_length=args.save_video_length)

    total_timesteps = int(args.num_timesteps)
    set_global_seeds(args.seed)

    if isinstance(lr, float): lr = constfn(lr)
    else: assert callable(lr)
    if isinstance(cliprange, float): cliprange = constfn(cliprange)
    else: assert callable(cliprange)

    policy = build_policy(env, network)

    # Number of parallel envs and the spaces the models are built for.
    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space

    # Calculate the batch_size
    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches
    # The usual MPI-root check is disabled: every rank logs and saves checkpoints.
    is_mpi_root = True

    # Instantiate the model objects (each creates act_model and train_model)
    if model_fn is None:
        from baselines.ppo2.model import Model
        model_fn = Model

    shared_model_kwargs = dict(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs,
                               nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
                               max_grad_norm=max_grad_norm, comm=comm, mpi_rank_weight=mpi_rank_weight)
    model = model_fn(model_type='ppo2_model', nbatch_train=nbatch_train, **shared_model_kwargs)
    model_a2c = model_fn(model_type='a2c_model', nbatch_train=20, **shared_model_kwargs)
    model_acer = model_fn(model_type='acer_model', nbatch_train=84, **shared_model_kwargs)
    if load_path is not None:
        model.load(load_path)

    # Instantiate the runner object
    runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam, model_a2c=model_a2c, model_acer=model_acer)
    if eval_env is not None:
        eval_runner = Runner(env = eval_env, model = model, nsteps = nsteps, gamma = gamma, lam = lam, EVAL=True)

    epinfobuf = deque(maxlen=100)
    if eval_env is not None:
        eval_epinfobuf = deque(maxlen=100)

    if init_fn is not None:
        init_fn()

    # Start total timer
    tfirststart = time.perf_counter()

    # Latest mean episode reward per rank: indices 0 and 2 are filled from peers, 1 by this learner.
    rewmean = [None, None, None]
    # Receive buffer for the scalar reward messages (tag 0). It is never used as a send buffer.
    bufs = np.zeros(1, dtype='double')
    # Non-blocking sends that may still be in flight; holding the requests keeps their buffers alive.
    pending_sends = []

    Params_cur = [[] for _ in range(size)]
    Params_pre = [[] for _ in range(size)]
    Arr_index = [0 for _ in range(size)]
    nupdates = total_timesteps//nbatch
    for update in range(1, nupdates+1):
        print("ppo2 update: " + str(update))
        assert nbatch % nminibatches == 0
        # Start timer
        tstart = time.perf_counter()
        frac = 1.0 - (update - 1.0) / nupdates
        # Calculate the learning rate
        lrnow = lr(frac)
        # Calculate the cliprange
        cliprangenow = cliprange(frac)

        if update % log_interval == 0 and is_mpi_root: logger.log('Stepping environment...', level=20)

        # Pull the newest reward means published by the peer ranks (tag 0).
        rewmean[0] = _drain_latest_reward(comm, bufs, 0, rewmean[0])
        rewmean[2] = _drain_latest_reward(comm, bufs, 2, rewmean[2])

        ret = runner.run(rewmean, Arr_index, Params_cur, Params_pre)
        for i, batch in enumerate(ret):
            obs, returns, masks, actions, values, neglogpacs, states, epinfos = batch #pylint: disable=E0632
            if eval_env is not None and i == 0:
                eval_obs, eval_returns, eval_masks, eval_actions, eval_values, eval_neglogpacs, eval_states, eval_epinfos = eval_runner.run() #pylint: disable=E0632

            if update % log_interval == 0 and is_mpi_root: logger.log('Done.', level=20)

            epinfobuf.extend(epinfos)
            if eval_env is not None and i == 0:
                eval_epinfobuf.extend(eval_epinfos)

            # Per-minibatch losses. NOTE(review): this is reset for every runner entry, so only
            # the last entry's losses (and its values/returns) reach the logging below -- confirm
            # that is intended.
            mblossvals = []
            if states is None: # nonrecurrent version
                inds = np.arange(nbatch)
                for _ in range(noptepochs):
                    # Only the first runner entry is shuffled and trained on every minibatch.
                    if i == 0:
                        np.random.shuffle(inds)
                    # 0 to batch_size with batch_train_size step
                    for start in range(0, nbatch, nbatch_train):
                        end = start + nbatch_train
                        mbinds = inds[start:end]
                        slices = (arr[mbinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
                        mblossvals.append(model.train(lrnow, cliprangenow, *slices))
                        # Other entries contribute only their first minibatch per epoch.
                        if i != 0:
                            break
            else: # recurrent version
                assert nenvs % nminibatches == 0
                envsperbatch = nenvs // nminibatches
                envinds = np.arange(nenvs)
                flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
                for _ in range(noptepochs):
                    np.random.shuffle(envinds)
                    for start in range(0, nenvs, envsperbatch):
                        end = start + envsperbatch
                        mbenvinds = envinds[start:end]
                        mbflatinds = flatinds[mbenvinds].ravel()
                        slices = (arr[mbflatinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
                        mbstates = states[mbenvinds]
                        mblossvals.append(model.train(lrnow, cliprangenow, *slices, states=mbstates))

        # Release sends from earlier updates that have completed, then publish this update's
        # parameters (tag 1, one message per trainable variable, in variable order).
        pending_sends = _prune_completed(pending_sends)
        param_val = model.sess.run(tf.trainable_variables('ppo2_model'))
        for dest in range(size):
            if dest != rank:
                for value in param_val:
                    pending_sends.append(comm.Isend(value, dest=dest, tag=1))

        ppo2rewmean = safemean([epinfo['r'] for epinfo in epinfobuf])
        if not math.isnan(ppo2rewmean):
            # Dedicated send buffer: `bufs` is reused as a receive buffer next update and must
            # not be handed to an Isend that may still be in flight.
            rew_payload = np.array([ppo2rewmean], dtype='double')
            for dest in range(size):
                if dest != rank:
                    pending_sends.append(comm.Isend(rew_payload, dest=dest, tag=0))
            rewmean[1] = ppo2rewmean

        # Feedforward --> get losses --> update
        lossvals = np.mean(mblossvals, axis=0)
        # End timer
        tnow = time.perf_counter()

        wallclock_time = time.time() - uni_tstart

        # Calculate the fps (frame per second)
        fps = int(nbatch / (tnow - tstart))

        if update_fn is not None:
            update_fn(update)

        if update % log_interval == 0 or update == 1:
            # Calculates if value function is a good predicator of the returns (ev > 1)
            # or if it's just worse than predicting nothing (ev =< 0)
            ev = explained_variance(values, returns)
            logger.logkv("misc/serial_timesteps", update*nsteps)
            logger.logkv("misc/nupdates", update)
            logger.logkv("misc/total_timesteps", update*nbatch)
            logger.logkv("fps", fps)
            logger.logkv("misc/explained_variance", float(ev))
            logger.logkv('eprewmean', safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv('eplenmean', safemean([epinfo['l'] for epinfo in epinfobuf]))
            if eval_env is not None:
                logger.logkv('eval_eprewmean', safemean([epinfo['r'] for epinfo in eval_epinfobuf]) )
                logger.logkv('eval_eplenmean', safemean([epinfo['l'] for epinfo in eval_epinfobuf]) )
            logger.logkv('misc/time_elapsed', tnow - tfirststart)
            logger.logkv("wallclock_time", wallclock_time)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv('loss/' + lossname, lossval)

            logger.dumpkvs()
        if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and is_mpi_root:
            checkdir = osp.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = osp.join(checkdir, '%.5i'%update)
            print('Saving to', savepath)
            model.save(savepath)

    # NOTE(review): sends still in `pending_sends` are not waited on here (a blocking Waitall
    # could hang if peers have stopped receiving); they are released when the list is dropped.
    return model
# Mean that short-circuits to np.nan for an empty input instead of calling np.mean on it
# (used for episode-info buffers that can still be empty early in training).
def safemean(xs):
    """Return ``np.mean(xs)``, or ``np.nan`` when ``xs`` has length zero."""
    if len(xs) == 0:
        return np.nan
    return np.mean(xs)



