import os
import subprocess
import sys
import importlib
import inspect
import functools

import tensorflow as tf
import numpy as np

from baselines.common import tf_util as U


def store_args(method):
    """Decorator that copies a method's arguments onto ``self``.

    Before ``method`` runs, a mapping of argument name -> value is built in
    three layers, later layers overriding earlier ones:

    1. declared defaults (both positional and keyword-only),
    2. positional arguments actually passed (matched to parameter names,
       skipping ``self``),
    3. keyword arguments actually passed.

    The mapping is merged into ``self.__dict__`` and ``method`` is then
    called with the original, untouched arguments.
    """
    spec = inspect.getfullargspec(method)
    param_names = spec.args[1:]  # drop `self`

    default_values = {}
    if spec.defaults is not None:
        # Defaults bind to the trailing positional parameters.
        trailing_params = spec.args[-len(spec.defaults):]
        default_values.update(zip(trailing_params, spec.defaults))
    if spec.kwonlydefaults is not None:
        default_values.update(spec.kwonlydefaults)

    @functools.wraps(method)
    def wrapper(*positional_args, **keyword_args):
        instance = positional_args[0]
        collected = dict(default_values)
        collected.update(zip(param_names, positional_args[1:]))
        collected.update(keyword_args)
        instance.__dict__.update(collected)
        return method(*positional_args, **keyword_args)

    return wrapper


def import_function(spec):
    """Resolve a ``"package.module:attr_name"`` string to that attribute.

    The part before the single ``:`` is imported with ``importlib`` and the
    part after it is looked up on the resulting module. A spec without
    exactly one ``:`` raises ``ValueError`` from the tuple unpacking.
    """
    module_path, attr_name = spec.split(':')
    return getattr(importlib.import_module(module_path), attr_name)


def flatten_grads(var_list, grads):
    """Concatenate a list of gradients into one 1-D tensor.

    Each gradient is reshaped to a flat vector whose length is the element
    count of its paired variable (``U.numel(v)``). The vectors are then joined
    along axis 0 in ``var_list`` order. As with ``zip``, unpaired trailing
    entries are ignored.
    """
    flat_pieces = [
        tf.reshape(gradient, [U.numel(variable)])
        for variable, gradient in zip(var_list, grads)
    ]
    return tf.concat(flat_pieces, 0)


def nn(input, layers_sizes, reuse=None, flatten=False, name="", last_activation=None):
    """Build a stack of fully connected layers.

    One ``tf.layers.dense`` layer is created per entry of ``layers_sizes``,
    each with a Xavier kernel initializer. Layer ``i`` is named
    ``name + '_' + str(i)`` and honours ``reuse``. Every layer except the last
    is followed by ReLU. The last layer is followed by ``last_activation``,
    or by nothing if that is falsy.

    If ``flatten`` is true, the final layer must have exactly one unit and the
    output is reshaped to rank 1.
    """
    last_index = len(layers_sizes) - 1
    hidden = input
    for index, units in enumerate(layers_sizes):
        hidden = tf.layers.dense(
            inputs=hidden,
            units=units,
            kernel_initializer=tf.contrib.layers.xavier_initializer(),
            reuse=reuse,
            name=name + '_' + str(index),
        )
        act_fn = last_activation if index == last_index else tf.nn.relu
        if act_fn:
            hidden = act_fn(hidden)

    if flatten:
        assert layers_sizes[-1] == 1
        hidden = tf.reshape(hidden, [-1])
    return hidden


def install_mpi_excepthook():
    """Chain an MPI abort onto ``sys.excepthook``.

    The new hook first delegates to whichever hook was installed before.
    It then flushes stdout and stderr and calls ``MPI.COMM_WORLD.Abort()``,
    so an uncaught exception in any rank aborts the whole communicator.
    """
    from mpi4py import MPI
    previous_hook = sys.excepthook

    def abort_on_uncaught(exc_type, exc_value, exc_traceback):
        previous_hook(exc_type, exc_value, exc_traceback)
        for stream in (sys.stdout, sys.stderr):
            stream.flush()
        MPI.COMM_WORLD.Abort()

    sys.excepthook = abort_on_uncaught


