import sys
from turtle import shape
sys.path.append('..')
import os
import argparse
import multiprocessing as mp
import pickle
import glob
import numpy as np
import shutil
import gzip
import torch
import pyscipopt as scip
from src import utils
from src.utils import set_scip
import time
#nohup python sample_generator_depth.py -p setcover --code_id new_33_1 --scale 100 >logs/generator1.setcover.out
#nohup python sample_generator_depth.py -p setcover --code_id new_33_3 --scale 100 >logs/generator1.setcover.out
#srun -p dell --gres=gpu:1 python sample_generator_depth.py --code_id new_33_1
#srun -p dell --gres=gpu:1 python sample_generator_depth.py --code_id new_33_3
#srun -p dell --gres=gpu:1 python sample_generator_depth.py --code_id new_33_2
#srun -p dell --gres=gpu:1 python sample_generator_depth.py --code_id new_33_4
#srun -p dell --gres=gpu:1 python sample_generator_depth.py --code_id new_35
#
#
class SamplingAgent_baseline(scip.Branchrule):
    """Branching rule that collects expert samples built from the GCN state only.

    At each LP branching callback the agent draws from its own seeded RNG and,
    with probability ``query_expert_prob``, runs the 'vanillafullstrong' rule
    to obtain candidate scores. Accepted samples are not written to disk here:
    they are accumulated in ``file_name_list``, ``data_dict`` and
    ``out_queue_dict`` (both dicts keyed by the sample's target filename), and
    ``max_depth`` tracks the deepest node a sample was recorded at.
    """

    def __init__(self, episode, instance, seed, out_queue, exploration_policy, query_expert_prob, out_dir, follow_expert=True):
        """Store the episode settings and create empty sample containers.

        Args:
            episode: Episode identifier; copied into each sample and filename.
            instance: Instance identifier; copied into each sample.
            seed: Seed of the private ``RandomState``; also copied into samples.
            out_queue: Kept on the instance; not used by the code in this class.
            exploration_policy: Name of the branching rule executed when the
                expert decision is not applied.
            query_expert_prob: Probability of querying the expert per callback.
            out_dir: Directory prefix used to build sample filenames.
            follow_expert: When true, branch on the expert's candidate whenever
                the expert was queried.
        """
        super().__init__()
        self.episode = episode
        self.instance = instance
        self.seed = seed
        self.out_queue = out_queue
        self.exploration_policy = exploration_policy
        self.query_expert_prob = query_expert_prob
        self.out_dir = out_dir
        self.follow_expert = follow_expert

        self.rng = np.random.RandomState(seed)
        self.new_node = True
        self.sample_counter = 0
        self.khalil_root_buffer = {}

        self.file_name_list = []
        self.data_dict = {}
        self.out_queue_dict = {}
        self.max_depth = 0

    def branchinit(self):
        """Reset the state-extraction buffer and the branching-call counter."""
        self.state_buffer = {}
        self.branch_count = 0

    # Features are assembled here for later model training.
    def branchexeclp(self, allowaddcons):
        """Optionally record an expert sample, then make the branching decision.

        Args:
            allowaddcons: Forwarded unchanged to ``executeBranchRule``.

        Returns:
            A dict ``{"result": <SCIP_RESULT>}``.
        """
        result = scip.SCIP_RESULT.DIDNOTRUN

        # Once in a while, also run the expert policy and record the
        # (state, action) pair.
        query_expert = self.rng.rand() < self.query_expert_prob

        if query_expert:
            gcn_state = utils.extract_state(self.model, self.state_buffer)
            cands, *_ = self.model.getPseudoBranchCands()

            result, scores, bestcand, action_set = self._run_expert(cands, allowaddcons)
            data = [gcn_state, bestcand, action_set, scores]

            # Do not record inconsistent scores. May happen if SCIP was early
            # stopped (time limit).
            if not any(s < 0 for s in scores) and len(gcn_state[1]['values']):
                self._record_sample(data)

        result = self._decide(query_expert, result, allowaddcons)
        self.branch_count += 1

        return {"result": result}

    def _run_expert(self, cands, allowaddcons):
        """Run 'vanillafullstrong' and return (result, scores, bestcand, action_set)."""
        result = self.model.executeBranchRule('vanillafullstrong', allowaddcons)

        # NOTE(review): here the best candidate is read from the 4th of 5
        # returned values, whereas `_decide` reads it from the 5th. Confirm the
        # tuple layout of getVanillafullstrongData() in the PySCIPOpt build used.
        cands_, scores, _npriocands, bestcand, _ = self.model.getVanillafullstrongData()

        assert result == scip.SCIP_RESULT.DIDNOTRUN
        # The expert must have scored the same candidates, in the same order.
        assert all(c1.getCol().getLPPos() == c2.getCol().getLPPos() for c1, c2 in zip(cands, cands_))

        action_set = [c.getCol().getLPPos() for c in cands]
        return result, scores, bestcand, action_set

    def _record_sample(self, data):
        """Store one accepted sample and its queue message under a new filename."""
        node = self.model.getCurrentNode()
        node_depth = node.getDepth()
        self.max_depth = max(self.max_depth, node_depth)

        filename = f'{self.out_dir}/sample_{self.episode}_{self.sample_counter}.pkl'
        common = {
            'episode': self.episode,
            'instance': self.instance,
            'seed': self.seed,
            'node_number': node.getNumber(),
            'node_depth': node_depth,
        }

        self.file_name_list.append(filename)
        self.data_dict[filename] = {**common, 'data': data}
        self.out_queue_dict[filename] = {'type': 'sample', **common, 'filename': filename}

        self.sample_counter += 1

    def _decide(self, query_expert, result, allowaddcons):
        """Run the exploration policy and/or apply the expert's choice."""
        explores_with_expert = self.exploration_policy == 'vanillafullstrong'

        # If exploration and expert policies are the same, prevent running it twice.
        if not query_expert or (not self.follow_expert and not explores_with_expert):
            result = self.model.executeBranchRule(self.exploration_policy, allowaddcons)

        # Apply the 'vanillafullstrong' branching decision if needed.
        if (query_expert and self.follow_expert) or explores_with_expert:
            assert result == scip.SCIP_RESULT.DIDNOTRUN
            cands, _scores, _ncands, _npriocands, bestcand = self.model.getVanillafullstrongData()
            self.model.branchVar(cands[bestcand])
            result = scip.SCIP_RESULT.BRANCHED

        return result

class SamplingAgent_new_3(scip.Branchrule):
    """Branching rule that records expert samples with GCN plus custom states.

    Each recorded state combines the GCN state from ``utils.extract_state``
    with the node, MIP and candidate states returned by the model's
    ``getMyNodeState`` / ``getMyMIPState`` / ``getMyCandsState`` methods.
    Samples are kept in memory in ``file_name_list``, ``data_dict`` and
    ``out_queue_dict`` (both dicts keyed by the sample's target filename).

    ``branchexeclp_1`` is a variant that also stores Khalil variable features;
    it is only defined here, and nothing in this class calls it.
    """

    def __init__(self, episode, instance, seed, out_queue, exploration_policy, query_expert_prob, out_dir, follow_expert=True):
        """Store the episode settings and create empty sample containers.

        Args:
            episode: Episode identifier; copied into each sample and filename.
            instance: Instance identifier; copied into each sample.
            seed: Seed of the private ``RandomState``; also copied into samples.
            out_queue: Kept on the instance; not used by the code in this class.
            exploration_policy: Name of the branching rule executed when the
                expert decision is not applied.
            query_expert_prob: Probability of querying the expert per callback.
            out_dir: Directory prefix used to build sample filenames.
            follow_expert: When true, branch on the expert's candidate whenever
                the expert was queried.
        """
        super().__init__()
        self.episode = episode
        self.instance = instance
        self.seed = seed
        self.out_queue = out_queue
        self.exploration_policy = exploration_policy
        self.query_expert_prob = query_expert_prob
        self.out_dir = out_dir
        self.follow_expert = follow_expert

        self.rng = np.random.RandomState(seed)
        self.new_node = True
        self.sample_counter = 0
        self.khalil_root_buffer = {}

        self.file_name_list = []
        self.data_dict = {}
        self.out_queue_dict = {}
        self.max_depth = 0

        # Timing accumulators; initialised here but never updated in this class.
        self.time_gcn = 0
        self.time_kh = 0
        self.time_node = 0
        self.time_mip = 0
        self.time_cand = 0
        self.time_get_cands = 0

    def branchinit(self):
        """Reset the state-extraction buffer and the branching-call counter."""
        self.state_buffer = {}
        self.branch_count = 0

    # Features are assembled here for later model training.
    def branchexeclp(self, allowaddcons):
        """Optionally record an expert sample, then make the branching decision.

        ``branch_count`` is incremented on every call.

        Args:
            allowaddcons: Forwarded unchanged to ``executeBranchRule``.

        Returns:
            A dict ``{"result": <SCIP_RESULT>}``.
        """
        self._init_khalil_root_buffer()

        result = scip.SCIP_RESULT.DIDNOTRUN
        # Once in a while, also run the expert policy and record the
        # (state, action) pair.
        query_expert = self.rng.rand() < self.query_expert_prob

        if query_expert:
            cands, *_ = self.model.getPseudoBranchCands()
            gcn_state = utils.extract_state(self.model, self.state_buffer)

            result, scores, bestcand, action_set = self._run_expert(cands, allowaddcons)
            node_state, mip_state, cands_state_mat = self._extract_custom_states()

            state = [gcn_state, node_state, mip_state, cands_state_mat]
            data = [state, bestcand, action_set, scores]

            if self._is_consistent(scores, gcn_state):
                self._record_sample(data)

        result = self._decide(query_expert, result, allowaddcons)
        self.branch_count += 1

        return {"result": result}

    # Features are assembled here for later model training.
    def branchexeclp_1(self, allowaddcons):
        """Variant of ``branchexeclp`` that also stores Khalil variable features.

        Unlike ``branchexeclp``, ``branch_count`` is incremented only when a
        sample is actually recorded.

        Args:
            allowaddcons: Forwarded unchanged to ``executeBranchRule``.

        Returns:
            A dict ``{"result": <SCIP_RESULT>}``.
        """
        self._init_khalil_root_buffer()

        result = scip.SCIP_RESULT.DIDNOTRUN
        # Once in a while, also run the expert policy and record the
        # (state, action) pair.
        query_expert = self.rng.rand() < self.query_expert_prob

        if query_expert:
            cands, *_ = self.model.getPseudoBranchCands()
            gcn_state = utils.extract_state(self.model, self.state_buffer)
            state_khalil = utils.extract_khalil_variable_features(self.model, cands, self.khalil_root_buffer)

            result, scores, bestcand, action_set = self._run_expert(cands, allowaddcons)
            node_state, mip_state, cands_state_mat = self._extract_custom_states()

            state = [gcn_state, state_khalil, node_state, mip_state, cands_state_mat]
            data = [state, bestcand, action_set, scores]

            if self._is_consistent(scores, gcn_state):
                self._record_sample(data)
                self.branch_count += 1

        return {"result": self._decide(query_expert, result, allowaddcons)}

    def _init_khalil_root_buffer(self):
        """Fill the Khalil root buffer while SCIP reports exactly one node."""
        if self.model.getNNodes() == 1:
            # Initialize root buffer for Khalil features extraction.
            utils.extract_khalil_variable_features(self.model, [], self.khalil_root_buffer)

    def _run_expert(self, cands, allowaddcons):
        """Run 'vanillafullstrong' and return (result, scores, bestcand, action_set)."""
        result = self.model.executeBranchRule('vanillafullstrong', allowaddcons)

        # NOTE(review): here the best candidate is read from the 4th of 5
        # returned values, whereas `_decide` reads it from the 5th. Confirm the
        # tuple layout of getVanillafullstrongData() in the PySCIPOpt build used.
        cands_, scores, _npriocands, bestcand, _ = self.model.getVanillafullstrongData()

        assert result == scip.SCIP_RESULT.DIDNOTRUN
        # The expert must have scored the same candidates, in the same order.
        assert all(c1.getCol().getLPPos() == c2.getCol().getLPPos() for c1, c2 in zip(cands, cands_))

        action_set = [c.getCol().getLPPos() for c in cands]
        return result, scores, bestcand, action_set

    def _extract_custom_states(self):
        """Return (node_state, mip_state, cands_state_mat) from the model."""
        # Sizes passed to the custom state extractors.
        var_dim = 25
        node_dim = 8
        mip_dim = 53

        node_state = self.model.getMyNodeState(node_dim)
        mip_state = self.model.getMyMIPState(mip_dim)
        _, _, cands_state_mat = self.model.getMyCandsState(var_dim, self.branch_count)
        return node_state, mip_state, cands_state_mat

    @staticmethod
    def _is_consistent(scores, gcn_state):
        """Tell whether a sample is worth recording.

        Inconsistent (negative) scores may happen if SCIP was early stopped
        (time limit); such samples are skipped.
        """
        has_negative = any(s < 0 for s in scores)
        return not has_negative and len(gcn_state[1]['values'])

    def _record_sample(self, data):
        """Store one accepted sample and its queue message under a new filename."""
        node = self.model.getCurrentNode()
        node_depth = node.getDepth()
        self.max_depth = max(self.max_depth, node_depth)

        filename = f'{self.out_dir}/sample_{self.episode}_{self.sample_counter}.pkl'
        common = {
            'episode': self.episode,
            'instance': self.instance,
            'seed': self.seed,
            'node_number': node.getNumber(),
            'node_depth': node_depth,
        }

        self.file_name_list.append(filename)
        self.data_dict[filename] = {**common, 'data': data}
        self.out_queue_dict[filename] = {'type': 'sample', **common, 'filename': filename}

        self.sample_counter += 1

    def _decide(self, query_expert, result, allowaddcons):
        """Run the exploration policy and/or apply the expert's choice."""
        explores_with_expert = self.exploration_policy == 'vanillafullstrong'

        # If exploration and expert policies are the same, prevent running it twice.
        if not query_expert or (not self.follow_expert and not explores_with_expert):
            result = self.model.executeBranchRule(self.exploration_policy, allowaddcons)

        # Apply the 'vanillafullstrong' branching decision if needed.
        if (query_expert and self.follow_expert) or explores_with_expert:
            assert result == scip.SCIP_RESULT.DIDNOTRUN
            cands, _scores, _ncands, _npriocands, bestcand = self.model.getVanillafullstrongData()
            self.model.branchVar(cands[bestcand])
            result = scip.SCIP_RESULT.BRANCHED

        return result

