import os
import argparse
import multiprocessing as mp
import pickle
import glob
import numpy as np
import shutil
import gzip
from sklearn.metrics import v_measure_score
import tensorflow as tf
import csv
from pyscipopt import Model

import pyscipopt as scip
import utilities
from utilities_tf import load_batch_gcnn, load_batch_gcnn_branching

import time
from collections import deque
import pickle
import random

from ddpg.ddpg_learner import DDPG
from ddpg.models import GCNPolicy, GCNPolicy_critic
from ddpg.memory import Memory
from ddpg.noise import AdaptiveParamNoiseSpec, NormalActionNoise, OrnsteinUhlenbeckActionNoise
from ddpg.common import set_global_seeds
import ddpg.common.tf_util as U

import generate_instances_fly
import pandas as pd

import tensorflow.contrib.eager as tfe

import gurobipy as gp
from gurobipy import GRB

import shutil
import networkx as nx

from branching_net.model import GCNPolicy as Branching_policy

from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)


def make_samples(in_queue, branching_step, integer_update_flag):
    """
    Re-solve one instance as a restricted sub-problem and return the result.

    Parameters
    ----------
    in_queue : sequence
        A single order, unpacked as (episode, instance, obs, actions, seed,
        exploration_policy, eval_flag, time_limit, out_dir,
        no_improve_rounds).
    branching_step : int
        When ``no_improve_rounds`` exceeds this value, a distance constraint
        around ``obs`` is added instead of fixing variables.
    integer_update_flag : bool
        When true, selected variables of type 'INTEGER' are restricted to
        the half of their original range that contains ``obs[i]`` instead
        of being fixed (at most 2000 such variables).

    Returns
    -------
    dict
        Keys 'type', 'episode', 'instance', 'sol', 'obj', 'seed' and 'mask'.
    """
    (episode, instance, obs, actions, seed, exploration_policy, eval_flag,
     time_limit, out_dir, no_improve_rounds) = in_queue
    print('[w {}] episode {}, seed {}, processing instance \'{}\'...'.format(os.getpid(),episode,seed,instance))

    # The received seed is only logged above; the solver is always seeded
    # with 0, whatever the value of eval_flag.
    seed = 0

    model = scip.Model()
    model.setIntParam('display/verblevel', 0)
    model.readProblem('{}'.format(instance))
    if time_limit < 30:
        utilities.init_scip_paramsR_setcover(model, seed=seed)
    model.setIntParam('timing/clocktype', 1)
    # Cap the solving time so that training time cannot grow without bound.
    model.setRealParam('limits/time', time_limit)

    variables = list(model.getVars())

    # Indices whose action value is below 0.5.
    min_k = np.where(np.array(actions.squeeze()) < 0.5)[0]

    if no_improve_rounds > branching_step:
        # Keep the selected variables within a total distance of 5 from obs.
        model.setRealParam('limits/time', 10)
        distance = sum(abs(variables[i] - obs[i]) for i in min_k)
        model.addCons(distance <= 5)
    else:
        n_range_updates = 0
        for i in min_k:
            var = variables[i]
            narrow_range = (integer_update_flag
                            and var.vtype() == 'INTEGER'
                            and n_range_updates < 2000)
            if not narrow_range:
                model.fixVar(var, obs[i])
                continue

            # After narrowing an integer range, the per-step solving time
            # is raised.
            model.setRealParam('limits/time', 180)
            midpoint = (var.getUbOriginal() + var.getLbOriginal()) / 2
            if obs[i] >= midpoint:
                model.addCons(var >= midpoint)
            else:
                model.addCons(var <= midpoint)
            n_range_updates += 1

    model.optimize()

    primal_bound = model.getPrimalbound()
    print(primal_bound)
    if abs(primal_bound) > 1e15:
        # No feasible solution was obtained: keep the incoming variable
        # values and report the (huge) bound magnitude as the objective.
        solution = obs
        obj = abs(primal_bound)
    else:
        solution = [model.getVal(x) for x in model.getVars()]
        obj = model.getObjVal()

    model.freeProb()

    return {
        'type': 'solution',
        'episode': episode,
        'instance': instance,
        'sol': np.array(solution),
        'obj': obj,
        'seed': seed,
        'mask': min_k,
    }


def send_orders(instances, epi, obs, actions, seed, exploration_policy, eval_flag, time_limit, out_dir, no_improve_rounds):
    """
    Build one order (a plain list) per instance.

    Parameters
    ----------
    instances : list
        Instance file names, one order is produced for each.
    epi, obs, actions : indexable
        Per-instance episode id, observation and action; entry ``k`` goes
        into the order of ``instances[k]``.
    seed : int
        Seed of the random generator used to draw one seed per order.
    exploration_policy, eval_flag, time_limit, out_dir, no_improve_rounds
        Copied unchanged into every order.

    Returns
    -------
    list of list
        Each order is [episode, instance, obs, actions, seed,
        exploration_policy, eval_flag, time_limit, out_dir,
        no_improve_rounds].
    """
    rng = np.random.RandomState(seed)

    orders = []
    for idx, instance in enumerate(instances):
        order_seed = rng.randint(2**32)
        orders.append([
            epi[idx],
            instance,
            obs[idx],
            actions[idx],
            order_seed,
            exploration_policy,
            eval_flag,
            time_limit,
            out_dir,
            no_improve_rounds,
        ])

    return orders


