import os
import argparse
import multiprocessing as mp
import pickle
import glob
import numpy as np
import shutil
import gzip
import tensorflow as tf
import csv
from pyscipopt import Model
import math

import pyscipopt as scip
import utilities
from utilities_tf import load_batch_gcnn, load_batch_gcnn_integer, load_batch_gcnn_branching

import time
from collections import deque
import pickle
import random

from ddpg.ddpg_learner import DDPG
from ddpg.models import GCNPolicy, GCNPolicy_critic
from ddpg.memory_integer import Memory
from ddpg.noise import AdaptiveParamNoiseSpec, NormalActionNoise, OrnsteinUhlenbeckActionNoise
from ddpg.common import set_global_seeds
import ddpg.common.tf_util as U

import generate_instances_fly
import pandas as pd

import tensorflow.contrib.eager as tfe

import gurobipy as gp
from gurobipy import GRB

import shutil
import networkx as nx

from branching_net.model import GCNPolicy as Branching_policy

from warnings import simplefilter
from scipy.integrate import simps
simplefilter(action='ignore', category=FutureWarning)


def make_samples(in_queue, branching_step, integer_update_flag, integer_list, integer_length):
    """
    Re-optimise one instance in a neighbourhood of its current solution.

    The instance is read into a fresh SCIP model. Every encoded position whose
    action value is below 0.5 is "masked". Masked non-INTEGER variables are
    fixed to their current value, or, once ``no_improve_rounds`` exceeds
    ``branching_step``, only required to stay within a total distance of 5
    from it. Bounded INTEGER variables whose leading bits are masked get their
    range narrowed around their current value. The restricted model is then
    solved and its solution re-encoded.

    Parameters
    ----------
    in_queue : sequence
        A single order as produced by ``send_orders``: ``(episode, instance,
        obs, actions, seed, exploration_policy, eval_flag, time_limit,
        out_dir, no_improve_rounds)``. ``obs`` is the current solution in the
        bit-expanded encoding and ``actions`` must provide ``squeeze()``.
    branching_step : int
        Threshold on ``no_improve_rounds`` above which the distance
        constraint replaces hard variable fixing.
    integer_update_flag : bool
        Unused; kept for interface compatibility.
    integer_list : list
        Unused; kept for interface compatibility.
    integer_length : sequence of int
        Number of encoded bits of each bounded, non-fixed INTEGER variable,
        in variable order.

    Returns
    -------
    dict
        Keys ``'type'`` (always ``'solution'``), ``'episode'``,
        ``'instance'``, ``'sol'`` (encoded solution as a numpy array, or
        ``obs`` unchanged when no feasible solution was found), ``'obj'``,
        ``'seed'`` and ``'mask'`` (indices of the masked positions).
    """

    episode, instance, obs, actions, seed, exploration_policy, eval_flag, time_limit, out_dir, no_improve_rounds = in_queue
    print('[w {}] episode {}, seed {}, processing instance \'{}\'...'.format(os.getpid(),episode,seed,instance))

    # The sampled seed is only logged above; SCIP always runs with seed 0,
    # whatever the value of eval_flag.
    seed = 0

    m = scip.Model()
    m.setIntParam('display/verblevel', 0)
    m.readProblem('{}'.format(instance))
    if time_limit < 30:
        utilities.init_scip_paramsR_setcover(m, seed=seed)
    m.setIntParam('timing/clocktype', 1)
    # Cap the solving time so that a single episode cannot run forever.
    m.setRealParam('limits/time', time_limit)

    varss = m.getVars()

    # Positions whose action value is below 0.5 form the mask.
    minimum_k = np.where(np.array(actions.squeeze())<0.5)
    min_k = minimum_k[0]
    # Set of plain ints: O(1) membership instead of scanning the array.
    masked = set(min_k.tolist())

    start_index = 0      # position of the current variable in the encoding
    integer_index = 0    # position in integer_length

    fix_var_list = []
    fix_value_list = []
    for cur_variable in varss:
        if cur_variable.vtype() == 'INTEGER':
            # Any INTEGER variable raises the time limit to 180 s,
            # overriding time_limit.
            m.setRealParam('limits/time', 180)
            upbound = cur_variable.getUbOriginal()
            lowbound = cur_variable.getLbOriginal()
            if upbound == lowbound:
                # Already fixed by its bounds: one slot, nothing to restrict.
                start_index += 1
                continue

            if not (upbound < 1e10 and lowbound >= -1e10):
                # Unbounded INTEGER variables are encoded as a single raw
                # value, so they occupy one slot. Skipping the increment
                # would misalign every following variable.
                start_index += 1
                continue

            binary_length = integer_length[integer_index]
            integer_index += 1

            solutions = obs[start_index:start_index+binary_length]
            solutions = [str(int(temp)) for temp in solutions]
            for idx in range(start_index, start_index + binary_length):
                if idx not in masked:
                    # Only a leading run of masked bits tightens the range.
                    break

                current_sol = int(''.join(solutions), 2)

                # NOTE(review): the upper bound is computed with the lower
                # bound that was just raised, so the window is not symmetric
                # around current_sol - confirm this is intended.
                lowbound = max(current_sol - (upbound - lowbound)/2, lowbound)
                upbound = min(current_sol + (upbound - lowbound)/2, upbound)

                m.addCons(cur_variable >= lowbound)
                m.addCons(cur_variable <= upbound)

            start_index += binary_length
        else:
            value = obs[start_index]
            if start_index in masked:
                fix_var_list.append(cur_variable)
                fix_value_list.append(value)

            start_index += 1

    if no_improve_rounds > branching_step:
        # Soft neighbourhood: bounded total deviation from the current values.
        # With nothing to fix the expression degenerates to the Python bool
        # `0 <= 5`, which addCons cannot accept, so skip it.
        if fix_var_list:
            m.addCons(sum(abs(var - val) for var, val in zip(fix_var_list, fix_value_list)) <= 5)
    else:
        for var, val in zip(fix_var_list, fix_value_list):
            m.fixVar(var, val)

    m.optimize()

    print(m.getPrimalbound())
    if abs(m.getPrimalbound()) > 1e15:
        # No feasible solution found: keep the current values and report an
        # effectively infinite objective.
        K = obs
        obj = abs(m.getPrimalbound())
    else:
        all_vars = m.getVars()
        raw_values = [m.getVal(x) for x in all_vars]
        obj = m.getObjVal()

        # Re-encode: bounded, non-fixed INTEGER variables become a list of
        # bits; every other variable keeps its raw value.
        # NOTE(review): the bit string is built from the absolute value, not
        # from (value - lower bound) - confirm for non-zero lower bounds.
        K = []
        for cur_variable, value in zip(all_vars, raw_values):
            if cur_variable.vtype() == 'INTEGER':
                upbound = cur_variable.getUbOriginal()
                lowbound = cur_variable.getLbOriginal()
                if upbound != lowbound and upbound < 1e10 and lowbound >= -1e10:
                    cur_length = int(np.ceil(math.log2(int(upbound - lowbound + 1))))
                    binary_str = bin(int(value))[2:].zfill(cur_length)
                    K.extend(int(bit) for bit in binary_str)
                    continue
            K.append(value)

    m.freeProb()

    out_queue = {
        'type': 'solution',
        'episode': episode,
        'instance': instance,
        'sol' : np.array(K),
        'obj' : obj,
        'seed': seed,
        'mask': min_k,
    }

    return out_queue