class SamplingAgent_new_3_root(scip.Branchrule):
    """Branching rule that queries the expert at every branching callback.

    ``query_expert_prob`` is stored but not consulted: the expert is always
    queried. Each sample holds the GCN state, Khalil variable features, the
    custom node / MIP / candidate states, and the per-candidate values
    returned by ``getCandsDowngainUpgainData``. Samples are kept in memory in
    ``file_name_list``, ``data_dict`` and ``out_queue_dict`` (both dicts keyed
    by the sample's target filename).
    """

    def __init__(self, episode, instance, seed, out_queue, exploration_policy, query_expert_prob, out_dir, follow_expert=True):
        """Store the episode settings and create empty sample containers.

        Args:
            episode: Episode identifier; copied into each sample and filename.
            instance: Instance identifier; copied into each sample.
            seed: Seed of the private ``RandomState``; also copied into samples.
            out_queue: Kept on the instance; not used by the code in this class.
            exploration_policy: Name of the branching rule executed when the
                expert decision is not applied.
            query_expert_prob: Kept on the instance; not used by this class.
            out_dir: Directory prefix used to build sample filenames.
            follow_expert: When true, branch on the expert's candidate.
        """
        super().__init__()
        self.episode = episode
        self.instance = instance
        self.seed = seed
        self.out_queue = out_queue
        self.exploration_policy = exploration_policy
        self.query_expert_prob = query_expert_prob
        self.out_dir = out_dir
        self.follow_expert = follow_expert

        self.rng = np.random.RandomState(seed)
        self.new_node = True
        self.sample_counter = 0
        self.khalil_root_buffer = {}

        self.file_name_list = []
        self.data_dict = {}
        self.out_queue_dict = {}
        self.max_depth = 0

    def branchinit(self):
        """Reset the state-extraction buffer and the branch counter."""
        self.state_buffer = {}
        self.branch_count = 0

    # Features are assembled here for later model training.
    def branchexeclp(self, allowaddcons):
        """Record an expert sample, then make the branching decision.

        Args:
            allowaddcons: Forwarded unchanged to ``executeBranchRule``.

        Returns:
            A dict ``{"result": <SCIP_RESULT>}``.
        """
        if self.model.getNNodes() == 1:
            # Initialize root buffer for Khalil features extraction.
            utils.extract_khalil_variable_features(self.model, [], self.khalil_root_buffer)

        # The expert is queried unconditionally in this agent.
        cands, *_ = self.model.getPseudoBranchCands()

        lp_data = self.model.getCandsDowngainUpgainData(cands)
        lp_scores_0 = [entry[0] for entry in lp_data]
        lp_scores_1 = [entry[1] for entry in lp_data]

        gcn_state = utils.extract_state(self.model, self.state_buffer)
        state_khalil = utils.extract_khalil_variable_features(self.model, cands, self.khalil_root_buffer)

        result, scores, bestcand, action_set = self._run_expert(cands, allowaddcons)
        node_state, mip_state, cands_state_mat = self._extract_custom_states()

        state = [gcn_state, state_khalil, node_state, mip_state, cands_state_mat]
        data = [state, bestcand, action_set, scores, lp_scores_0, lp_scores_1]

        # Do not record inconsistent scores. May happen if SCIP was early
        # stopped (time limit).
        if not any(s < 0 for s in scores) and len(gcn_state[1]['values']):
            self._record_sample(data)
            # NOTE(review): the counter only advances when a sample is
            # recorded, not on every callback — confirm this is intended.
            self.branch_count += 1

        return {"result": self._decide(result, allowaddcons)}

    def _run_expert(self, cands, allowaddcons):
        """Run 'vanillafullstrong' and return (result, scores, bestcand, action_set)."""
        result = self.model.executeBranchRule('vanillafullstrong', allowaddcons)

        # NOTE(review): here the best candidate is read from the 4th of 5
        # returned values, whereas `_decide` reads it from the 5th. Confirm the
        # tuple layout of getVanillafullstrongData() in the PySCIPOpt build used.
        cands_, scores, _npriocands, bestcand, _ = self.model.getVanillafullstrongData()

        assert result == scip.SCIP_RESULT.DIDNOTRUN
        # The expert must have scored the same candidates, in the same order.
        assert all(c1.getCol().getLPPos() == c2.getCol().getLPPos() for c1, c2 in zip(cands, cands_))

        action_set = [c.getCol().getLPPos() for c in cands]
        return result, scores, bestcand, action_set

    def _extract_custom_states(self):
        """Return (node_state, mip_state, cands_state_mat) from the model."""
        # Sizes passed to the custom state extractors.
        var_dim = 25
        node_dim = 8
        mip_dim = 53

        node_state = self.model.getMyNodeState(node_dim)
        mip_state = self.model.getMyMIPState(mip_dim)
        _, _, cands_state_mat = self.model.getMyCandsState(var_dim, self.branch_count)
        return node_state, mip_state, cands_state_mat

    def _record_sample(self, data):
        """Store one accepted sample and its queue message under a new filename."""
        node = self.model.getCurrentNode()
        node_depth = node.getDepth()
        self.max_depth = max(self.max_depth, node_depth)

        filename = f'{self.out_dir}/sample_{self.episode}_{self.sample_counter}.pkl'
        common = {
            'episode': self.episode,
            'instance': self.instance,
            'seed': self.seed,
            'node_number': node.getNumber(),
            'node_depth': node_depth,
        }

        self.file_name_list.append(filename)
        self.data_dict[filename] = {**common, 'data': data}
        self.out_queue_dict[filename] = {'type': 'sample', **common, 'filename': filename}

        self.sample_counter += 1

    def _decide(self, result, allowaddcons):
        """Run the exploration policy and/or apply the expert's choice.

        The expert has always been queried by the time this runs, so the
        exploration policy is executed only when the expert is not followed
        and the exploration policy is a different rule.
        """
        explores_with_expert = self.exploration_policy == 'vanillafullstrong'

        # If exploration and expert policies are the same, prevent running it twice.
        if not self.follow_expert and not explores_with_expert:
            result = self.model.executeBranchRule(self.exploration_policy, allowaddcons)

        # Apply the 'vanillafullstrong' branching decision if needed.
        if self.follow_expert or explores_with_expert:
            assert result == scip.SCIP_RESULT.DIDNOTRUN
            cands, _scores, _ncands, _npriocands, bestcand = self.model.getVanillafullstrongData()
            self.model.branchVar(cands[bestcand])
            result = scip.SCIP_RESULT.BRANCHED

        return result

class SamplingAgent_new_20(scip.Branchrule):
    """Branching rule that records GCN-state samples with LP gain data.

    With probability ``query_expert_prob`` per callback, the agent runs the
    'vanillafullstrong' rule and stores the GCN state, the expert's scores and
    the per-candidate values from ``getCandsDowngainUpgainData``. Samples are
    kept in memory in ``file_name_list``, ``data_dict`` and ``out_queue_dict``
    (both dicts keyed by the sample's target filename).
    """

    def __init__(self, episode, instance, seed, out_queue, exploration_policy, query_expert_prob, out_dir, follow_expert=True):
        """Store the episode settings and create empty sample containers.

        Args:
            episode: Episode identifier; copied into each sample and filename.
            instance: Instance identifier; copied into each sample.
            seed: Seed of the private ``RandomState``; also copied into samples.
            out_queue: Kept on the instance; not used by the code in this class.
            exploration_policy: Name of the branching rule executed when the
                expert decision is not applied.
            query_expert_prob: Probability of querying the expert per callback.
            out_dir: Directory prefix used to build sample filenames.
            follow_expert: When true, branch on the expert's candidate whenever
                the expert was queried.
        """
        super().__init__()
        self.episode = episode
        self.instance = instance
        self.seed = seed
        self.out_queue = out_queue
        self.exploration_policy = exploration_policy
        self.query_expert_prob = query_expert_prob
        self.out_dir = out_dir
        self.follow_expert = follow_expert

        self.rng = np.random.RandomState(seed)
        self.new_node = True
        self.sample_counter = 0

        self.file_name_list = []
        self.data_dict = {}
        self.out_queue_dict = {}
        self.max_depth = 0

    def branchinit(self):
        """Start with an empty Khalil root buffer."""
        self.khalil_root_buffer = {}

    # Features are assembled here for later model training.
    def branchexeclp(self, allowaddcons):
        """Optionally record an expert sample, then make the branching decision.

        Args:
            allowaddcons: Forwarded unchanged to ``executeBranchRule``.

        Returns:
            A dict ``{"result": <SCIP_RESULT>}``.
        """
        if self.model.getNNodes() == 1:
            # Initialize root buffer for Khalil features extraction.
            utils.extract_khalil_variable_features(self.model, [], self.khalil_root_buffer)

        result = scip.SCIP_RESULT.DIDNOTRUN
        # Once in a while, also run the expert policy and record the
        # (state, action) pair.
        query_expert = self.rng.rand() < self.query_expert_prob

        if query_expert:
            gcn_state = utils.extract_state(self.model)
            cands, *_ = self.model.getPseudoBranchCands()

            lp_data = self.model.getCandsDowngainUpgainData(cands)
            lp_scores_0 = [entry[0] for entry in lp_data]
            lp_scores_1 = [entry[1] for entry in lp_data]

            result, scores, bestcand, action_set = self._run_expert(cands, allowaddcons)
            data = [gcn_state, bestcand, action_set, scores, lp_scores_0, lp_scores_1]

            # Do not record inconsistent scores. May happen if SCIP was early
            # stopped (time limit).
            if not any(s < 0 for s in scores) and len(gcn_state[1]['values']):
                self._record_sample(data)

        return {"result": self._decide(query_expert, result, allowaddcons)}

    def _run_expert(self, cands, allowaddcons):
        """Run 'vanillafullstrong' and return (result, scores, bestcand, action_set)."""
        result = self.model.executeBranchRule('vanillafullstrong', allowaddcons)

        # NOTE(review): here the best candidate is read from the 4th of 5
        # returned values, whereas `_decide` reads it from the 5th. Confirm the
        # tuple layout of getVanillafullstrongData() in the PySCIPOpt build used.
        cands_, scores, _npriocands, bestcand, _ = self.model.getVanillafullstrongData()

        assert result == scip.SCIP_RESULT.DIDNOTRUN
        # The expert must have scored the same candidates, in the same order.
        assert all(c1.getCol().getLPPos() == c2.getCol().getLPPos() for c1, c2 in zip(cands, cands_))

        action_set = [c.getCol().getLPPos() for c in cands]
        return result, scores, bestcand, action_set

    def _record_sample(self, data):
        """Store one accepted sample and its queue message under a new filename."""
        node = self.model.getCurrentNode()
        node_depth = node.getDepth()
        self.max_depth = max(self.max_depth, node_depth)

        filename = f'{self.out_dir}/sample_{self.episode}_{self.sample_counter}.pkl'
        common = {
            'episode': self.episode,
            'instance': self.instance,
            'seed': self.seed,
            'node_number': node.getNumber(),
            'node_depth': node_depth,
        }

        self.file_name_list.append(filename)
        self.data_dict[filename] = {**common, 'data': data}
        self.out_queue_dict[filename] = {'type': 'sample', **common, 'filename': filename}

        self.sample_counter += 1

    def _decide(self, query_expert, result, allowaddcons):
        """Run the exploration policy and/or apply the expert's choice."""
        explores_with_expert = self.exploration_policy == 'vanillafullstrong'

        # If exploration and expert policies are the same, prevent running it twice.
        if not query_expert or (not self.follow_expert and not explores_with_expert):
            result = self.model.executeBranchRule(self.exploration_policy, allowaddcons)

        # Apply the 'vanillafullstrong' branching decision if needed.
        if (query_expert and self.follow_expert) or explores_with_expert:
            assert result == scip.SCIP_RESULT.DIDNOTRUN
            cands, _scores, _ncands, _npriocands, bestcand = self.model.getVanillafullstrongData()
            self.model.branchVar(cands[bestcand])
            result = scip.SCIP_RESULT.BRANCHED

        return result

