import os

# Cap the native thread pools (OpenMP, OpenBLAS, MKL, vecLib, numexpr) at two
# threads each. These variables have to be in the environment before the
# native libraries that consult them are loaded, which is why the networkit
# import is placed after the assignments rather than before them.
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["OPENBLAS_NUM_THREADS"] = "2"
os.environ["MKL_NUM_THREADS"] = "2"
os.environ["VECLIB_MAXIMUM_THREADS"] = "2"
os.environ["NUMEXPR_NUM_THREADS"] = "2"

import networkit as nk

import sys
import distutils
# The invocation mode is selected purely by the argument count:
#   len(sys.argv) == 9 -> parameter-print mode:
#       param_file scenario agent_conf num_layers num_hidden
#       disable_ra disable_ic poten_reg
#   otherwise          -> rollout mode:
#       in_fifo out_fifo num_episodes scenario agent_conf num_layers
#       num_hidden disable_ra disable_ic poten_reg
# The two boolean flags are true only for the literal string 'True'.
do_param_print = len(sys.argv) == 9
if do_param_print:
    param_file, scenario, agent_conf = sys.argv[1], sys.argv[2], sys.argv[3]

    num_layers, num_hidden = int(sys.argv[4]), int(sys.argv[5])
    disable_ra = sys.argv[6] == 'True'
    disable_ic = sys.argv[7] == 'True'
    poten_reg = float(sys.argv[8])
else:
    in_fifo, out_fifo = sys.argv[1], sys.argv[2]
    num_episodes = int(sys.argv[3])
    scenario, agent_conf = sys.argv[4], sys.argv[5]

    num_layers, num_hidden = int(sys.argv[6]), int(sys.argv[7])

    disable_ra = sys.argv[8] == 'True'
    disable_ic = sys.argv[9] == 'True'
    poten_reg = float(sys.argv[10])

from multiprocessing import Process, Pipe
from multiprocessing import Pool

from src.multiagent_mujoco.mujoco_multi import MujocoMulti

from src.envs.particle import Particle
import numpy as np
import time
import datetime
from scipy.optimize import linear_sum_assignment
import math



import os
# An empty CUDA_VISIBLE_DEVICES leaves no CUDA device visible to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

# USER_RAND_KEY is the directory part of the input FIFO path. In
# parameter-print mode ``in_fifo`` is never assigned, so the lookup raises
# NameError and the key falls back to the empty string. Only that error is
# caught so that any other failure stays visible.
try:
    os.environ['USER_RAND_KEY'] = os.path.split(in_fifo)[0]
except NameError:
    os.environ['USER_RAND_KEY'] = ''
import sys
import numpy as np
import tensorflow as tf
# Run TensorFlow without any GPU device.
tf.config.experimental.set_visible_devices([], 'GPU')
import pickle
import tensorflow_probability as tfp

# Restrict TensorFlow to one inter-op and one intra-op thread. Both setters
# return None, so their results are not printed.
tf.config.threading.set_inter_op_parallelism_threads(1)
tf.config.threading.set_intra_op_parallelism_threads(1)


# Baseline keyword arguments for the Particle environment. mujoco_stub
# overrides "scenario_name" (and "episode_limit" for 'particled' scenarios).
env_args_particle = dict(
    benchmark=False,
    episode_limit=25,
    scenario_name="continuous_pred_prey_3a",
    agent_view_radius=1.5,
    score_function="min",
    partial_obs=True,
)

# Scenario family given on the command line -> scenario-name template.
# The %d placeholder is filled with the number of agents.
particle_map_dict = {
    'particle': 'continuous_pred_prey_%da',
    'particleh': 'hcontinuous_pred_prey_%da',
    'particled': 'drone_delivery_%da',
}

import os
import time

def _agent_observations(env):
    """Return the current per-agent observations of *env* as one array.

    For scenarios whose name contains 'Ant' or 'manyagent_ant' only a fixed
    subset of observation columns is kept.
    """
    obsarray = np.array(env.get_obs())
    if 'Ant' in scenario:
        # Ant observations carry spurious entries; keep these columns only.
        obsarray = np.take(obsarray, [0, 1, 20, 21], axis = 1)
    if 'manyagent_ant' in scenario:
        obsarray = np.take(obsarray, [0, 1, 20, 21, 40, 41, 48, 49], axis = 1)
    return obsarray


def _sanitize_action(action):
    """Return *action* as an array, or zeros of the same shape if it has a NaN."""
    arr = np.array(action)
    if np.isnan(arr).any():
        return np.zeros(arr.shape)
    return arr