def send_orders(instances, epi, obs, actions, seed, exploration_policy, eval_flag, time_limit, out_dir, no_improve_rounds):
    """
    Build one solving order per instance.

    Parameters
    ----------
    instances : sequence
        Instance file names, one per order.
    epi, obs, actions : sequence
        Per-instance episode ids, current solutions and actions; indexed in
        step with ``instances``.
    seed : int
        Seed of the generator used to draw one seed per order.
    exploration_policy : str
        Branching rule name, copied into every order.
    eval_flag, time_limit, out_dir, no_improve_rounds
        Copied unchanged into every order.

    Returns
    -------
    list of list
        ``[episode, instance, obs, actions, seed, exploration_policy,
        eval_flag, time_limit, out_dir, no_improve_rounds]`` per instance.
    """
    generator = np.random.RandomState(seed)

    return [
        [
            epi[position],
            instance,
            obs[position],
            actions[position],
            generator.randint(2**32),
            exploration_policy,
            eval_flag,
            time_limit,
            out_dir,
            no_improve_rounds,
        ]
        for position, instance in enumerate(instances)
    ]


def collect_samples(instances, epi, obs, actions, out_dir, rng, n_samples, n_jobs,
                    exploration_policy, eval_flag, time_limit, no_improve_rounds, branching_step, integer_update_flag, integer_list, integer_length):
    """
    Solve the neighbourhood sub-problem of every sample in the batch.

    Orders are built with ``send_orders`` and processed one after another by
    ``make_samples`` (no parallelism, despite ``n_jobs``).

    Parameters
    ----------
    instances : list
        Instance files of the batch.
    epi, obs, actions : sequence
        Per-instance episode ids, current solutions and actions.
    out_dir : str
        Directory under which a temporary ``tmp`` folder is created and
        removed again before returning.
    rng : numpy.random.RandomState
        Generator from which the seed of the order generator is drawn.
    n_samples : int
        Number of orders to process.
    n_jobs : int
        Unused; kept for interface compatibility.
    exploration_policy : str
        Branching rule name forwarded with each order.
    eval_flag, time_limit, no_improve_rounds
        Forwarded to ``send_orders``.
    branching_step, integer_update_flag
        Forwarded to ``make_samples``.
    integer_list, integer_length : sequence
        Per-sample integer metadata; element ``n`` is forwarded to the
        ``n``-th ``make_samples`` call.

    Returns
    -------
    tuple
        ``(solutions, episodes, objectives, instances, masks)`` where
        ``solutions`` is all per-sample solutions concatenated along axis 0,
        ``episodes`` and ``objectives`` are stacked arrays, and ``instances``
        and ``masks`` are lists.
    """
    tmp_samples_dir = '{}/tmp'.format(out_dir)
    os.makedirs(tmp_samples_dir, exist_ok=True)

    try:
        pars = send_orders(instances, epi, obs, actions, rng.randint(2**32), exploration_policy, eval_flag, time_limit, tmp_samples_dir, no_improve_rounds)

        results = [
            make_samples(pars[n], branching_step, integer_update_flag, integer_list[n], integer_length[n])
            for n in range(n_samples)
        ]
    finally:
        # Remove the scratch directory even when a solve raises.
        shutil.rmtree(tmp_samples_dir, ignore_errors=True)

    solutions = [sample['sol'] for sample in results]
    episodes = [sample['episode'] for sample in results]
    objectives = [sample['obj'] for sample in results]
    solved_instances = [sample['instance'] for sample in results]
    masks = [sample['mask'] for sample in results]

    # Concatenating directly gives the same result as stack-then-concatenate
    # for equally sized solutions, and also works when the instances of a
    # batch have different numbers of encoded variables.
    return np.concatenate(solutions, axis=0), np.stack(episodes), np.stack(objectives), solved_instances, masks
    