def collect_samples(instances, epi, obs, actions, out_dir, rng, n_samples, n_jobs,
                    exploration_policy, eval_flag, time_limit, no_improve_rounds, branching_step, integer_update_flag):
    """
    Run ``make_samples`` sequentially on the first ``n_samples`` orders and
    gather the results.

    Parameters
    ----------
    instances : list
        Instance files to solve.
    epi, obs, actions : indexable
        Per-instance episode id, observation and action.
    out_dir : str
        Parent directory; a 'tmp' sub-directory is created and removed again.
    rng : numpy.random.RandomState
        Generator from which the seed passed to ``send_orders`` is drawn.
    n_samples : int
        Number of orders to process.
    n_jobs : int
        Not used by this function.
    exploration_policy, eval_flag, time_limit, no_improve_rounds
        Forwarded to ``send_orders``.
    branching_step, integer_update_flag
        Forwarded to ``make_samples``.

    Returns
    -------
    tuple
        (solutions stacked then concatenated along axis 0, stacked episode
        ids, stacked objective values, list of instances, list of masks).
    """
    tmp_samples_dir = f'{out_dir}/tmp'
    os.makedirs(tmp_samples_dir, exist_ok=True)

    orders = send_orders(instances, epi, obs, actions, rng.randint(2**32),
                         exploration_policy, eval_flag, time_limit,
                         tmp_samples_dir, no_improve_rounds)

    results = [make_samples(orders[k], branching_step, integer_update_flag)
               for k in range(n_samples)]

    # Regroup the per-sample dictionaries field by field.
    solutions = [res['sol'] for res in results]
    episodes = [res['episode'] for res in results]
    objectives = [res['obj'] for res in results]
    solved_instances = [res['instance'] for res in results]
    masks = [res['mask'] for res in results]

    shutil.rmtree(tmp_samples_dir, ignore_errors=True)

    return (np.concatenate(np.stack(solutions), axis=0),
            np.stack(episodes),
            np.stack(objectives),
            solved_instances,
            masks)
    
##########  initial formulation features    
class SamplingAgent0(scip.Branchrule):
    """
    Branching rule that stores the extracted state while the solver reports
    a node count of 1, and always delegates the branching decision to the
    rule named by ``exploration_policy``.
    """

    def __init__(self, episode, instance, seed, exploration_policy, out_dir):
        # Bookkeeping about the order being processed.
        self.episode, self.instance, self.seed = episode, instance, seed
        self.exploration_policy = exploration_policy
        self.out_dir = out_dir

        self.rng = np.random.RandomState(seed)
        self.new_node = True
        self.sample_counter = 0

    def branchinitsol(self):
        # Reset the per-solve counters and the feature-extraction buffer.
        self.state_buffer = {}
        self.ndomchgs = 0
        self.ncutoffs = 0

    def branchexeclp(self, allowaddcons):
        # Extract the formulation features only while the node count is 1.
        if self.model.getNNodes() == 1:
            self.state = utilities.extract_state(self.model, self.state_buffer)

        # The branching decision itself always comes from the exploration policy.
        result = self.model.executeBranchRule(self.exploration_policy, allowaddcons)

        # fair node counting
        if result == scip.SCIP_RESULT.CUTOFF:
            self.ncutoffs += 1
        elif result == scip.SCIP_RESULT.REDUCEDDOM:
            self.ndomchgs += 1

        return {'result': result}


def make_samples0(in_queue, initial_solution_flag, node_limit, pre_solve, conflict):
    """
    Solve one instance to obtain an initial solution and the formulation
    state recorded by ``SamplingAgent0``.

    Parameters
    ----------
    in_queue : sequence
        A single order, unpacked as (episode, instance, seed,
        exploration_policy, eval_flag, time_limit, out_dir). ``time_limit``
        is unpacked but not used in this function.
    initial_solution_flag
        Passed as ``heuristics`` to ``utilities.init_scip_paramsH_setcover``.
    node_limit
        If truthy, the solve is limited to one node and the state comes from
        that solve. Otherwise the solve stops at the first solution and a
        second, one-node solve with part of the 0/1-valued variables fixed
        is used to obtain the state.
    pre_solve, conflict
        Passed as ``presolving`` and ``conflict`` to
        ``utilities.init_scip_paramsH_setcover``.

    Returns
    -------
    dict
        Keys 'type', 'episode', 'instance', 'state', 'seed', 'b_obj' (the
        objective value of the first solve), 'sol' (variable values of the
        first solve), 'vars' (always an empty list here) and 'obj' (the sum
        of the objective coefficients of all variables).
    """

#     while True:
    episode, instance, seed, exploration_policy, eval_flag, time_limit, out_dir = in_queue
    print('[w {}] episode {}, seed {}, processing instance \'{}\'...'.format(os.getpid(),episode,seed,instance))

    # NOTE: both branches assign 0, so the seed received in the order is
    # only used for the log line above.
    if eval_flag==1:
        seed=0
    else:
        seed=0
    
    m = scip.Model()
    m.setIntParam('display/verblevel', 0)
    m.readProblem('{}'.format(instance))
    # utilities.init_scip_paramsH(m, seed=seed)
    utilities.init_scip_paramsH_setcover(m, heuristics=initial_solution_flag, presolving=pre_solve, conflict=conflict, seed=seed)
    m.setIntParam('timing/clocktype', 2)
    if node_limit:
        m.setLongintParam('limits/nodes', 1)  # only process the current node
    else:
        m.setParam('limits/solutions', 1)  # stop at the first solution (a 'limits/time' of 50 was used previously)