def mujoco_stub(recv_conn):
    """Environment worker that is driven through the pipe end *recv_conn*.

    Protocol, as seen from this worker:
      1. Once, after the first reset, send
         ``(n_agents, observation_width, n_actions)``.
         In parameter-print mode the worker returns right after that.
      2. Repeat forever: reset the environment, receive ``run_for`` (number
         of steps), then for each step send
         ``(observations, reward_accumulated_so_far)`` and receive the joint
         action to apply. After the last step send the list of per-step
         rewards.

    Actions containing NaN are replaced by zeros before being applied.
    The ``terminated`` flag returned by ``env.step`` is ignored.
    """
    # Per-process seed derived from the pid and the wall clock.
    np.random.seed((os.getpid() * int(time.time()*10000000)) % 123456789)

    # 'particleh' and 'particled' both contain 'particle', so one test covers
    # every key of particle_map_dict.
    if "particle" in scenario:
        # Work on a copy so the module-level defaults stay untouched.
        env_args = dict(env_args_particle)
        num_agents = int(agent_conf.split('x')[0])
        env_args['scenario_name'] = particle_map_dict[scenario] % num_agents
        if 'particled' in scenario:
            env_args['episode_limit'] = 150
        env = Particle(env_args = env_args)
    elif 'kaz' in scenario:
        env_args = {"scenario": scenario,
                    "agent_conf": agent_conf,
                    "agent_obsk": 0,
                    "episode_limit": 1000,
                    "scenario_name": 'kaz'}
        # NOTE(review): KAZ is not imported anywhere in the visible part of
        # this file, so this branch raises NameError unless it is provided
        # elsewhere. Confirm the intended import.
        env = KAZ(env_args = env_args)
    else:
        env_args = {"scenario": scenario,
                    "agent_conf": agent_conf,
                    "agent_obsk": 0,
                    "obs_add_global_pos": False,
                    "episode_limit": 1000}
        env = MujocoMulti(env_args=env_args)

    env_info = env.get_env_info()
    n_actions = env_info["n_actions"]
    n_agents = env_info["n_agents"]

    settings_sent = False
    while True:
        env.reset()
        episode_reward = 0
        rewards = []
        obsarray = _agent_observations(env)

        if not settings_sent:
            settings_sent = True
            recv_conn.send((n_agents, obsarray.shape[1], n_actions))

        if do_param_print:
            return

        run_for = recv_conn.recv()
        for step in range(run_for):
            if step > 0:
                # The observation for step 0 was taken right after the reset.
                obsarray = _agent_observations(env)
            recv_conn.send((obsarray, episode_reward))

            actions = [_sanitize_action(a) for a in recv_conn.recv()]

            reward, _terminated, _ = env.step(actions)
            rewards.append(reward)
            episode_reward += np.sum(reward)
        recv_conn.send(rewards)



def wrapper_scipy_linear_optimize(vt):
    """Solve the linear sum assignment problem for cost matrix *vt*.

    Returns only the column indices of the optimal assignment, i.e. the
    second array produced by ``scipy.optimize.linear_sum_assignment``.
    """
    _row_ind, col_ind = linear_sum_assignment(vt)
    return col_ind


# One pipe and one environment worker per parallel episode; parameter-print
# mode only needs a single worker. Each worker receives the second end of its
# pipe, the first end stays in this process. The assignment-solver pool is
# sized like the worker set.
conns = [Pipe() for _ in range(1 if do_param_print else num_episodes)]
processes = [Process(target = mujoco_stub, args = (pipe_pair[1],)) for pipe_pair in conns]
for worker in processes:
    worker.start()
la_pool = Pool(len(conns))



import math



@tf.function
def concat_idx(states):
    """Append an index encoding to the last axis of *states*.

    Every row along axis -2 gets ceil(log2(number_of_rows)) extra features
    holding the bits of its row index, written as +1 (bit set) or -1 (bit
    clear). When the module-level ``disable_ra`` flag is set the extra
    features keep the same shape but are all zeros.
    """
    num_rows = tf.shape(states)[-2]
    num_bits = tf.cast(
        tf.math.ceil(tf.math.log(tf.cast(num_rows, dtype=tf.float32)) / tf.math.log(2.)),
        dtype=tf.int32)

    # bit_is_set[r, b] is True when bit b of row index r is 1.
    row_ids = tf.expand_dims(tf.range(num_rows), -1)
    bit_is_set = tf.cast(tf.math.floormod(tf.bitwise.right_shift(row_ids, tf.range(num_bits)), 2), dtype=tf.bool)
    signed_bits = tf.where(bit_is_set, tf.ones(tf.shape(bit_is_set)), -1. * tf.ones(tf.shape(bit_is_set)))
    signed_bits = tf.reshape(signed_bits, (num_rows, -1))

    # Multiplying by ones of shape states.shape[:-1] + (1,) broadcasts the
    # encoding over all leading axes.
    leading_ones = tf.ones(tf.concat([tf.shape(states)[:-1], (1,)], axis = 0))
    index_features = leading_ones * signed_bits
    if disable_ra:
        index_features = tf.zeros(shape = tf.shape(index_features), dtype = index_features.dtype)

    return tf.concat([states, index_features], axis = -1)