##########  initial formulation features    
class SamplingAgent0(scip.Branchrule):
    """
    Branching rule that records the state once, then delegates branching.

    While the model reports exactly one node, the state returned by
    ``utilities.extract_state`` is stored in ``self.state``. Every branching
    decision is delegated to the SCIP rule named by ``exploration_policy``.
    ``self.state`` only exists after the rule has been called in that
    one-node situation.
    """

    def __init__(self, episode, instance, seed, exploration_policy, out_dir):
        self.episode = episode
        self.instance = instance
        self.seed = seed
        # Name of the SCIP branching rule that takes the actual decisions.
        self.exploration_policy = exploration_policy
        self.out_dir = out_dir

        self.rng = np.random.RandomState(seed)
        self.new_node = True
        self.sample_counter = 0

    def branchinitsol(self):
        """Reset the per-solve counters and the state extraction buffer."""
        self.ndomchgs = 0
        self.ncutoffs = 0
        self.state_buffer = {}

    def branchexeclp(self, allowaddcons):
        """
        Record the state at the first node and run the delegated rule.

        Returns
        -------
        dict
            ``{'result': result}`` with the result of the delegated rule.
        """
        if self.model.getNNodes() == 1:
            # extract formula features
            self.state = utilities.extract_state(self.model, self.state_buffer)

        result = self.model.executeBranchRule(self.exploration_policy, allowaddcons)

        # fair node counting
        if result == scip.SCIP_RESULT.REDUCEDDOM:
            self.ndomchgs += 1
        elif result == scip.SCIP_RESULT.CUTOFF:
            self.ncutoffs += 1

        return {'result': result}


def make_samples0(in_queue, initial_solution_flag, node_limit, pre_solve, conflict):
    """
    Compute an initial solution and the initial state of one instance.

    The instance is solved with SCIP, limited either to one node
    (``node_limit`` true) or to the first solution found. The solution is
    re-encoded so that every bounded, non-fixed INTEGER variable is replaced
    by its bits. With ``node_limit`` true the state recorded during that solve
    is returned. Otherwise the instance is read a second time, part of its
    variables with value 0 or 1 are fixed to that value, and the state is
    recorded from a one-node solve of that reduced model.

    Parameters
    ----------
    in_queue : sequence
        A single order as produced by ``send_orders0``: ``(episode, instance,
        seed, exploration_policy, eval_flag, time_limit, out_dir)``.
        ``time_limit`` is not used here.
    initial_solution_flag : bool
        Passed as ``heuristics`` to ``utilities.init_scip_paramsH_setcover``.
    node_limit : bool
        Limit the solve to one node (true) or to one solution (false).
    pre_solve : bool
        Passed as ``presolving`` to ``utilities.init_scip_paramsH_setcover``.
    conflict : bool
        Passed as ``conflict`` to ``utilities.init_scip_paramsH_setcover``.

    Returns
    -------
    dict
        Keys ``'type'`` (always ``'formula'``), ``'episode'``, ``'instance'``,
        ``'state'``, ``'seed'``, ``'b_obj'`` (objective value of the solve),
        ``'sol'`` (encoded solution as a numpy array), ``'integer_name'``,
        ``'integer_length'``, ``'integer_idx'`` (name, bit count and variable
        index of every expanded INTEGER variable) and ``'obj'`` (sum of the
        objective coefficients of all variables).
    """

    episode, instance, seed, exploration_policy, eval_flag, time_limit, out_dir = in_queue
    print('[w {}] episode {}, seed {}, processing instance \'{}\'...'.format(os.getpid(),episode,seed,instance))

    # The sampled seed is only logged above; SCIP always runs with seed 0,
    # whatever the value of eval_flag.
    seed = 0

    m = scip.Model()
    m.setIntParam('display/verblevel', 0)
    m.readProblem('{}'.format(instance))
    utilities.init_scip_paramsH_setcover(m, heuristics=initial_solution_flag, presolving=pre_solve, conflict=conflict, seed=seed)
    m.setIntParam('timing/clocktype', 2)
    if node_limit:
        # Process the root node only.
        m.setLongintParam('limits/nodes', 1)
    else:
        # Stop at the first feasible solution.
        m.setParam('limits/solutions', 1)

    branchrule = SamplingAgent0(
        episode=episode,
        instance=instance,
        seed=seed,
        exploration_policy=exploration_policy,
        out_dir=out_dir)

    m.includeBranchrule(
        branchrule=branchrule,
        name="Sampling branching rule", desc="",
        priority=666666, maxdepth=-1, maxbounddist=1)

    abc=time.time()
    print(m)
    print("------------------------------------------------------------------------")
    m.optimize()
    print(time.time()-abc)

    # Objective value of the best solution found.
    b_obj = m.getObjVal()

    all_vars = m.getVars()
    # One raw value per variable, in variable order. Kept separately from the
    # bit-expanded encoding because the two no longer share indices once an
    # INTEGER variable has been expanded.
    raw_values = [m.getVal(x) for x in all_vars]

    integer_list = []
    integer_length = []
    integer_idx_list = []

    # NOTE(review): the bit string is built from the absolute value, not from
    # (value - lower bound) - confirm for non-zero lower bounds.
    K = []
    for i, cur_variable in enumerate(all_vars):
        if cur_variable.vtype() == 'INTEGER':
            upbound = cur_variable.getUbOriginal()
            lowbound = cur_variable.getLbOriginal()
            # Expand only variables with finite, distinct bounds.
            if upbound != lowbound and upbound < 1e10 and lowbound >= -1e10:
                integer_list.append(cur_variable.name)
                cur_length = int(np.ceil(math.log2(int(upbound - lowbound + 1))))
                integer_length.append(cur_length)
                binary_str = bin(int(raw_values[i]))[2:].zfill(cur_length)
                K.extend(int(bit) for bit in binary_str)
                integer_idx_list.append(i)
                continue
        K.append(raw_values[i])

    if node_limit:
        state = branchrule.state
    else:
        g = scip.Model()
        g.readProblem('{}'.format(instance))
        varss = g.getVars()
        idsx = 0
        # Share of the variables to fix; later matches override earlier ones.
        length = len(varss) * 0.03
        if "131" in instance:
            length = len(varss) * 0.1
        if '108' in instance:
            length = len(varss) * 0.2
        if 'cauction' in instance:
            length = len(varss) * 0.9
        # Fix against the raw per-variable values: indexing the bit-expanded
        # encoding by variable index would pick up values that belong to
        # other variables as soon as one INTEGER variable is expanded.
        # Assumes g lists its variables in the same order as m, both being
        # read from the same file.
        for variable, value in zip(varss, raw_values):
            if (value == 1 or value == 0) and idsx <= length:
                g.fixVar(variable, value)
                idsx += 1
        print("Here")
        utilities.init_scip_paramsH_setcover(g, heuristics=True, presolving=False, seed=seed)
        g.setLongintParam('limits/nodes', 1)
        new_branchrule = SamplingAgent0(
            episode=episode,
            instance=instance,
            seed=seed,
            exploration_policy=exploration_policy,
            out_dir=out_dir)
        g.includeBranchrule(
            branchrule=new_branchrule,
            name="Sampling branching rule", desc="",
            priority=666666, maxdepth=-1, maxbounddist=1)
        g.optimize()
        state = new_branchrule.state

        g.freeProb()


    out_queue = {
        'type': 'formula',
        'episode': episode,
        'instance': instance,
        'state' : state,
        'seed': seed,
        'b_obj': b_obj,
        'sol' : np.array(K),
        'integer_name' : integer_list,
        'integer_length': integer_length,
        'integer_idx': integer_idx_list,
    }

    print(b_obj)

    m.freeTransform()

    obj = [x.getObj() for x in m.getVars()]

    out_queue['obj'] = sum(obj)

    m.freeProb()

    print("[w {}] episode {} done".format(os.getpid(),episode))

    return out_queue