#    m.setParam('limits/solutions', 1)

    branchrule = SamplingAgent0(
        episode=episode,
        instance=instance,
        seed=seed,
        exploration_policy=exploration_policy,
        out_dir=out_dir)

    m.includeBranchrule(
        branchrule=branchrule,
        name="Sampling branching rule", desc="",
        priority=666666, maxdepth=-1, maxbounddist=1)

    # Time the first solve; the elapsed seconds are printed below.
    abc=time.time()
    print(m)
    print("------------------------------------------------------------------------")
    m.optimize()       
    print(time.time()-abc)    

    b_obj = m.getObjVal()  # objective value of the best solution found

    K = [m.getVal(x) for x in m.getVars()]       # value of every variable
    varss = [x for x in m.getVars()]

    if node_limit:
        # The state was stored by the branch rule during the solve above.
        state = branchrule.state
    else:
        # Second model on the same instance, used only to extract the state.
        g = scip.Model()
        # g.setIntParam('display/verblevel', 0)
        g.readProblem('{}'.format(instance))
        varss = [x for x in g.getVars()]
        idsx = 0
        # Number of variables to fix, as a fraction of all variables; the
        # fraction is chosen by substring match on the instance path.
        length = len(varss) * 0.03
        if "131" in instance:
            length = len(varss) * 0.1
        if '108' in instance:
            length = len(varss) * 0.2
        if 'cauction' in instance:
            if '48' in instance:
                length = len(varss) * 0.85
            else:
                length = len(varss) * 0.9
        # Fix variables whose value in the first solve is exactly 0 or 1,
        # in variable order, until the counter exceeds `length`.
        # NOTE(review): assumes g.getVars() returns variables in the same
        # order as m.getVars(), since K is indexed by position - confirm.
        for i in range(len(varss)):
            if (K[i] == 1 or K[i] == 0) and idsx <= length:
                a,b = g.fixVar(varss[i], K[i])
                idsx += 1
        # utilities.init_scip_paramsH(m, seed=seed)
        print("Here")
        utilities.init_scip_paramsH_setcover(g, heuristics=True, presolving=False, seed=seed)
        g.setLongintParam('limits/nodes', 1)
        new_branchrule = SamplingAgent0(
            episode=episode,
            instance=instance,
            seed=seed,
            exploration_policy=exploration_policy,
            out_dir=out_dir)
        g.includeBranchrule(
            branchrule=new_branchrule,
            name="Sampling branching rule", desc="",
            priority=666666, maxdepth=-1, maxbounddist=1)
        g.optimize()
        state = new_branchrule.state

        g.freeProb()
    
    print("HERE")
    # Per-variable metadata collection is disabled, so 'vars' stays empty.
    var_list = []
    # print(varss)
    # for cur_var in varss:
    #     print(cur_var.name)
    #     var_list.append({})
    #     var_list[-1]['name'] = cur_var.name
    #     var_list[-1]['type'] = cur_var.vtype()
    #     var_list[-1]['upbound'] = cur_var.getUbOriginal()
    #     var_list[-1]['lowbound'] = cur_var.getLbOriginal()
    
    print("FINISH")

    out_queue = {
        'type': 'formula',
        'episode': episode,
        'instance': instance,
        'state' : state,
        'seed': seed,
        'b_obj': b_obj,
        'sol' : np.array(K),
        'vars': var_list,      
    }   

    print(b_obj)
       
    # The objective coefficients are read after freeTransform().
    m.freeTransform()  

    obj = [x.getObj() for x in m.getVars()]  
    
    # 'obj' is the sum of all objective coefficients, not a solution value.
    out_queue['obj'] = sum(obj) 
    
    m.freeProb() 
        
    print("[w {}] episode {} done".format(os.getpid(),episode))
    
    return out_queue


def send_orders0(instances, n_samples, seed, exploration_policy, batch_id, eval_flag, time_limit, out_dir):
    """
    Build the orders for one batch of instances.

    Parameters
    ----------
    instances : list
        All instance file names; only the slice belonging to ``batch_id`` is
        used.
    n_samples : int
        Batch size, i.e. number of instances per batch.
    seed : int
        Seed of the random generator used to draw one seed per order.
    exploration_policy : str
        Copied unchanged into every order.
    batch_id : int
        Index of the batch; the slice starts at ``batch_id * n_samples``.
    eval_flag, time_limit, out_dir
        Copied unchanged into every order.

    Returns
    -------
    list of list
        Each order is [episode, instance, seed, exploration_policy,
        eval_flag, time_limit, out_dir], with ``episode`` counting from 0
        within the batch.
    """
    rng = np.random.RandomState(seed)

    first = batch_id * n_samples
    batch = instances[first:first + n_samples]

    # One fresh seed is drawn per order, in batch order.
    return [
        [episode, instance, rng.randint(2**32), exploration_policy, eval_flag, time_limit, out_dir]
        for episode, instance in enumerate(batch)
    ]