class RoleAssignment(object):
    """Scores every (role index, agent) pair with one shared MLP and turns the
    scores into a one-to-one assignment.

    With ``disable_ra`` no network is built and the score matrix is a scaled
    identity.
    """

    def __init__(self, numagents, agentspec, num_layers, num_hidden, disable_ra = False):
        """Build the scoring network unless ``disable_ra`` is set.

        The network input is an agent observation (``agentspec`` values)
        followed by the +/-1 bit encodings of two indices; it has
        ``num_layers - 1`` ReLU hidden layers of width ``num_hidden`` and a
        single linear output.
        """
        self.tfparams = []
        self.numagents = numagents
        self.agentspec = agentspec
        self.disable_ra = disable_ra
        if disable_ra:
            return

        index_bits = math.ceil(math.log(numagents, 2.0))
        net_in = tf.keras.Input(shape = (agentspec + 2*index_bits,))
        hidden = tf.keras.layers.Flatten()(net_in)
        for _ in range(num_layers - 1):
            hidden = tf.keras.layers.Dense(num_hidden, activation = 'relu')(hidden)
        score = tf.keras.layers.Dense(1)(hidden)

        self.tfparams.append(tf.keras.Model(net_in, score))

    def get_params_spec(self):
        """Return the total number of scalar weights (0 when disabled)."""
        per_model = [np.sum([np.prod(w.shape) for w in model.weights]) for model in self.tfparams]
        return np.sum(per_model).astype(int)

    def get_params_spec_raw(self):
        """Return the parameter count wrapped in a one-element list."""
        return [self.get_params_spec()]

    def get_weights(self):
        """Return the scoring network's weight variables, or [] when disabled."""
        if not self.tfparams:
            return []
        return self.tfparams[0].weights

    def instantiate(self, params):
        """Load the flat vector *params* into the network weights.

        Weights are filled in ``model.weights`` order. Tensors of rank > 1
        are multiplied by sqrt(6 / sum(shape)) (Glorot-style scaling);
        rank-1 tensors are copied unscaled.
        """
        assert(self.get_params_spec() == len(params))
        offset = 0
        for model in self.tfparams:
            for var in model.weights:
                count = np.prod(var.shape)
                chunk = params[offset:offset+count]
                if var.shape.rank > 1:
                    #glorot init scaling
                    chunk = chunk * np.sqrt(6. / np.sum(var.shape))
                var.assign(chunk.reshape(var.shape))
                offset += count

    @tf.function
    def value(self, agentstates, device = '/CPU:0'):
        '''
        Score matrix per batch entry: rows index roles, columns index agents
        (per the original note: "rows = roles, columns = affinity for agents").
        With disable_ra the result is 10 * identity for every batch entry.
        '''
        if self.disable_ra:
            with tf.device(device):
                return tf.eye(tf.shape(agentstates)[1], batch_shape = tf.shape(agentstates)[:1]) * 10

        with tf.device(device):
            ids = tf.range(self.numagents)

            # pair_ids[i, j] = (j, i)
            col_ids = tf.expand_dims(tf.tile(tf.expand_dims(ids, axis = 0), (self.numagents, 1)), -1)
            row_ids = tf.expand_dims(tf.tile(tf.expand_dims(ids, axis = 1), (1, self.numagents)), -1)
            pair_ids = tf.concat([col_ids, row_ids], axis = -1)

            bw_count = math.ceil(math.log(self.numagents, 2.0))

            # +/-1 bit encoding of both indices of every pair.
            bit_is_set = tf.cast(tf.math.floormod(tf.bitwise.right_shift(tf.expand_dims(pair_ids,-1), tf.range(bw_count)), 2), dtype=tf.bool)
            signed_bits = tf.where(bit_is_set, tf.ones(tf.shape(bit_is_set)), -1. * tf.ones(tf.shape(bit_is_set)))
            pair_code = tf.reshape(signed_bits, (self.numagents, self.numagents, -1))

            # Repeat the agent observations once per row and the pair codes
            # once per batch entry so they can be concatenated.
            states_per_row = tf.tile(tf.expand_dims(agentstates, 1), (1, self.numagents, 1, 1))
            code_per_batch = tf.tile(tf.expand_dims(pair_code, 0), (tf.shape(agentstates)[0], 1, 1, 1))
            net_input = tf.concat([states_per_row, code_per_batch], axis = -1)

            scores = self.tfparams[0](tf.reshape(net_input, (-1, self.agentspec + 2*bw_count)))

            return tf.reshape(scores, agentstates.shape[0:2] + [self.numagents])

    def compile(self, agentstates, sinkhorn_iters = 32):
        '''
        Return, per batch entry, the column indices of the assignment that
        maximizes the total score (costs are the negated scores).
        sinkhorn_iters is accepted but not used.
        '''
        costs = self.value(agentstates) * -1.0
        solved = la_pool.map(wrapper_scipy_linear_optimize, tf.unstack(costs))
        return tf.stack(solved)
    
# Objective passed to the optimizer by bfgs_wrapper. Nothing in the visible
# part of this file assigns it, so it is None unless set from elsewhere.
closure_func = None

@tf.function(autograph=False)
def bfgs_wrapper(init_locs, additional_args, device = '/CPU:0'):
    """Run L-BFGS on ``closure_func`` on the given device.

    NOTE(review): ``additional_args`` is passed as the second positional
    argument and ``init_locs`` as the third. This matches the way
    ``lbfgs_minimize`` is called elsewhere in this file, but presumably relies
    on a modified tensorflow_probability signature -- confirm before use.
    """
    with tf.device(device):
        return tfp.optimizer.lbfgs_minimize(closure_func, additional_args, init_locs, tolerance = 1e-6, x_tolerance=1e-6, f_relative_tolerance = 1e-6, max_line_search_iterations = 20, max_iterations = 40)