def send_orders0(instances, n_samples, seed, exploration_policy, batch_id, eval_flag, time_limit, out_dir):
    """
    Build the orders for one batch of instances.

    Parameters
    ----------
    instances : list
        All instance file names; batch ``batch_id`` is the slice
        ``instances[batch_id * n_samples:(batch_id + 1) * n_samples]``.
    n_samples : int
        Batch size.
    seed : int
        Seed of the generator used to draw one seed per order.
    exploration_policy : str
        Branching rule name, copied into every order.
    batch_id : int
        Index of the batch to build.
    eval_flag, time_limit, out_dir
        Copied unchanged into every order.

    Returns
    -------
    list of list
        ``[episode, instance, seed, exploration_policy, eval_flag,
        time_limit, out_dir]`` per instance, with ``episode`` counting from 0
        inside the batch.
    """
    generator = np.random.RandomState(seed)

    first = batch_id*n_samples
    batch = instances[first:first+n_samples]

    return [
        [episode, instance, generator.randint(2**32), exploration_policy, eval_flag, time_limit, out_dir]
        for episode, instance in enumerate(batch)
    ]



def collect_samples0(instances, out_dir, rng, n_samples, n_jobs,
                    exploration_policy, batch_id, eval_flag, time_limit, initial_solution_flag, node_limit, presolve, conflict):
    """
    Compute initial solutions and states for one batch of instances.

    Orders for batch ``batch_id`` are built with ``send_orders0`` and
    processed one after another by ``make_samples0`` (no parallelism, despite
    ``n_jobs``).

    Parameters
    ----------
    instances : list
        All instance files; the batch is selected by ``batch_id``.
    out_dir : str
        Directory that is created if missing. A temporary ``tmp`` folder is
        created inside it and removed again before returning.
    rng : numpy.random.RandomState
        Generator from which the seed of the order generator is drawn.
    n_samples : int
        Number of instances in a batch.
    n_jobs : int
        Unused; kept for interface compatibility.
    exploration_policy : str
        Branching rule name forwarded with each order.
    batch_id : int
        Index of the batch to process.
    eval_flag, time_limit
        Forwarded to ``send_orders0``.
    initial_solution_flag, node_limit, presolve, conflict
        Forwarded to ``make_samples0``.

    Returns
    -------
    tuple
        ``(states, episodes, objs, best_objs, instances, initial_solutions,
        integer_names, integer_lengths, integer_idx)``. ``initial_solutions``
        is all per-sample encoded solutions concatenated along axis 0;
        ``episodes``, ``objs`` and ``best_objs`` are stacked arrays; the
        remaining entries are per-sample lists.
    """
    os.makedirs(out_dir, exist_ok=True)

    tmp_samples_dir = '{}/tmp'.format(out_dir)
    os.makedirs(tmp_samples_dir, exist_ok=True)

    try:
        pars = send_orders0(instances, n_samples, rng.randint(2**32), exploration_policy, batch_id, eval_flag, time_limit, tmp_samples_dir)

        results = [
            make_samples0(pars[n], initial_solution_flag, node_limit, presolve, conflict)
            for n in range(n_samples)
        ]
    finally:
        # Remove the scratch directory even when a solve raises.
        shutil.rmtree(tmp_samples_dir, ignore_errors=True)

    states = [sample['state'] for sample in results]  # states at current node
    episodes = [sample['episode'] for sample in results]
    objs = [sample['obj'] for sample in results]
    best_objs = [sample['b_obj'] for sample in results]
    solved_instances = [sample['instance'] for sample in results]
    initial_solutions = [sample['sol'] for sample in results]
    integer_names = [sample['integer_name'] for sample in results]
    integer_lengths = [sample['integer_length'] for sample in results]
    integer_idx = [sample['integer_idx'] for sample in results]

    # The initial solutions are laid out end to end, matching the batched
    # graph. Concatenating directly gives the same result as
    # stack-then-concatenate for equally sized solutions, and also works when
    # the instances of a batch have different numbers of encoded variables.
    return states, np.stack(episodes), np.stack(objs), np.stack(best_objs), solved_instances, np.concatenate(initial_solutions, axis=0), integer_names, integer_lengths, integer_idx