def mpi_fork(n, cpu_index=-1, extra_mpi_args=None):
    """Re-launch the current script under ``mpirun`` with ``n`` workers.

    Args:
        n: number of MPI processes (passed as ``mpirun -np n``).
        cpu_index: ``-1`` (default) selects the automatic path. If ``n <= 1``,
            no fork happens at all; otherwise the workers are launched with
            ``--cpu-set 0``. Any other value always forks (even for
            ``n <= 1``) and uses ``--cpu-set str(n % 20)``.
        extra_mpi_args: optional sequence of extra arguments inserted between
            the ``mpirun`` options and the Python executable. ``None`` means
            no extra arguments.

    Returns:
        ``"parent"`` in the original process after the ``mpirun`` command has
        finished successfully. ``"child"`` inside an MPI worker (detected via
        the ``IN_MPI`` environment variable), or when no fork is needed.

    Raises:
        subprocess.CalledProcessError: if ``mpirun`` exits with a non-zero
            status.
    """
    # Avoid a shared mutable default; also accept tuples and other sequences.
    extra_mpi_args = [] if extra_mpi_args is None else list(extra_mpi_args)

    if cpu_index == -1 and n <= 1:
        return "child"

    if os.getenv("IN_MPI") is not None:
        # Already running inside an mpirun-launched worker.
        install_mpi_excepthook()
        return "child"

    env = os.environ.copy()
    # Single-threaded math libraries per worker; IN_MPI marks the children.
    env.update(
        MKL_NUM_THREADS="1",
        OMP_NUM_THREADS="1",
        IN_MPI="1",
    )

    # CPU pinning is done through mpirun's --cpu-set option.
    if cpu_index == -1:
        cpu_set = "0"
    else:
        # NOTE(review): the CPU is derived from `n`, not from `cpu_index`.
        # This preserves the original behaviour; confirm whether
        # `cpu_index` was intended here.
        cpu_set = str(n % 20)

    args = (["mpirun", "-np", str(n), "--cpu-set", cpu_set]
            + extra_mpi_args
            + [sys.executable]
            + sys.argv)
    subprocess.check_call(args, env=env)
    return "parent"


def convert_episode_to_batch_major(episode):
    """Return a new dict with axes 0 and 1 swapped for every entry.

    Each value of ``episode`` is copied into a fresh NumPy array. Its first two
    axes are then exchanged, turning time-major data into batch-major data.
    The input mapping and its arrays are left unmodified.
    """
    return {
        key: np.array(episode[key]).copy().swapaxes(0, 1)
        for key in episode.keys()
    }


def transitions_in_episode_batch(episode_batch):
    """Count the transitions stored in ``episode_batch``.

    The count is the product of the first two dimensions of
    ``episode_batch['u']``.
    """
    dims = episode_batch['u'].shape
    n_rows, n_cols = dims[0], dims[1]
    return n_rows * n_cols


def reshape_for_broadcasting(source, target):
    """Cast ``source`` to ``target``'s dtype and reshape it to ``target``'s rank.

    The returned tensor has size-1 leading dimensions. All of ``source``'s
    elements sit in the last dimension (shape ``[1, ..., 1, -1]``), so it
    broadcasts along the final axis of ``target``.
    """
    rank = len(target.get_shape())
    broadcast_shape = [1] * (rank - 1) + [-1]
    casted = tf.cast(source, target.dtype)
    return tf.reshape(casted, broadcast_shape)