class InteractionCapture(object):
    """Networks that turn role states into actions, optionally through a
    learned interaction graph.

    Three groups of networks are held as lists of Keras models:
      * ``affparams``  - pairwise affinity network (absent with disable_ic),
      * ``nodeparams`` - per-role action network, one shared model applied to
        every role,
      * ``edgeparams`` - pairwise edge network, one shared model applied to
        every role (absent with disable_ic).
    """

    def __init__(self, numroles, rolespec, actionspec, actionbounds, num_layers, num_hidden, disable_ic = False):
        self.numroles = numroles
        self.rolespec = rolespec
        self.actionspec = actionspec
        self.actionbounds = actionbounds
        self.disable_ic = disable_ic

        def create_model(inpsize, num_layers, num_hidden, num_out, edge_mode = False, subgaussian_regularize = (), actionbounds = None, return_raw = False):
            """Build an MLP with ``num_layers - 1`` ReLU layers and a tanh output.

            When ``subgaussian_regularize`` is non-empty the output is
            multiplied by the module-level ``poten_reg``.
            ``actionbounds`` and ``return_raw`` are accepted but not used.
            """
            if not edge_mode:
                inp = tf.keras.Input(shape = (inpsize,))
                x = inp
            else:
                inp = tf.keras.Input(shape = inpsize)
                x = tf.keras.layers.Flatten()(inp)

            for _ in range(num_layers - 1):
                x = tf.keras.layers.Dense(num_hidden, activation = 'relu')(x)

            out = tf.keras.activations.tanh(tf.keras.layers.Dense(num_out)(x))
            if len(subgaussian_regularize) > 0:
                out = out * poten_reg

            return tf.keras.Model(inp, out)

        def dupe_model(inpsize, num_dupes, model):
            """Apply the same *model* to ``num_dupes`` fresh inputs (shared weights)."""
            inps = [tf.keras.Input(shape = (inpsize,)) for i in range(num_dupes)]
            outs = [model(inpi) for inpi in inps]
            return inps, outs

        self.affparams = []
        self.nodeparams = []
        self.edgeparams = []

        ab = np.array(self.actionbounds)

        # Number of +/-1 features that encode a role index.
        roleencoding = math.ceil(math.log(numroles, 2.0))
        self.roleencoding = roleencoding
        if ab.ndim == 1:
            ab = np.expand_dims(ab, axis = -1)

        if not self.disable_ic:
            # Input: two role states, each with its index encoding.
            self.affparams = [create_model(self.rolespec * 2 + roleencoding * 2, num_layers, num_hidden, 1)]

        # Input: role state + index encoding + an action-sized message.
        node_width = self.rolespec + roleencoding + self.actionspec
        node_net = create_model(node_width, num_layers, num_hidden, self.actionspec,
                                actionbounds = [ab[0][0], ab[1][0]], return_raw = True)
        nins, nouts = dupe_model(node_width, self.numroles, node_net)
        self.nodeparams = [tf.keras.Model(inputs = nins, outputs = tf.stack(nouts))]

        if not self.disable_ic:
            # Input: two actions, each with its index encoding.
            edge_width = self.actionspec*2 + roleencoding*2
            edge_net = create_model(edge_width, num_layers, num_hidden, self.actionspec, edge_mode = False,
                subgaussian_regularize = list(range(0, actionspec)) + list(range(actionspec+roleencoding, 2*actionspec+roleencoding)),
                actionbounds = [ab[0][0], ab[1][0]], return_raw = True)
            nins, nouts = dupe_model(edge_width, self.numroles, edge_net)
            self.edgeparams = [tf.keras.Model(inputs = nins, outputs = nouts)]

    @staticmethod
    def _size_of_params(tfparams):
        """Return the number of scalar weights in a list of Keras models."""
        per_model = [np.sum([np.prod(w.shape) for w in model.weights]) for model in tfparams]
        return np.sum(per_model).astype(int)

    def get_params_spec(self):
        """Return the total number of scalar weights over all three groups."""
        return (self._size_of_params(self.affparams)
                + self._size_of_params(self.nodeparams)
                + self._size_of_params(self.edgeparams))

    def get_params_spec_raw(self):
        """Return the per-group weight counts as [affinity, node, edge]."""
        return [self._size_of_params(self.affparams),
                self._size_of_params(self.nodeparams),
                self._size_of_params(self.edgeparams)]

    def get_weights(self):
        """Return the weight variables in affinity, node, edge order."""
        if self.disable_ic:
            return self.nodeparams[0].weights
        else:
            return self.affparams[0].weights + self.nodeparams[0].weights + self.edgeparams[0].weights

    def instantiate(self, params):
        """Load the flat vector *params* into all networks.

        Groups are filled in affinity, node, edge order. Tensors of rank > 1
        are multiplied by sqrt(6 / sum(shape)) (Glorot-style scaling).
        """
        assert(self.get_params_spec() == len(params))

        def instantiate_tflist(curr, tfparams):
            for tfs in tfparams:
                for wts in tfs.weights:
                    csize = np.prod(wts.shape)
                    wtsassign = params[curr:curr+csize]
                    if wts.shape.rank > 1:
                        #glorot init scaling
                        wtsassign = wtsassign * np.sqrt(6. / np.sum(wts.shape))
                    wts.assign(wtsassign.reshape(wts.shape))
                    curr += csize
            return curr

        curr_ = instantiate_tflist(0, self.affparams)
        curr_ = instantiate_tflist(curr_, self.nodeparams)
        curr_ = instantiate_tflist(curr_, self.edgeparams)

    def _pairwise_affinity(self, rolestates):
        """Run the affinity network on every ordered pair of roles.

        Returns a tensor reshaped to (batch, numroles, numroles).

        NOTE(review): the pair tensor is stacked as (role_i, role_j, batch, ...)
        and flattened in that order, but the network output is reshaped
        directly to (batch, role, role) without a transpose. For batch size
        > 1 this looks like it mixes entries across batch elements -- confirm
        whether that is intended before changing it, since existing weights
        were obtained with this layout.
        """
        rperm = tf.transpose(rolestates, perm = [1, 0, 2])
        pairs = [[tf.concat([rperm[i], rperm[j]], axis = -1) for j in range(self.numroles)]
                 for i in range(self.numroles)]
        prep = tf.reshape(tf.stack(pairs), (-1, self.rolespec*2 + self.roleencoding*2))
        out = self.affparams[0](prep)
        return tf.reshape(out, (tf.shape(rolestates)[0], self.numroles, self.numroles))

    @tf.function
    def connectivity(self, rolestates, device = '/CPU:0'):
        """Return a boolean (batch, numroles, numroles) adjacency tensor.

        An edge exists where the affinity output is positive; the diagonal is
        always False. With disable_ic every entry is False.
        """
        with tf.device(device):
            if self.disable_ic:
                out = tf.ones((tf.shape(rolestates)[0], self.numroles, self.numroles)) * -20.
                return out > 0
            else:
                out = self._pairwise_affinity(rolestates)
                res = tf.zeros(tf.concat([[tf.shape(rolestates)[0]], [self.numroles]], axis = 0))
                out = tf.linalg.set_diag(out, res)
                return out > 0.

    @tf.function(autograph = False)
    def diff_connectivity(self, rolestates, device = '/CPU:0'):
        """Return the raw affinity values with the diagonal set to -20.

        With disable_ic every entry is -20.
        """
        with tf.device(device):
            if self.disable_ic:
                out = tf.ones((tf.shape(rolestates)[0], self.numroles, self.numroles)) * -20.
                return out

            out = self._pairwise_affinity(rolestates)
            res = -20.*tf.ones(tf.concat([[tf.shape(rolestates)[0]], [self.numroles]], axis = 0))
            out = tf.linalg.set_diag(out, res)
            return out

    @tf.function
    def init_locations(self, num_locs = 1, device = '/CPU:0'):
        """Draw uniform random joint actions within the action bounds.

        The result has shape (num_episodes, num_locs, numroles * actionspec);
        ``num_episodes`` is read from module scope.
        """
        with tf.device(device):
            return tf.random.uniform((num_episodes, num_locs, self.numroles*self.actionspec), minval = self.actionbounds[0], maxval = self.actionbounds[1])

    @tf.function(autograph=False)
    def potential_closure(self, actions, additional_args, device = '/CPU:0'):
        """Iterate the joint-action update from the starting points *actions*.

        ``additional_args`` is
        ``[numroles, rolespec, actionspec, rolestates, connectivity, actionbounds]``
        (the last entry is not used here).

        With disable_ic only the node network output is returned. Otherwise
        the update below is applied 30 times and a 3-tuple is returned:
        zeros shaped like the node output, the final actions reshaped to
        (..., numroles, actionspec), and the normalized pairwise squared
        distances between the final candidates of each batch entry.
        """
        with tf.device(device):
            numroles = additional_args[0]
            rolespec = additional_args[1]
            actionspec = additional_args[2]
            rolestates = additional_args[3]
            connectivity = additional_args[4]

            # Node network with an all-zero incoming message.
            actions_init = tf.zeros(tf.concat([tf.shape(rolestates)[0:2], [actionspec]], axis = 0), dtype = rolestates.dtype)
            rolestates_ = tf.concat([rolestates, actions_init], axis = 2)
            rolestates_us = tf.unstack(rolestates_, axis = 1)
            role_actions = self.nodeparams[0](rolestates_us)
            role_actions = tf.transpose(role_actions, perm = [1,0,2])

            if self.disable_ic:
                return role_actions

            @tf.function(autograph=False)
            def loop_body(i, actions):
                # One update: edge network on connected role pairs, summed
                # per role, then fed to the node network with the role states.
                ashape = actions.shape[0:len(actions.shape)-1].concatenate(tf.TensorShape([numroles, actionspec]))

                actions_prep = tf.reshape(actions, ashape)
                actions_prep = concat_idx(actions_prep)
                actions_prep = tf.transpose(actions_prep, perm = [0, 2, 1, 3])
                # Index pairs of the connected roles, per episode
                # (num_episodes is read from module scope).
                epotsspec = [tf.where(connectivity[i]) for i in list(range(num_episodes))]
                egather = [tf.gather(actions_prep[i], epotsspec[i]) for i in list(range(num_episodes))]

                # Group the edges by their first role index.
                shuffles = [[tf.squeeze(tf.where(ep[:, 0] == k), axis = -1) for k in list(range(numroles))] for i, ep in enumerate(epotsspec)]
                gathered_roled = [[tf.transpose(tf.gather(eg,shuffles[i][k]), perm = [0,2,1,3]) for k in list(range(numroles))] for i, eg in enumerate(egather)]
                gathered_concatted = [tf.reshape(tf.concat([gii[idx] for gii in gathered_roled], axis = 0), (-1, 2*actionspec + 2*self.roleencoding)) for idx in list(range(numroles))]
                mapped = self.edgeparams[0](gathered_concatted)

                # Undo the flattening and sum the edge outputs per role and episode.
                mapped_numlocs_rev = [tf.reshape(mi, (-1, tf.shape(actions)[1], actionspec)) for mi in mapped]
                mapped_episode_rev = [tf.split(mi, tf.stack([tf.shape(si[i])[0] for si in shuffles]),  axis = 0) for i, mi in enumerate(mapped_numlocs_rev)]
                dynamic_sum = tf.stack([tf.stack([tf.math.reduce_sum(mii, axis = 0) for mii in mi]) for i, mi in enumerate(mapped_episode_rev)])
                dynamic_sum = tf.transpose(dynamic_sum, perm = [1,2,0,3])

                role_actions_expand = tf.expand_dims(rolestates, axis = 1)
                role_actions_expand_repeat = tf.repeat(role_actions_expand, tf.shape(dynamic_sum)[1], axis = 1)
                combined_ = tf.concat([role_actions_expand_repeat, dynamic_sum], axis = 3)

                roleprep = tf.reshape(combined_, [-1, numroles, actionspec+self.roleencoding+rolespec])
                roleprep_2 = tf.unstack(roleprep, axis = 1)

                iter_next = self.nodeparams[0](roleprep_2)
                iter_next = tf.transpose(iter_next, perm = [1, 0, 2])

                dynamic_sum = tf.reshape(iter_next, actions.shape)
                return tf.add(i, 1), dynamic_sum

            i = tf.constant(0)
            c = lambda i, p:tf.less(i, 30)
            b = lambda i, p:loop_body(i, p)
            r = tf.while_loop(c, b, [i, actions])

            @tf.function(autograph = False)
            def pwisedist_bcast(x):
                @tf.function(autograph = False)
                def pwisedist(xi):
                    # Squared Euclidean distance between all rows of xi,
                    # divided by the mean of the two squared norms.
                    xi = tf.reshape(xi, (tf.shape(xi)[0],-1))
                    r = tf.math.reduce_sum(xi*xi, 1)
                    r = tf.reshape(r, [-1, 1])
                    D = r - 2*tf.matmul(xi, tf.transpose(xi)) + tf.transpose(r)
                    D = D / ((r + tf.transpose(r)) / 2.0)
                    return D

                return tf.map_fn(pwisedist, x)

            ashape = actions.shape[0:len(actions.shape)-1].concatenate(tf.TensorShape([numroles, actionspec]))
            aout = tf.reshape(r[1], ashape)
            pwise = pwisedist_bcast(r[1])
            return tf.zeros(role_actions.shape, dtype = role_actions.dtype), aout, pwise

    @tf.function(autograph=False)
    def diff_potential_closure(self, actions, diffconn, additional_args, device = '/CPU:0'):
        """Single update step with soft edge weights *diffconn*.

        ``additional_args`` has the same layout as for ``potential_closure``;
        only numroles, actionspec and rolestates are read. The edge network
        is evaluated on every ordered role pair, its outputs are weighted by
        ``diffconn`` and summed, and the node network maps the role states
        plus that sum to per-role actions. With disable_ic only the node
        network output for a zero message is returned.
        """
        with tf.device(device):
            numroles = additional_args[0]
            actionspec = additional_args[2]
            rolestates = additional_args[3]

            rolestates_save = rolestates

            actions_init = tf.zeros(tf.concat([tf.shape(rolestates)[0:2], [actionspec]], axis = 0), dtype = rolestates.dtype)
            rolestates_ = tf.concat([rolestates, actions_init], axis = 2)
            rolestates_us = tf.unstack(rolestates_, axis = 1)
            role_actions = self.nodeparams[0](rolestates_us)
            role_actions = tf.transpose(role_actions, perm = [1,0,2])
            if self.disable_ic:
                return role_actions

            ashape = actions.shape[0:len(actions.shape)-1].concatenate(tf.TensorShape([numroles, actionspec]))
            actreshape = tf.reshape(actions, ashape)
            actreshape_2 = concat_idx(actreshape)

            # Move the role axis to the front.
            actreshape_2 = tf.transpose(actreshape_2, perm = [actreshape.shape.rank-2] + list(range(0, actreshape.shape.rank - 2)) + [actreshape.shape.rank-1])
            rolestates = actreshape_2
            s1 = tf.expand_dims(rolestates, 0)
            s2 = tf.expand_dims(rolestates, 1)

            # All ordered pairs of per-role actions.
            s1_tile = tf.tile(s1, (numroles, 1, 1, 1, 1))
            s2_tile = tf.tile(s2, (1, numroles, 1, 1, 1))
            s_ready = tf.concat([s2_tile, s1_tile], axis = -1)
            s_unstacked = tf.unstack(s_ready, axis = 0)
            s_unstacked = [tf.reshape(si, (-1, 2*actionspec + 2*self.roleencoding)) for si in s_unstacked]
            mapped = tf.stack(self.edgeparams[0](s_unstacked))

            mapped = tf.reshape(mapped, (numroles, numroles, rolestates.shape[1], actionspec))
            mapped = tf.transpose(mapped, perm = [2, 0, 1, 3])

            # Weight every edge output by its soft connectivity and sum over axis 2.
            esum = mapped * tf.expand_dims(diffconn, axis = -1)
            esum = tf.math.reduce_sum(esum, axis = [2])

            rolestates = tf.concat([rolestates_save, esum], axis = -1)

            rolestates_us = tf.unstack(rolestates, axis = 1)
            role_actions = self.nodeparams[0](rolestates_us)
            role_actions = tf.transpose(role_actions, perm = [1,0,2])

            return role_actions

    @tf.function(autograph = False)
    def compile(self, rolestates, actionbounds, device = '/CPU:0'):
        """Compute connectivity, draw 35 random starts and run ``potential_closure``."""
        with tf.device(device):
            connectivity = self.connectivity(rolestates)
            actions = self.init_locations(35)
            out = self.potential_closure(actions, [self.numroles, self.rolespec, self.actionspec, rolestates, connectivity, actionbounds])
            return out