def pad_output(output, n_vars_per_sample, pad_value=-1e8):
    """
    Split the first row of ``output`` into consecutive per-sample pieces.

    Despite the name, nothing is padded and ``pad_value`` is not used.

    Parameters
    ----------
    output : sequence
        Only ``output[0]`` is read; it must support slicing.
    n_vars_per_sample : iterable of int
        Length of each piece, in order.
    pad_value : float
        Unused.

    Returns
    -------
    list
        One slice of ``output[0]`` per entry of ``n_vars_per_sample``.
    """

    def _bounds():
        # Yield the (start, stop) pair of every piece.
        lower = 0
        for width in n_vars_per_sample:
            upper = lower + width
            yield lower, upper
            lower = upper

    return [output[0][lower:upper] for lower, upper in _bounds()]

    # n_vars_max = tf.reduce_max(n_vars_per_sample)

    # output = tf.split(
    #     value=output,
    #     num_or_size_splits=n_vars_per_sample,
    #     axis=1,
    # )
    # output = tf.concat([
    #     tf.pad(
    #         x,
    #         paddings=[[0, 0], [0, n_vars_max - tf.shape(x)[1]]],
    #         mode='CONSTANT',
    #         constant_values=pad_value)
    #     for x in output
    # ], axis=0)


    # return output