class SamplingAgent_new_33_1(scip.Branchrule):
    """Branching rule that records GCN-state samples at a fixed 0.4 query rate.

    ``query_expert_prob`` is stored but not consulted: the expert is queried
    with the hard-coded probability 0.4 at every callback. Samples are kept in
    memory in ``file_name_list``, ``data_dict`` and ``out_queue_dict`` (both
    dicts keyed by the sample's target filename); ``min_depth`` and
    ``max_depth`` track the depth range of the recorded samples.
    """

    def __init__(self, episode, instance, seed, out_queue, exploration_policy, query_expert_prob, out_dir, follow_expert=True):
        """Store the episode settings and create empty sample containers.

        Args:
            episode: Episode identifier; copied into each sample and filename.
            instance: Instance identifier; copied into each sample.
            seed: Seed of the private ``RandomState``; also copied into samples.
            out_queue: Kept on the instance; not used by the code in this class.
            exploration_policy: Name of the branching rule executed when the
                expert decision is not applied.
            query_expert_prob: Kept on the instance; not used by this class.
            out_dir: Directory prefix used to build sample filenames.
            follow_expert: When true, branch on the expert's candidate whenever
                the expert was queried.
        """
        super().__init__()
        self.episode = episode
        self.instance = instance
        self.seed = seed
        self.out_queue = out_queue
        self.exploration_policy = exploration_policy
        self.query_expert_prob = query_expert_prob
        self.out_dir = out_dir
        self.follow_expert = follow_expert

        self.rng = np.random.RandomState(seed)
        self.new_node = True
        self.sample_counter = 0

        self.file_name_list = []
        self.data_dict = {}
        self.out_queue_dict = {}
        self.max_depth = 0
        self.min_depth = 1e9

    def branchinit(self):
        """Start with an empty Khalil root buffer."""
        self.khalil_root_buffer = {}

    # Features are assembled here for later model training.
    def branchexeclp(self, allowaddcons):
        """Optionally record an expert sample, then make the branching decision.

        Args:
            allowaddcons: Forwarded unchanged to ``executeBranchRule``.

        Returns:
            A dict ``{"result": <SCIP_RESULT>}``.
        """
        if self.model.getNNodes() == 1:
            # Initialize root buffer for Khalil features extraction.
            utils.extract_khalil_variable_features(self.model, [], self.khalil_root_buffer)

        result = scip.SCIP_RESULT.DIDNOTRUN
        # Once in a while, also run the expert policy and record the
        # (state, action) pair. The rate is fixed at 0.4 here.
        query_expert = self.rng.rand() < 0.4

        if query_expert:
            gcn_state = utils.extract_state(self.model)
            cands, *_ = self.model.getPseudoBranchCands()

            result, scores, bestcand, action_set = self._run_expert(cands, allowaddcons)

            # The down/up gain lookup (getCandsDowngainUpgainData) is disabled
            # in this agent; zeros are stored in place of the two score lists.
            lp_scores_0 = 0
            lp_scores_1 = 0

            data = [gcn_state, bestcand, action_set, scores, lp_scores_0, lp_scores_1]

            # Do not record inconsistent scores. May happen if SCIP was early
            # stopped (time limit).
            if not any(s < 0 for s in scores) and len(gcn_state[1]['values']):
                self._record_sample(data)

        return {"result": self._decide(query_expert, result, allowaddcons)}

    def _run_expert(self, cands, allowaddcons):
        """Run 'vanillafullstrong' and return (result, scores, bestcand, action_set)."""
        result = self.model.executeBranchRule('vanillafullstrong', allowaddcons)

        # NOTE(review): here the best candidate is read from the 4th of 5
        # returned values, whereas `_decide` reads it from the 5th. Confirm the
        # tuple layout of getVanillafullstrongData() in the PySCIPOpt build used.
        cands_, scores, _npriocands, bestcand, _ = self.model.getVanillafullstrongData()

        assert result == scip.SCIP_RESULT.DIDNOTRUN
        # The expert must have scored the same candidates, in the same order.
        assert all(c1.getCol().getLPPos() == c2.getCol().getLPPos() for c1, c2 in zip(cands, cands_))

        action_set = [c.getCol().getLPPos() for c in cands]
        return result, scores, bestcand, action_set

    def _record_sample(self, data):
        """Store one accepted sample and update the recorded depth range."""
        node = self.model.getCurrentNode()
        node_depth = node.getDepth()
        self.min_depth = min(self.min_depth, node_depth)
        self.max_depth = max(self.max_depth, node_depth)

        filename = f'{self.out_dir}/sample_{self.episode}_{self.sample_counter}.pkl'
        common = {
            'episode': self.episode,
            'instance': self.instance,
            'seed': self.seed,
            'node_number': node.getNumber(),
            'node_depth': node_depth,
        }

        self.file_name_list.append(filename)
        self.data_dict[filename] = {**common, 'data': data}
        self.out_queue_dict[filename] = {'type': 'sample', **common, 'filename': filename}

        self.sample_counter += 1

    def _decide(self, query_expert, result, allowaddcons):
        """Run the exploration policy and/or apply the expert's choice."""
        explores_with_expert = self.exploration_policy == 'vanillafullstrong'

        # If exploration and expert policies are the same, prevent running it twice.
        if not query_expert or (not self.follow_expert and not explores_with_expert):
            result = self.model.executeBranchRule(self.exploration_policy, allowaddcons)

        # Apply the 'vanillafullstrong' branching decision if needed.
        if (query_expert and self.follow_expert) or explores_with_expert:
            assert result == scip.SCIP_RESULT.DIDNOTRUN
            cands, _scores, _ncands, _npriocands, bestcand = self.model.getVanillafullstrongData()
            self.model.branchVar(cands[bestcand])
            result = scip.SCIP_RESULT.BRANCHED

        return result

class SamplingAgent_new_33_2(scip.Branchrule):
    """Branching rule that samples only at node depths 2 to 4, at a 0.1 rate.

    ``query_expert_prob`` is stored but not consulted: the expert is queried
    with the hard-coded probability 0.1, and only when the current node depth
    is greater than 1 and at most 4. Samples are kept in memory in
    ``file_name_list``, ``data_dict`` and ``out_queue_dict`` (both dicts keyed
    by the sample's target filename); ``min_depth`` and ``max_depth`` track the
    depth range of the recorded samples.
    """

    def __init__(self, episode, instance, seed, out_queue, exploration_policy, query_expert_prob, out_dir, follow_expert=True):
        """Store the episode settings and create empty sample containers.

        Args:
            episode: Episode identifier; copied into each sample and filename.
            instance: Instance identifier; copied into each sample.
            seed: Seed of the private ``RandomState``; also copied into samples.
            out_queue: Kept on the instance; not used by the code in this class.
            exploration_policy: Name of the branching rule executed when the
                expert decision is not applied.
            query_expert_prob: Kept on the instance; not used by this class.
            out_dir: Directory prefix used to build sample filenames.
            follow_expert: When true, branch on the expert's candidate whenever
                the expert was queried.
        """
        super().__init__()
        self.episode = episode
        self.instance = instance
        self.seed = seed
        self.out_queue = out_queue
        self.exploration_policy = exploration_policy
        self.query_expert_prob = query_expert_prob
        self.out_dir = out_dir
        self.follow_expert = follow_expert

        self.rng = np.random.RandomState(seed)
        self.new_node = True
        self.sample_counter = 0

        self.file_name_list = []
        self.data_dict = {}
        self.out_queue_dict = {}
        self.max_depth = 0
        self.min_depth = 1e9

    def branchinit(self):
        """Start with an empty Khalil root buffer."""
        self.khalil_root_buffer = {}

    # Features are assembled here for later model training.
    def branchexeclp(self, allowaddcons):
        """Optionally record an expert sample, then make the branching decision.

        Args:
            allowaddcons: Forwarded unchanged to ``executeBranchRule``.

        Returns:
            A dict ``{"result": <SCIP_RESULT>}``.
        """
        if self.model.getNNodes() == 1:
            # Initialize root buffer for Khalil features extraction.
            utils.extract_khalil_variable_features(self.model, [], self.khalil_root_buffer)

        result = scip.SCIP_RESULT.DIDNOTRUN

        # The expert is only eligible inside the depth window; the RNG is
        # drawn from only in that case (short-circuit), at a fixed 0.1 rate.
        depth = self.model.getCurrentNode().getDepth()
        query_expert = 1 < depth <= 4 and self.rng.rand() < 0.1

        if query_expert:
            gcn_state = utils.extract_state(self.model)
            cands, *_ = self.model.getPseudoBranchCands()

            result, scores, bestcand, action_set = self._run_expert(cands, allowaddcons)

            # The down/up gain lookup (getCandsDowngainUpgainData) is disabled
            # in this agent; zeros are stored in place of the two score lists.
            lp_scores_0 = 0
            lp_scores_1 = 0

            data = [gcn_state, bestcand, action_set, scores, lp_scores_0, lp_scores_1]

            # Do not record inconsistent scores. May happen if SCIP was early
            # stopped (time limit).
            if not any(s < 0 for s in scores) and len(gcn_state[1]['values']):
                self._record_sample(data)

        return {"result": self._decide(query_expert, result, allowaddcons)}

    def _run_expert(self, cands, allowaddcons):
        """Run 'vanillafullstrong' and return (result, scores, bestcand, action_set)."""
        result = self.model.executeBranchRule('vanillafullstrong', allowaddcons)

        # NOTE(review): here the best candidate is read from the 4th of 5
        # returned values, whereas `_decide` reads it from the 5th. Confirm the
        # tuple layout of getVanillafullstrongData() in the PySCIPOpt build used.
        cands_, scores, _npriocands, bestcand, _ = self.model.getVanillafullstrongData()

        assert result == scip.SCIP_RESULT.DIDNOTRUN
        # The expert must have scored the same candidates, in the same order.
        assert all(c1.getCol().getLPPos() == c2.getCol().getLPPos() for c1, c2 in zip(cands, cands_))

        action_set = [c.getCol().getLPPos() for c in cands]
        return result, scores, bestcand, action_set

    def _record_sample(self, data):
        """Store one accepted sample and update the recorded depth range."""
        node = self.model.getCurrentNode()
        node_depth = node.getDepth()
        self.min_depth = min(self.min_depth, node_depth)
        self.max_depth = max(self.max_depth, node_depth)

        filename = f'{self.out_dir}/sample_{self.episode}_{self.sample_counter}.pkl'
        common = {
            'episode': self.episode,
            'instance': self.instance,
            'seed': self.seed,
            'node_number': node.getNumber(),
            'node_depth': node_depth,
        }

        self.file_name_list.append(filename)
        self.data_dict[filename] = {**common, 'data': data}
        self.out_queue_dict[filename] = {'type': 'sample', **common, 'filename': filename}

        self.sample_counter += 1

    def _decide(self, query_expert, result, allowaddcons):
        """Run the exploration policy and/or apply the expert's choice."""
        explores_with_expert = self.exploration_policy == 'vanillafullstrong'

        # If exploration and expert policies are the same, prevent running it twice.
        if not query_expert or (not self.follow_expert and not explores_with_expert):
            result = self.model.executeBranchRule(self.exploration_policy, allowaddcons)

        # Apply the 'vanillafullstrong' branching decision if needed.
        if (query_expert and self.follow_expert) or explores_with_expert:
            assert result == scip.SCIP_RESULT.DIDNOTRUN
            cands, _scores, _ncands, _npriocands, bestcand = self.model.getVanillafullstrongData()
            self.model.branchVar(cands[bestcand])
            result = scip.SCIP_RESULT.BRANCHED

        return result