# Total policy parameter count; main() fills it in from plc.get_params_spec().
num_params__ = None

class Policy(object):
    """Complete policy: role assignment followed by interaction capture."""

    def __init__(self, numagents, agentspec, actionspec, actionbounds, num_layers, num_hidden, disable_ra = False, disable_ic = False):
        self.ra = RoleAssignment(numagents, agentspec, num_layers, num_hidden, disable_ra = disable_ra)
        self.ic = InteractionCapture(numagents, agentspec, actionspec, actionbounds, num_layers, num_hidden, disable_ic = disable_ic)
        self.agentspec = agentspec
        self.actionbounds = actionbounds
        self.actionspec = actionspec
        self.numagents = numagents

    def get_params_spec(self):
        """Return the total number of scalar parameters of both components."""
        return np.sum([self.ra.get_params_spec(), self.ic.get_params_spec()])

    def get_params_spec_raw(self):
        """Return the parameter counts as [role assignment, affinity, node, edge]."""
        return self.ra.get_params_spec_raw() + self.ic.get_params_spec_raw()

    def instantiate(self, params):
        """Split the flat vector *params* between role assignment and interaction capture."""
        racount = self.ra.get_params_spec()

        self.ra.instantiate(params[0:racount])
        self.ic.instantiate(params[racount:])

    def get_weights(self):
        """Return all weight variables, role assignment first."""
        return self.ra.get_weights() + self.ic.get_weights()

    @staticmethod
    def _aggregate_edge_actions(edge_acts, pwise_dists):
        """Reduce the candidate joint actions of every episode to one action set.

        Per episode, candidates whose normalized pairwise distance is below
        0.1 are linked in a graph. The candidate that reaches the most nodes
        is selected (first one on ties) and the candidates reachable from it
        are averaged. The mean runs over all candidates, with unreachable
        ones contributing zeros.
        """
        num_locs = edge_acts.shape[1]
        close_enough = pwise_dists < 0.1
        aggregated = []
        for iep in range(edge_acts.shape[0]):
            g = nk.Graph(n = num_locs)
            for ei in range(num_locs):
                for ej in range(num_locs):
                    if close_enough[iep][ei][ej]:
                        g.addEdge(ei, ej)

            reachable = nk.reachability.ReachableNodes(g, exact = True)
            reachable.run()

            best_count = 0
            best_node = 0
            for ei in range(num_locs):
                count = reachable.numberOfReachableNodes(ei)
                if count > best_count:
                    best_count = count
                    best_node = ei

            bfs = nk.distance.BFS(g, best_node)
            bfs.run()
            dists = np.array(bfs.getDistances()).reshape(-1)
            # A distance below the node count marks a node reached from best_node.
            in_component = dists < num_locs
            aggregated.append(np.mean(np.expand_dims(in_component, [1,2]) * edge_acts[iep], axis = 0))

        return np.array(aggregated)

    def eval_policy(self, agentstates):
        """Return one action per agent for every batch entry of *agentstates*.

        Agents are assigned to roles, actions are computed per role, clipped
        to the action bounds and finally reordered with the inverse of each
        episode's assignment. The result has shape
        ``agentstates.shape[0:2] + [-1]``.
        """
        roleassign = self.ra.compile(agentstates)
        runstacked = tf.unstack(roleassign, axis = 1)

        rolesstates = tf.stack([tf.gather_nd(batch_dims = 1, params = agentstates, indices = tf.expand_dims(stsi, axis = -1)) for stsi in runstacked], axis = 1)

        #put the role idx info which is consumed by interaction capture.
        rolesstates = concat_idx(rolesstates)

        # Use the flag the interaction-capture object was built with; it also
        # decides whether compile() returns one tensor or three.
        if not self.ic.disable_ic:
            role_acts, edge_acts, pwise_dists = self.ic.compile(rolesstates, self.actionbounds)
            role_acts, edge_acts, pwise_dists = role_acts.numpy(), edge_acts.numpy(), pwise_dists.numpy()
            out = role_acts + self._aggregate_edge_actions(edge_acts, pwise_dists)
        else:
            out = self.ic.compile(rolesstates, self.actionbounds).numpy()

        out = np.clip(out, self.actionbounds[0], self.actionbounds[1])
        outprep = out.reshape(list(agentstates.shape[0:2]) + [-1])

        rassign = roleassign.numpy()

        def invert_permutation(p):
            s = np.empty(p.size, p.dtype)
            s[p] = np.arange(p.size)
            return s

        # Iterate over the actual batch size instead of a module global.
        for j in range(rassign.shape[0]):
            outprep[j] = outprep[j][invert_permutation(rassign[j])]
        return outprep
    