def features(input, penulti_linear, feature_size=50):
    """Convolutional feature extractor: four conv layers, then two dense layers.

    Four ReLU convolutions (32 filters, 3x3 kernel, stride 2, 'same' padding,
    named ``cov2d_1`` .. ``cov2d_4``) are applied in sequence. The result is
    flattened per example and passed through a ReLU dense layer of
    ``penulti_linear`` units. A final linear dense layer of ``feature_size``
    units produces the returned tensor.
    """
    conv_out = input
    for layer_no in range(1, 5):
        conv_out = tf.layers.conv2d(
            conv_out, filters=32, kernel_size=[3, 3], strides=2,
            padding='same', activation=tf.nn.relu,
            name="cov2d_%d" % layer_no)

    # Collapse every dimension after the leading one into a single axis.
    flat_dim = np.prod(conv_out.get_shape().as_list()[1:])
    flat = tf.reshape(conv_out, [-1, flat_dim])

    hidden = tf.layers.dense(
        inputs=flat, units=penulti_linear, activation=tf.nn.relu,
        kernel_initializer=tf.contrib.layers.xavier_initializer(),
        reuse=False)
    return tf.layers.dense(
        inputs=hidden, units=feature_size,
        kernel_initializer=tf.contrib.layers.xavier_initializer(),
        reuse=False)

# DQN NIPS 2013 and A3C
def featuresDQN13(input, penulti_linear, feature_size=50):
    """Feature extractor with the two-conv-layer layout of DQN (NIPS 2013) / A3C.

    The layers are:
      * ``cov2d_1``: 16 filters, 8x8 kernel, stride 4, ReLU
      * ``cov2d_2``: 32 filters, 4x4 kernel, stride 2, ReLU
    Both use 'same' padding. The result is flattened per example and fed to a
    dense layer of ``feature_size`` units, followed by ``tf.nn.relu``.

    ``penulti_linear`` is accepted but not used by this variant.
    """
    conv_specs = ((16, 8, 4), (32, 4, 2))  # (filters, kernel, stride)
    out = input
    for layer_no, (n_filters, kernel, stride) in enumerate(conv_specs, start=1):
        out = tf.layers.conv2d(
            out, filters=n_filters, kernel_size=[kernel, kernel],
            strides=stride, padding='same', activation=tf.nn.relu,
            name="cov2d_%d" % layer_no)

    # Collapse every dimension after the leading one into a single axis.
    flat_dim = np.prod(out.get_shape().as_list()[1:])
    flat = tf.reshape(out, [-1, flat_dim])

    projected = tf.layers.dense(
        inputs=flat, units=feature_size,
        kernel_initializer=tf.contrib.layers.xavier_initializer(),
        reuse=False)
    return tf.nn.relu(projected)

# DQN Nature 2015 paper
def featuresDQN15(input, penulti_linear, feature_size=50, reuse=False):
    """Feature extractor with the three-conv-layer layout of DQN (Nature 2015).

    The convolutions all use ReLU and 'same' padding:
      * ``cov2d_1``: 32 filters, 8x8 kernel, stride 4
      * ``cov2d_2``: 64 filters, 4x4 kernel, stride 2
      * ``cov2d_3``: 64 filters, 3x3 kernel, stride 1
    After flattening per example, ``dense1`` (``penulti_linear`` units, ReLU)
    and ``dense2`` (``feature_size`` units, linear) are applied. Every layer
    honours ``reuse`` so the same weights can be shared between calls.
    """
    conv_specs = ((32, 8, 4), (64, 4, 2), (64, 3, 1))  # (filters, kernel, stride)
    out = input
    for layer_no, (n_filters, kernel, stride) in enumerate(conv_specs, start=1):
        out = tf.layers.conv2d(
            out, filters=n_filters, kernel_size=[kernel, kernel],
            strides=stride, padding='same', activation=tf.nn.relu,
            reuse=reuse, name="cov2d_%d" % layer_no)

    # Collapse every dimension after the leading one into a single axis.
    flat_dim = np.prod(out.get_shape().as_list()[1:])
    flat = tf.reshape(out, [-1, flat_dim])

    hidden = tf.layers.dense(
        inputs=flat, units=penulti_linear,
        kernel_initializer=tf.contrib.layers.xavier_initializer(),
        activation='relu', name='dense1', reuse=reuse)
    return tf.layers.dense(
        inputs=hidden, units=feature_size,
        kernel_initializer=tf.contrib.layers.xavier_initializer(),
        activation=None, name='dense2', reuse=reuse)