class SamplingAgent_new_33_3(scip.Branchrule):
    """Branching rule that collects expert samples at node depths in (4, 7].

    When the current node depth is greater than 4 and at most 7, the
    'vanillafullstrong' rule is queried with probability
    ``query_expert_prob`` and its scores are kept in memory together with the
    extracted state.  Samples are only accumulated in ``data_dict`` and
    ``out_queue_dict`` (both keyed by the target filename); this class
    neither writes files nor touches ``out_queue``.
    """

    def __init__(self, episode, instance, seed, out_queue, exploration_policy, query_expert_prob, out_dir, follow_expert=True):
        """Store the run configuration and create empty sample containers.

        Args:
            episode: Episode identifier, embedded in sample filenames.
            instance: Instance identifier copied into every sample record.
            seed: Seed of the private ``numpy`` random generator.
            out_queue: Stored as-is; not used inside this class.
            exploration_policy: Name of the SCIP branching rule executed when
                the expert decision is not followed.
            query_expert_prob: Probability of querying the expert at a node
                inside the depth window.
            out_dir: Directory prefix used to build sample filenames.
            follow_expert: If true, branch on the expert's best candidate
                whenever the expert was queried.
        """
        super().__init__()
        self.episode = episode
        self.instance = instance
        self.seed = seed
        self.out_queue = out_queue
        self.exploration_policy = exploration_policy
        self.query_expert_prob = query_expert_prob
        self.out_dir = out_dir
        self.follow_expert = follow_expert

        self.rng = np.random.RandomState(seed)
        self.new_node = True
        self.sample_counter = 0

        # In-memory sample storage, keyed by the filename each sample targets.
        self.file_name_list = []
        self.data_dict = {}
        self.out_queue_dict = {}
        # Depth range over all recorded samples.
        self.max_depth = 0
        self.min_depth = 1e9

    def branchinit(self):
        """Reset the root buffer passed to the Khalil feature extractor."""
        self.khalil_root_buffer = {}

    def branchexeclp(self, allowaddcons):
        """Optionally record an expert sample, then branch.

        Args:
            allowaddcons: Forwarded unchanged to ``executeBranchRule``.

        Returns:
            A dict with the single key ``"result"`` holding the SCIP result
            code of the branching step.
        """
        if self.model.getNNodes() == 1:
            # initialize root buffer for Khalil features extraction
            utils.extract_khalil_variable_features(self.model, [], self.khalil_root_buffer)

        result = scip.SCIP_RESULT.DIDNOTRUN

        # The random draw only happens inside the depth window, so the RNG
        # stream is not consumed at other depths.
        now_depth = self.model.getCurrentNode().getDepth()
        query_expert = 4 < now_depth <= 7 and self.rng.rand() < self.query_expert_prob

        if query_expert:
            # The state must be extracted before the expert rule is executed.
            gcn_state = utils.extract_state(self.model)
            cands, *_ = self.model.getPseudoBranchCands()

            result = self.model.executeBranchRule('vanillafullstrong', allowaddcons)
            # NOTE(review): the best-candidate index is read from position 3
            # of the returned 5-tuple here, but from position 4 in the
            # branching step at the end of this method -- confirm which one
            # matches the getVanillafullstrongData() binding in use.
            cands_, scores, _, bestcand, _ = self.model.getVanillafullstrongData()

            assert result == scip.SCIP_RESULT.DIDNOTRUN
            assert all(c1.getCol().getLPPos() == c2.getCol().getLPPos() for c1, c2 in zip(cands, cands_))

            action_set = [c.getCol().getLPPos() for c in cands]

            # The two trailing zeros are placeholders for per-candidate LP
            # gain scores, which are currently not collected.
            lp_scores_0 = 0
            lp_scores_1 = 0
            data = [gcn_state, bestcand, action_set, scores, lp_scores_0, lp_scores_1]

            # Do not record inconsistent scores. May happen if SCIP was early stopped (time limit).
            if not any(s < 0 for s in scores) and len(gcn_state[1]['values']):
                node = self.model.getCurrentNode()
                node_depth = node.getDepth()
                node_number = node.getNumber()

                self.min_depth = min(self.min_depth, node_depth)
                self.max_depth = max(self.max_depth, node_depth)

                filename = f'{self.out_dir}/sample_{self.episode}_{self.sample_counter}.pkl'

                self.file_name_list.append(filename)
                self.data_dict[filename] = {
                    'episode': self.episode,
                    'instance': self.instance,
                    'seed': self.seed,
                    'node_number': node_number,
                    'node_depth': node_depth,
                    'data': data,
                }
                self.out_queue_dict[filename] = {
                    'type': 'sample',
                    'episode': self.episode,
                    'instance': self.instance,
                    'seed': self.seed,
                    'node_number': node_number,
                    'node_depth': node_depth,
                    'filename': filename,
                }

                self.sample_counter += 1

        # if exploration and expert policies are the same, prevent running it twice
        if not query_expert or (not self.follow_expert and self.exploration_policy != 'vanillafullstrong'):
            result = self.model.executeBranchRule(self.exploration_policy, allowaddcons)

        # apply 'vanillafullstrong' branching decision if needed
        if (query_expert and self.follow_expert) or self.exploration_policy == 'vanillafullstrong':
            assert result == scip.SCIP_RESULT.DIDNOTRUN
            cands, scores, ncands, npriocands, bestcand = self.model.getVanillafullstrongData()
            self.model.branchVar(cands[bestcand])
            result = scip.SCIP_RESULT.BRANCHED

        return {"result": result}

class SamplingAgent_new_33_4(scip.Branchrule):
    """Sampling branch rule that queries the expert only at depths above 7.

    At nodes deeper than 7 the 'vanillafullstrong' rule is run with
    probability ``query_expert_prob`` and the resulting state / expert-score
    record is kept in memory.  Records accumulate in ``data_dict`` and
    ``out_queue_dict`` under their target filename; nothing is written to
    disk or pushed to ``out_queue`` by this class.
    """

    def __init__(self, episode, instance, seed, out_queue, exploration_policy,
                 query_expert_prob, out_dir, follow_expert=True):
        """Remember the run settings and start with empty sample storage."""
        super().__init__()
        # Identification of the run, copied into every record.
        self.episode, self.instance, self.seed = episode, instance, seed
        self.out_queue = out_queue
        self.exploration_policy = exploration_policy
        self.query_expert_prob = query_expert_prob
        self.out_dir = out_dir
        self.follow_expert = follow_expert

        self.rng = np.random.RandomState(seed)
        self.new_node = True
        self.sample_counter = 0

        self.file_name_list = []
        self.data_dict = {}
        self.out_queue_dict = {}
        # Smallest / largest depth seen among recorded samples.
        self.max_depth = 0
        self.min_depth = 1e9

    def branchinit(self):
        """Start with an empty Khalil root buffer."""
        self.khalil_root_buffer = {}

    def branchexeclp(self, allowaddcons):
        """Possibly record an expert sample, then perform the branching step.

        ``allowaddcons`` is handed through to ``executeBranchRule``.  The
        return value is ``{"result": <SCIP result code>}``.
        """
        model = self.model
        if model.getNNodes() == 1:
            # First node: prime the root buffer of the Khalil extractor.
            utils.extract_khalil_variable_features(model, [], self.khalil_root_buffer)

        result = scip.SCIP_RESULT.DIDNOTRUN

        # The expert is only ever consulted below depth 7; the random draw is
        # skipped entirely at shallower nodes.
        query_expert = False
        if model.getCurrentNode().getDepth() > 7:
            query_expert = self.rng.rand() < self.query_expert_prob

        if query_expert:
            gcn_state = utils.extract_state(model)
            pseudo_cands, *_ = model.getPseudoBranchCands()

            result = model.executeBranchRule('vanillafullstrong', allowaddcons)
            # NOTE(review): bestcand is taken from position 3 of the 5-tuple
            # here but from position 4 in the branching step below -- confirm
            # against the getVanillafullstrongData() binding in use.
            expert_cands, scores, _, bestcand, _ = model.getVanillafullstrongData()

            assert result == scip.SCIP_RESULT.DIDNOTRUN
            assert all(
                mine.getCol().getLPPos() == theirs.getCol().getLPPos()
                for mine, theirs in zip(pseudo_cands, expert_cands)
            )

            action_set = [var.getCol().getLPPos() for var in pseudo_cands]

            # Trailing zeros stand in for LP gain scores that are not collected.
            data = [gcn_state, bestcand, action_set, scores, 0, 0]

            # Negative scores can appear when SCIP stopped early (time
            # limit); such samples are dropped.
            has_negative_score = any(s < 0 for s in scores)
            if not has_negative_score and len(gcn_state[1]['values']):
                node = model.getCurrentNode()
                depth = node.getDepth()
                self.min_depth = min(self.min_depth, depth)
                self.max_depth = max(self.max_depth, depth)

                filename = f'{self.out_dir}/sample_{self.episode}_{self.sample_counter}.pkl'
                meta = {
                    'episode': self.episode,
                    'instance': self.instance,
                    'seed': self.seed,
                    'node_number': node.getNumber(),
                    'node_depth': depth,
                }

                self.file_name_list.append(filename)
                self.data_dict[filename] = {**meta, 'data': data}
                self.out_queue_dict[filename] = {'type': 'sample', **meta, 'filename': filename}

                self.sample_counter += 1

        expert_is_policy = self.exploration_policy == 'vanillafullstrong'

        # Run the exploration policy unless the expert already ran and is
        # either followed or identical to the exploration policy.
        if not query_expert or not (self.follow_expert or expert_is_policy):
            result = model.executeBranchRule(self.exploration_policy, allowaddcons)

        # Branch on the expert's choice when it has to be applied.
        if (query_expert and self.follow_expert) or expert_is_policy:
            assert result == scip.SCIP_RESULT.DIDNOTRUN
            branch_cands, _, _, _, best = model.getVanillafullstrongData()
            model.branchVar(branch_cands[best])
            result = scip.SCIP_RESULT.BRANCHED

        return {"result": result}