# The first message of every worker is its
# (n_agents, per-agent observation width, n_actions) triple.
consume_agent_settings = [pipe_pair[0].recv() for pipe_pair in conns]

# The policy is sized from the first worker's settings; actions are bounded
# to [-1, 1].
plc = Policy(
    consume_agent_settings[0][0],
    consume_agent_settings[0][1],
    consume_agent_settings[0][2],
    [-1., 1.],
    num_layers = num_layers,
    num_hidden = num_hidden,
    disable_ra = disable_ra,
    disable_ic = disable_ic,
)


def main(run_for, weights = None):
    """Run ``run_for`` synchronized steps in every environment worker.

    Args:
        run_for: number of environment steps to execute.
        weights: flat parameter vector for the policy. When None, uniform
            random values in [-1, 1) are drawn and used.

    Returns:
        ``(obs_collect, acts_collect, plc, recvs)``: the per-step observation
        batches, the per-step action batches, the policy object and the
        per-worker reward lists sent after the last step.

    Side effect: stores the parameter count in the module-level
    ``num_params__``.
    """
    global num_params__
    with tf.device('/CPU:0'):
        numparams = plc.get_params_spec()
        num_params__ = numparams
        if weights is None:
            print('random weights')
            weights = np.array(np.random.rand(numparams)*2.0 - 1.0, dtype = np.float32)
        # Always load the weights: previously the randomly drawn vector was
        # generated and then never applied to the policy.
        plc.instantiate(weights)

        obs_collect = []
        acts_collect = []
        for ci in conns:
            ci[0].send(run_for)

        for _ in range(run_for):
            recvs = [ci[0].recv() for ci in conns]
            # Each message is (observations, accumulated reward); only the
            # observations are needed here.
            obslist = np.array([ri[0] for ri in recvs], dtype=np.float32)
            obsreshape = obslist.reshape((num_episodes, plc.numagents, plc.agentspec))
            acts = np.array(plc.eval_policy(obsreshape), dtype=np.float32)

            for worker_idx, ci in enumerate(conns):
                ci[0].send(acts[worker_idx].reshape(plc.numagents, plc.actionspec))
            obs_collect.append(obsreshape)
            acts_collect.append(acts)

        recvs = [ci[0].recv() for ci in conns]
        return obs_collect, acts_collect, plc, recvs