def convDQN15(input, feature_size=256, reuse=False):
    """Convolutional trunk of the DQN (Nature 2015) layout, flattened.

    This is the same conv stack as ``featuresDQN15``. All layers use ReLU,
    'same' padding and ``reuse``:
      * ``cov2d_1``: 32 filters, 8x8 kernel, stride 4
      * ``cov2d_2``: 64 filters, 4x4 kernel, stride 2
      * ``cov2d_3``: 64 filters, 3x3 kernel, stride 1
    The returned tensor is the conv output reshaped to ``[-1, D]``, where
    ``D`` is the product of the non-leading static dimensions. No dense layers
    are applied.

    ``feature_size`` is accepted but not used.
    """
    conv_specs = ((32, 8, 4), (64, 4, 2), (64, 3, 1))  # (filters, kernel, stride)
    out = input
    for layer_no, (n_filters, kernel, stride) in enumerate(conv_specs, start=1):
        out = tf.layers.conv2d(
            out, filters=n_filters, kernel_size=[kernel, kernel],
            strides=stride, padding='same', activation=tf.nn.relu,
            reuse=reuse, name="cov2d_%d" % layer_no)

    flat_dim = np.prod(out.get_shape().as_list()[1:])
    return tf.reshape(out, [-1, flat_dim])

# def convDQN15(input, feature_size=256, reuse=False):

#     out = input 
#     # out = tf.layers.conv2d(out,filters=32,kernel_size=[8,8],strides=4,padding='same',activation=tf.nn.relu,reuse=reuse,name="cov2d_%d" % 1)
#     out = tf.layers.conv2d(out,filters=32,kernel_size=[3,3],strides=2,padding='same',activation=tf.nn.relu,reuse=reuse,name="cov2d_%d" % 2)
#     out = tf.layers.conv2d(out,filters=64,kernel_size=[2,2],strides=1,padding='same',activation=tf.nn.relu,reuse=reuse,name="cov2d_%d" % 3)
    
#     shape = out.get_shape().as_list()        # a list: [None, 9, 2]
#     dim = np.prod(shape[1:])            # dim = prod(9,2) = 18
#     x = tf.reshape(out, [-1, dim])           # -1 means "all"
    
#     # x =  tf.layers.dense(inputs=x,units=penulti_linear,activation=tf.nn.relu,kernel_initializer=tf.contrib.layers.xavier_initializer(),reuse=False)
#     # feature =  tf.layers.dense(inputs=x,units=feature_size,kernel_initializer=tf.contrib.layers.xavier_initializer(),reuse=False)
#     # feature =  tf.nn.relu(tf.layers.dense(inputs=x,units=feature_size,kernel_initializer=tf.contrib.layers.xavier_initializer(),reuse=False))
#     # x =  tf.layers.dense(inputs=x,units=penulti_linear,kernel_initializer=tf.contrib.layers.xavier_initializer(),activation='relu',name='dense1',reuse=reuse)
#     # feature =  tf.layers.dense(inputs=x,units=feature_size,kernel_initializer=tf.contrib.layers.xavier_initializer(),activation=None,name='dense2',reuse=reuse)

#     # print("after_convolution_feature",feature.get_shape())

#     return x