class SamplingAgent_new_35(scip.Branchrule):
    """Sampling branch rule that may query the expert at any node depth.

    On each LP branching call the 'vanillafullstrong' rule is run with
    probability ``query_expert_prob``.  Accepted samples are held in
    ``data_dict`` / ``out_queue_dict`` under their target filename, and the
    depth range of recorded samples is tracked in ``min_depth`` /
    ``max_depth``.  ``out_queue`` is stored but not used by this class.
    """

    def __init__(self, episode, instance, seed, out_queue, exploration_policy,
                 query_expert_prob, out_dir, follow_expert=True):
        """Keep the run configuration and set up empty sample containers."""
        super().__init__()
        self.episode = episode
        self.instance = instance
        self.seed = seed
        self.out_queue = out_queue
        self.out_dir = out_dir

        # Branching behaviour.
        self.exploration_policy = exploration_policy
        self.query_expert_prob = query_expert_prob
        self.follow_expert = follow_expert

        self.rng = np.random.RandomState(seed)
        self.new_node = True
        self.sample_counter = 0

        # Sample bookkeeping.
        self.file_name_list = []
        self.data_dict = {}
        self.out_queue_dict = {}
        self.max_depth = 0
        self.min_depth = 1e9

    def branchinit(self):
        """Create a fresh root buffer for Khalil feature extraction."""
        self.khalil_root_buffer = {}

    def _stash_sample(self, data):
        """File ``data`` under the next sample filename and update depth stats."""
        node = self.model.getCurrentNode()
        depth = node.getDepth()
        self.min_depth = min(self.min_depth, depth)
        self.max_depth = max(self.max_depth, depth)

        filename = f'{self.out_dir}/sample_{self.episode}_{self.sample_counter}.pkl'
        record = {
            'episode': self.episode,
            'instance': self.instance,
            'seed': self.seed,
            'node_number': node.getNumber(),
            'node_depth': depth,
        }

        self.file_name_list.append(filename)
        self.data_dict[filename] = dict(record, data=data)
        self.out_queue_dict[filename] = {'type': 'sample', **record, 'filename': filename}
        self.sample_counter += 1

    def branchexeclp(self, allowaddcons):
        """Record an expert sample with some probability, then branch.

        ``allowaddcons`` is passed on to ``executeBranchRule``; the method
        returns ``{"result": <SCIP result code>}``.
        """
        if self.model.getNNodes() == 1:
            # Root node: fill the buffer the Khalil extractor relies on.
            utils.extract_khalil_variable_features(self.model, [], self.khalil_root_buffer)

        result = scip.SCIP_RESULT.DIDNOTRUN

        # Occasionally consult the expert and keep what it reports.
        query_expert = self.rng.rand() < self.query_expert_prob

        if query_expert:
            gcn_state = utils.extract_state(self.model)
            cands, *_ = self.model.getPseudoBranchCands()

            result = self.model.executeBranchRule('vanillafullstrong', allowaddcons)
            # NOTE(review): position 3 of the 5-tuple is used as bestcand
            # here, position 4 in the branching step further down -- confirm
            # which one the getVanillafullstrongData() binding provides.
            strong_cands, scores, _, bestcand, _ = self.model.getVanillafullstrongData()

            assert result == scip.SCIP_RESULT.DIDNOTRUN
            assert all(
                a.getCol().getLPPos() == b.getCol().getLPPos()
                for a, b in zip(cands, strong_cands)
            )

            action_set = [cand.getCol().getLPPos() for cand in cands]

            # LP gain scores are not collected; zeros keep the record layout.
            lp_placeholder_0, lp_placeholder_1 = 0, 0
            data = [gcn_state, bestcand, action_set, scores, lp_placeholder_0, lp_placeholder_1]

            # Skip samples with negative scores (possible when SCIP was
            # stopped early by the time limit) or without edge values.
            if not any(score < 0 for score in scores) and len(gcn_state[1]['values']):
                self._stash_sample(data)

        policy_is_expert = self.exploration_policy == 'vanillafullstrong'

        # Avoid running the same rule twice when expert and exploration
        # policies coincide.
        if not query_expert or (not self.follow_expert and not policy_is_expert):
            result = self.model.executeBranchRule(self.exploration_policy, allowaddcons)

        # Apply the 'vanillafullstrong' decision where required.
        if (query_expert and self.follow_expert) or policy_is_expert:
            assert result == scip.SCIP_RESULT.DIDNOTRUN
            chosen_from, _, _, _, chosen = self.model.getVanillafullstrongData()
            self.model.branchVar(chosen_from[chosen])
            result = scip.SCIP_RESULT.BRANCHED

        return {"result": result}

class SamplingAgent_new_36(scip.Branchrule):
    """Sampling branch rule restricted to nodes deeper than 7.

    For nodes with depth greater than 7 the expert rule
    ('vanillafullstrong') is queried with probability
    ``query_expert_prob``; shallower nodes never query it.  Accepted samples
    stay in memory in ``data_dict`` / ``out_queue_dict``, indexed by the
    filename they are meant for.  ``out_queue`` is only stored.
    """

    def __init__(self, episode, instance, seed, out_queue, exploration_policy,
                 query_expert_prob, out_dir, follow_expert=True):
        """Save the constructor arguments and reset all sample bookkeeping."""
        super().__init__()
        self.episode = episode
        self.instance = instance
        self.seed = seed
        self.out_queue = out_queue
        self.exploration_policy = exploration_policy
        self.query_expert_prob = query_expert_prob
        self.out_dir = out_dir
        self.follow_expert = follow_expert

        self.rng = np.random.RandomState(seed)
        self.new_node = True
        self.sample_counter = 0

        self.file_name_list, self.data_dict, self.out_queue_dict = [], {}, {}
        self.max_depth, self.min_depth = 0, 1e9

    def branchinit(self):
        """Clear the Khalil root buffer."""
        self.khalil_root_buffer = {}

    def branchexeclp(self, allowaddcons):
        """Collect an expert sample when eligible, then branch.

        ``allowaddcons`` goes straight to ``executeBranchRule``.  Returns a
        one-entry dict ``{"result": <SCIP result code>}``.
        """
        did_not_run = scip.SCIP_RESULT.DIDNOTRUN

        if self.model.getNNodes() == 1:
            # Khalil feature extraction needs its root buffer initialised once.
            utils.extract_khalil_variable_features(self.model, [], self.khalil_root_buffer)

        result = did_not_run

        # No random number is drawn unless the node is deep enough.
        depth_now = self.model.getCurrentNode().getDepth()
        query_expert = (self.rng.rand() < self.query_expert_prob) if depth_now > 7 else False

        if query_expert:
            gcn_state = utils.extract_state(self.model)
            cands, *_ = self.model.getPseudoBranchCands()

            result = self.model.executeBranchRule('vanillafullstrong', allowaddcons)
            # NOTE(review): bestcand comes from position 3 of the 5-tuple in
            # this unpacking and from position 4 in the one near the end of
            # the method -- confirm the getVanillafullstrongData() layout.
            vfs_cands, scores, _, bestcand, _ = self.model.getVanillafullstrongData()

            assert result == did_not_run

            action_set = [var.getCol().getLPPos() for var in cands]
            assert all(
                pos == other.getCol().getLPPos()
                for pos, other in zip(action_set, vfs_cands)
            )

            # The last two entries are placeholders for uncollected LP gain scores.
            data = [gcn_state, bestcand, action_set, scores, 0, 0]

            # Inconsistent (negative) scores may show up if SCIP hit its
            # time limit; those samples are not kept.
            scores_usable = not any(value < 0 for value in scores)
            if scores_usable and len(gcn_state[1]['values']):
                current = self.model.getCurrentNode()
                sample_depth = current.getDepth()
                sample_node = current.getNumber()

                self.min_depth = min(self.min_depth, sample_depth)
                self.max_depth = max(self.max_depth, sample_depth)

                filename = f'{self.out_dir}/sample_{self.episode}_{self.sample_counter}.pkl'
                self.file_name_list.append(filename)

                self.data_dict[filename] = dict(
                    episode=self.episode,
                    instance=self.instance,
                    seed=self.seed,
                    node_number=sample_node,
                    node_depth=sample_depth,
                    data=data,
                )
                self.out_queue_dict[filename] = dict(
                    type='sample',
                    episode=self.episode,
                    instance=self.instance,
                    seed=self.seed,
                    node_number=sample_node,
                    node_depth=sample_depth,
                    filename=filename,
                )

                self.sample_counter += 1

        explore_with_expert = self.exploration_policy == 'vanillafullstrong'
        apply_expert = (query_expert and self.follow_expert) or explore_with_expert

        # The exploration policy is skipped when the expert was queried and
        # is followed, or when it is the expert rule that already ran.
        if not query_expert or not (self.follow_expert or explore_with_expert):
            result = self.model.executeBranchRule(self.exploration_policy, allowaddcons)

        if apply_expert:
            assert result == did_not_run
            final_cands, _, _, _, final_best = self.model.getVanillafullstrongData()
            self.model.branchVar(final_cands[final_best])
            result = scip.SCIP_RESULT.BRANCHED

        return {"result": result}

class SamplingAgent_new_37(scip.Branchrule):
    """Branching rule that records an expert sample at every LP branching call.

    The expert ('vanillafullstrong') is queried unconditionally:
    ``query_expert_prob`` is stored but not consulted.  Each sample holds the
    root base state plus the three states returned by
    ``utils.extract_state_new_37_base`` / ``_col`` / ``_all`` at the current
    node.  Samples are only accumulated in ``data_dict`` and
    ``out_queue_dict`` (keyed by target filename); this class neither writes
    files nor touches ``out_queue``.
    """

    def __init__(self, episode, instance, seed, out_queue, exploration_policy, query_expert_prob, out_dir, follow_expert=True):
        """Store the run configuration and create empty sample containers.

        Args:
            episode: Episode identifier, embedded in sample filenames.
            instance: Instance identifier copied into every sample record.
            seed: Seed of the private ``numpy`` random generator.
            out_queue: Stored as-is; not used inside this class.
            exploration_policy: Name of the SCIP branching rule executed when
                the expert decision is not followed.
            query_expert_prob: Stored as-is; not used inside this class.
            out_dir: Directory prefix used to build sample filenames.
            follow_expert: If true, branch on the expert's best candidate.
        """
        super().__init__()
        self.episode = episode
        self.instance = instance
        self.seed = seed
        self.out_queue = out_queue
        self.exploration_policy = exploration_policy
        self.query_expert_prob = query_expert_prob
        self.out_dir = out_dir
        self.follow_expert = follow_expert

        self.rng = np.random.RandomState(seed)
        self.new_node = True
        self.sample_counter = 0

        # In-memory sample storage, keyed by the filename each sample targets.
        self.file_name_list = []
        self.data_dict = {}
        self.out_queue_dict = {}
        # Depth range over all recorded samples.
        self.max_depth = 0
        self.min_depth = 1e9

    def branchinit(self):
        """Reset the per-solve buffers and the cached root state."""
        self.khalil_root_buffer = {}
        self.state_buffer = {}
        self.root_state_base = None

    def branchexeclp(self, allowaddcons):
        """Record an expert sample, then branch.

        Args:
            allowaddcons: Forwarded unchanged to ``executeBranchRule``.

        Returns:
            A dict with the single key ``"result"`` holding the SCIP result
            code of the branching step.
        """
        if self.model.getNNodes() == 1:
            # initialize root buffer for Khalil features extraction
            utils.extract_khalil_variable_features(self.model, [], self.khalil_root_buffer)
            # Keep the base state seen at the root; it is attached to every sample.
            self.root_state_base, _, _ = utils.extract_state_new_37_base(self.model)

        result = scip.SCIP_RESULT.DIDNOTRUN

        # The expert is queried at every call; no random draw is made.
        query_expert = True

        if query_expert:
            # Only the candidate list returned by the last extractor is used.
            state_base, _, _ = utils.extract_state_new_37_base(self.model)
            state_cands_col, _, _ = utils.extract_state_new_37_col(self.model)
            state_cands_all, cands, _ = utils.extract_state_new_37_all(self.model)

            result = self.model.executeBranchRule('vanillafullstrong', allowaddcons)
            # NOTE(review): the best-candidate index is read from position 3
            # of the returned 5-tuple here, but from position 4 in the
            # branching step at the end of this method -- confirm which one
            # matches the getVanillafullstrongData() binding in use.
            cands_, scores, _, bestcand, _ = self.model.getVanillafullstrongData()

            assert result == scip.SCIP_RESULT.DIDNOTRUN
            assert all(c1.getCol().getLPPos() == c2.getCol().getLPPos() for c1, c2 in zip(cands, cands_))

            action_set = [c.getCol().getLPPos() for c in cands]

            data = [self.root_state_base, state_base, state_cands_col, state_cands_all, bestcand, action_set, scores]

            # Do not record inconsistent scores. May happen if SCIP was early stopped (time limit).
            if not any(s < 0 for s in scores):
                node = self.model.getCurrentNode()
                node_depth = node.getDepth()
                node_number = node.getNumber()

                self.min_depth = min(self.min_depth, node_depth)
                self.max_depth = max(self.max_depth, node_depth)

                filename = f'{self.out_dir}/sample_{self.episode}_{self.sample_counter}.pkl'

                self.file_name_list.append(filename)
                self.data_dict[filename] = {
                    'episode': self.episode,
                    'instance': self.instance,
                    'seed': self.seed,
                    'node_number': node_number,
                    'node_depth': node_depth,
                    'data': data,
                }
                self.out_queue_dict[filename] = {
                    'type': 'sample',
                    'episode': self.episode,
                    'instance': self.instance,
                    'seed': self.seed,
                    'node_number': node_number,
                    'node_depth': node_depth,
                    'filename': filename,
                }

                self.sample_counter += 1

        # if exploration and expert policies are the same, prevent running it twice
        if not query_expert or (not self.follow_expert and self.exploration_policy != 'vanillafullstrong'):
            result = self.model.executeBranchRule(self.exploration_policy, allowaddcons)

        # apply 'vanillafullstrong' branching decision if needed
        if (query_expert and self.follow_expert) or self.exploration_policy == 'vanillafullstrong':
            assert result == scip.SCIP_RESULT.DIDNOTRUN
            cands, scores, ncands, npriocands, bestcand = self.model.getVanillafullstrongData()
            self.model.branchVar(cands[bestcand])
            result = scip.SCIP_RESULT.BRANCHED

        return {"result": result}