def collect_samples0(instances, out_dir, rng, n_samples, n_jobs,
                    exploration_policy, batch_id, eval_flag, time_limit, initial_solution_flag, node_limit, presolve, conflict):
    """
    Run ``make_samples0`` sequentially on one batch of instances and gather
    the results.

    Parameters
    ----------
    instances : list
        All instance files; the batch is selected by ``batch_id``.
    out_dir : str
        Directory that is created if missing; a 'tmp' sub-directory is
        created inside it and removed again before returning.
    rng : numpy.random.RandomState
        Generator from which the seed passed to ``send_orders0`` is drawn.
    n_samples : int
        Number of orders to process.
    n_jobs : int
        Not used by this function.
    exploration_policy, batch_id, eval_flag, time_limit
        Forwarded to ``send_orders0``.
    initial_solution_flag, node_limit, presolve, conflict
        Forwarded to ``make_samples0``.

    Returns
    -------
    tuple
        (list of states, stacked episode ids, stacked 'obj' values, stacked
        'b_obj' values, list of instances, initial solutions stacked then
        concatenated along axis 0, 'vars' entries stacked then concatenated
        along axis 0).
    """
    os.makedirs(out_dir, exist_ok=True)

    tmp_samples_dir = f'{out_dir}/tmp'
    os.makedirs(tmp_samples_dir, exist_ok=True)

    orders = send_orders0(instances, n_samples, rng.randint(2**32), exploration_policy,
                          batch_id, eval_flag, time_limit, tmp_samples_dir)

    results = [make_samples0(orders[k], initial_solution_flag, node_limit, presolve, conflict)
               for k in range(n_samples)]

    # Regroup the per-sample dictionaries field by field.
    initial_solutions = [res['sol'] for res in results]
    states = [res['state'] for res in results]  # states at current node
    episodes = [res['episode'] for res in results]
    solved_instances = [res['instance'] for res in results]
    objectives = [res['obj'] for res in results]
    best_objectives = [res['b_obj'] for res in results]
    variables = [res['vars'] for res in results]

    shutil.rmtree(tmp_samples_dir, ignore_errors=True)

    # The initial solutions are concatenated so that they line up with the
    # batched (big) graph.
    return (states,
            np.stack(episodes),
            np.stack(objectives),
            np.stack(best_objectives),
            solved_instances,
            np.concatenate(np.stack(initial_solutions), axis=0),
            np.concatenate(np.stack(variables), axis=0))


def pad_output(output, n_vars_per_sample, pad_value=-1e8):
    """
    Split the first row of ``output`` into consecutive per-sample chunks.

    Parameters
    ----------
    output : indexable
        Only ``output[0]`` is read; it must support slicing.
    n_vars_per_sample : iterable of int
        Length of each consecutive chunk.
    pad_value : float
        Not used; no padding is applied.

    Returns
    -------
    list
        One slice of ``output[0]`` per entry of ``n_vars_per_sample``.
    """
    # Cumulative chunk boundaries: [0, n0, n0 + n1, ...].
    bounds = [0]
    for width in n_vars_per_sample:
        bounds.append(bounds[-1] + width)

    return [output[0][lo:hi] for lo, hi in zip(bounds, bounds[1:])]

    # n_vars_max = tf.reduce_max(n_vars_per_sample)

    # output = tf.split(
    #     value=output,
    #     num_or_size_splits=n_vars_per_sample,
    #     axis=1,
    # )
    # output = tf.concat([
    #     tf.pad(
    #         x,
    #         paddings=[[0, 0], [0, n_vars_max - tf.shape(x)[1]]],
    #         mode='CONSTANT',
    #         constant_values=pad_value)
    #     for x in output
    # ], axis=0)


    # return output



def learn(args, is_maximum=1,network='mlp',
          seed=None,
          total_timesteps=None,
          nb_epochs=None, # with default settings, perform 1M steps total
          nb_epoch_cycles=25,
          nb_rollout_steps=20,
          reward_scale=1.0,
          render=False,
          render_eval=False,
          noise_type=None,
          normalize_returns=False,
          normalize_observations=False,
          critic_l2_reg=1e-2,
          actor_lr=1e-4,
          critic_lr=1e-3,
          popart=False,
          gamma=0.99,  #0.9 #0.96
          clip_norm=None,
          nb_train_steps=1, # per epoch cycle and MPI worker,  50 10 30   100  10  3
          nb_eval_steps=20000, # default:1000
          batch_size=10, # per MPI worker  64 32  64   128   64  128
          tau=0.01,
          eval_env=None,
          save_path=None,
          param_noise_adaption_interval=30):
    
    test_time_limit = 200  # default:200
    load_path = 'models/RL_model/' + args.problem + "/model_graph.joblib"

    print("seed {}".format(args.seed))

    batch_sample = 10  #8
    batch_sample_eval = 1 #8
    exploration_strategy = 'relpscost'
    eval_val = 0
    time_limit = 2  # 5  #2
    variable_to_branching = 0.02  # Branching variable ratio

    integer_update_flag = False  # 问题中是否包含整数，若包含，对fix部分进行range update

    pre_solve = True
    conflict = True

    max_variable_size = None  # 每个子问题的最大variable size

    instances_valid = []

    run_times = 1

    if args.problem == 'maxcut':