def denseDQN15(x, penulti_linear, feature_size=50, reuse=False, index=1):
    """Three-layer fully connected head.

    The layers are two ReLU layers of ``penulti_linear`` units followed by a
    linear layer of ``feature_size`` units. They are named
    ``'dense' + str(index)``, ``'dense' + str(index + 1)`` and
    ``'dense' + str(index + 2)``, all use a Xavier kernel initializer and
    honour ``reuse``.
    """
    layer_specs = (
        (penulti_linear, tf.nn.relu),
        (penulti_linear, tf.nn.relu),
        (feature_size, None),
    )
    out = x
    for offset, (units, act_fn) in enumerate(layer_specs):
        out = tf.layers.dense(
            inputs=out, units=units,
            kernel_initializer=tf.contrib.layers.xavier_initializer(),
            activation=act_fn, name='dense' + str(index + offset),
            reuse=reuse)
    return out

def featuresDQN15_Q(input, input_action, penulti_linear, feature_size=50, reuse=False):
    """Conv feature extractor that injects an action before the final layer.

    The convolutions all use ReLU, 'same' padding and ``reuse``:
      * ``cov2d_1``: 32 filters, 8x8 kernel, stride 4
      * ``cov2d_2``: 32 filters, 4x4 kernel, stride 2
      * ``cov2d_3``: 64 filters, 3x3 kernel, stride 1
    The output is flattened per example and passed through ``dense1``
    (``penulti_linear`` units, ReLU). It is then concatenated with
    ``input_action`` along axis 1 and projected by ``dense2``
    (``feature_size`` units, linear).
    """
    conv_specs = ((32, 8, 4), (32, 4, 2), (64, 3, 1))  # (filters, kernel, stride)
    out = input
    for layer_no, (n_filters, kernel, stride) in enumerate(conv_specs, start=1):
        out = tf.layers.conv2d(
            out, filters=n_filters, kernel_size=[kernel, kernel],
            strides=stride, padding='same', activation=tf.nn.relu,
            reuse=reuse, name="cov2d_%d" % layer_no)

    # Collapse every dimension after the leading one into a single axis.
    flat_dim = np.prod(out.get_shape().as_list()[1:])
    flat = tf.reshape(out, [-1, flat_dim])

    hidden = tf.layers.dense(
        inputs=flat, units=penulti_linear,
        kernel_initializer=tf.contrib.layers.xavier_initializer(),
        activation='relu', name='dense1', reuse=reuse)
    state_action = tf.concat(axis=1, values=[hidden, input_action])
    return tf.layers.dense(
        inputs=state_action, units=feature_size,
        kernel_initializer=tf.contrib.layers.xavier_initializer(),
        activation=None, name='dense2', reuse=reuse)

# Doom
def featuresDoom(input, penulti_linear, feature_size=50):
    """ELU-based convolutional feature extractor (labelled "Doom" in this file).

    The convolutions all use ELU and 'same' padding:
      * ``cov2d_1``: 8 filters, 5x5 kernel, stride 4
      * ``cov2d_2``: 16 filters, 3x3 kernel, stride 2
      * ``cov2d_3``: 32 filters, 3x3 kernel, stride 2
      * ``cov2d_4``: 64 filters, 3x3 kernel, stride 2
    The output is flattened per example, projected by a dense layer of
    ``feature_size`` units, and passed through ``tf.nn.elu``.

    ``penulti_linear`` is accepted but not used by this variant.
    """
    conv_specs = ((8, 5, 4), (16, 3, 2), (32, 3, 2), (64, 3, 2))  # (filters, kernel, stride)
    out = input
    for layer_no, (n_filters, kernel, stride) in enumerate(conv_specs, start=1):
        out = tf.layers.conv2d(
            out, filters=n_filters, kernel_size=[kernel, kernel],
            strides=stride, padding='same', activation=tf.nn.elu,
            name="cov2d_%d" % layer_no)

    # Collapse every dimension after the leading one into a single axis.
    flat_dim = np.prod(out.get_shape().as_list()[1:])
    flat = tf.reshape(out, [-1, flat_dim])

    projected = tf.layers.dense(
        inputs=flat, units=feature_size,
        kernel_initializer=tf.contrib.layers.xavier_initializer(),
        reuse=False)
    return tf.nn.elu(projected)