class SamplingAgent_new_37_1(scip.Branchrule):
    """Branching rule that records expert samples built from buffered states.

    With probability ``query_expert_prob`` the 'vanillafullstrong' rule is
    queried and its scores are stored together with the constraint, edge and
    variable features returned by ``utils.extract_state_new_37_1``.  Samples
    are only accumulated in ``data_dict`` and ``out_queue_dict`` (keyed by
    target filename); this class neither writes files nor touches
    ``out_queue``.
    """

    def __init__(self, episode, instance, seed, out_queue, exploration_policy, query_expert_prob, out_dir, follow_expert=True):
        """Store the run configuration and create empty sample containers.

        Args:
            episode: Episode identifier, embedded in sample filenames.
            instance: Instance identifier copied into every sample record.
            seed: Seed of the private ``numpy`` random generator.
            out_queue: Stored as-is; not used inside this class.
            exploration_policy: Name of the SCIP branching rule executed when
                the expert decision is not followed.
            query_expert_prob: Probability of querying the expert at a node.
            out_dir: Directory prefix used to build sample filenames.
            follow_expert: If true, branch on the expert's best candidate
                whenever the expert was queried.
        """
        super().__init__()
        self.episode = episode
        self.instance = instance
        self.seed = seed
        self.out_queue = out_queue
        self.exploration_policy = exploration_policy
        self.query_expert_prob = query_expert_prob
        self.out_dir = out_dir
        self.follow_expert = follow_expert

        self.rng = np.random.RandomState(seed)
        self.new_node = True
        self.sample_counter = 0

        # In-memory sample storage, keyed by the filename each sample targets.
        self.file_name_list = []
        self.data_dict = {}
        self.out_queue_dict = {}
        # Depth range over all recorded samples.
        self.max_depth = 0
        self.min_depth = 1e9

    def branchinit(self):
        """Reset the Khalil root buffer and the state-extraction buffer."""
        self.khalil_root_buffer = {}
        self.buffer = {}

    def branchexeclp(self, allowaddcons):
        """Optionally record an expert sample, then branch.

        Args:
            allowaddcons: Forwarded unchanged to ``executeBranchRule``.

        Returns:
            A dict with the single key ``"result"`` holding the SCIP result
            code of the branching step.
        """
        if self.model.getNNodes() == 1:
            # initialize root buffer for Khalil features extraction
            utils.extract_khalil_variable_features(self.model, [], self.khalil_root_buffer)

        result = scip.SCIP_RESULT.DIDNOTRUN

        # once in a while, also run the expert policy and record the (state, action) pair
        query_expert = self.rng.rand() < self.query_expert_prob

        if query_expert:
            # The candidates returned by the extractor are ignored; the
            # pseudo branching candidates queried next are used instead.
            constraint_features, edge_features, variable_features, _, _ = utils.extract_state_new_37_1(self.model, self.buffer)
            gcn_state = constraint_features, edge_features, variable_features

            cands, *_ = self.model.getPseudoBranchCands()

            result = self.model.executeBranchRule('vanillafullstrong', allowaddcons)
            # NOTE(review): the best-candidate index is read from position 3
            # of the returned 5-tuple here, but from position 4 in the
            # branching step at the end of this method -- confirm which one
            # matches the getVanillafullstrongData() binding in use.
            cands_, scores, _, bestcand, _ = self.model.getVanillafullstrongData()

            assert result == scip.SCIP_RESULT.DIDNOTRUN
            assert all(c1.getCol().getLPPos() == c2.getCol().getLPPos() for c1, c2 in zip(cands, cands_))

            action_set = [c.getCol().getLPPos() for c in cands]

            data = [gcn_state, bestcand, action_set, scores]

            # Do not record inconsistent scores. May happen if SCIP was early stopped (time limit).
            if not any(s < 0 for s in scores) and len(gcn_state[1]['values']):
                node = self.model.getCurrentNode()
                node_depth = node.getDepth()
                node_number = node.getNumber()

                self.min_depth = min(self.min_depth, node_depth)
                self.max_depth = max(self.max_depth, node_depth)

                filename = f'{self.out_dir}/sample_{self.episode}_{self.sample_counter}.pkl'

                self.file_name_list.append(filename)
                self.data_dict[filename] = {
                    'episode': self.episode,
                    'instance': self.instance,
                    'seed': self.seed,
                    'node_number': node_number,
                    'node_depth': node_depth,
                    'data': data,
                }
                self.out_queue_dict[filename] = {
                    'type': 'sample',
                    'episode': self.episode,
                    'instance': self.instance,
                    'seed': self.seed,
                    'node_number': node_number,
                    'node_depth': node_depth,
                    'filename': filename,
                }

                self.sample_counter += 1

        # if exploration and expert policies are the same, prevent running it twice
        if not query_expert or (not self.follow_expert and self.exploration_policy != 'vanillafullstrong'):
            result = self.model.executeBranchRule(self.exploration_policy, allowaddcons)

        # apply 'vanillafullstrong' branching decision if needed
        if (query_expert and self.follow_expert) or self.exploration_policy == 'vanillafullstrong':
            assert result == scip.SCIP_RESULT.DIDNOTRUN
            cands, scores, ncands, npriocands, bestcand = self.model.getVanillafullstrongData()
            self.model.branchVar(cands[bestcand])
            result = scip.SCIP_RESULT.BRANCHED

        return {"result": result}

class SamplingAgent_new_39(scip.Branchrule):
    """Branching rule that records expert samples at randomly chosen nodes.

    With probability ``query_expert_prob`` the 'vanillafullstrong' rule is
    queried and its scores are stored together with the state returned by
    ``utils.extract_state``.  Samples are only accumulated in ``data_dict``
    and ``out_queue_dict`` (keyed by target filename); this class neither
    writes files nor touches ``out_queue``.
    """

    def __init__(self, episode, instance, seed, out_queue, exploration_policy, query_expert_prob, out_dir, follow_expert=True):
        """Store the run configuration and create empty sample containers.

        Args:
            episode: Episode identifier, embedded in sample filenames.
            instance: Instance identifier copied into every sample record.
            seed: Seed of the private ``numpy`` random generator.
            out_queue: Stored as-is; not used inside this class.
            exploration_policy: Name of the SCIP branching rule executed when
                the expert decision is not followed.
            query_expert_prob: Probability of querying the expert at a node.
            out_dir: Directory prefix used to build sample filenames.
            follow_expert: If true, branch on the expert's best candidate
                whenever the expert was queried.
        """
        super().__init__()
        self.episode = episode
        self.instance = instance
        self.seed = seed
        self.out_queue = out_queue
        self.exploration_policy = exploration_policy
        self.query_expert_prob = query_expert_prob
        self.out_dir = out_dir
        self.follow_expert = follow_expert

        self.rng = np.random.RandomState(seed)
        self.new_node = True
        self.sample_counter = 0

        # In-memory sample storage, keyed by the filename each sample targets.
        self.file_name_list = []
        self.data_dict = {}
        self.out_queue_dict = {}
        # Depth range over all recorded samples.
        self.max_depth = 0
        self.min_depth = 1e9

    def branchinit(self):
        """Reset the root buffer passed to the Khalil feature extractor."""
        self.khalil_root_buffer = {}

    def branchexeclp(self, allowaddcons):
        """Optionally record an expert sample, then branch.

        Args:
            allowaddcons: Forwarded unchanged to ``executeBranchRule``.

        Returns:
            A dict with the single key ``"result"`` holding the SCIP result
            code of the branching step.
        """
        if self.model.getNNodes() == 1:
            # initialize root buffer for Khalil features extraction
            utils.extract_khalil_variable_features(self.model, [], self.khalil_root_buffer)

        result = scip.SCIP_RESULT.DIDNOTRUN

        # once in a while, also run the expert policy and record the (state, action) pair
        query_expert = self.rng.rand() < self.query_expert_prob

        if query_expert:
            # The state must be extracted before the expert rule is executed.
            gcn_state = utils.extract_state(self.model)
            cands, *_ = self.model.getPseudoBranchCands()

            result = self.model.executeBranchRule('vanillafullstrong', allowaddcons)
            # NOTE(review): the best-candidate index is read from position 3
            # of the returned 5-tuple here, but from position 4 in the
            # branching step at the end of this method -- confirm which one
            # matches the getVanillafullstrongData() binding in use.
            cands_, scores, _, bestcand, _ = self.model.getVanillafullstrongData()

            assert result == scip.SCIP_RESULT.DIDNOTRUN
            assert all(c1.getCol().getLPPos() == c2.getCol().getLPPos() for c1, c2 in zip(cands, cands_))

            action_set = [c.getCol().getLPPos() for c in cands]

            data = [gcn_state, bestcand, action_set, scores]

            # Do not record inconsistent scores. May happen if SCIP was early stopped (time limit).
            if not any(s < 0 for s in scores) and len(gcn_state[1]['values']):
                node = self.model.getCurrentNode()
                node_depth = node.getDepth()
                node_number = node.getNumber()

                self.min_depth = min(self.min_depth, node_depth)
                self.max_depth = max(self.max_depth, node_depth)

                filename = f'{self.out_dir}/sample_{self.episode}_{self.sample_counter}.pkl'

                self.file_name_list.append(filename)
                self.data_dict[filename] = {
                    'episode': self.episode,
                    'instance': self.instance,
                    'seed': self.seed,
                    'node_number': node_number,
                    'node_depth': node_depth,
                    'data': data,
                }
                self.out_queue_dict[filename] = {
                    'type': 'sample',
                    'episode': self.episode,
                    'instance': self.instance,
                    'seed': self.seed,
                    'node_number': node_number,
                    'node_depth': node_depth,
                    'filename': filename,
                }

                self.sample_counter += 1

        # if exploration and expert policies are the same, prevent running it twice
        if not query_expert or (not self.follow_expert and self.exploration_policy != 'vanillafullstrong'):
            result = self.model.executeBranchRule(self.exploration_policy, allowaddcons)

        # apply 'vanillafullstrong' branching decision if needed
        if (query_expert and self.follow_expert) or self.exploration_policy == 'vanillafullstrong':
            assert result == scip.SCIP_RESULT.DIDNOTRUN
            cands, scores, ncands, npriocands, bestcand = self.model.getVanillafullstrongData()
            self.model.branchVar(cands[bestcand])
            result = scip.SCIP_RESULT.BRANCHED

        return {"result": result}