#        instances_valid = glob.glob('data/instances/test_4950_2975/*.lp')
        instances_valid += glob.glob('data/instances/transfer_9950_5975/*.lp')
        instances_valid += glob.glob('data/instances/transfer_19950_11975/*.lp')
        out_dir = 'data/samples/tmp'
        initial_solution_heu = True  # initial feasible solution
        sub_mip_ratio = 0.75  # limited size ratio for sub-mip problems
        start_branching_rounds = 10000  # steps to start local branching
        is_maximum = 1  # maximum or Minimum problem
        node_limit = True  # limited to one node/ limit to one solution
        variable_to_branching = 0.85 # 0.85
        emb_size = 16

        run_times = 3

        test_time_limit = 500
    elif args.problem == 'cats':
#        instances_valid = glob.glob('data/instances/cauctions/test_2000_4000/*.lp')
        instances_valid += glob.glob('data/instances/cauctions/transfer_4000_8000/*.lp')
        instances_valid += glob.glob('data/instances/cauctions/transfer_8000_16000/*.lp')
        out_dir = 'data/samples/tmp'
        initial_solution_heu = True  # initial feasible solution
        sub_mip_ratio = 1  # limited size ratio for sub-mip problems
        start_branching_rounds = 10000  # steps to start local branching
        is_maximum = 1  # maximum or Minimum problem
        node_limit = False  # limited to one node/ limit to one solution

        pre_solve = False
        conflict = False
        variable_to_branching = 0.25  # 0.55,0.35
        emb_size = 16

        time_limit = 10  # 12

        run_times = 3

        test_time_limit = 500
    elif args.problem == 'indset':
#        instances_valid = glob.glob('data/instances/indset/test_1500_4/*.lp')
        instances_valid += glob.glob('data/instances/indset/transfer_6000_4/*.lp')
        instances_valid += glob.glob('data/instances/indset/transfer_3000_4/*.lp')
        out_dir = 'data/samples/tmp'
        initial_solution_heu = True  # initial feasible solution
        sub_mip_ratio = 0.75  # limited size ratio for sub-mip problems
        start_branching_rounds = 10000  # steps to start local branching
        is_maximum = 1  # maximum or Minimum problem
        node_limit = True  # limited to one node/ limit to one solution

        pre_solve = False
        variable_to_branching = 0.65
        emb_size = 16

        time_limit = 2

        run_times = 3
    elif args.problem == 'setcover':