def learn(args, is_maximum=1,network='mlp',
          seed=None,
          total_timesteps=None,
          nb_epochs=None, # with default settings, perform 1M steps total
          nb_epoch_cycles=25,
          nb_rollout_steps=20,
          reward_scale=1.0,
          render=False,
          render_eval=False,
          noise_type=None,
          normalize_returns=False,
          normalize_observations=False,
          critic_l2_reg=1e-2,
          actor_lr=1e-4,
          critic_lr=1e-3,
          popart=False,
          gamma=0.99,  #0.9 #0.96
          clip_norm=None,
          nb_train_steps=1, # per epoch cycle and MPI worker,  50 10 30   100  10  3
          nb_eval_steps=20000, # default:1000
          batch_size=10, # per MPI worker  64 32  64   128   64  128
          tau=0.01,
          eval_env=None,
          save_path=None,
          param_noise_adaption_interval=30):
    
    test_time_limit = 200  # default:200
    load_path = 'models/RL_model/' + args.problem + "/model_graph.joblib"

    print("seed {}".format(args.seed))

    batch_sample = 10  #8
    batch_sample_eval = 1 #8
    exploration_strategy = 'relpscost'
    eval_val = 0
    time_limit = 10  # 5  #2
    variable_to_branching = 0.25  # Branching variable ratio
    sub_mip_ratio = 1  # limited size ratio for sub-mip problems
    start_branching_rounds = 10000  # steps to start local branching
    emb_size = 16
    run_times = 1
    test_time_limit = 200

    integer_update_flag = False  # 问题中是否包含整数，若包含，对fix部分进行range update

    pre_solve = True
    conflict = True

    max_variable_size = None  # 每个子问题的最大variable size

    instances_valid = []

    run_times = 1

    best_bound_data = pd.read_csv("best_bound.csv")
    best_bound_dict = best_bound_data.set_index('Instances')[
        'best_bound'].to_dict()  # used to calculate the primal integral

    if args.problem == 'maxcut':
        instances_valid = glob.glob('data/instances/test_4950_2975/*.lp')
        #        instances_valid += glob.glob('data/instances/transfer_9950_5975/*.lp')
        #         instances_valid += glob.glob('data/instances/transfer_19950_11975/*.lp')

        out_dir = 'data/samples/tmp'
        instances_valid = instances_valid
        initial_solution_heu = True  # initial feasible solution
        is_maximum = 1  # maximum or Minimum problem
        node_limit = True  # limited to one node/ limit to one solution
    elif args.problem == 'cats':
        # instances_valid = glob.glob('data/instances/cauctions/test_2000_4000/*.lp')
        instances_valid += glob.glob('data/instances/cauctions/transfer_4000_8000/*.lp')
        instances_valid += glob.glob('data/instances/cauctions/transfer_8000_16000/*.lp')
        out_dir = 'data/samples/tmp'
        initial_solution_heu = True  # initial feasible solution
        is_maximum = 1  # maximum or Minimum problem
        node_limit = False  # limited to one node/ limit to one solution

        pre_solve = False
        conflict = False
    elif args.problem == 'indset':
        instances_valid = glob.glob('data/instances/indset/test_1500_4/*.lp')
        instances_valid += glob.glob('data/instances/indset/transfer_6000_4/*.lp')
        instances_valid += glob.glob('data/instances/indset/transfer_3000_4/*.lp')
        out_dir = 'data/samples/tmp'
        initial_solution_heu = True  # initial feasible solution
        is_maximum = 1  # maximum or Minimum problem
        node_limit = True  # limited to one node/ limit to one solution

        pre_solve = False
    elif args.problem == 'setcover':
        instances_valid += ["data/instances/setcover/test_5000r_1000c_0.05d/instance_{}.lp".format(i + 101) for i in
                            range(50)]
        instances_valid += glob.glob('data/instances/setcover/transfer_5000r_4000c_0.05d/*.lp')  # transfer
        instances_valid += glob.glob('data/instances/setcover/transfer_5000r_2000c_0.05d/*.lp')  # transfer
        out_dir = 'data/samples/tmp'
        initial_solution_heu = False
        start_branching_rounds = 10  # 10
        is_maximum = 0
        node_limit = True  # limited to one node/ limit to one solution
    elif args.problem == 'item':
        instances_valid += ["data/instances/item_placement/test/item_placement_{}.mps.gz".format(i + 10000) for i in
                            range(100)]
        out_dir = 'data/samples/tmp'
        initial_solution_heu = True
        sub_mip_ratio = 0.15  # 0.15
        is_maximum = 0
        node_limit = True  # limited to one node/ limit to one solution
    elif args.problem == 'miplib':
        # instances_valid += ["data/instances/anonymous/test/anonymous_{}.mps.gz".format(i + 119) for i in range(20)]
        instances_valid = ["data/instances/anonymous/test/anonymous_126.mps.gz"]
        # instances_valid = ["data/instances/anonymous/valid/anonymous_102.mps.gz"]

        out_dir = "data/samples/tmp"
        batch_size = 2  # different training batch size due to the complexity

        is_maximum = 0
        initial_solution_heu = True  # initial feasible solution
        node_limit = False  # limited to one node/ limit to one solution
        time_limit = 60
        test_time_limit = 1800  # different time limit as illustrated in the paper

        emb_size = 6  # different training emb size due to the complexity
        integer_update_flag = True
    else:
        raise NotImplementedError

    memory_size = 490

    memory = Memory(memory_size, batch_size)
    critic = GCNPolicy_critic(batch_size, emb_size)
    actor = GCNPolicy(emb_size)

    action_noise = None
    param_noise = None

    agent = DDPG(actor, critic, memory, gamma=gamma, tau=tau, normalize_returns=normalize_returns,
                 normalize_observations=normalize_observations,
                 batch_size=batch_size, action_noise=action_noise, param_noise=param_noise, critic_l2_reg=critic_l2_reg,
                 actor_lr=actor_lr, critic_lr=critic_lr, enable_popart=popart, clip_norm=clip_norm,
                 reward_scale=reward_scale)


    ### TENSORFLOW SETUP ###
    if args.gpu == -1:
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = '{}'.format(args.gpu)                    

    sess = U.get_session()
    # Prepare everything.
    agent.initialize(sess)

    if load_path is not None:
        agent.load(load_path)

    # sess.graph.finalize()

    rng = np.random.RandomState(args.seed)
    tf.set_random_seed(rng.randint(np.iinfo(int).max))


    agent.reset()

    # branching_policy = df_model(args=args)
    # exit()

    # Load Branching Policies
    branching_sess = None

    branching_load_path = f"branching_net/trained_models/{args.problem}/baseline/{args.seed}/best_params.pkl"
    branching_model = Branching_policy()
    branching_model.restore_state(branching_load_path)
    branching_model.call = tfe.defun(branching_model.call, input_signature=branching_model.input_signature)
    # create output directory, throws an error if it already exists                  
    episodes = 0 #scalar
    t = 0 # scalar

    var0 = tf.placeholder(tf.float32, shape=(None,) + (15,), name='var0')
    cons = tf.placeholder(tf.float32, shape=(None,) + (5,), name='cons')
    edge_features = tf.placeholder(tf.float32, shape=(None,) + (1,), name='edge_index')
    edge_indices = tf.placeholder(tf.int32, shape=(2,) + (None,), name='edge_fea')
    n_vars = tf.placeholder(tf.int32, shape=(None,), name='n_var')
    n_cons = tf.placeholder(tf.int32, shape=(None,), name='n_cons')
    is_training = tf.placeholder(tf.bool, shape=(), name='is_training')

    feature_inputs = (cons, edge_indices, edge_features, var0, n_cons, n_vars)
    outputs = branching_model.call(feature_inputs, is_training)


    max_obj = 0       
    #### start train    
    for epoch in range(1): 

        fieldnames = [
            'instance',
            'obj',
            'initial',
            'bestroot',
            'imp',
            'mean',
            'time',
            'Integral',
        ]
        result_file = "{}_{}.csv".format(args.problem,time.strftime('%Y%m%d-%H%M%S'))    
        os.makedirs('ddpg_test_results', exist_ok=True)
        with open("ddpg_test_results/{}".format(result_file), 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()    

            for cyc in range(len(instances_valid)//batch_sample_eval):
                best_miunu = None
                for t in range(run_times):  # Run multiple times

                    ### initial formulation features
                    no_improve_rounds = 0

                    obj_lis = []
                    time_list = []
                    abcd=time.time()

                    current_states, epi, ori_objs, best_root, instances, ini_sol, integer_list, integer_length, integer_idx = collect_samples0(instances_valid,
                                                                                                            out_dir + '/train',
                                                                                                            rng,
                                                                                                            batch_sample_eval,
                                                                                                            args.njobs,
                                                                                                            exploration_policy=exploration_strategy,
                                                                                                            batch_id=cyc,
                                                                                                            eval_flag=eval_val,
                                                                                                            time_limit=None,
                                                                                                            initial_solution_flag = initial_solution_heu,
                                                                                                            node_limit=node_limit,
                                                                                                            presolve = pre_solve,
                                                                                                            conflict = conflict)
                    
                    print(integer_length)
                    print("----------------------------------------------------------------------------")
                    
                    ### initial solution
                    init_sols = ini_sol  

                    ori_objs = np.copy(best_root)

                    cur_sols = init_sols

                    record_ini = np.copy(ori_objs)

                    rec_inc = []  # 保存各sample的feasible solution list
                    rec_inc.append(init_sols)  # 每个元素的长度为batch内全部variable num之和

                    rec_best = np.copy(best_root)  # ADD

                    inc_val = rec_inc[-1].copy()  # 当前时刻最优解对应各变量取值
                    avg_inc_val = np.array(rec_inc).mean(0)  # 各variable平均solution取值

                    pre_sols = np.ones([2, len(rec_inc[0])])  # variable feature

                    c_features, e_indices, e_features, v_features, n_cs_per_sample, n_vs_per_sample, variable_objective_features, variable_objective_indices, constraint_objective_features, constraint_objective_indices, objective_features = load_batch_gcnn_integer(
                        current_states, integer_list.copy(), integer_length.copy(), integer_idx.copy())
                    
                    # c_features_update, _, e_features_update, v_features_update, _, _, _, _, _, _, _ = load_batch_gcnn_branching(
                    #     current_states)
                    
                    dynamic_variable_features = np.concatenate(
                        (inc_val[np.newaxis, :], avg_inc_val[np.newaxis, :], pre_sols), axis=0)
                    updated_v_features = np.concatenate((v_features, dynamic_variable_features.transpose(1, 0)),
                                                        axis=1)
                    # updated_v_features_update = np.concatenate((v_features_update, dynamic_variable_features.transpose(1, 0)),
                    #                                     axis=1)

                    # current observations
                    variable_objective_features = variable_objective_features.reshape(-1, 1)
                    constraint_objective_features = constraint_objective_features.reshape(-1, 1)
                    cur_obs = [c_features, e_indices, e_features, updated_v_features, n_cs_per_sample,
                                n_vs_per_sample, variable_objective_features, variable_objective_indices, constraint_objective_features, constraint_objective_indices, objective_features]

                    mask = None

                    count_array = np.array([0 for i in range(n_vs_per_sample[0])])

                    # used to evaluate the parameters
                    variable_to_branching_update = variable_to_branching if 'max' not in args.problem else max(variable_to_branching, 1 - variable_to_branching)

                    # Perform rollouts.                
                    for t_rollout in range(nb_eval_steps): 
                    
                        action, q, _, _ = agent.step(cur_obs, apply_noise=True, compute_Q=True)
                        pre_sols = np.concatenate((pre_sols, cur_sols[np.newaxis, :]), axis=0)

                        action = np.random.binomial(1, action)

                        sample = np.where(action[0] == 1.0)[0]  # 当前step需要求解的变量集合
                        
                        binary_vars = np.where(v_features[:,2] == 0)  # 整数变量集合

                        sample = np.array(list(set(sample).intersection(set(binary_vars[0]))))  # 当前step需要求解的整数变量集合
                        
                        if max_variable_size is None:
                            sub_mip_length = int(len(action[0]) * sub_mip_ratio)
                        else:
                            sub_mip_length = min(int(len(action[0]) * sub_mip_ratio), max_variable_size)

                        count_prob_array = (max(count_array) + 1) - count_array
                        count_prob_array = count_prob_array[sample]
                        count_prob = count_prob_array/sum(count_prob_array)

                        if len(sample) > sub_mip_length:
                            idx = np.random.choice(sample, len(sample) - sub_mip_length, replace=False, p=count_prob)
                        else:
                            idx = []
                        
                        solve_ids = np.array(list(set(sample) - set(idx)))
                        count_array[solve_ids] += 1
                        
                        action = np.where(action > 0.5, action, 0.)
                        action = np.where(action == 0., action, 1.)

                        # action[0] = action[0] * sample
                        action[0][idx] = 0
                        action[0][np.where(v_features[:,2] == 1)] = 1
                        
                        # logits = branching_policy((c_features.astype(np.float32), e_indices.astype(np.int32), e_features.astype(np.float32), cur_obs[3].astype(np.float32), tf.reduce_sum(n_cs_per_sample, keepdims=True), tf.reduce_sum(n_vs_per_sample, keepdims=True)), tf.convert_to_tensor(False))

                        # updated_v_features_update = np.concatenate((updated_v_features_update, action.transpose(1,0)), axis=1)
                        updated_v_features_update = np.concatenate((cur_obs[3], action.transpose(1,0)), axis=1)
                        
                        # logits = branching_model.call((c_features_update.astype(np.float32), 
                        #                                 e_indices.astype(np.int32), 
                        #                                 e_features_update.astype(np.float32), 
                        #                                 updated_v_features_update.astype(np.float32), 
                        #                                 tf.reduce_sum(n_cs_per_sample, keepdims=True), 
                        #                                 tf.reduce_sum(n_vs_per_sample, keepdims=True)), 
                        #                                 tf.convert_to_tensor(False))

                        # logits = tf.keras.backend.get_value(logits)

                        # if branching_sess is None:
                        #     logits = tf.keras.backend.get_value(logits)
                        #     branching_sess = tf.keras.backend.get_session()
                        # else:
                        #     logits = branching_sess.run(logits)

                        if branching_sess is None:
                            # logits = tf.keras.backend.get_value(logits)
                            branching_sess = tf.keras.backend.get_session()
                        
                        # logits = branching_sess.run(outputs, feed_dict={
                        #     var0: updated_v_features_update, edge_features:e_features_update,
                        #     edge_indices:e_indices, cons: c_features_update, 
                        #     n_vars:n_vs_per_sample,
                        #     n_cons: n_cs_per_sample,
                        #     is_training: False
                        # })

                        logits = branching_sess.run(outputs, feed_dict={
                            var0: updated_v_features_update, edge_features:e_features,
                            edge_indices:e_indices, cons: c_features, 
                            n_vars:n_vs_per_sample,
                            n_cons: n_cs_per_sample,
                            is_training: False
                        })

                        branching_variable_counts = int(len(action[0]) * variable_to_branching_update)
                        # branching_variable_counts = int(len(action[0]) * variable_to_branching)
                        branching_index = np.argsort(logits[0])[-branching_variable_counts:]
                        action[0][branching_index] = 1

                        # print(logits[0])
                        # print(sum(sum(action)))
                        # exit()

                        action = pad_output(action, n_vs_per_sample)  # 还原到每个sample

                        # action = tf.Session().run(action)

                        # action = sess.run(action)

                        # for i in range(len(n_vs_per_sample)):
                        #     action[i] = action[i][:n_vs_per_sample[i]]  # 删掉补0的部分

                        sample_cur_sols = pad_output(cur_sols[np.newaxis, :], n_vs_per_sample)
                        # sample_cur_sols = tf.Session().run(sample_cur_sols)

                        # sample_cur_sols = sess.run(sample_cur_sols)

                        # for i in range(len(n_vs_per_sample)):
                        #     sample_cur_sols[i] = sample_cur_sols[i][:n_vs_per_sample[i]]

                        # Execute next action. derive next solution(state)
                        next_sols, epi, cur_objs, instances, mask = collect_samples(instances, epi, sample_cur_sols,
                                                                                    action, out_dir + '/train', rng,
                                                                                    batch_sample_eval,
                                                                                    args.njobs,
                                                                                    exploration_policy=exploration_strategy,
                                                                                    eval_flag=eval_val,
                                                                                    time_limit=time_limit,
                                                                                    no_improve_rounds=no_improve_rounds,
                                                                                    branching_step=start_branching_rounds,
                                                                                    integer_update_flag=integer_update_flag,
                                                                                    integer_list=integer_list.copy(),
                                                                                    integer_length=integer_length.copy())
                        

                        cur_sols = next_sols.copy()  # 获取优化后的solution

                        inc_ind = np.where(cur_objs > rec_best)[0]
                        rec_inc.append(rec_inc[-1])

                        for inds in inc_ind:
                            start_index = sum(n_vs_per_sample[:inds])
                            end_index = start_index + n_vs_per_sample[inds]
                            rec_inc[-1][start_index:end_index] = cur_sols[start_index:end_index]

                        rec_best[inc_ind] = cur_objs[inc_ind]

                        # compute rewards
                        r = cur_objs - ori_objs
                        print(cur_objs)
                        # note these outputs are batched from vecenv
                        t += 1

                        inc_val = rec_inc[-1].copy()
                        avg_inc_val = np.array(rec_inc).mean(0)

                        next_dynamic_variable_features = np.concatenate(
                            (inc_val[np.newaxis, :], avg_inc_val[np.newaxis, :], pre_sols[-2:]), axis=0)
                        next_updated_v_features = np.concatenate(
                            (v_features, next_dynamic_variable_features.transpose(1, 0)), axis=1)

                        next_obs = [c_features, e_indices, e_features, next_updated_v_features, n_cs_per_sample,
                                    n_vs_per_sample, variable_objective_features, variable_objective_indices, constraint_objective_features, constraint_objective_indices, objective_features]

                        cur_obs = next_obs
                        # updated_v_features_update = np.concatenate((v_features_update, next_dynamic_variable_features.transpose(1, 0)),
                        #                                 axis=1)

                        ori_objs = cur_objs

                        # Book-keeping.
                        if len(obj_lis) > 0:
                            if is_maximum == 1:
                                if cur_objs[0] <= max(obj_lis):
                                    no_improve_rounds += 1
                            else:
                                if cur_objs[0] >= min(obj_lis):
                                    no_improve_rounds += 1

                        obj_lis.append(cur_objs)
                        time_list.append(time.time() - abcd)

                        if time.time()-abcd >= test_time_limit:
                            # df = pd.DataFrame()
                            # df['Time'] = time_list
                            # df['Objective'] = obj_lis
                            # df.to_csv("data.csv", index=False)
                            # exit()

                            break
                    
                    if best_miunu is not None:
                        obj_lis.append(best_miunu)

                    tim = time.time()-abcd    
                    if is_maximum == 1:      
                        miniu = np.stack(obj_lis).max(axis=0)
                    else:
                        miniu = np.stack(obj_lis).min(axis=0)
                    
                    best_miunu = miniu

                    obj_lis = [tmp[0] for tmp in obj_lis]
                    obj_lis.insert(0, best_root[0])
                    time_list.insert(0, 0)
                    # time_list = [tmp/100 for tmp in time_list]
                    primal_integral = abs(simps(obj_lis, time_list) - best_bound_dict[instances[0]] * max(time_list))

                ave = np.mean(miniu)
                for j in range(batch_sample_eval):
                    writer.writerow({
                        'instance': instances[j],
                        'obj': miniu[j],
                        'initial': record_ini[j],
                        'bestroot': best_root[j],
                        'imp': miniu[j] - best_root[j],
                        'mean': ave,
                        'time': tim,
                        'Integral': primal_integral
                    })
                    csvfile.flush()



if __name__ == '__main__':
    # Script entry point: parse the command-line configuration and pass the
    # resulting namespace to learn().
    parser = argparse.ArgumentParser()

    # Positional: which family of MILP instances to work on.
    parser.add_argument(
        'problem',
        help='MILP instance type to process.',
        choices=['setcover', 'cats', 'facilities', 'indset', 'maxcut', 'item', 'miplib'],
    )
    # Seed value is validated/converted by utilities.valid_seed.
    parser.add_argument(
        '-s', '--seed',
        help='Random generator seed.',
        type=utilities.valid_seed,
        default=0,
    )
    parser.add_argument(
        '-j', '--njobs',
        help='Number of parallel jobs.',
        type=int,
        default=1,
    )

    parser.add_argument(
        '-t', '--total_timesteps',
        help='Number of total_timesteps.',
        type=int,
        # argparse only runs `type` on string defaults; a float literal such
        # as 1e4 would be passed through unchanged as 10000.0.  Use an int
        # literal so the value is an int whether or not the flag is given.
        default=10000,
    )

    parser.add_argument(
        '-g', '--gpu',
        help='CUDA GPU id (-1 for CPU).',
        type=int,
        default=4,
    )

    arg = parser.parse_args()
    # tf.enable_eager_execution()

    # Original note: the is_maximum parameter is 1 for a maximization problem
    # and 0 for a minimization problem.
    learn(args=arg)