class SamplingAgent_new_39_data(scip.Branchrule):
    """
    Sampling branching rule used for the 'new_39_data' code_id.

    On every LP branching call the rule draws a random number; with
    probability `query_expert_prob` it runs 'vanillafullstrong', buffers a
    `[state, bestcand, action_set, scores]` sample in memory and, when
    `follow_expert` is set, branches on the expert's candidate. Otherwise the
    `exploration_policy` branching rule is executed.

    Samples are only buffered (`file_name_list`, `data_dict`,
    `out_queue_dict`); this class neither writes files nor uses `out_queue`.
    `max_depth` / `min_depth` track the deepest / shallowest node depth among
    the buffered samples.
    """

    def __init__(self, episode, instance, seed, out_queue, exploration_policy, query_expert_prob, out_dir, follow_expert=True):
        super().__init__()
        self.episode = episode
        self.instance = instance
        self.seed = seed
        self.out_queue = out_queue
        self.exploration_policy = exploration_policy
        self.query_expert_prob = query_expert_prob
        self.out_dir = out_dir
        self.follow_expert = follow_expert

        self.rng = np.random.RandomState(seed)
        self.new_node = True
        self.sample_counter = 0

        # in-memory sample buffers, keyed by the target file name
        self.file_name_list = []
        self.data_dict = {}
        self.out_queue_dict = {}

        # depth range of the buffered samples; min_depth stays None until the
        # first sample is recorded (make_samples_max_depth reads both)
        self.max_depth = 0
        self.min_depth = None

        # NOTE(review): never filled anywhere in this class - presumably meant
        # to collect the per-call `one_data` snapshots; confirm intent.
        self.app_data = []

    def branchinit(self):
        # reset the root buffer used for Khalil feature extraction
        self.khalil_root_buffer = {}

    # Features are gathered here (they are later used to train the model)
    def branchexeclp(self, allowaddcons):
        """
        Branch on the current LP solution, recording an expert sample at random.

        Returns a dict with the SCIP result code under the key "result".
        """
        model = self.model

        # Snapshot of solver statistics and current-node information.
        # NOTE(review): `one_data` is built on every call but never stored or
        # returned; kept as-is until its intended destination is confirmed.
        one_data = {
            "ques_time": model.getTotalTime(),
            "ques_nodes_sum": model.getNNodes(),
            "ques_lp_sum": model.getNLPIterations(),
            "ques_gap": model.getGap(),
            "ques_Primalbound": model.getPrimalbound(),
            "ques_Dualbound": model.getDualbound(),
        }

        now_node = model.getCurrentNode()
        parent_node = now_node.getParent()

        node_id = now_node.getNumber()
        # -1 marks a node without parent
        node_parent_id = -1 if parent_node is None else parent_node.getNumber()
        node_depth = now_node.getDepth()
        now_low_bound = now_node.getLowerbound()

        one_data["node_id"] = node_id
        one_data["node_parent_id"] = node_parent_id
        one_data["node_depth"] = node_depth
        one_data["now_low_bound"] = now_low_bound

        if model.getNNodes() == 1:
            # initialize root buffer for Khalil features extraction
            utils.extract_khalil_variable_features(model, [], self.khalil_root_buffer)

        # once in a while, also run the expert policy and record the (state, action) pair
        result = scip.SCIP_RESULT.DIDNOTRUN
        query_expert = self.rng.rand() < self.query_expert_prob

        if query_expert:
            gcn_state = utils.extract_state(model)

            cands, *_ = model.getPseudoBranchCands()

            result = model.executeBranchRule('vanillafullstrong', allowaddcons)
            # NOTE(review): the 5-tuple is unpacked here as
            # (cands, scores, npriocands, bestcand, _) but in the branching
            # step below as (cands, scores, ncands, npriocands, bestcand).
            # Both orders cannot be right - confirm against the PySCIPOpt
            # build in use, since `bestcand` ends up in the recorded sample.
            cands_, scores, npriocands, bestcand, _ = model.getVanillafullstrongData()

            assert result == scip.SCIP_RESULT.DIDNOTRUN
            assert all(c1.getCol().getLPPos() == c2.getCol().getLPPos() for c1, c2 in zip(cands, cands_))

            action_set = [c.getCol().getLPPos() for c in cands]

            data = [gcn_state, bestcand, action_set, scores]

            # Do not record inconsistent scores. May happen if SCIP was early stopped (time limit).
            if not any(s < 0 for s in scores) and len(gcn_state[1]['values']):
                self.max_depth = max(self.max_depth, node_depth)
                self.min_depth = node_depth if self.min_depth is None else min(self.min_depth, node_depth)

                filename = f'{self.out_dir}/sample_{self.episode}_{self.sample_counter}.pkl'
                self.file_name_list.append(filename)

                # payload that is later pickled to `filename`
                self.data_dict[filename] = {
                    'episode': self.episode,
                    'instance': self.instance,
                    'seed': self.seed,
                    'node_number': node_id,
                    'node_depth': node_depth,
                    'data': data,
                }

                # notification message describing the sample
                self.out_queue_dict[filename] = {
                    'type': 'sample',
                    'episode': self.episode,
                    'instance': self.instance,
                    'seed': self.seed,
                    'node_number': node_id,
                    'node_depth': node_depth,
                    'filename': filename,
                }

                self.sample_counter += 1

        # if exploration and expert policies are the same, prevent running it twice
        if not query_expert or (not self.follow_expert and self.exploration_policy != 'vanillafullstrong'):
            result = model.executeBranchRule(self.exploration_policy, allowaddcons)

        # apply 'vanillafullstrong' branching decision if needed
        if (query_expert and self.follow_expert) or self.exploration_policy == 'vanillafullstrong':
            assert result == scip.SCIP_RESULT.DIDNOTRUN
            cands, scores, ncands, npriocands, bestcand = model.getVanillafullstrongData()
            model.branchVar(cands[bestcand])
            result = scip.SCIP_RESULT.BRANCHED

        return {"result": result}

class SamplingAgent_new_40(scip.Branchrule):
    """
    Sampling branching rule used for the 'new_40' code_id.

    On every LP branching call the rule draws a random number; with
    probability `query_expert_prob` it extracts the state with
    `utils.extract_state_new_40`, runs 'vanillafullstrong', buffers a
    `[state, bestcand, action_set, scores]` sample in memory and, when
    `follow_expert` is set, branches on the expert's candidate. Otherwise the
    `exploration_policy` branching rule is executed.

    Samples are only buffered (`file_name_list`, `data_dict`,
    `out_queue_dict`); this class neither writes files nor uses `out_queue`.
    `max_depth` / `min_depth` track the deepest / shallowest node depth among
    the buffered samples.
    """

    def __init__(self, episode, instance, seed, out_queue, exploration_policy, query_expert_prob, out_dir, follow_expert=True):
        super().__init__()
        self.episode = episode
        self.instance = instance
        self.seed = seed
        self.out_queue = out_queue
        self.exploration_policy = exploration_policy
        self.query_expert_prob = query_expert_prob
        self.out_dir = out_dir
        self.follow_expert = follow_expert

        self.rng = np.random.RandomState(seed)
        self.new_node = True
        self.sample_counter = 0

        # in-memory sample buffers, keyed by the target file name
        self.file_name_list = []
        self.data_dict = {}
        self.out_queue_dict = {}

        # depth range of the buffered samples; min_depth stays None until the
        # first sample is recorded (make_samples_max_depth reads both)
        self.max_depth = 0
        self.min_depth = None

    def branchinit(self):
        # reset the root buffer used for Khalil feature extraction
        self.khalil_root_buffer = {}

    # Features are gathered here (they are later used to train the model)
    def branchexeclp(self, allowaddcons):
        """
        Branch on the current LP solution, recording an expert sample at random.

        Returns a dict with the SCIP result code under the key "result".
        """
        model = self.model

        if model.getNNodes() == 1:
            # initialize root buffer for Khalil features extraction
            utils.extract_khalil_variable_features(model, [], self.khalil_root_buffer)

        # once in a while, also run the expert policy and record the (state, action) pair
        result = scip.SCIP_RESULT.DIDNOTRUN
        query_expert = self.rng.rand() < self.query_expert_prob

        if query_expert:
            gcn_state = utils.extract_state_new_40(model)

            cands, *_ = model.getPseudoBranchCands()

            result = model.executeBranchRule('vanillafullstrong', allowaddcons)
            # NOTE(review): the 5-tuple is unpacked here as
            # (cands, scores, npriocands, bestcand, _) but in the branching
            # step below as (cands, scores, ncands, npriocands, bestcand).
            # Both orders cannot be right - confirm against the PySCIPOpt
            # build in use, since `bestcand` ends up in the recorded sample.
            cands_, scores, npriocands, bestcand, _ = model.getVanillafullstrongData()

            assert result == scip.SCIP_RESULT.DIDNOTRUN
            assert all(c1.getCol().getLPPos() == c2.getCol().getLPPos() for c1, c2 in zip(cands, cands_))

            action_set = [c.getCol().getLPPos() for c in cands]

            data = [gcn_state, bestcand, action_set, scores]

            # Do not record inconsistent scores. May happen if SCIP was early stopped (time limit).
            if not any(s < 0 for s in scores) and len(gcn_state[1]['values']):
                current_node = model.getCurrentNode()
                node_number = current_node.getNumber()
                node_depth = current_node.getDepth()

                self.max_depth = max(self.max_depth, node_depth)
                self.min_depth = node_depth if self.min_depth is None else min(self.min_depth, node_depth)

                filename = f'{self.out_dir}/sample_{self.episode}_{self.sample_counter}.pkl'
                self.file_name_list.append(filename)

                # payload that is later pickled to `filename`
                self.data_dict[filename] = {
                    'episode': self.episode,
                    'instance': self.instance,
                    'seed': self.seed,
                    'node_number': node_number,
                    'node_depth': node_depth,
                    'data': data,
                }

                # notification message describing the sample
                self.out_queue_dict[filename] = {
                    'type': 'sample',
                    'episode': self.episode,
                    'instance': self.instance,
                    'seed': self.seed,
                    'node_number': node_number,
                    'node_depth': node_depth,
                    'filename': filename,
                }

                self.sample_counter += 1

        # if exploration and expert policies are the same, prevent running it twice
        if not query_expert or (not self.follow_expert and self.exploration_policy != 'vanillafullstrong'):
            result = model.executeBranchRule(self.exploration_policy, allowaddcons)

        # apply 'vanillafullstrong' branching decision if needed
        if (query_expert and self.follow_expert) or self.exploration_policy == 'vanillafullstrong':
            assert result == scip.SCIP_RESULT.DIDNOTRUN
            cands, scores, ncands, npriocands, bestcand = model.getVanillafullstrongData()
            model.branchVar(cands[bestcand])
            result = scip.SCIP_RESULT.BRANCHED

        return {"result": result}






def _sampling_agent_spec(code_id):
    """
    Map a `--code_id` value to `(agent class, branch rule maxdepth, extra SCIP params)`.

    The agent class is only looked up inside the branch for the requested
    `code_id`, so the other classes are never touched. `extra SCIP params` is
    a tuple of `(name, value)` pairs to apply with `Model.setParam`.

    Raises
    ------
    ValueError
        If `code_id` is not a supported identifier.
    """
    if code_id == "baseline":
        return SamplingAgent_baseline, -1, ()
    if code_id == "new_3":
        return SamplingAgent_new_3, -1, ()
    if code_id == "new_3_root":
        # this variant limits the solve to a single node
        root_only = (
            ('limits/nodes', 1),
            ('limits/totalnodes', 1),
            ('limits/stallnodes', 1),
        )
        return SamplingAgent_new_3_root, -1, root_only
    if code_id == "new_20":
        return SamplingAgent_new_20, -1, ()
    if code_id == "new_33_1":
        return SamplingAgent_new_33_1, 1, ()
    if code_id == "new_33_2":
        return SamplingAgent_new_33_2, 4, ()
    if code_id == "new_33_3":
        return SamplingAgent_new_33_3, 7, ()
    if code_id == "new_33_4":
        return SamplingAgent_new_33_4, -1, ()
    if code_id == "new_35":
        return SamplingAgent_new_35, 6, ()
    if code_id == "new_37":
        return SamplingAgent_new_37, -1, ()
    if code_id == "new_37_1":
        return SamplingAgent_new_37_1, -1, ()
    if code_id == "new_39":
        return SamplingAgent_new_39, -1, ()
    if code_id == "new_39_data":
        return SamplingAgent_new_39_data, -1, ()
    if code_id == "new_40":
        return SamplingAgent_new_40, -1, ()
    raise ValueError(f"unsupported code_id: {code_id!r}")


def make_samples_max_depth(args, in_queue, out_queue):
    """
    Worker loop: fetch an instance, run an episode and record samples.

    For each order the instance is solved with the sampling branching rule
    selected by `args.code_id`. After the solve, every buffered sample is
    completed with the episode's 'max_depth' / 'min_depth', written to its
    gzip-compressed pickle file and announced on `out_queue`. Each episode is
    framed by a 'start' and a 'done' message.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command line; only `args.code_id` is read here.
    in_queue : multiprocessing.Queue
        Input queue from which orders are received.
    out_queue : multiprocessing.Queue
        Output queue in which to send samples.

    Raises
    ------
    ValueError
        If `args.code_id` is not supported (raised before any order is taken).
    """
    # Resolve the agent once, up front: an unknown code_id previously solved
    # the whole instance before failing with a NameError on `branchrule`.
    agent_cls, rule_maxdepth, extra_params = _sampling_agent_spec(args.code_id)

    while True:
        episode, instance, seed, exploration_policy, query_expert_prob, time_limit, out_dir = in_queue.get()
        print(f'[w {os.getpid()}] episode {episode}, seed {seed}, processing instance \'{instance}\'...')

        m = scip.Model()
        m.setIntParam('display/verblevel', 0)
        m.readProblem(f'{instance}')
        set_scip(m, seed=seed, separator=False, restart=False)
        m.setIntParam('timing/clocktype', 2)
        m.setRealParam('limits/time', time_limit)

        for param_name, param_value in extra_params:
            m.setParam(param_name, param_value)

        branchrule = agent_cls(
            episode=episode,
            instance=instance,
            seed=seed,
            out_queue=out_queue,
            exploration_policy=exploration_policy,
            query_expert_prob=query_expert_prob,
            out_dir=out_dir)

        m.includeBranchrule(
            branchrule=branchrule,
            name="Sampling branching rule", desc="",
            priority=666666, maxdepth=rule_maxdepth, maxbounddist=1)

        m.setBoolParam('branching/vanillafullstrong/integralcands', True)
        m.setBoolParam('branching/vanillafullstrong/scoreall', True)
        m.setBoolParam('branching/vanillafullstrong/collectscores', True)
        m.setBoolParam('branching/vanillafullstrong/donotbranch', True)
        m.setBoolParam('branching/vanillafullstrong/idempotent', True)

        out_queue.put({
            'type': 'start',
            'episode': episode,
            'instance': instance,
            'seed': seed,
        })

        m.optimize()

        recorded_files = branchrule.file_name_list
        if recorded_files:
            max_depth = branchrule.max_depth
            # Not every agent tracks `min_depth`; fall back to the shallowest
            # recorded node depth instead of failing with AttributeError.
            min_depth = getattr(branchrule, 'min_depth', None)
            if min_depth is None:
                min_depth = min(branchrule.data_dict[name]['node_depth'] for name in recorded_files)

            # depth range is per episode, so report it once rather than per file
            print("branchrule.max_depth:", max_depth)
            print("branchrule.min_depth:", min_depth)

            for one_file_name in recorded_files:
                record = branchrule.data_dict[one_file_name]
                record['max_depth'] = max_depth
                record['min_depth'] = min_depth

                with gzip.open(one_file_name, 'wb') as f:
                    pickle.dump(record, f)

                # announce the sample only once its file is on disk
                out_queue.put(branchrule.out_queue_dict[one_file_name])

        m.freeProb()

        print(f"[w {os.getpid()}] episode {episode} done, {branchrule.sample_counter} samples")

        out_queue.put({
            'type': 'done',
            'episode': episode,
            'instance': instance,
            'seed': seed,
        })