@tf.function(autograph = False)
def diffapprox__(ra, ic, agentstates, agentactions, weights, num_vars):
    """Gradient of the differentiable policy surrogate w.r.t. *weights*.

    The hard role assignment is replaced by a softplus score matrix that is
    alternately normalized along axis 2 and axis 1 twelve times, and the
    boolean connectivity by a sigmoid of the affinity values. Returns the
    gradients of the resulting actions, flattened and concatenated into a
    tensor with a leading axis of size 1. ``num_vars`` is not used.
    """
    with tf.device('/CPU:0'):
        soft_assign = tf.math.softplus(ra.value(agentstates, device = '/CPU:0'))
        for _ in range(12):
            soft_assign = soft_assign / tf.math.reduce_sum(soft_assign, axis = 2, keepdims = True)
            soft_assign = soft_assign / tf.math.reduce_sum(soft_assign, axis = 1, keepdims = True)

        # Soft role states: assignment-weighted mix of the agent states, with
        # the role index encoding appended.
        role_states = concat_idx(tf.linalg.matmul(soft_assign, agentstates))

        soft_conn = tf.math.sigmoid(ic.diff_connectivity(role_states, device = '/CPU:0'))
        flat_actions = tf.reshape(agentactions, (tf.shape(agentstates)[0], 1, -1))
        closure_args = [ic.numroles, ic.rolespec, ic.actionspec, role_states, None, [-1., 1.]]
        policy_out = ic.diff_potential_closure(flat_actions, soft_conn, closure_args)

        per_weight_grads = tf.gradients(policy_out, weights)
        flat_grads = tf.concat([tf.reshape(g, (-1,)) for g in per_weight_grads], axis = 0)
        return tf.expand_dims(flat_grads, 0)