#        instances_valid += ["data/instances/setcover/test_5000r_1000c_0.05d/instance_{}.lp".format(i+101) for i in range(50)]
        instances_valid += glob.glob('data/instances/setcover/transfer_5000r_4000c_0.05d/*.lp')  # transfer
        instances_valid += glob.glob('data/instances/setcover/transfer_5000r_2000c_0.05d/*.lp')  # transfer
        out_dir = 'data/samples/tmp'
        initial_solution_heu = False
        sub_mip_ratio = 1
        start_branching_rounds = 10  # 10
        is_maximum = 0
        node_limit = True  # limited to one node/ limit to one solution

        emb_size = 16

        run_times = 3
    elif args.problem == 'item':
        instances_valid += ["data/instances/item_placement/test/item_placement_{}.mps.gz".format(i + 10000) for i in range(100)]
        out_dir = 'data/samples/tmp'
        initial_solution_heu = True
        sub_mip_ratio = 0.15  # 0.15
        start_branching_rounds = 10000
        is_maximum = 0
        node_limit = True  # limited to one node/ limit to one solution
        emb_size = 16

        variable_to_branching = 0.25

        run_times = 8

        time_limit = 1
    elif args.problem == 'miplib':
        # instances_valid += ["data/instances/anonymous/test/anonymous_{}.mps.gz".format(i + 119) for i in range(20)]
        instances_valid += ["data/instances/anonymous/valid/anonymous_{}.mps.gz".format(i + 98) for i in range(20)]
        # instances_valid = ["data/instances/anonymous/test/anonymous_126.mps.gz"]
        # instances_valid = ["data/instances/anonymous/test/anonymous_133.mps.gz"]

        out_dir = "data/samples/tmp"

        # filter the samples
        for instances in list(instances_valid):
            model = Model()
            model.readProblem(instances)
            if model.getNVars() > 200000:
                instances_valid.remove(instances)
        
        batch_size = 2

        is_maximum = 0
        initial_solution_heu = True  # initial feasible solution
        node_limit = False  # limited to one node/ limit to one solution

        sub_mip_ratio = 0.7
        start_branching_rounds = 10000

        time_limit = 60
        test_time_limit = 1800

        max_variable_size = 100000  # 对于MIPLIB，需要限制最大variable size，否则子问题可能无法求解
        emb_size = 6

        variable_to_branching = 0.25
        integer_update_flag = True
    else:
        raise NotImplementedError
    
    memory_size = 490

    memory = Memory(memory_size, batch_size)
    critic = GCNPolicy_critic(batch_size, emb_size)
    actor = GCNPolicy(emb_size)

    action_noise = None
    param_noise = None

    agent = DDPG(actor, critic, memory, gamma=gamma, tau=tau, normalize_returns=normalize_returns,
                 normalize_observations=normalize_observations,
                 batch_size=batch_size, action_noise=action_noise, param_noise=param_noise, critic_l2_reg=critic_l2_reg,
                 actor_lr=actor_lr, critic_lr=critic_lr, enable_popart=popart, clip_norm=clip_norm,
                 reward_scale=reward_scale)


    ### TENSORFLOW SETUP ###
    if args.gpu == -1:
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = '{}'.format(args.gpu)                    

    sess = U.get_session()
    # Prepare everything.
    agent.initialize(sess)

    if load_path is not None:
        agent.load(load_path)

    # sess.graph.finalize()

    rng = np.random.RandomState(args.seed)
    tf.set_random_seed(rng.randint(np.iinfo(int).max))


    agent.reset()

    # branching_policy = df_model(args=args)
    # exit()

    # Load Branching Policies
    branching_sess = None

    branching_load_path = f"branching_net/trained_models/{args.problem}/baseline/{args.seed}/best_params.pkl"
    branching_model = Branching_policy()
    branching_model.restore_state(branching_load_path)
    branching_model.call = tfe.defun(branching_model.call, input_signature=branching_model.input_signature)
    # create output directory, throws an error if it already exists                  
    episodes = 0 #scalar
    t = 0 # scalar

    var0 = tf.placeholder(tf.float32, shape=(None,) + (15,), name='var0')
    cons = tf.placeholder(tf.float32, shape=(None,) + (5,), name='cons')
    edge_features = tf.placeholder(tf.float32, shape=(None,) + (1,), name='edge_index')
    edge_indices = tf.placeholder(tf.int32, shape=(2,) + (None,), name='edge_fea')
    n_vars = tf.placeholder(tf.int32, shape=(None,), name='n_var')
    n_cons = tf.placeholder(tf.int32, shape=(None,), name='n_cons')
    is_training = tf.placeholder(tf.bool, shape=(), name='is_training')

    feature_inputs = (cons, edge_indices, edge_features, var0, n_cons, n_vars)
    outputs = branching_model.call(feature_inputs, is_training)


    max_obj = 0       
    #### start train
    for epoch in range(1):

        fieldnames = [
            'instance',
            'obj',
            'initial',
            'bestroot',
            'imp',
            'mean',
            'time',
        ]
        result_file = "{}_{}.csv".format(args.problem,time.strftime('%Y%m%d-%H%M%S'))    
        # result_file = "{}_{}.csv".format(args.problem, str(variable_to_branching))
        os.makedirs('ddpg_test_results', exist_ok=True)
        with open("ddpg_test_results/{}".format(result_file), 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()    

            for cyc in range(len(instances_valid)//batch_sample_eval):
                best_miunu = None
                for t in range(run_times):  # Run multiple times

                    ### initial formulation features
                    no_improve_rounds = 0

                    obj_lis = []
                    time_list = []
                    abcd=time.time()

                    current_states, epi, ori_objs, best_root, instances, ini_sol, varss = collect_samples0(instances_valid,
                                                                                                            out_dir + '/train',
                                                                                                            rng,
                                                                                                            batch_sample_eval,
                                                                                                            args.njobs,
                                                                                                            exploration_policy=exploration_strategy,
                                                                                                            batch_id=cyc,
                                                                                                            eval_flag=eval_val,
                                                                                                            time_limit=None,
                                                                                                            initial_solution_flag = initial_solution_heu,
                                                                                                            node_limit=node_limit,
                                                                                                            presolve = pre_solve,
                                                                                                            conflict = conflict)

                    collect_data = [current_states, epi, ori_objs, best_root, instances, ini_sol, varss]
                    print(instances)
                    print(instances[0])
                    if args.problem == 'item' or args.problem == 'miplib':
                        dirr = instances[0].split("/")[4][:-7]
                    elif args.problem == 'maxcut':
                        cur_dir = instances[0].split("/")[2]
                        all_cur_dir = 'collect/' + args.problem + "/" + cur_dir
                        if not os.path.exists(all_cur_dir):
                            os.makedirs(all_cur_dir, exist_ok=True)
                        dirr = instances[0].split("/")[3][:-3]
                    else:
                        cur_dir = instances[0].split("/")[3]
                        all_cur_dir = 'collect/' + args.problem + "/" + cur_dir
                        if not os.path.exists(all_cur_dir):
                            os.makedirs(all_cur_dir, exist_ok=True)
                        dirr = instances[0].split("/")[4][:-3]
                    # filename = f'collect/' + args.problem + "/" + cur_dir + "/" + dirr + '.pkl'
                    if args.problem == 'item' or args.problem == 'miplib':
                        filename = f'collect/' + args.problem + "/" + dirr + '.pkl'
                    else:
                        filename = f'collect/' + args.problem + "/" + cur_dir + "/" + dirr + '.pkl'

                    with gzip.open(filename, 'wb') as f:
                        pickle.dump({
                        'data': collect_data,
                        }, f)
                    
                    break


                    ### initial solution
                    init_sols = ini_sol  

                    ori_objs = np.copy(best_root)

                    cur_sols = init_sols

                    record_ini = np.copy(ori_objs)

                    rec_inc = []  # 保存各sample的feasible solution list
                    rec_inc.append(init_sols)  # 每个元素的长度为batch内全部variable num之和

                    rec_best = np.copy(best_root)  # ADD

                    inc_val = rec_inc[-1].copy()  # 当前时刻最优解对应各变量取值
                    avg_inc_val = np.array(rec_inc).mean(0)  # 各variable平均solution取值

                    pre_sols = np.ones([2, len(rec_inc[0])])  # variable feature

                    c_features, e_indices, e_features, v_features, n_cs_per_sample, n_vs_per_sample, variable_objective_features, variable_objective_indices, constraint_objective_features, constraint_objective_indices, objective_features = load_batch_gcnn(
                        current_states)
                    
                    c_features_update, _, e_features_update, v_features_update, _, _, _, _, _, _, _ = load_batch_gcnn_branching(
                        current_states)
                    
                    dynamic_variable_features = np.concatenate(
                        (inc_val[np.newaxis, :], avg_inc_val[np.newaxis, :], pre_sols), axis=0)
                    updated_v_features = np.concatenate((v_features, dynamic_variable_features.transpose(1, 0)),
                                                        axis=1)
                    updated_v_features_update = np.concatenate((v_features_update, dynamic_variable_features.transpose(1, 0)),
                                                        axis=1)

                    # current observations
                    variable_objective_features = variable_objective_features.reshape(-1, 1)
                    constraint_objective_features = constraint_objective_features.reshape(-1, 1)
                    cur_obs = [c_features, e_indices, e_features, updated_v_features, n_cs_per_sample,
                                n_vs_per_sample, variable_objective_features, variable_objective_indices, constraint_objective_features, constraint_objective_indices, objective_features]

                    mask = None

                    count_array = np.array([0 for i in range(n_vs_per_sample[0])])

                    variable_to_branching_update = variable_to_branching

                    # Perform rollouts.                
                    for t_rollout in range(nb_eval_steps): 
                    
                        action, q, _, _ = agent.step(cur_obs, apply_noise=True, compute_Q=True)
                        pre_sols = np.concatenate((pre_sols, cur_sols[np.newaxis, :]), axis=0)

                        action = np.random.binomial(1, action)

                        sample = np.where(action[0] == 1.0)[0]  # 当前step需要求解的变量集合
                        
                        binary_vars = np.where(v_features[:,2] == 0)  # 整数变量集合

                        sample = np.array(list(set(sample).intersection(set(binary_vars[0]))))  # 当前step需要求解的整数变量集合
                        
                        if max_variable_size is None:
                            sub_mip_length = int(len(action[0]) * sub_mip_ratio)
                        else:
                            sub_mip_length = min(int(len(action[0]) * sub_mip_ratio), max_variable_size)

                        count_prob_array = (max(count_array) + 1) - count_array
                        count_prob_array = count_prob_array[sample]
                        count_prob = count_prob_array/sum(count_prob_array)

                        if len(sample) > sub_mip_length:
                            idx = np.random.choice(sample, len(sample) - sub_mip_length, replace=False, p=count_prob)
                        else:
                            idx = []
                        
                        solve_ids = np.array(list(set(sample) - set(idx)))
                        count_array[solve_ids] += 1
                        
                        action = np.where(action > 0.5, action, 0.)
                        action = np.where(action == 0., action, 1.)

                        # action[0] = action[0] * sample
                        action[0][idx] = 0
                        action[0][np.where(v_features[:,2] == 1)] = 1
                        
                        # logits = branching_policy((c_features.astype(np.float32), e_indices.astype(np.int32), e_features.astype(np.float32), cur_obs[3].astype(np.float32), tf.reduce_sum(n_cs_per_sample, keepdims=True), tf.reduce_sum(n_vs_per_sample, keepdims=True)), tf.convert_to_tensor(False))

                        # updated_v_features_update = np.concatenate((updated_v_features_update, action.transpose(1,0)), axis=1)
                        updated_v_features_update = np.concatenate((cur_obs[3], action.transpose(1,0)), axis=1)
                        
                        # logits = branching_model.call((c_features_update.astype(np.float32), 
                        #                                 e_indices.astype(np.int32), 
                        #                                 e_features_update.astype(np.float32), 
                        #                                 updated_v_features_update.astype(np.float32), 
                        #                                 tf.reduce_sum(n_cs_per_sample, keepdims=True), 
                        #                                 tf.reduce_sum(n_vs_per_sample, keepdims=True)), 
                        #                                 tf.convert_to_tensor(False))

                        # logits = tf.keras.backend.get_value(logits)

                        # if branching_sess is None:
                        #     logits = tf.keras.backend.get_value(logits)
                        #     branching_sess = tf.keras.backend.get_session()
                        # else:
                        #     logits = branching_sess.run(logits)

                        if branching_sess is None:
                            # logits = tf.keras.backend.get_value(logits)
                            branching_sess = tf.keras.backend.get_session()
                        
                        # logits = branching_sess.run(outputs, feed_dict={
                        #     var0: updated_v_features_update, edge_features:e_features_update,
                        #     edge_indices:e_indices, cons: c_features_update, 
                        #     n_vars:n_vs_per_sample,
                        #     n_cons: n_cs_per_sample,
                        #     is_training: False
                        # })

                        logits = branching_sess.run(outputs, feed_dict={
                            var0: updated_v_features_update, edge_features:e_features,
                            edge_indices:e_indices, cons: c_features, 
                            n_vars:n_vs_per_sample,
                            n_cons: n_cs_per_sample,
                            is_training: False
                        })
                        
                        # 需要branching的变量，强制转换为自由变量
                        if no_improve_rounds > 25 and (args.problem == 'cats' or args.problem == 'miplib'):
                            variable_to_branching_update = max(variable_to_branching_update + 0.04, 0.55)
                            no_improve_rounds = 0
                            # logits[0] = 1 - logits[0]
                            # logits[0] = logits[0] + np.array([random.random() for i in range(len(logits[0]))])

                        branching_variable_counts = int(len(action[0]) * variable_to_branching_update)
                        # branching_variable_counts = int(len(action[0]) * variable_to_branching)
                        branching_index = np.argsort(logits[0])[-branching_variable_counts:]
                        action[0][branching_index] = 1

                        # print(logits[0])
                        # print(sum(sum(action)))
                        # exit()

                        action = pad_output(action, n_vs_per_sample)  # 还原到每个sample

                        # action = tf.Session().run(action)

                        # action = sess.run(action)

                        # for i in range(len(n_vs_per_sample)):
                        #     action[i] = action[i][:n_vs_per_sample[i]]  # 删掉补0的部分

                        sample_cur_sols = pad_output(cur_sols[np.newaxis, :], n_vs_per_sample)
                        # sample_cur_sols = tf.Session().run(sample_cur_sols)

                        # sample_cur_sols = sess.run(sample_cur_sols)

                        # for i in range(len(n_vs_per_sample)):
                        #     sample_cur_sols[i] = sample_cur_sols[i][:n_vs_per_sample[i]]

                        # Execute next action. derive next solution(state)
                        next_sols, epi, cur_objs, instances, mask = collect_samples(instances, epi, sample_cur_sols,
                                                                                    action, out_dir + '/train', rng,
                                                                                    batch_sample_eval,
                                                                                    args.njobs,
                                                                                    exploration_policy=exploration_strategy,
                                                                                    eval_flag=eval_val,
                                                                                    time_limit=time_limit,
                                                                                    no_improve_rounds=no_improve_rounds,
                                                                                    branching_step=start_branching_rounds,
                                                                                    integer_update_flag=integer_update_flag)
                        

                        cur_sols = next_sols.copy()  # 获取优化后的solution

                        inc_ind = np.where(cur_objs > rec_best)[0]
                        rec_inc.append(rec_inc[-1])

                        for inds in inc_ind:
                            start_index = sum(n_vs_per_sample[:inds])
                            end_index = start_index + n_vs_per_sample[inds]
                            rec_inc[-1][start_index:end_index] = cur_sols[start_index:end_index]

                        rec_best[inc_ind] = cur_objs[inc_ind]

                        # compute rewards
                        r = cur_objs - ori_objs
                        print(cur_objs)
                        # note these outputs are batched from vecenv
                        t += 1

                        inc_val = rec_inc[-1].copy()
                        avg_inc_val = np.array(rec_inc).mean(0)

                        next_dynamic_variable_features = np.concatenate(
                            (inc_val[np.newaxis, :], avg_inc_val[np.newaxis, :], pre_sols[-2:]), axis=0)
                        next_updated_v_features = np.concatenate(
                            (v_features, next_dynamic_variable_features.transpose(1, 0)), axis=1)

                        next_obs = [c_features, e_indices, e_features, next_updated_v_features, n_cs_per_sample,
                                    n_vs_per_sample, variable_objective_features, variable_objective_indices, constraint_objective_features, constraint_objective_indices, objective_features]

                        cur_obs = next_obs
                        updated_v_features_update = np.concatenate((v_features_update, next_dynamic_variable_features.transpose(1, 0)),
                                                        axis=1)

                        ori_objs = cur_objs

                        # Book-keeping.
                        if len(obj_lis) > 0:
                            if is_maximum == 1:
                                if cur_objs[0] <= max(obj_lis):
                                    no_improve_rounds += 1
                            else:
                                if cur_objs[0] >= min(obj_lis):
                                    no_improve_rounds += 1

                        obj_lis.append(cur_objs)
                        time_list.append(time.time() - abcd)

                        if time.time()-abcd >= test_time_limit:
                            # df = pd.DataFrame()
                            # df['Time'] = time_list
                            # df['Objective'] = obj_lis
                            # df.to_csv("data_maxcut.csv", index=False)
                            # exit()

                            break
                    
                    if best_miunu is not None:
                        obj_lis.append(best_miunu)

                    tim = time.time()-abcd    
                    if is_maximum == 1:      
                        miniu = np.stack(obj_lis).max(axis=0)
                    else:
                        miniu = np.stack(obj_lis).min(axis=0)
                    
                    best_miunu = miniu



if __name__ == '__main__':
    # Script entry point: parse the command-line configuration and pass the
    # resulting namespace to learn().
    parser = argparse.ArgumentParser()

    # Positional: which family of MILP instances to work on.
    parser.add_argument(
        'problem',
        help='MILP instance type to process.',
        choices=['setcover', 'cats', 'facilities', 'indset', 'maxcut', 'item', 'miplib'],
    )
    parser.add_argument(
        '-s', '--seed',
        help='Random generator seed.',
        type=utilities.valid_seed,
        default=0,
    )
    parser.add_argument(
        '-j', '--njobs',
        help='Number of parallel jobs.',
        type=int,
        default=1,
    )
    parser.add_argument(
        '-t', '--total_timesteps',
        help='Number of total_timesteps.',
        type=int,
        # argparse only runs `type` on string defaults, so the previous float
        # literal (1e4) was delivered as 10000.0 instead of an int. Use an int
        # literal so the default has the same type as a user-supplied value.
        default=10000,
    )
    parser.add_argument(
        '-g', '--gpu',
        help='CUDA GPU id (-1 for CPU).',
        type=int,
        default=4,
    )

    arg = parser.parse_args()
    # tf.enable_eager_execution()

    # is_maximum parameter: 1 denotes a maximization problem, 0 a minimization
    # problem.
    learn(args=arg)