def send_orders(orders_queue, instances, seed, exploration_policy, query_expert_prob, time_limit, out_dir):
    """
    Dispatcher loop: push an endless stream of sampling orders to the workers.

    The loop never terminates on its own; it is throttled only by
    `orders_queue.put` (relies on limited queue capacity).

    Parameters
    ----------
    orders_queue : multiprocessing.Queue
        Queue to which to send orders.
    instances : list
        Instance file names from which to sample episodes.
    seed : int
        Random seed for reproducibility.
    exploration_policy : str
        Branching strategy for exploration.
    query_expert_prob : float in [0, 1]
        Probability of running the expert strategy and collecting samples.
    time_limit : float in [0, 1e+20]
        Maximum running time for an episode, in seconds.
    out_dir: str
        Output directory in which to write samples.
    """
    order_rng = np.random.RandomState(seed)

    # settings that are identical for every order
    shared_settings = [exploration_policy, query_expert_prob, time_limit, out_dir]

    episode_id = 0
    while True:
        # draw order matters for reproducibility: instance first, then its seed
        chosen_instance = order_rng.choice(instances)
        episode_seed = order_rng.randint(2**32)

        orders_queue.put([episode_id, chosen_instance, episode_seed] + shared_settings)
        episode_id += 1


def collect_samples(args, instances, out_dir, rng, n_samples, n_jobs,
                    exploration_policy, query_expert_prob, time_limit):
    """
    Runs branch-and-bound episodes on the given set of instances, and collects
    randomly (state, action) pairs from the 'vanilla-fullstrong' expert
    brancher.

    Workers write their samples into `<out_dir>/tmp`; this function moves them
    to `<out_dir>/sample_<k>.pkl` in episode order, so the numbering does not
    depend on which worker finishes first. When `n_samples` is not positive,
    only `out_dir` is created and no process is started.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command line, forwarded unchanged to the workers.
    instances : list
        Instance files from which to collect samples.
    out_dir : str
        Directory in which to write samples.
    rng : numpy.random.RandomState
        A random number generator for reproducibility.
    n_samples : int
        Number of samples to collect.
    n_jobs : int
        Number of jobs for parallel sampling.
    exploration_policy : str
        Exploration policy (branching rule) for sampling.
    query_expert_prob : float in [0, 1]
        Probability of using the expert policy and recording a (state, action)
        pair.
    time_limit : float in [0, 1e+20]
        Maximum running time for an episode, in seconds.
    """

    os.makedirs(out_dir, exist_ok=True)

    # Drawn up front so the caller's rng advances by exactly one draw whether
    # or not anything is sampled.
    dispatcher_seed = rng.randint(2**32)

    # Nothing to collect: do not spawn workers that would only be killed again
    # (and a dispatcher that would never be stopped).
    if n_samples <= 0:
        return

    tmp_samples_dir = f'{out_dir}/tmp'
    os.makedirs(tmp_samples_dir, exist_ok=True)

    orders_queue = mp.Queue(maxsize=2*n_jobs)
    answers_queue = mp.SimpleQueue()
    workers = []
    dispatcher = None

    try:
        # start workers
        for _ in range(n_jobs):
            p = mp.Process(
                target=make_samples_max_depth,
                args=(args, orders_queue, answers_queue),
                daemon=True)
            p.start()
            workers.append(p)

        # start dispatcher
        dispatcher = mp.Process(
            target=send_orders,
            args=(orders_queue, instances, dispatcher_seed, exploration_policy, query_expert_prob, time_limit, tmp_samples_dir),
            daemon=True)
        dispatcher.start()

        # record answers and write samples
        buffer = {}            # episode -> messages received but not yet processed
        current_episode = 0    # episode whose samples are currently being written
        i = 0                  # number of samples written so far
        in_buffer = 0          # number of samples received but not yet written
        while i < n_samples:
            sample = answers_queue.get()

            # add received sample to buffer
            if sample['type'] == 'start':
                buffer[sample['episode']] = []
            else:
                buffer[sample['episode']].append(sample)
                if sample['type'] == 'sample':
                    in_buffer += 1

            # if any, write samples from current episode
            while current_episode in buffer and buffer[current_episode]:
                samples_to_write = buffer[current_episode]
                buffer[current_episode] = []

                for item in samples_to_write:

                    # if no more samples here, move to next episode
                    if item['type'] == 'done':
                        del buffer[current_episode]
                        current_episode += 1

                    # else write sample
                    else:
                        os.rename(item['filename'], f'{out_dir}/sample_{i+1}.pkl')
                        in_buffer -= 1
                        i += 1
                        print(f"[m {os.getpid()}] {i} / {n_samples} samples written, ep {item['episode']} ({in_buffer} in buffer).")

                        # early stop dispatcher (hard)
                        if in_buffer + i >= n_samples and dispatcher.is_alive():
                            dispatcher.terminate()
                            print(f"[m {os.getpid()}] dispatcher stopped...")

                        # as soon as enough samples are collected, stop
                        if i == n_samples:
                            buffer = {}
                            break
    finally:
        # stop the dispatcher and all workers (hard), even on error
        if dispatcher is not None and dispatcher.is_alive():
            dispatcher.terminate()
        for p in workers:
            p.terminate()

        shutil.rmtree(tmp_samples_dir, ignore_errors=True)


if __name__ == '__main__':
    # Command line: problem family, base seed, number of worker processes,
    # sampling-agent variant (--code_id) and dataset scale.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-p', '--problem',
        help='MILP instance type to process.',
        choices=['setcover', 'cauctions', 'facility', 'indset', 'gisp', 'wpms', 'fcmcnf'],
        default='setcover',
    )
    parser.add_argument(
        '-s', '--seed',
        help='Random generator seed.',
        type=utils.valid_seed,
        default=0,
    )
    parser.add_argument(
        '-j', '--njobs',
        help='Number of parallel jobs.',
        type=int,
        default=1,
    )
    parser.add_argument(
        '--code_id',
        type=str,
        default='baseline')

    parser.add_argument('--scale', type=int, default=0)
    args = parser.parse_args()

    print(f"seed {args.seed}")

    # Number of samples per split; --scale 0 requests a tiny run.
    if args.scale == 0:
        train_size = 5
        valid_size = 3
    else:
        train_size = 100 * args.scale
        valid_size = 20 * args.scale

    # evaluates to 0 for every scale, so no test samples are requested
    test_size = 0*args.scale
    exploration_strategy = 'pscost'
    node_record_prob = 0.05
    time_limit = 3600

    # Instance locations, sample output directory and (where it differs from
    # the default above) the per-episode time limit for each problem family.
    if args.problem == 'setcover':
        instances_train = glob.glob('../data/instances/setcover_400r_1000c_0.05d_100mc_0se/train_500r/*.lp')
        instances_valid = glob.glob('../data/instances/setcover_400r_1000c_0.05d_100mc_0se/valid_500r/*.lp')
        instances_test = glob.glob('../data/instances/setcover_400r_1000c_0.05d_100mc_0se/test_500r/*.lp')
        # The instance lists can be sliced here to shard the work across runs.
        out_dir = f'../data/samples/setcover_400r_1000c_0.05d_100mc_0se/{args.code_id}/'

    elif args.problem == 'cauctions':
        instances_train = glob.glob('../data/instances/cauctions_0se/train_100_500/*.lp')
        instances_valid = glob.glob('../data/instances/cauctions_0se/valid_100_500/*.lp')
        instances_test = glob.glob('../data/instances/cauctions_0se/test_100_500/*.lp')
        out_dir = f'../data/samples/cauctions_0se/{args.code_id}/'

    elif args.problem == 'indset':
        instances_train = glob.glob('../data/instances/indset_400n_4a_0se/train_500n/*.lp')
        instances_valid = glob.glob('../data/instances/indset_400n_4a_0se/valid_500n/*.lp')
        instances_test = glob.glob('../data/instances/indset_400n_4a_0se/test_500n/*.lp')
        out_dir = f'../data/samples/indset_400n_4a_0se/{args.code_id}/'

    elif args.problem == 'facility':
        instances_train = glob.glob('../data/instances/facility_0se/train_100_100_5/*.lp')
        instances_valid = glob.glob('../data/instances/facility_0se/valid_100_100_5/*.lp')
        instances_test = glob.glob('../data/instances/facility_0se/test_100_100_5/*.lp')
        out_dir = f'../data/samples/facility_0se/{args.code_id}/'
        time_limit = 600

    elif args.problem == 'gisp':
        instances_train = glob.glob('/home/LAB/linglong/3.rl4my/3.nodes/1.node_model/problem_generation/data/GISP/train/*.lp')
        instances_valid = glob.glob('/home/LAB/linglong/3.rl4my/3.nodes/1.node_model/problem_generation/data/GISP/test/*.lp')
        instances_test = glob.glob('/home/LAB/linglong/3.rl4my/3.nodes/1.node_model/problem_generation/data/GISP/transfer/*.lp')
        out_dir = f'../data/samples/gisp/{args.code_id}/'
        time_limit = 600

    elif args.problem == 'wpms':
        instances_train = glob.glob('/home/LAB/linglong/3.rl4my/3.nodes/1.node_model/problem_generation/data/WPMS/train/*.lp')
        instances_valid = glob.glob('/home/LAB/linglong/3.rl4my/3.nodes/1.node_model/problem_generation/data/WPMS/test/*.lp')
        instances_test = glob.glob('/home/LAB/linglong/3.rl4my/3.nodes/1.node_model/problem_generation/data/WPMS/transfer/*.lp')
        out_dir = f'../data/samples/wpms/{args.code_id}/'
        time_limit = 600

    elif args.problem == 'fcmcnf':
        instances_train = glob.glob('/home/LAB/linglong/3.rl4my/3.nodes/1.node_model/problem_generation/data/FCMCNF/train/*.lp')
        instances_valid = glob.glob('/home/LAB/linglong/3.rl4my/3.nodes/1.node_model/problem_generation/data/FCMCNF/test/*.lp')
        instances_test = glob.glob('/home/LAB/linglong/3.rl4my/3.nodes/1.node_model/problem_generation/data/FCMCNF/transfer/*.lp')
        out_dir = f'../data/samples/fcmcnf/{args.code_id}/'
        time_limit = 600

    else:
        raise NotImplementedError

    print(f"{len(instances_train)} train instances for {train_size} samples")
    print(f"{len(instances_valid)} validation instances for {valid_size} samples")
    print(f"{len(instances_test)} test instances for {test_size} samples")

    # create the output directory; an existing directory is reused
    os.makedirs(out_dir, exist_ok=True)

    # each split gets its own generator, offset from the base seed
    rng = np.random.RandomState(args.seed)
    collect_samples(args, instances_train, out_dir + '/train', rng, train_size,
                    args.njobs, exploration_policy=exploration_strategy,
                    query_expert_prob=node_record_prob,
                    time_limit=time_limit)

    rng = np.random.RandomState(args.seed + 1)
    collect_samples(args, instances_valid, out_dir + '/valid', rng, valid_size,
                    args.njobs, exploration_policy=exploration_strategy,
                    query_expert_prob=node_record_prob,
                    time_limit=time_limit)

    rng = np.random.RandomState(args.seed + 2)
    collect_samples(args, instances_test, out_dir + '/test', rng, test_size,
                    args.njobs, exploration_policy=exploration_strategy,
                    query_expert_prob=node_record_prob,
                    time_limit=time_limit)