@tf.function(autograph = False)
def gen_hvp_vectors(ra, ic, agentstates, agentactions, plc, num_vars, batch_size, num_approx = 500, perturb = 0.20):
    """Collect gradient differences under random weight perturbations.

    For ``num_approx`` iterations the weights are set to their starting
    values plus uniform noise in [-perturb, perturb), and the change of
    ``diffapprox__`` relative to the unperturbed weights is recorded.

    Returns the ``tf.while_loop`` result ``(i, grad_diffs, perturbations)``
    where ``grad_diffs`` has shape (num_approx, 1, num_vars) and
    ``perturbations`` has shape (num_approx, num_vars).

    ``plc`` and ``batch_size`` are accepted but not used. The weights are not
    reset to their starting values afterwards.
    """
    with tf.device('/CPU:0'):
        i = tf.constant(0)
        weights = ra.get_weights() + ic.get_weights()
        # Snapshot of the starting weight values; perturbations are applied
        # relative to these.
        tngts = [tf.identity(wi + tf.constant(0.0)) for wi in weights]
        # Gradient at the unperturbed weights.
        basef = diffapprox__(ra, ic, agentstates, agentactions, weights, num_vars)

        @tf.function(autograph = False)
        def perturb_and_diff(i, weights):
            # Draw one perturbation per weight tensor, assign it on top of the
            # snapshot and re-evaluate the gradient.
            vecs =[tf.random.uniform(tf.shape(wi), minval = -1. * perturb, maxval = 1. * perturb) for i, wi in enumerate(tngts)]
            ign = [wi.assign(tngts[i] + vecs[i]) for i, wi in enumerate(weights)]
            out_grads = diffapprox__(ra, ic, agentstates, agentactions, weights, num_vars)
            vecf = tf.concat([tf.reshape(si, (-1,)) for si in vecs], axis = 0)
            return (out_grads - basef, vecf)

        @tf.function
        def lbody(i, acc, acc2):
            # Append this iteration's gradient difference and perturbation to
            # the two accumulators.
            gw, vecf = perturb_and_diff(i, weights)
            a1 = tf.concat([acc, tf.expand_dims(gw, axis = 0)], axis = 0)
            a2 = tf.concat([acc2, tf.expand_dims(vecf, axis = 0)], axis = 0)
            return tf.add(i, 1), a1, a2

        c = lambda i, acc, acc2: tf.less(i, num_approx)
        # The accumulators start empty and grow along axis 0, hence the
        # shape invariants with an unknown leading dimension.
        r = tf.while_loop(c, lbody, (i, tf.zeros(shape = (0, 1, num_vars), dtype =tf.float32), tf.zeros(shape = (0, num_vars))), parallel_iterations =20,
                         shape_invariants = (tf.TensorShape([]), tf.TensorShape([None, 1, num_vars]), tf.TensorShape([None,num_vars])))
        return r



# Parameter-print mode: report the environment settings and the parameter
# counts, store the raw per-component counts in param_file and exit.
if do_param_print:
    print('param print')
    print(consume_agent_settings[0])
    print(disable_ra)
    print(disable_ic)
    print(plc.get_params_spec())
    print(plc.ra.get_params_spec())
    print(plc.ic.get_params_spec())
    # Close the file explicitly so the data is flushed before exiting.
    with open(param_file, 'wb') as param_out:
        pickle.dump(plc.get_params_spec_raw(), param_out)
    sys.exit(0)


@tf.function
def quadratic(xs, args):
    """Least-squares objective and gradient for the matrix fit.

    ``xs`` is the flattened (num_params__ x num_params__) matrix X,
    ``args[0]`` the perturbation vectors and ``args[1]`` the observed
    gradient differences. Returns ``(sum((args[0] @ X - args[1])**2), grad)``.

    Defined once, outside the request loop, so that it is not re-created and
    retraced for every request.
    """
    v_rest_flat = args[0]
    a_subbed = args[1]
    with tf.GradientTape() as g:
        g.watch(xs)
        x = tf.reshape(xs, (num_params__, num_params__))
        out = tf.reduce_sum((tf.linalg.matmul(v_rest_flat, x) - tf.squeeze(a_subbed))**2)
    grad = g.gradient(out, xs)
    return out, grad


# Request loop: read (weights, run_for) from the input FIFO, run the rollout,
# fit the matrix and write (total reward, matrix) to the output FIFO.
while True:
    with open(in_fifo, 'rb') as fifo_read:
        input_stuff = fifo_read.read()
    # NOTE(security): pickle.loads can execute arbitrary code; the input FIFO
    # must only ever be written by a trusted process.
    (weights, run_for) = pickle.loads(input_stuff)
    print('running main')
    print(datetime.datetime.now())
    oc, ac, plc, rew = main(run_for, weights)

    print('end main')
    print(datetime.datetime.now())
    obs = tf.concat(oc, axis = 0)
    acts = tf.concat(ac, axis = 0)
    # NaN actions are replaced by zeros before differentiation.
    acts_clean = tf.where(tf.math.is_nan(acts), tf.zeros(tf.shape(acts)), acts)

    with tf.device('/CPU:0'):
        # Total number of scalar entries over all policy weight tensors.
        num_vars = sum(int(np.prod(t.shape)) for t in plc.get_weights())
        rout = gen_hvp_vectors(plc.ra, plc.ic, obs, acts_clean, plc, num_vars, batch_size = 32)

        # rout[2] holds the perturbation vectors, rout[1] the gradient differences.
        a = tfp.optimizer.lbfgs_minimize(
              quadratic,[rout[2], rout[1]], initial_position=tf.zeros((num_params__*num_params__,)),
              stopping_condition=tfp.optimizer.converged_all,
              max_iterations=100,
              tolerance=1e-5)

        # Symmetrize the magnitudes of the fitted matrix.
        atf = tf.reshape(a.position, (num_params__, num_params__))
        atf = tf.math.abs(atf) + tf.transpose(tf.math.abs(atf))

    print('end procedure')
    print(datetime.datetime.now())
    str_out = pickle.dumps((np.sum(rew), atf.numpy()))
    with open(out_fifo, 'wb') as fifo_write:
        fifo_write.write(str_out)
