import argparse
# import pulp
from ast import arg
import sys, time
import os
import pandas as pd
import pickle
import torch
from pyscipopt import Model, Branchrule, SCIP_RESULT
sys.path.append('../')
import  src.utils as utils
import src.model as model
from src.utils import set_device_seed, set_scip
from src.logger import Logger
logger = Logger.logger
import shutil
from collections import deque,OrderedDict
import dgl
import heapq
import numpy as np
from multiprocessing.dummy import Pool as ThreadPool
import time
import copy
import pickle
from datetime import datetime
from scipy.optimize import linprog
from concurrent.futures import ProcessPoolExecutor
from functools import partial
import multiprocessing

# Module-level numerical tolerance. It is not referenced in this part of the
# file. NOTE(review): confirm where it is used before changing it.
epsilon = 1e-5

def check_deque(temp_deque):
    """Summarise which entries of ``temp_deque`` hold data.

    Args:
        temp_deque: any iterable, e.g. a deque of cached items with ``None``
            placeholders.

    Returns:
        A ``deque`` of the same length and order, containing ``None`` where the
        input entry is ``None`` and the string ``"data"`` everywhere else.
    """
    return deque(None if item is None else "data" for item in temp_deque)

def geo_mean(x, shift=1.0):
    """Return the shifted geometric mean of ``x`` as a 0-d float tensor.

    Computes ``exp(mean(log(x_i + shift))) - shift``. The shift keeps values
    near zero from dominating the mean.
    """
    shifted = torch.tensor(x).float().add(shift)
    log_mean = torch.log(shifted).mean()
    return torch.exp(log_mean) - shift

def num_mean(x):
    """Return the arithmetic mean of ``x``, or ``0`` when ``x`` is empty."""
    count = len(x)
    return sum(x) / count if count else 0

class PolicyBranching(Branchrule):
    """Branching rule that delegates every decision to a built-in SCIP rule.

    ``branchexeclp`` always runs ``relpscost``. ``branchexeclp_old`` picks the
    built-in rule from ``args.method`` and returns ``DIDNOTRUN`` for any other
    method. ``agent`` and ``device`` are stored but not used by this class.
    """

    # args.method -> name of the built-in SCIP rule used by branchexeclp_old.
    _RULE_FOR_METHOD = {
        'relpscost': 'relpscost',
        'presolve': 'relpscost',
        'fullstrong': 'fullstrong',
        'vanillafullstrong': 'vanillafullstrong',
        'vanilla': 'vanillafullstrong',
        'mostinf': 'mostinf',
        'random': 'random',
    }

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0          # number of branchexeclp(_old) calls
        self.dom_reduction = 0  # calls that returned REDUCEDDOM
        self.cut_off = 0        # incremented by 2 per CUTOFF result

    def branchinit(self):
        self.root_buffer = {}

    def _tally(self, outcome):
        """Update the domain-reduction / cutoff counters from a SCIP result."""
        # NOTE(review): cutoffs are counted in steps of 2 throughout this file;
        # confirm whether a step of 1 was intended.
        if outcome == SCIP_RESULT.REDUCEDDOM:
            self.dom_reduction += 1
        elif outcome == SCIP_RESULT.CUTOFF:
            self.cut_off += 2

    def branchexeclp(self, allowaddcons):
        """Branch with SCIP's ``relpscost`` rule."""
        self.count += 1
        outcome = self.model.executeBranchRule('relpscost', allowaddcons)
        self._tally(outcome)
        return {'result': outcome}

    def branchexeclp_old(self, allowaddcons):
        """Branch with the built-in rule selected by ``args.method``."""
        self.count += 1
        rule_name = self._RULE_FOR_METHOD.get(self.args.method)
        if rule_name is None:
            outcome = SCIP_RESULT.DIDNOTRUN
        else:
            outcome = self.model.executeBranchRule(rule_name, allowaddcons)
        self._tally(outcome)
        return {'result': outcome}

class PolicyBranching_baseline(Branchrule):
    """GNN branching rule using the ``extract_state`` + ``graph_transform`` encoding.

    The candidate with the highest ``agent.eval_forward`` score is branched on.
    The wall-clock time spent choosing it (excluding ``branchVar``) is appended
    to ``list_time_all``. The other ``list_time_*`` lists are created but not
    filled by this class.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args

        # Counters.
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

        # Timing buffers (only list_time_all is filled by branchexeclp).
        self.list_time_get_state = []
        self.list_time_get_graph = []
        self.list_time_get_cands = []
        self.list_time_get_action_set = []
        self.list_time_get_agent_out = []
        self.list_time_all = []
        self.list_time_get_best_cand = []

    def branchinit(self):
        self.root_buffer = {}

    def branchexeclp(self, allowaddcons):
        """Pick the highest-scoring pseudo branching candidate and branch on it."""
        started = time.time()

        graph = utils.graph_transform(utils.extract_state(self.model)).to(self.device)

        candidates, *_ = self.model.getPseudoBranchCands()
        # NOTE(review): assumes every candidate column is in the current LP
        # (getLPPos() >= 0). A -1 would silently index the last score; confirm.
        lp_positions = [var.getCol().getLPPos() for var in candidates]

        scores = self.agent.eval_forward(graph)[lp_positions]
        chosen = int(scores.argmax())

        self.list_time_all.append(time.time() - started)

        self.model.branchVar(candidates[chosen])
        # The result is always BRANCHED here, so dom_reduction / cut_off never change.
        self.count += 1
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_gcn(Branchrule):
    """Branching rule that selects the branching variable with a learned GNN policy.

    ``args.method`` selects how the current SCIP state is encoded as a graph
    (see ``_GRAPH_BUILDERS``). The pseudo branching candidate with the highest
    score from ``agent.eval_forward`` is branched on.

    ``branchexeclp_time`` and ``branchexeclp_lp`` are alternative entry points.
    Both also accept SCIP's built-in rules (``_SCIP_RULES``) and a restricted
    set of GNN methods (``_SECONDARY_GNN_METHODS``); they additionally record
    timing measurements.
    """

    # args.method -> (state extractor, graph transformer). Both names are
    # resolved on ``utils`` at call time, so a method's helpers are only
    # required when that method is actually used.
    # NOTE(review): 'new_18' reuses the 'new_16' extractor/transformer; confirm
    # this is intended.
    _GRAPH_BUILDERS = {
        'gcnn': ('extract_state', 'graph_transform'),
        'pd': ('extract_pd_state', 'graph_transform'),
        'baseline': ('extract_state', 'graph_transform'),
        'new_1': ('extract_state_new_1', 'graph_transform_new_1'),
        'new_2': ('extract_state_new_2', 'graph_transform_new_2'),
        'new_3': ('extract_state_new_3', 'graph_transform_new_3'),
        'new_5': ('extract_state_new_5', 'graph_transform_new_5'),
        'new_6': ('extract_state_new_6', 'graph_transform_new_6'),
        'new_7': ('extract_state_new_7', 'graph_transform_new_7'),
        'new_16': ('extract_state_new_16', 'graph_transform_new_16'),
        'new_17': ('extract_state_new_17', 'graph_transform_new_17'),
        'new_18': ('extract_state_new_16', 'graph_transform_new_16'),
    }

    # GNN methods accepted by branchexeclp_time / branchexeclp_lp. Any other
    # non-SCIP method makes those entry points return DIDNOTRUN.
    _SECONDARY_GNN_METHODS = (
        'pd', 'gcnn', 'baseline', 'new_1', 'new_2', 'new_3', 'new_5', 'new_6', 'new_7',
    )

    # args.method -> built-in SCIP rule used by branchexeclp_time / branchexeclp_lp.
    _SCIP_RULES = {
        'relpscost': 'relpscost',
        'presolve': 'relpscost',
        'fullstrong': 'fullstrong',
        'vanillafullstrong': 'vanillafullstrong',
        'vanilla': 'vanillafullstrong',
        'mostinf': 'mostinf',
        'random': 'random',
    }

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

        # Per-call timings filled by branchexeclp_time.
        self.list_time_get_state = []
        self.list_time_get_graph = []
        self.list_time_get_cands = []
        self.list_time_get_action_set = []
        self.list_time_get_agent_out = []
        self.list_time_all = []
        self.list_time_get_best_cand = []

        # Per-call timings of lpiSolvePrimal / lpiSolveDual, filled by
        # branchexeclp_lp. branchexeclp_lp appends to these lists, so they must
        # exist before it is first called.
        self.list_1 = []
        self.list_2 = []

    def branchinit(self):
        self.root_buffer = {}

    def _build_graph(self, method):
        """Extract the state for ``method`` and return its graph on ``self.device``.

        Raises:
            NotImplementedError: if ``method`` has no entry in ``_GRAPH_BUILDERS``.
        """
        builder = self._GRAPH_BUILDERS.get(method)
        if builder is None:
            print("!!!!!!!method error!!!!")
            raise NotImplementedError
        extract_name, transform_name = builder
        state = getattr(utils, extract_name)(self.model)
        return getattr(utils, transform_name)(state).to(self.device)

    def _branch_on_best(self, g):
        """Branch on the pseudo candidate with the highest agent score; return BRANCHED."""
        cands, *_ = self.model.getPseudoBranchCands()
        # NOTE(review): assumes every candidate column is in the current LP
        # (getLPPos() >= 0). A -1 would silently index the last score; confirm.
        action_set = [c.getCol().getLPPos() for c in cands]
        agent_out = self.agent.eval_forward(g)
        best_cand = int(agent_out[action_set].argmax())
        self.model.branchVar(cands[best_cand])
        return SCIP_RESULT.BRANCHED

    def _record_result(self, result):
        """Update the domain-reduction / cutoff counters from a SCIP result."""
        # NOTE(review): cutoffs are counted in steps of 2 throughout this file;
        # confirm whether a step of 1 was intended.
        if result == SCIP_RESULT.REDUCEDDOM:
            self.dom_reduction += 1
        elif result == SCIP_RESULT.CUTOFF:
            self.cut_off += 2

    def branchexeclp(self, allowaddcons):
        """Branch using the GNN encoding selected by ``args.method``.

        Raises:
            NotImplementedError: if ``args.method`` is not in ``_GRAPH_BUILDERS``.
        """
        self.count += 1
        g = self._build_graph(self.args.method)
        result = self._branch_on_best(g)
        self._record_result(result)
        return {'result': result}

    def branchexeclp_time(self, allowaddcons):
        """Branch like ``branchexeclp`` while recording per-step timings.

        Built-in SCIP methods are delegated to ``executeBranchRule``. For GNN
        methods in ``_SECONDARY_GNN_METHODS`` the durations of these steps are
        appended to the matching ``list_time_*`` lists:

        - candidate retrieval,
        - action-set construction,
        - the agent forward pass,
        - the argmax,
        - the total decision time before ``branchVar`` (``list_time_all``).

        State-extraction and graph-building times are recorded only for
        'baseline'. Any other method returns ``DIDNOTRUN``.
        """
        time_start = time.time()
        result = SCIP_RESULT.DIDNOTRUN
        self.count += 1
        method = self.args.method

        if method in self._SCIP_RULES:
            result = self.model.executeBranchRule(self._SCIP_RULES[method], allowaddcons)
        elif method in self._SECONDARY_GNN_METHODS:
            if method == 'baseline':
                time_1 = time.time()
                state = utils.extract_state(self.model)
                self.list_time_get_state.append(time.time() - time_1)

                time_1 = time.time()
                g = utils.graph_transform(state).to(self.device)
                self.list_time_get_graph.append(time.time() - time_1)
            else:
                g = self._build_graph(method)

            time_1 = time.time()
            cands, *_ = self.model.getPseudoBranchCands()
            self.list_time_get_cands.append(time.time() - time_1)

            time_1 = time.time()
            action_set = [c.getCol().getLPPos() for c in cands]
            self.list_time_get_action_set.append(time.time() - time_1)

            time_1 = time.time()
            agent_out = self.agent.eval_forward(g)
            self.list_time_get_agent_out.append(time.time() - time_1)

            time_1 = time.time()
            best_cand = int(agent_out[action_set].argmax())
            self.list_time_get_best_cand.append(time.time() - time_1)

            self.list_time_all.append(time.time() - time_start)

            self.model.branchVar(cands[best_cand])
            result = SCIP_RESULT.BRANCHED

        self._record_result(result)
        return {'result': result}

    def branchexeclp_lp(self, allowaddcons):
        """Time ``lpiSolvePrimal``/``lpiSolveDual``, then branch.

        The primal and dual solve times are appended to ``list_1`` and
        ``list_2``. Method dispatch is the same as in ``branchexeclp_time``, but
        without the per-step timings.
        """
        result = SCIP_RESULT.DIDNOTRUN
        self.count += 1

        time_1 = time.time()
        self.model.lpiSolvePrimal()
        self.list_1.append(time.time() - time_1)

        time_1 = time.time()
        self.model.lpiSolveDual()
        self.list_2.append(time.time() - time_1)

        method = self.args.method
        if method in self._SCIP_RULES:
            result = self.model.executeBranchRule(self._SCIP_RULES[method], allowaddcons)
        elif method in self._SECONDARY_GNN_METHODS:
            result = self._branch_on_best(self._build_graph(method))

        self._record_result(result)
        return {'result': result}

class PolicyBranching_fsb_gcnn(Branchrule):
    """Hybrid branching rule: SCIP full strong branching first, then a GNN policy.

    The first ``fsb_steps`` calls delegate to SCIP's built-in ``fullstrong``
    rule. Later calls branch on the pseudo candidate with the highest
    ``agent.eval_forward`` score, computed on the ``utils.extract_state`` /
    ``utils.graph_transform`` graph (the 'baseline' encoding).

    The phase is chosen from ``self.count`` alone. The caller's shared
    ``args`` object (including ``args.method``) is never modified.
    """

    # Number of initial branching calls handled by full strong branching.
    fsb_steps = 20

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}

    def branchexeclp(self, allowaddcons):
        """Use full strong branching for the first calls, then the GNN policy."""
        self.count += 1

        if self.count <= self.fsb_steps:
            result = self.model.executeBranchRule('fullstrong', allowaddcons)
        else:
            state = utils.extract_state(self.model)
            g = utils.graph_transform(state).to(self.device)

            cands, *_ = self.model.getPseudoBranchCands()
            # NOTE(review): assumes every candidate column is in the current LP
            # (getLPPos() >= 0); confirm.
            action_set = [c.getCol().getLPPos() for c in cands]

            score = self.agent.eval_forward(g)[action_set]
            self.model.branchVar(cands[int(score.argmax())])
            result = SCIP_RESULT.BRANCHED

        # NOTE(review): cutoffs are counted in steps of 2 throughout this file;
        # confirm whether a step of 1 was intended.
        if result == SCIP_RESULT.REDUCEDDOM:
            self.dom_reduction += 1
        elif result == SCIP_RESULT.CUTOFF:
            self.cut_off += 2
        return {'result': result}

class PolicyBranching_3(Branchrule):
    """GNN branching rule using the ``utils.graph_transform_new_3`` encoding.

    The state passed to the transform is
    ``[gcn_state, None, node_state, mip_state, None]``. The Khalil-feature and
    candidate-state-matrix slots are left empty in this variant. The agent's
    output graph is read at ``nodes['v'].data['s']`` for candidate scores.
    """

    # Sizes requested from the solver's node / MIP state queries.
    NODE_STATE_DIM = 8
    MIP_STATE_DIM = 53

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0
        # Passed to utils.extract_state on every call; presumably a per-instance
        # feature cache (TODO confirm).
        self.state_buffer = {}
        self.khalil_root_buffer = {}
        self.branch_count = 0  # number of branching decisions taken

    def branchinit(self):
        self.root_buffer = {}
        self.count = 0

    def branchexeclp(self, allowaddcons):
        """Score the pseudo branching candidates with the agent and branch on the best."""
        gcn_state = utils.extract_state(self.model, self.state_buffer)
        node_state = self.model.getMyNodeState(self.NODE_STATE_DIM)
        mip_state = self.model.getMyMIPState(self.MIP_STATE_DIM)
        state = [gcn_state, None, node_state, mip_state, None]

        g = utils.graph_transform_new_3(state).to(self.device)
        agent_out = self.agent(g).nodes['v'].data['s']

        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]
        best_cand = int(agent_out[action_set].argmax())

        self.model.branchVar(cands[best_cand])
        # The result is always BRANCHED, so dom_reduction / cut_off never change here.
        self.branch_count += 1
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_3_0(Branchrule):
    """GNN branching rule using the ``utils.graph_transform_new_3_0`` encoding.

    The state passed to the transform is
    ``[gcn_state, khalil_features, node_state, mip_state, None]``. The agent's
    output graph is read at ``nodes['v'].data['s']`` for candidate scores.
    """

    # Sizes requested from the solver's node / MIP state queries.
    NODE_STATE_DIM = 8
    MIP_STATE_DIM = 53

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0
        # Buffers passed to the feature extractors on every call; presumably
        # per-instance caches (TODO confirm).
        self.state_buffer = {}
        self.khalil_root_buffer = {}
        self.branch_count = 0  # number of branching decisions taken

    def branchinit(self):
        self.root_buffer = {}
        self.count = 0

    def branchexeclp(self, allowaddcons):
        """Score the pseudo branching candidates with the agent and branch on the best."""
        # Fetch candidates once; the same list feeds the Khalil features and
        # the action set, so both refer to the same variables in the same order.
        cands, *_ = self.model.getPseudoBranchCands()

        gcn_state = utils.extract_state(self.model, self.state_buffer)
        state_khalil = utils.extract_khalil_variable_features(self.model, cands, self.khalil_root_buffer)
        node_state = self.model.getMyNodeState(self.NODE_STATE_DIM)
        mip_state = self.model.getMyMIPState(self.MIP_STATE_DIM)
        state = [gcn_state, state_khalil, node_state, mip_state, None]

        g = utils.graph_transform_new_3_0(state).to(self.device)
        agent_out = self.agent(g).nodes['v'].data['s']

        action_set = [c.getCol().getLPPos() for c in cands]
        best_cand = int(agent_out[action_set].argmax())

        self.model.branchVar(cands[best_cand])
        # The result is always BRANCHED, so dom_reduction / cut_off never change here.
        self.branch_count += 1
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_3_1(Branchrule):
    """GNN branching rule using the ``utils.graph_transform_new_3_1`` encoding.

    The state passed to the transform is
    ``[gcn_state, khalil_features, node_state, mip_state, None]``. The agent's
    output graph is read at ``nodes['v'].data['s']`` for candidate scores.
    """

    # Sizes requested from the solver's node / MIP state queries.
    NODE_STATE_DIM = 8
    MIP_STATE_DIM = 53

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0
        # Buffers passed to the feature extractors on every call; presumably
        # per-instance caches (TODO confirm).
        self.state_buffer = {}
        self.khalil_root_buffer = {}
        self.branch_count = 0  # number of branching decisions taken

    def branchinit(self):
        self.root_buffer = {}
        self.count = 0

    def branchexeclp(self, allowaddcons):
        """Score the pseudo branching candidates with the agent and branch on the best."""
        # Fetch candidates once; the same list feeds the Khalil features and
        # the action set, so both refer to the same variables in the same order.
        cands, *_ = self.model.getPseudoBranchCands()

        gcn_state = utils.extract_state(self.model, self.state_buffer)
        state_khalil = utils.extract_khalil_variable_features(self.model, cands, self.khalil_root_buffer)
        node_state = self.model.getMyNodeState(self.NODE_STATE_DIM)
        mip_state = self.model.getMyMIPState(self.MIP_STATE_DIM)
        state = [gcn_state, state_khalil, node_state, mip_state, None]

        g = utils.graph_transform_new_3_1(state).to(self.device)
        agent_out = self.agent(g).nodes['v'].data['s']

        action_set = [c.getCol().getLPPos() for c in cands]
        best_cand = int(agent_out[action_set].argmax())

        self.model.branchVar(cands[best_cand])
        # The result is always BRANCHED, so dom_reduction / cut_off never change here.
        self.branch_count += 1
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_3_2(Branchrule):
    """GNN branching rule using the ``utils.graph_transform_new_3_2`` encoding.

    The state passed to the transform is
    ``[gcn_state, khalil_features, node_state, mip_state, None]``. The agent's
    output graph is read at ``nodes['v'].data['s']`` for candidate scores.
    """

    # Sizes requested from the solver's node / MIP state queries.
    NODE_STATE_DIM = 8
    MIP_STATE_DIM = 53

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0
        # Buffers passed to the feature extractors on every call; presumably
        # per-instance caches (TODO confirm).
        self.state_buffer = {}
        self.khalil_root_buffer = {}
        self.branch_count = 0  # number of branching decisions taken

    def branchinit(self):
        self.root_buffer = {}
        self.count = 0

    def branchexeclp(self, allowaddcons):
        """Score the pseudo branching candidates with the agent and branch on the best."""
        # Fetch candidates once; the same list feeds the Khalil features and
        # the action set, so both refer to the same variables in the same order.
        cands, *_ = self.model.getPseudoBranchCands()

        gcn_state = utils.extract_state(self.model, self.state_buffer)
        state_khalil = utils.extract_khalil_variable_features(self.model, cands, self.khalil_root_buffer)
        node_state = self.model.getMyNodeState(self.NODE_STATE_DIM)
        mip_state = self.model.getMyMIPState(self.MIP_STATE_DIM)
        state = [gcn_state, state_khalil, node_state, mip_state, None]

        g = utils.graph_transform_new_3_2(state).to(self.device)
        agent_out = self.agent(g).nodes['v'].data['s']

        action_set = [c.getCol().getLPPos() for c in cands]
        best_cand = int(agent_out[action_set].argmax())

        self.model.branchVar(cands[best_cand])
        # The result is always BRANCHED, so dom_reduction / cut_off never change here.
        self.branch_count += 1
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_4(Branchrule):
    """GNN branching rule using the ``*_new_4`` state encoding, with timing output.

    Wall-clock time spent in ``branchexeclp`` is accumulated in
    ``branch_time``. The accumulated time is printed when the call count
    reaches one of ``REPORT_AT``. ``branchinit`` resets both ``count`` and
    ``branch_time``, so each pair of printed values refers to the same solve.
    """

    # Call counts at which the accumulated branching time is printed.
    REPORT_AT = (10, 100, 1000, 10000)

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0
        # Passed to utils.extract_state_new_4 on every call; presumably a
        # per-instance feature cache (TODO confirm).
        self.state_buffer = {}
        self.khalil_root_buffer = {}
        self.branch_count = 0

        self.branch_time = 0  # accumulated seconds spent in branchexeclp

    def branchinit(self):
        self.root_buffer = {}
        self.count = 0
        # Reset together with count so printed (count, time) pairs stay consistent.
        self.branch_time = 0

    def branchexeclp(self, allowaddcons):
        """Score the pseudo branching candidates with the agent and branch on the best."""
        branch_start_time = time.time()

        state = utils.extract_state_new_4(self.model, self.state_buffer)
        g = utils.graph_transform_new_4(state).to(self.device)

        # Candidates are fetched once, after state extraction.
        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]

        agent_out = self.agent(g).nodes['v'].data['s']
        best_cand = int(agent_out[action_set].argmax())

        self.model.branchVar(cands[best_cand])
        # The result is always BRANCHED, so dom_reduction / cut_off never change here.
        result = SCIP_RESULT.BRANCHED

        self.branch_time += time.time() - branch_start_time
        self.count += 1

        if self.count in self.REPORT_AT:
            print(f"PolicyBranching_4 self.count:{self.count}, branch time:{self.branch_time}")

        return {'result': result}

class PolicyBranching_8(Branchrule):
    """GNN branching rule that scores candidates from a window of the last 8 states.

    ``cache_deque`` always holds 8 entries. It starts as 8 ``None``
    placeholders, and each call drops the oldest entry and appends the newest
    one. The window is batched with ``dgl.batch`` and passed through
    ``agent.eval_forward_1`` then ``agent.eval_forward_2``.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

        # Fixed-length window of recent states, pre-filled with placeholders.
        self.cache_deque = deque([None] * 8)

    def get_cache_img(self, state):
        """Push ``state`` into the window and return every entry as one batched graph.

        NOTE(review): until 8 calls have happened, the window still contains
        ``None`` placeholders, which are passed to
        ``utils.graph_transform_new_8``; confirm that function accepts ``None``.
        Every window entry is re-transformed on every call.
        """
        window = self.cache_deque
        window.popleft()
        window.append(state)
        return dgl.batch([utils.graph_transform_new_8(item).to(self.device) for item in window])

    def get_cache_deque(self, agent_out):
        """Push ``agent_out`` into the window and return the window entries as a list.

        NOTE(review): shares ``cache_deque`` with ``get_cache_img`` and is not
        called by ``branchexeclp``.
        """
        self.cache_deque.popleft()
        self.cache_deque.append(agent_out)
        return list(self.cache_deque)

    def branchinit(self):
        self.root_buffer = {}

    def branchexeclp(self, allowaddcons):
        """Score the pseudo branching candidates from the state window and branch on the best."""
        self.count += 1

        newest_state = utils.extract_state_new_8(self.model)
        window_graph = self.get_cache_img(newest_state)

        candidates, *_ = self.model.getPseudoBranchCands()
        lp_positions = [var.getCol().getLPPos() for var in candidates]

        first_stage = self.agent.eval_forward_1(window_graph)
        scores = self.agent.eval_forward_2(first_stage)[lp_positions]

        self.model.branchVar(candidates[int(scores.argmax())])
        # The result is always BRANCHED, so dom_reduction / cut_off never change here.
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_8_2(Branchrule):
    """Branching rule scored by a two-stage agent over a rolling output history.

    ``agent.eval_forward_1`` is applied to the current node's graph. The first
    element of its output is pushed into a fixed 8-slot FIFO, and
    ``agent.eval_forward_2`` scores LP columns from that window. The window
    starts filled with a single shared zero tensor of shape (1000, 64).
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        # Kept for callers that report them; this rule always branches, so
        # they are never incremented here.
        self.dom_reduction = 0
        self.cut_off = 0

        # Rolling window of the 8 most recent stage-1 outputs, oldest first,
        # initially padded with one shared zero tensor.
        temp_data = torch.zeros(1000, 64).to(self.device)
        self.cache_deque = deque([temp_data] * 8)

    def get_cache_deque(self, agent_out):
        """Push ``agent_out[0]`` into the window and return it as a list (oldest first)."""
        # Index before mutating so a bad ``agent_out`` leaves the window intact.
        newest = agent_out[0]
        self.cache_deque.popleft()
        self.cache_deque.append(newest)
        return list(self.cache_deque)

    def branchinit(self):
        self.root_buffer = {}

    def branchexeclp(self, allowaddcons):
        """Branch on the pseudo candidate with the highest stage-2 score.

        Returns:
            ``{'result': SCIP_RESULT.BRANCHED}`` after branching, or
            ``{'result': SCIP_RESULT.DIDNOTRUN}`` when there are no pseudo
            branching candidates.
        """
        self.count += 1

        state = utils.extract_state_new_8(self.model)
        g = utils.graph_transform_new_8(state).to(self.device)

        cands, *_ = self.model.getPseudoBranchCands()
        if not cands:
            return {'result': SCIP_RESULT.DIDNOTRUN}
        # LP column position of each candidate, used to select its score row.
        action_set = [c.getCol().getLPPos() for c in cands]

        agent_out_1 = self.agent.eval_forward_1(g)
        agent_out_2 = self.agent.eval_forward_2(self.get_cache_deque(agent_out_1))

        best_cand = int(agent_out_2[action_set].argmax())
        self.model.branchVar(cands[best_cand])
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_10(Branchrule):
    """Branching rule scored from the states of the current node and its ancestors.

    The extracted state of each visited node is cached in a FIFO-bounded
    table of ``order_dict_length`` nodes. At each call, the states of the
    current node and up to 7 cached ancestors are converted with
    ``utils.graph_transform_new_10``; missing slots are passed as ``None``.
    The graphs are batched with ``dgl.batch`` and scored by the two-stage
    agent.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        # Kept for callers that report them; this rule always branches.
        self.dom_reduction = 0
        self.cut_off = 0

        # Both tables share the same keys (SCIP nodes) in insertion order:
        # node -> parent node, and node -> extracted state.
        self.order_dict_length = 1000
        self.father_order_dict = OrderedDict()
        self.cache_order_dict = OrderedDict()

    def append_now_node(self, node, state):
        """Record ``node``'s parent and ``state``, evicting the oldest entry when full."""
        father_node = node.getParent()
        self.father_order_dict[node] = father_node
        self.cache_order_dict[node] = state

        # Drop the oldest-inserted node once the capacity is exceeded.
        if len(self.cache_order_dict) > self.order_dict_length:
            self.father_order_dict.popitem(last=False)
            self.cache_order_dict.popitem(last=False)

    def get_cache_deque(self, node, depth=0):
        """Collect cached entries along the ancestor chain of ``node``.

        The walk follows ``father_order_dict`` upward for at most
        ``7 - depth`` steps and stops at the first ancestor with no cached
        entry. Entries are returned farthest ancestor first and ``node`` last.
        For the top-level call (``depth == 0``) the deque is left-padded with
        ``None`` to length 8.

        Raises:
            KeyError: if ``node`` itself has no cached entry.
        """
        chain = [self.cache_order_dict[node]]
        current = node
        for _ in range(7 - depth):
            if current not in self.father_order_dict:
                break
            parent = self.father_order_dict[current]
            if parent not in self.cache_order_dict:
                break
            chain.append(self.cache_order_dict[parent])
            current = parent

        history = deque(reversed(chain))
        if depth == 0:
            while len(history) < 8:
                history.appendleft(None)
        return history

    def get_cache_img(self, cache_deque):
        """Convert every cached entry to a graph on ``self.device`` and batch them in order."""
        graphs = [utils.graph_transform_new_10(d).to(self.device) for d in cache_deque]
        return dgl.batch(graphs)

    def branchinit(self):
        self.root_buffer = {}

    def branchexeclp(self, allowaddcons):
        """Branch on the pseudo candidate with the highest agent score.

        Returns:
            ``{'result': SCIP_RESULT.BRANCHED}`` after branching, or
            ``{'result': SCIP_RESULT.DIDNOTRUN}`` when there are no pseudo
            branching candidates.
        """
        self.count += 1

        state = utils.extract_state_new_10(self.model)
        cands, *_ = self.model.getPseudoBranchCands()
        if not cands:
            return {'result': SCIP_RESULT.DIDNOTRUN}
        # LP column position of each candidate, used to select its score row.
        action_set = [c.getCol().getLPPos() for c in cands]

        now_node = self.model.getCurrentNode()
        self.append_now_node(now_node, state)

        batched_graph = self.get_cache_img(self.get_cache_deque(now_node))
        agent_out_1 = self.agent.eval_forward_1(batched_graph)
        agent_out_2 = self.agent.eval_forward_2(agent_out_1)

        best_cand = int(agent_out_2[action_set].argmax())
        self.model.branchVar(cands[best_cand])
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_10_2(Branchrule):
    """Branching rule scored from cached stage-1 outputs along the ancestor chain.

    ``agent.eval_forward_1`` is applied to the current node's graph, and the
    first element of its output is cached per node in a FIFO-bounded table of
    ``order_dict_length`` nodes. ``agent.eval_forward_2`` receives the cached
    outputs of the current node and up to 7 cached ancestors, farthest
    ancestor first. Missing slots are replaced by a shared zero tensor of
    shape (1000, 64).
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        # Kept for callers that report them; this rule always branches.
        self.dom_reduction = 0
        self.cut_off = 0

        # Both tables share the same keys (SCIP nodes) in insertion order:
        # node -> parent node, and node -> first stage-1 output.
        self.order_dict_length = 100
        self.father_order_dict = OrderedDict()
        self.cache_order_dict = OrderedDict()

        # Padding for missing ancestors, allocated once instead of per call.
        self.pad_data = torch.zeros(1000, 64).to(self.device)

    def append_now_node(self, node, state):
        """Record ``node``'s parent and ``state``, evicting the oldest entry when full."""
        father_node = node.getParent()
        self.father_order_dict[node] = father_node
        self.cache_order_dict[node] = state

        # Drop the oldest-inserted node once the capacity is exceeded.
        if len(self.cache_order_dict) > self.order_dict_length:
            self.father_order_dict.popitem(last=False)
            self.cache_order_dict.popitem(last=False)

    def get_cache_deque(self, node, depth=0):
        """Collect cached entries along the ancestor chain of ``node``.

        The walk follows ``father_order_dict`` upward for at most
        ``7 - depth`` steps and stops at the first ancestor with no cached
        entry. Entries are returned farthest ancestor first and ``node`` last.
        For the top-level call (``depth == 0``) the deque is left-padded with
        ``None`` to length 8.

        Raises:
            KeyError: if ``node`` itself has no cached entry.
        """
        chain = [self.cache_order_dict[node]]
        current = node
        for _ in range(7 - depth):
            if current not in self.father_order_dict:
                break
            parent = self.father_order_dict[current]
            if parent not in self.cache_order_dict:
                break
            chain.append(self.cache_order_dict[parent])
            current = parent

        history = deque(reversed(chain))
        if depth == 0:
            while len(history) < 8:
                history.appendleft(None)
        return history

    def get_cache_list(self, temp_deque):
        """Return ``temp_deque`` as a list with ``None`` replaced by the zero padding tensor."""
        return [self.pad_data if item is None else item for item in temp_deque]

    def branchinit(self):
        self.root_buffer = {}

    def branchexeclp(self, allowaddcons):
        """Branch on the pseudo candidate with the highest stage-2 score.

        Returns:
            ``{'result': SCIP_RESULT.BRANCHED}`` after branching, or
            ``{'result': SCIP_RESULT.DIDNOTRUN}`` when there are no pseudo
            branching candidates.
        """
        self.count += 1

        state = utils.extract_state_new_10(self.model)
        cands, *_ = self.model.getPseudoBranchCands()
        if not cands:
            return {'result': SCIP_RESULT.DIDNOTRUN}
        # LP column position of each candidate, used to select its score row.
        action_set = [c.getCol().getLPPos() for c in cands]
        # NOTE(review): the state comes from extract_state_new_10 but is
        # converted with graph_transform_new_8 — confirm this pairing matches
        # the agent's training pipeline.
        g = utils.graph_transform_new_8(state).to(self.device)

        now_node = self.model.getCurrentNode()
        agent_out_1 = self.agent.eval_forward_1(g)
        self.append_now_node(now_node, agent_out_1[0])

        history = self.get_cache_list(self.get_cache_deque(now_node))
        agent_out_2 = self.agent.eval_forward_2(history)

        best_cand = int(agent_out_2[action_set].argmax())
        self.model.branchVar(cands[best_cand])
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_12_2(Branchrule):
    """Branching rule scored from cached stage-1 outputs along the ancestor chain.

    The state is extracted with ``utils.extract_state_new_12`` and converted
    with ``utils.graph_transform_new_12``. ``agent.eval_forward_1`` is applied
    to that graph, and the first element of its output is cached per node in
    a FIFO-bounded table of ``order_dict_length`` nodes.
    ``agent.eval_forward_2`` receives the cached outputs of the current node
    and up to 7 cached ancestors, farthest ancestor first. Missing slots are
    replaced by a shared zero tensor of shape (1000, 64).
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        # Kept for callers that report them; this rule always branches.
        self.dom_reduction = 0
        self.cut_off = 0

        # Both tables share the same keys (SCIP nodes) in insertion order:
        # node -> parent node, and node -> first stage-1 output.
        self.order_dict_length = 100
        self.father_order_dict = OrderedDict()
        self.cache_order_dict = OrderedDict()

        # Padding for missing ancestors, allocated once instead of per call.
        self.pad_data = torch.zeros(1000, 64).to(self.device)

    def append_now_node(self, node, state):
        """Record ``node``'s parent and ``state``, evicting the oldest entry when full."""
        father_node = node.getParent()
        self.father_order_dict[node] = father_node
        self.cache_order_dict[node] = state

        # Drop the oldest-inserted node once the capacity is exceeded.
        if len(self.cache_order_dict) > self.order_dict_length:
            self.father_order_dict.popitem(last=False)
            self.cache_order_dict.popitem(last=False)

    def get_cache_deque(self, node, depth=0):
        """Collect cached entries along the ancestor chain of ``node``.

        The walk follows ``father_order_dict`` upward for at most
        ``7 - depth`` steps and stops at the first ancestor with no cached
        entry. Entries are returned farthest ancestor first and ``node`` last.
        For the top-level call (``depth == 0``) the deque is left-padded with
        ``None`` to length 8.

        Raises:
            KeyError: if ``node`` itself has no cached entry.
        """
        chain = [self.cache_order_dict[node]]
        current = node
        for _ in range(7 - depth):
            if current not in self.father_order_dict:
                break
            parent = self.father_order_dict[current]
            if parent not in self.cache_order_dict:
                break
            chain.append(self.cache_order_dict[parent])
            current = parent

        history = deque(reversed(chain))
        if depth == 0:
            while len(history) < 8:
                history.appendleft(None)
        return history

    def get_cache_list(self, temp_deque):
        """Return ``temp_deque`` as a list with ``None`` replaced by the zero padding tensor."""
        return [self.pad_data if item is None else item for item in temp_deque]

    def branchinit(self):
        self.root_buffer = {}

    def branchexeclp(self, allowaddcons):
        """Branch on the pseudo candidate with the highest stage-2 score.

        Returns:
            ``{'result': SCIP_RESULT.BRANCHED}`` after branching, or
            ``{'result': SCIP_RESULT.DIDNOTRUN}`` when there are no pseudo
            branching candidates.
        """
        self.count += 1

        state = utils.extract_state_new_12(self.model)
        cands, *_ = self.model.getPseudoBranchCands()
        if not cands:
            return {'result': SCIP_RESULT.DIDNOTRUN}
        # LP column position of each candidate, used to select its score row.
        action_set = [c.getCol().getLPPos() for c in cands]
        g = utils.graph_transform_new_12(state).to(self.device)

        now_node = self.model.getCurrentNode()
        agent_out_1 = self.agent.eval_forward_1(g)
        self.append_now_node(now_node, agent_out_1[0])

        history = self.get_cache_list(self.get_cache_deque(now_node))
        agent_out_2 = self.agent.eval_forward_2(history)

        best_cand = int(agent_out_2[action_set].argmax())
        self.model.branchVar(cands[best_cand])
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_13_2(Branchrule):
    """Branching rule scored from cached stage-1 outputs along the ancestor chain.

    The state is extracted with ``utils.extract_state_new_13`` and converted
    with ``utils.graph_transform_new_13``. ``agent.eval_forward_1`` is applied
    to that graph, and the first element of its output is cached per node in
    a FIFO-bounded table of ``order_dict_length`` nodes.
    ``agent.eval_forward_2`` receives the cached outputs of the current node
    and up to 7 cached ancestors, farthest ancestor first. Missing slots are
    replaced by a shared zero tensor of shape (1000, 64).
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        # Kept for callers that report them; this rule always branches.
        self.dom_reduction = 0
        self.cut_off = 0

        # Both tables share the same keys (SCIP nodes) in insertion order:
        # node -> parent node, and node -> first stage-1 output.
        self.order_dict_length = 100
        self.father_order_dict = OrderedDict()
        self.cache_order_dict = OrderedDict()

        # Padding for missing ancestors, allocated once instead of per call.
        self.pad_data = torch.zeros(1000, 64).to(self.device)

    def append_now_node(self, node, state):
        """Record ``node``'s parent and ``state``, evicting the oldest entry when full."""
        father_node = node.getParent()
        self.father_order_dict[node] = father_node
        self.cache_order_dict[node] = state

        # Drop the oldest-inserted node once the capacity is exceeded.
        if len(self.cache_order_dict) > self.order_dict_length:
            self.father_order_dict.popitem(last=False)
            self.cache_order_dict.popitem(last=False)

    def get_cache_deque(self, node, depth=0):
        """Collect cached entries along the ancestor chain of ``node``.

        The walk follows ``father_order_dict`` upward for at most
        ``7 - depth`` steps and stops at the first ancestor with no cached
        entry. Entries are returned farthest ancestor first and ``node`` last.
        For the top-level call (``depth == 0``) the deque is left-padded with
        ``None`` to length 8.

        Raises:
            KeyError: if ``node`` itself has no cached entry.
        """
        chain = [self.cache_order_dict[node]]
        current = node
        for _ in range(7 - depth):
            if current not in self.father_order_dict:
                break
            parent = self.father_order_dict[current]
            if parent not in self.cache_order_dict:
                break
            chain.append(self.cache_order_dict[parent])
            current = parent

        history = deque(reversed(chain))
        if depth == 0:
            while len(history) < 8:
                history.appendleft(None)
        return history

    def get_cache_list(self, temp_deque):
        """Return ``temp_deque`` as a list with ``None`` replaced by the zero padding tensor."""
        return [self.pad_data if item is None else item for item in temp_deque]

    def branchinit(self):
        self.root_buffer = {}

    def branchexeclp(self, allowaddcons):
        """Branch on the pseudo candidate with the highest stage-2 score.

        Returns:
            ``{'result': SCIP_RESULT.BRANCHED}`` after branching, or
            ``{'result': SCIP_RESULT.DIDNOTRUN}`` when there are no pseudo
            branching candidates.
        """
        self.count += 1

        state = utils.extract_state_new_13(self.model)
        cands, *_ = self.model.getPseudoBranchCands()
        if not cands:
            return {'result': SCIP_RESULT.DIDNOTRUN}
        # LP column position of each candidate, used to select its score row.
        action_set = [c.getCol().getLPPos() for c in cands]
        g = utils.graph_transform_new_13(state).to(self.device)

        now_node = self.model.getCurrentNode()
        agent_out_1 = self.agent.eval_forward_1(g)
        self.append_now_node(now_node, agent_out_1[0])

        history = self.get_cache_list(self.get_cache_deque(now_node))
        agent_out_2 = self.agent.eval_forward_2(history)

        best_cand = int(agent_out_2[action_set].argmax())
        self.model.branchVar(cands[best_cand])
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_14_2(Branchrule):
    """Branching rule scored from cached stage-1 outputs along the ancestor chain.

    The state is extracted with ``utils.extract_state_new_14`` and converted
    with ``utils.graph_transform_new_14``. ``agent.eval_forward_1`` is applied
    to that graph, and the first element of its output is cached per node in
    a FIFO-bounded table of ``order_dict_length`` nodes.
    ``agent.eval_forward_2`` receives the cached outputs of the current node
    and up to 7 cached ancestors, farthest ancestor first. Missing slots are
    replaced by a shared zero tensor of shape (1000, 64).
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        # Kept for callers that report them; this rule always branches.
        self.dom_reduction = 0
        self.cut_off = 0

        # Both tables share the same keys (SCIP nodes) in insertion order:
        # node -> parent node, and node -> first stage-1 output.
        self.order_dict_length = 100
        self.father_order_dict = OrderedDict()
        self.cache_order_dict = OrderedDict()

        # Padding for missing ancestors, allocated once instead of per call.
        self.pad_data = torch.zeros(1000, 64).to(self.device)

    def append_now_node(self, node, state):
        """Record ``node``'s parent and ``state``, evicting the oldest entry when full."""
        father_node = node.getParent()
        self.father_order_dict[node] = father_node
        self.cache_order_dict[node] = state

        # Drop the oldest-inserted node once the capacity is exceeded.
        if len(self.cache_order_dict) > self.order_dict_length:
            self.father_order_dict.popitem(last=False)
            self.cache_order_dict.popitem(last=False)

    def get_cache_deque(self, node, depth=0):
        """Collect cached entries along the ancestor chain of ``node``.

        The walk follows ``father_order_dict`` upward for at most
        ``7 - depth`` steps and stops at the first ancestor with no cached
        entry. Entries are returned farthest ancestor first and ``node`` last.
        For the top-level call (``depth == 0``) the deque is left-padded with
        ``None`` to length 8.

        Raises:
            KeyError: if ``node`` itself has no cached entry.
        """
        chain = [self.cache_order_dict[node]]
        current = node
        for _ in range(7 - depth):
            if current not in self.father_order_dict:
                break
            parent = self.father_order_dict[current]
            if parent not in self.cache_order_dict:
                break
            chain.append(self.cache_order_dict[parent])
            current = parent

        history = deque(reversed(chain))
        if depth == 0:
            while len(history) < 8:
                history.appendleft(None)
        return history

    def get_cache_list(self, temp_deque):
        """Return ``temp_deque`` as a list with ``None`` replaced by the zero padding tensor."""
        return [self.pad_data if item is None else item for item in temp_deque]

    def branchinit(self):
        self.root_buffer = {}

    def branchexeclp(self, allowaddcons):
        """Branch on the pseudo candidate with the highest stage-2 score.

        Returns:
            ``{'result': SCIP_RESULT.BRANCHED}`` after branching, or
            ``{'result': SCIP_RESULT.DIDNOTRUN}`` when there are no pseudo
            branching candidates.
        """
        self.count += 1

        state = utils.extract_state_new_14(self.model)
        cands, *_ = self.model.getPseudoBranchCands()
        if not cands:
            return {'result': SCIP_RESULT.DIDNOTRUN}
        # LP column position of each candidate, used to select its score row.
        action_set = [c.getCol().getLPPos() for c in cands]
        g = utils.graph_transform_new_14(state).to(self.device)

        now_node = self.model.getCurrentNode()
        agent_out_1 = self.agent.eval_forward_1(g)
        self.append_now_node(now_node, agent_out_1[0])

        history = self.get_cache_list(self.get_cache_deque(now_node))
        agent_out_2 = self.agent.eval_forward_2(history)

        best_cand = int(agent_out_2[action_set].argmax())
        self.model.branchVar(cands[best_cand])
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_15_2(Branchrule):
    """Branching rule scored from cached stage-1 outputs along the ancestor chain.

    The state is extracted with ``utils.extract_state_new_15`` and converted
    with ``utils.graph_transform_new_15``. ``agent.eval_forward_1`` is applied
    to that graph, and the first element of its output is cached per node in
    a FIFO-bounded table of ``order_dict_length`` nodes.
    ``agent.eval_forward_2`` receives the cached outputs of the current node
    and up to 7 cached ancestors, farthest ancestor first. Missing slots are
    replaced by a shared zero tensor of shape (1000, 64).
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        # Kept for callers that report them; this rule always branches.
        self.dom_reduction = 0
        self.cut_off = 0

        # Both tables share the same keys (SCIP nodes) in insertion order:
        # node -> parent node, and node -> first stage-1 output.
        self.order_dict_length = 100
        self.father_order_dict = OrderedDict()
        self.cache_order_dict = OrderedDict()

        # Padding for missing ancestors, allocated once instead of per call.
        self.pad_data = torch.zeros(1000, 64).to(self.device)

    def append_now_node(self, node, state):
        """Record ``node``'s parent and ``state``, evicting the oldest entry when full."""
        father_node = node.getParent()
        self.father_order_dict[node] = father_node
        self.cache_order_dict[node] = state

        # Drop the oldest-inserted node once the capacity is exceeded.
        if len(self.cache_order_dict) > self.order_dict_length:
            self.father_order_dict.popitem(last=False)
            self.cache_order_dict.popitem(last=False)

    def get_cache_deque(self, node, depth=0):
        """Collect cached entries along the ancestor chain of ``node``.

        The walk follows ``father_order_dict`` upward for at most
        ``7 - depth`` steps and stops at the first ancestor with no cached
        entry. Entries are returned farthest ancestor first and ``node`` last.
        For the top-level call (``depth == 0``) the deque is left-padded with
        ``None`` to length 8.

        Raises:
            KeyError: if ``node`` itself has no cached entry.
        """
        chain = [self.cache_order_dict[node]]
        current = node
        for _ in range(7 - depth):
            if current not in self.father_order_dict:
                break
            parent = self.father_order_dict[current]
            if parent not in self.cache_order_dict:
                break
            chain.append(self.cache_order_dict[parent])
            current = parent

        history = deque(reversed(chain))
        if depth == 0:
            while len(history) < 8:
                history.appendleft(None)
        return history

    def get_cache_list(self, temp_deque):
        """Return ``temp_deque`` as a list with ``None`` replaced by the zero padding tensor."""
        return [self.pad_data if item is None else item for item in temp_deque]

    def branchinit(self):
        self.root_buffer = {}

    def branchexeclp(self, allowaddcons):
        """Branch on the pseudo candidate with the highest stage-2 score.

        Returns:
            ``{'result': SCIP_RESULT.BRANCHED}`` after branching, or
            ``{'result': SCIP_RESULT.DIDNOTRUN}`` when there are no pseudo
            branching candidates.
        """
        self.count += 1

        state = utils.extract_state_new_15(self.model)
        cands, *_ = self.model.getPseudoBranchCands()
        if not cands:
            return {'result': SCIP_RESULT.DIDNOTRUN}
        # LP column position of each candidate, used to select its score row.
        action_set = [c.getCol().getLPPos() for c in cands]
        g = utils.graph_transform_new_15(state).to(self.device)

        now_node = self.model.getCurrentNode()
        agent_out_1 = self.agent.eval_forward_1(g)
        self.append_now_node(now_node, agent_out_1[0])

        history = self.get_cache_list(self.get_cache_deque(now_node))
        agent_out_2 = self.agent.eval_forward_2(history)

        best_cand = int(agent_out_2[action_set].argmax())
        self.model.branchVar(cands[best_cand])
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_20(Branchrule):
    """Branching rule that follows the agent's ``forward_sb`` scores greedily.

    The state from ``utils.extract_state`` is converted with
    ``utils.graph_transform_new_20``. The agent's per-variable output
    ``nodes['v'].data['s']`` is read at each candidate's LP column, and SCIP
    branches on the best-scoring candidate.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        # Kept for callers that report them; this rule always branches.
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}

    def branchexeclp(self, allowaddcons):
        """Branch on the pseudo candidate with the highest ``forward_sb`` score.

        Returns:
            ``{'result': SCIP_RESULT.BRANCHED}`` after branching, or
            ``{'result': SCIP_RESULT.DIDNOTRUN}`` when there are no pseudo
            branching candidates.
        """
        self.count += 1

        state = utils.extract_state(self.model)
        g = utils.graph_transform_new_20(state).to(self.device)

        cands, *_ = self.model.getPseudoBranchCands()
        if not cands:
            return {'result': SCIP_RESULT.DIDNOTRUN}
        # LP column position of each candidate, used to select its score row.
        action_set = [c.getCol().getLPPos() for c in cands]

        agent_out = self.agent.forward_sb(g)
        logits = agent_out.nodes['v'].data['s'][action_set]

        best_cand = int(logits.argmax())
        self.model.branchVar(cands[best_cand])
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_20_topk(Branchrule):
    """Agent-scored branching with probing-based re-ranking near the root.

    At depths 0-2, the ``k`` best pseudo candidates by ``forward_sb`` score
    are re-ranked. For each one, the variable is fixed to its local lower and
    then its local upper bound in probing mode, and the LP gains are
    combined as ``max(down - z, 1e-6) * max(up - z, 1e-6)``, where ``z`` is
    the current LP objective. Deeper nodes use the agent's argmax directly.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        # Kept for callers that report them; this rule always branches.
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}

    def _probe_fixed(self, var, value):
        """Fix ``var`` to ``value`` in probing mode and return the probing LP objective."""
        self.model.startProbing()
        try:
            assert not self.model.isObjChangedProbing()
            self.model.fixVarProbing(var, value)
            self.model.constructLP()
            self.model.solveProbingLP()
            return self.model.getLPObjVal()
        finally:
            # Always leave probing mode, even if the LP solve raises.
            self.model.endProbing()

    def _probing_scores(self, cand_vars, cur_z):
        """Return the product score of down/up LP gains for each variable in ``cand_vars``."""
        scores = []
        for cand_var in cand_vars:
            # Fix to the node-local bounds: the global bounds may lie outside
            # the local domain, which would give inconsistent probing bounds.
            lb_score = self._probe_fixed(cand_var, cand_var.getLbLocal())
            ub_score = self._probe_fixed(cand_var, cand_var.getUbLocal())
            scores.append(max(lb_score - cur_z, 1e-6) * max(ub_score - cur_z, 1e-6))
        return scores

    def branchexeclp(self, allowaddcons, k=5):
        """Branch using agent scores, re-ranking the top ``k`` by probing at depth <= 2.

        Args:
            allowaddcons: Passed by SCIP; unused here.
            k: Number of top agent-scored candidates to re-rank by probing
                (clamped to the number of candidates).

        Returns:
            ``{'result': SCIP_RESULT.BRANCHED}`` after branching, or
            ``{'result': SCIP_RESULT.DIDNOTRUN}`` when there are no pseudo
            branching candidates.
        """
        self.count += 1

        state = utils.extract_state(self.model)
        g = utils.graph_transform_new_20(state).to(self.device)

        cands, *_ = self.model.getPseudoBranchCands()
        if not cands:
            return {'result': SCIP_RESULT.DIDNOTRUN}
        # LP column position of each candidate, used to select its score row.
        action_set = [c.getCol().getLPPos() for c in cands]

        agent_out = self.agent.forward_sb(g)
        score = agent_out.nodes['v'].data['s'][action_set]

        depth = self.model.getCurrentNode().getDepth()
        if 0 <= depth <= 2:
            k = min(k, len(action_set))
            _, topk_cand_index = torch.topk(score, k, dim=0)
            print('current depth', depth, len(action_set))
            start = time.time()
            cur_z = self.model.getLPObjVal()
            topk_cands = [cands[int(cand_index)] for cand_index in topk_cand_index]
            sb_list = self._probing_scores(topk_cands, cur_z)
            print('SCIP time', time.time() - start)
            best_cand = topk_cands[int(np.argmax(sb_list))]
        else:
            best_cand = cands[int(score.argmax())]

        self.model.branchVar(best_cand)
        return {'result': SCIP_RESULT.BRANCHED}

# class PolicyBranching_20_topk(Branchrule):
#     # Get accurate data for imitation learning and supervised learning
#     # 为模仿学习和监督学习获得准确的数据

#     def __init__(self, scip, agent, device, args):
#         super().__init__()
#         self.model = scip
#         self.device = device
#         self.agent = agent
#         self.args = args
#         self.count = 0
#         self.dom_reduction = 0
#         self.cut_off = 0

#     def branchinit(self):
#         self.root_buffer = {}

#     def branchexeclp(self, allowaddcons, k = 5):#getDowngainUpgainData试试这个

#         result = SCIP_RESULT.DIDNOTRUN
#         self.count += 1

#         state = utils.extract_state(self.model)
#         g = utils.graph_transform_new_20(state).to(self.device)

#         cands, *_ = self.model.getPseudoBranchCands()
#         action_set = [c.getCol().getLPPos() for c in cands]
    
#         # agent_out = self.agent.forward_lp(g)

#         # logits_0 = agent_out.nodes['v'].data['s_0'][action_set]
#         # logits_1 = agent_out.nodes['v'].data['s_1'][action_set]

#         # logits = torch.mul(logits_0, logits_1)

#         agent_out = self.agent.forward_sb(g)
#         logits = agent_out.nodes['v'].data['s'][action_set]
#         score = logits
#         if (self.model.getCurrentNode().getDepth() <= 2 and self.model.getCurrentNode().getDepth() >-1):
#             # if (self.model.getCurrentNode().getDepth()<1):
#             #     k = 10
#             # elif (self.model.getCurrentNode().getDepth()<=2):
#             #     k = 5
#             # else:
#             #     k = 2
#             #187286 不用top_k方法的，即正常的 ，187287 用top-5且self.model.getCurrentNode().getDepth() <= 2 and self.model.getCurrentNode().getDepth() >-1
#             _, topk_cand_index = torch.topk(score, k, dim=0)#先对应出top_k score的index，再根据index完善c,A,b
#             topk_index = [cands[int(cand_index)].getCol().getLPPos() for cand_index in topk_cand_index]#在所有变量中的index
#             aaaa = time.time()
#             # cur_z = self.model.getLPObjVal()
#             # aa = time.time()
#             # SB_list = []
#             topk_index_cand = [cands[int(cand_index)] for cand_index in topk_cand_index]
#             lp_data = self.model.getCandsDowngainUpgainData(topk_index_cand)
#             lp_scores_0 = [item[0] for item in lp_data]
#             lp_scores_1 = [item[1] for item in lp_data]
#             sb_list = [max(lp_scores_0[i], 1e-6)*max(lp_scores_1[i], 1e-6) for i in range(len(lp_scores_0))]

#             # sb_list = []
#             # for cand_var in topk_index_cand:
#             #     self.model.startProbing()
#             #     assert not self.model.isObjChangedProbing()
#             #     lb = cand_var.getLbGlobal()
#             #     ub = cand_var.getUbGlobal()
#             #     self.model.fixVarProbing(cand_var, lb)
#             #     self.model.constructLP()
#             #     self.model.solveProbingLP()
#             #     lb_score = self.model.getLPObjVal()
#             #     self.model.endProbing()
#             #     self.model.startProbing()
#             #     self.model.fixVarProbing(cand_var, ub)
#             #     self.model.constructLP()
#             #     self.model.solveProbingLP()
#             #     ub_score = self.model.getLPObjVal()
#             #     self.model.endProbing()
#             #     sb_score = max(lb_score-cur_z, 1e-6)*max(ub_score-cur_z, 1e-6)
#             # #     #sb = max(lp_down-cur_z, 1e-6)*max(lp_up-cur_z, 1e-6)
#             #     sb_list.append(sb_score)

#             cand_indexx = np.argmax(sb_list)
#             # print('cand_index',cand_indexx,topk_index_cand[cand_indexx])
#             print('SCIP time', time.time() - aaaa)
#             best_cand = topk_index_cand[cand_indexx]#topk_index_cand[calculate_SB_score(topk_state, topk_index, cur_z, k)]
#         else:
#             best_cand = cands[int(score.argmax())]

#         self.model.branchVar(best_cand)#self.model.branchVar(cands[best_cand])

#         # best_cand = int(logits.argmax())

#         # self.model.branchVar(cands[best_cand])
#         result = SCIP_RESULT.BRANCHED

#         if result == SCIP_RESULT.REDUCEDDOM:
#             self.dom_reduction += 1
#         elif result == SCIP_RESULT.CUTOFF:
#             self.cut_off += 2
#         return {'result': result}

class PolicyBranching_21(Branchrule):
    """Branching rule that follows the agent's ``forward_sb`` scores greedily.

    The state from ``utils.extract_state`` is converted with
    ``utils.graph_transform_new_21``. The agent's per-variable output
    ``nodes['v'].data['s']`` is read at each candidate's LP column, and SCIP
    branches on the best-scoring candidate.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        # Kept for callers that report them; this rule always branches.
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}

    def branchexeclp(self, allowaddcons):
        """Branch on the pseudo candidate with the highest ``forward_sb`` score.

        Returns:
            ``{'result': SCIP_RESULT.BRANCHED}`` after branching, or
            ``{'result': SCIP_RESULT.DIDNOTRUN}`` when there are no pseudo
            branching candidates.
        """
        self.count += 1

        state = utils.extract_state(self.model)
        g = utils.graph_transform_new_21(state).to(self.device)

        cands, *_ = self.model.getPseudoBranchCands()
        if not cands:
            return {'result': SCIP_RESULT.DIDNOTRUN}
        # LP column position of each candidate, used to select its score row.
        action_set = [c.getCol().getLPPos() for c in cands]

        agent_out = self.agent.forward_sb(g)
        logits = agent_out.nodes['v'].data['s'][action_set]

        best_cand = int(logits.argmax())
        self.model.branchVar(cands[best_cand])
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_21_topk(Branchrule):
    """GNN branching rule (``forward_sb`` head) with probing re-ranking near the root.

    Every LP branching call scores the pseudo-branching candidates with the
    agent. At nodes with ``0 <= depth <= topk_max_depth``, the ``k``
    best-scored candidates are re-evaluated. Each one gets two probing LPs:
    first with the variable fixed to its local lower bound, then to its local
    upper bound. The rule branches on the candidate that maximises
    ``max(z_lb - z, 1e-6) * max(z_ub - z, 1e-6)``, where ``z`` is the current
    LP objective. Deeper nodes branch directly on the GNN arg-max.
    """

    # Deepest node depth at which the probing re-ranking is applied.
    topk_max_depth = 2

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}

    def _probing_objval(self, var, value):
        """Fix ``var`` to ``value`` in probing mode, solve the LP, return its objective.

        ``endProbing`` runs in a ``finally`` clause, so an exception raised
        mid-probe cannot leave SCIP stuck in probing mode.
        """
        self.model.startProbing()
        try:
            assert not self.model.isObjChangedProbing()
            self.model.fixVarProbing(var, value)
            self.model.constructLP()
            self.model.solveProbingLP()
            return self.model.getLPObjVal()
        finally:
            self.model.endProbing()

    def _probing_sb_score(self, var, cur_z):
        """Return the product strong-branching score of ``var`` from two probing LPs."""
        # Use the node-local bounds. Global bounds can lie outside the local
        # domain of a general-integer variable; for an unfixed binary the two
        # coincide, so binary problems behave as before.
        lb_score = self._probing_objval(var, var.getLbLocal())
        ub_score = self._probing_objval(var, var.getUbLocal())
        return max(lb_score - cur_z, 1e-6) * max(ub_score - cur_z, 1e-6)

    def branchexeclp(self, allowaddcons, k=5):
        """Branch on the best candidate.

        Args:
            allowaddcons: passed by SCIP; unused.
            k: number of top GNN candidates re-ranked by probing near the root.
        """
        self.count += 1

        state = utils.extract_state(self.model)
        g = utils.graph_transform_new_21(state).to(self.device)

        cands, *_ = self.model.getPseudoBranchCands()
        if not cands:
            # Nothing to branch on; let SCIP fall back to another rule.
            return {'result': SCIP_RESULT.DIDNOTRUN}
        action_set = [c.getCol().getLPPos() for c in cands]

        agent_out = self.agent.forward_sb(g)
        # Flatten so that topk/argmax indices are candidate indices.
        score = agent_out.nodes['v'].data['s'][action_set].reshape(-1)

        depth = self.model.getCurrentNode().getDepth()
        if k > 0 and 0 <= depth <= self.topk_max_depth:
            # torch.topk raises if asked for more elements than exist.
            n_top = min(k, len(cands))
            _, topk_cand_index = torch.topk(score, n_top, dim=0)
            start = time.time()
            cur_z = self.model.getLPObjVal()
            topk_cands = [cands[int(i)] for i in topk_cand_index]
            sb_list = [self._probing_sb_score(var, cur_z) for var in topk_cands]
            print('SCIP time', time.time() - start)
            best_cand = topk_cands[int(np.argmax(sb_list))]
        else:
            best_cand = cands[int(score.argmax())]

        self.model.branchVar(best_cand)
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_22(Branchrule):
    """GNN branching rule that scores candidates with the product of two clamped heads.

    The agent's ``forward`` output provides per-variable ``s_0`` and ``s_1``
    scores. Each is clamped from below at 1/6, and the element-wise product
    selects the pseudo-branching candidate to branch on.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}

    def branchexeclp(self, allowaddcons):
        self.count += 1
        if self.count == 1:
            print("!!!!!!!!!22!!!!!!!!")

        graph = utils.graph_transform_new_22(utils.extract_state_new_22(self.model)).to(self.device)

        candidates = self.model.getPseudoBranchCands()[0]
        lp_positions = [var.getCol().getLPPos() for var in candidates]
        node_data = self.agent.forward(graph).nodes['v'].data

        # Lower-clamp both heads so neither can drive the product to ~0 alone.
        min_score = 1.0 / 6.0
        head_0 = torch.clamp(node_data['s_0'][lp_positions], min=min_score)
        head_1 = torch.clamp(node_data['s_1'][lp_positions], min=min_score)
        chosen = int(torch.mul(head_0, head_1).argmax())

        self.model.branchVar(candidates[chosen])
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_22_topk(Branchrule):
    """Two-head GNN branching rule with probing re-ranking near the root.

    Candidates are scored by ``clamp(s_0, 1/6) * clamp(s_1, 1/6)`` from
    ``agent.forward``. At nodes with ``0 <= depth <= topk_max_depth``, the
    ``k`` best-scored candidates are re-evaluated with two probing LPs each
    (variable fixed to its local lower bound, then to its local upper bound).
    The rule branches on the candidate that maximises
    ``max(z_lb - z, 1e-6) * max(z_ub - z, 1e-6)``, where ``z`` is the current
    LP objective. Deeper nodes branch on the GNN arg-max.
    """

    # Deepest node depth at which the probing re-ranking is applied.
    topk_max_depth = 2

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}

    def _probing_objval(self, var, value):
        """Fix ``var`` to ``value`` in probing mode, solve the LP, return its objective.

        ``endProbing`` runs in a ``finally`` clause, so an exception raised
        mid-probe cannot leave SCIP stuck in probing mode.
        """
        self.model.startProbing()
        try:
            assert not self.model.isObjChangedProbing()
            self.model.fixVarProbing(var, value)
            self.model.constructLP()
            self.model.solveProbingLP()
            return self.model.getLPObjVal()
        finally:
            self.model.endProbing()

    def _probing_sb_score(self, var, cur_z):
        """Return the product strong-branching score of ``var`` from two probing LPs."""
        # Use the node-local bounds. Global bounds can lie outside the local
        # domain of a general-integer variable; for an unfixed binary the two
        # coincide, so binary problems behave as before.
        lb_score = self._probing_objval(var, var.getLbLocal())
        ub_score = self._probing_objval(var, var.getUbLocal())
        return max(lb_score - cur_z, 1e-6) * max(ub_score - cur_z, 1e-6)

    def branchexeclp(self, allowaddcons, k=5):
        """Branch on the best candidate.

        Args:
            allowaddcons: passed by SCIP; unused.
            k: number of top GNN candidates re-ranked by probing near the root.
        """
        self.count += 1

        if self.count == 1:
            print("!!!!!!!!!22!!!!!!!!")

        state = utils.extract_state_new_22(self.model)
        g = utils.graph_transform_new_22(state).to(self.device)

        cands, *_ = self.model.getPseudoBranchCands()
        if not cands:
            # Nothing to branch on; let SCIP fall back to another rule.
            return {'result': SCIP_RESULT.DIDNOTRUN}
        action_set = [c.getCol().getLPPos() for c in cands]

        agent_out = self.agent.forward(g)
        logits_0 = agent_out.nodes['v'].data['s_0'][action_set]
        logits_1 = agent_out.nodes['v'].data['s_1'][action_set]
        # Flatten so that topk/argmax indices are candidate indices.
        score = torch.mul(torch.clamp(logits_0, min=1.0 / 6.0),
                          torch.clamp(logits_1, min=1.0 / 6.0)).reshape(-1)

        depth = self.model.getCurrentNode().getDepth()
        if k > 0 and 0 <= depth <= self.topk_max_depth:
            # torch.topk raises if asked for more elements than exist.
            n_top = min(k, len(cands))
            _, topk_cand_index = torch.topk(score, n_top, dim=0)
            start = time.time()
            cur_z = self.model.getLPObjVal()
            topk_cands = [cands[int(i)] for i in topk_cand_index]
            sb_list = [self._probing_sb_score(var, cur_z) for var in topk_cands]
            print('SCIP time', time.time() - start)
            best_cand = topk_cands[int(np.argmax(sb_list))]
        else:
            best_cand = cands[int(score.argmax())]

        self.model.branchVar(best_cand)
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_23(Branchrule):
    """GCNN branching rule that stops querying the network after a warm-up.

    For the first ``n_warmup_calls`` calls, the agent is evaluated on the
    current state. Branching uses that call's output, and the output is also
    summed into ``pre_data``. Afterwards ``pre_data`` (the running sum, which
    has the same arg-max as the average) is indexed by the current
    candidates' LP column positions, so no further forward passes are made.
    """

    # Number of initial calls that run the network.
    n_warmup_calls = 4

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

        # Sum of the agent outputs seen during warm-up (None until the first call).
        self.pre_data = None
        # Number of outputs summed into pre_data.
        self.ave_time = None

    def branchinit(self):
        self.root_buffer = {}

    def branchexeclp(self, allowaddcons):
        self.count += 1

        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]

        # `is None`, not `== None`: pre_data becomes a tensor, and tensor
        # equality with None is not a reliable truth test across torch versions.
        if self.count <= self.n_warmup_calls or self.pre_data is None:
            state = utils.extract_state(self.model)
            g = utils.graph_transform(state).to(self.device)

            agent_out = self.agent.eval_forward(g)
            score = agent_out[action_set]
            if self.pre_data is None:
                self.pre_data = agent_out
                self.ave_time = 1
            else:
                self.pre_data = torch.add(self.pre_data, agent_out)
                self.ave_time += 1
        else:
            # NOTE(review): assumes LP column positions stay within the size
            # of the warm-up outputs — confirm no columns are added later.
            score = self.pre_data[action_set]

        best_cand = int(score.argmax())
        self.model.branchVar(cands[best_cand])
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_25(Branchrule):
    """Branching rule combining cached root embeddings with per-node Khalil features.

    On the first call, the Khalil root buffer is initialised and variable
    embeddings of the root graph are cached. Every call extracts Khalil
    features for the current candidates and scores them with
    ``agent.eval_forword`` (method name as spelled on the agent). The rule
    then branches on the arg-max.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

        self.root_state = None
        self.root_embeds = None
        self.khalil_root_buffer = {}

    def branchinit(self):
        self.root_buffer = {}

    def branchexeclp(self, allowaddcons):
        self.count += 1

        if self.count == 1:
            # First call: fill the Khalil root buffer and cache root embeddings.
            utils.extract_khalil_variable_features_new_25(self.model, [], self.khalil_root_buffer)
            self.root_state = utils.extract_state_new_25(self.model)
            g = utils.graph_transform(self.root_state).to(self.device)
            self.root_embeds = self.agent.get_var_embeds(g)

        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]
        khalil_state = utils.extract_khalil_variable_features_new_25(self.model, cands, self.khalil_root_buffer)
        # Bug fix: was the undefined global `device`; the rule's own device is intended.
        khalil_state = torch.FloatTensor(khalil_state).to(self.device)
        data = self.agent.eval_forword(1, self.root_embeds, khalil_state)

        score = data[0][action_set]
        best_cand = int(score.argmax())

        self.model.branchVar(cands[best_cand])
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_26(Branchrule):
    """Experimental hybrid branching: GCNN ranking plus strong-branching (SB) re-scoring.

    ``branchexeclp`` dispatches to ``branchexeclp_gcnn_sb_ave``:

    - The agent ranks the pseudo-branching candidates.
    - The ``top_sb_k`` best are scored with ``Model.getCandsVsbScores_1``
      (a custom PySCIPOpt extension).
    - SB observations are cached per node. Once a variable has more than
      ``ave_k`` observations along the root-to-node path, its mean is frozen
      and reused by descendants instead of being re-evaluated.

    The other ``branchexeclp_*`` methods are alternative or debugging
    strategies kept for experiments. Several of them rely on attributes that
    ``__init__`` never creates (see the NOTE(review) comments).
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        # Number of GCNN-ranked candidates re-scored with strong branching.
        self.top_sb_k = 10
        # A variable's SB score is frozen to its mean once it has more than
        # `ave_k` observations along the path.
        self.ave_k = 3

        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

        # node number -> {LP column position: [SB scores observed along the path]}
        self.father_history_sb_score = {}
        self.khalil_root_buffer = {}
        # node number -> {LP column position: frozen (averaged) SB score}
        self.father_ave_flag = {}

    def _top_k_indices(self, score):
        """Return the indices (Python ints) of the ``top_sb_k`` largest entries of ``score``.

        ``k`` is clamped to the number of entries, because ``torch.topk``
        raises when asked for more elements than the tensor holds.
        """
        flat = score.reshape(-1)
        k = min(self.top_sb_k, flat.numel())
        return torch.topk(flat, k).indices.tolist()

    def get_now_sb_score(self, top_pos_list, top_cands, parent_id):
        """Return SB scores for ``top_cands``, reusing the parent's frozen averages.

        Args:
            top_pos_list: LP column positions, aligned with ``top_cands``.
            top_cands: candidate variables to score.
            parent_id: node number of the current node's parent.

        Returns:
            ``(sb_score, history_sb_score)``. ``sb_score`` is a NumPy array
            aligned with ``top_cands``. ``history_sb_score`` maps every freshly
            evaluated position to a one-element list holding its new score.
        """
        parent_ave_flag = self.father_ave_flag.get(parent_id, {})

        # Only candidates without a frozen average need a fresh SB evaluation.
        update_cands = [cand for pos, cand in zip(top_pos_list, top_cands)
                        if pos not in parent_ave_flag]
        lp_scores = np.array(self.model.getCandsVsbScores_1(update_cands))

        sb_score = []
        history_sb_score = {}
        fresh_index = 0
        for pos in top_pos_list:
            if pos in parent_ave_flag:
                sb_score.append(parent_ave_flag[pos])
            else:
                history_sb_score[pos] = [lp_scores[fresh_index]]
                sb_score.append(lp_scores[fresh_index])
                fresh_index += 1

        return np.array(sb_score), history_sb_score

    def add_now_sb_score(self, history_sb_score, current_node_id, parent_id):
        """Merge fresh SB observations with the parent's history and store them.

        For each position in ``history_sb_score`` that is not already frozen,
        the new observations are appended to the parent's observation list for
        that position. If the merged list holds more than ``self.ave_k``
        observations, its mean is moved into the frozen table. Both tables are
        stored under ``current_node_id`` for the node's children.

        Bug fixes vs. the previous version:

        - Each list used to be extended with itself instead of the parent's
          list, so averaging never triggered.
        - The undefined ``ave`` helper was replaced.
        - Entries are no longer deleted while iterating the dict.

        NOTE(review): parent entries that were not re-evaluated at this node
        are not carried over — confirm this is intended.

        Returns:
            The history dict stored for ``current_node_id``.
        """
        now_ave_flag = copy.deepcopy(self.father_ave_flag.get(parent_id, {}))
        parent_sb_score = self.father_history_sb_score.get(parent_id, {})

        now_history = {}
        for key, value in history_sb_score.items():
            if key in now_ave_flag:
                # Already frozen: the averaged score takes precedence.
                continue
            observations = list(parent_sb_score.get(key, [])) + list(value)
            if len(observations) > self.ave_k:
                now_ave_flag[key] = float(np.mean(observations))
            else:
                now_history[key] = observations

        self.father_history_sb_score[current_node_id] = now_history
        self.father_ave_flag[current_node_id] = now_ave_flag

        return now_history

    def branchinit(self):
        self.root_buffer = {}
        self.branch_time = 0
        self.get_sb_time = 0
        # NOTE: branchexeclp_vfsb_branch presumably needs 'vanillafullstrong'
        # set to score-only mode, via setBoolParam(..., True) on
        # 'branching/vanillafullstrong/' + integralcands, scoreall,
        # collectscores, donotbranch and idempotent.

    def branchexeclp(self, allowaddcons):
        return self.branchexeclp_gcnn_sb_ave(allowaddcons)
        # Alternatives: branchexeclp_vfsb_branch, branchexeclp_my_sbScores

    def branchexeclp_fsb(self, allowaddcons):
        """Delegate to SCIP's 'fullstrong' rule and log timing."""
        self.count += 1
        time_1 = time.time()

        result = self.model.executeBranchRule('fullstrong', allowaddcons)
        if result == SCIP_RESULT.REDUCEDDOM:
            self.dom_reduction += 1
        elif result == SCIP_RESULT.CUTOFF:
            self.cut_off += 2

        self.branch_time = self.branch_time + time.time() - time_1
        print("one_time:", time.time() - time_1)
        print("self.count full sb:",self.count, "\tnodes:", self.model.getNNodes(), "\ttime:", self.branch_time)

        return {'result': result}

    def branchexeclp_vfsb(self, allowaddcons):
        """Delegate to SCIP's 'vanillafullstrong' rule and log timing."""
        self.count += 1
        time_1 = time.time()

        result = self.model.executeBranchRule('vanillafullstrong', allowaddcons)
        if result == SCIP_RESULT.REDUCEDDOM:
            self.dom_reduction += 1
        elif result == SCIP_RESULT.CUTOFF:
            self.cut_off += 2

        self.branch_time = self.branch_time + time.time() - time_1
        print("one_time:", time.time() - time_1)
        print("self.count full sb:",self.count, "\tnodes:", self.model.getNNodes(), "\ttime:", self.branch_time)

        return {'result': result}

    def branchexeclp_vfsb_branch(self, allowaddcons):
        """Score candidates with 'vanillafullstrong' and branch on its best candidate.

        Asserts that the SCIP rule itself did not run a branching (DIDNOTRUN),
        which presumably requires the score-only settings noted in ``branchinit``.
        """
        self.count += 1

        if self.model.getNNodes() == 1:
            # initialize root buffer for Khalil features extraction
            utils.extract_khalil_variable_features(self.model, [], self.khalil_root_buffer)

        result = self.model.executeBranchRule('vanillafullstrong', allowaddcons)
        assert result == SCIP_RESULT.DIDNOTRUN

        # apply the 'vanillafullstrong' branching decision ourselves
        cands, scores, ncands, bestcand, _ = self.model.getVanillafullstrongData()

        print("self.count sb branch:",self.count, "\tcands:", len(cands), "\tbestcand:", bestcand)

        self.model.branchVar(cands[bestcand])
        return {"result": SCIP_RESULT.BRANCHED}

    def branchexeclp_gcnn(self, allowaddcons):
        """Branch on the GCNN arg-max over the pseudo-branching candidates."""
        self.count += 1
        time_1 = time.time()

        state = utils.extract_state(self.model)
        g = utils.graph_transform(state).to(self.device)

        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]

        agent_out = self.agent.eval_forward(g)
        score = agent_out[action_set]
        best_cand = int(score.argmax())

        self.model.branchVar(cands[best_cand])

        self.branch_time = self.branch_time + time.time() - time_1
        print("one_time:", time.time() - time_1)
        print("self.count gcnn:",self.count, "\tnodes:", self.model.getNNodes(), "\ttime:", self.branch_time)

        return {'result': SCIP_RESULT.BRANCHED}

    def branchexeclp_my_sbScores(self, allowaddcons):
        """Branch on the arg-max of ``getCandsVsbScores`` over all candidates."""
        self.count += 1
        time_1 = time.time()

        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]
        lp_scores = np.array(self.model.getCandsVsbScores(cands))

        best_cand = int(lp_scores.argmax())
        self.model.branchVar(cands[best_cand])

        self.branch_time = self.branch_time + time.time() - time_1

        print("\n")
        print("self.count van full sb:",self.count, "\tnodes:", self.model.getNNodes(), "\ttime:", self.branch_time)
        print("one_time:", time.time() - time_1)
        print("cands:",len(cands), "\tbest_cand:", action_set[best_cand])
        print("action_set:",action_set[:10])
        print("lp_scores:",lp_scores[:10])

        return {'result': SCIP_RESULT.BRANCHED}

    def branchexeclp_my_lpScores(self, allowaddcons):
        """Branch on the largest downgain*upgain product from ``getDowngainUpgainData``."""
        self.count += 1

        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]
        indices_index, lp_data = self.model.getDowngainUpgainData()

        # Product of the two gains, indexed by LP column position.
        lp_data = [gains[0] * gains[1] for gains in lp_data]
        scores = [lp_data[i] for i in action_set]

        best_cand = scores.index(max(scores))
        self.model.branchVar(cands[best_cand])

        print("cands.len:",len(cands), "\tnodes:", self.model.getNNodes())
        print("action_set:",action_set[:10])
        print("scores:",scores[:10])
        print("self.count my_lpScores:",self.count, "\tcands.len:", len(cands),  "\tbest_cand:", best_cand)

        return {'result': SCIP_RESULT.BRANCHED}

    def branchexeclp_gcnn_sb(self, allowaddcons):
        """Re-score the GCNN top-``top_sb_k`` candidates with SB and branch on the best."""
        self.count += 1
        time_1 = time.time()

        # Rank candidates with the learned model.
        state = utils.extract_state(self.model)
        g = utils.graph_transform(state).to(self.device)

        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]

        agent_out = self.agent.eval_forward(g)
        score = agent_out[action_set]

        time_get_sb_time_start = time.time()
        top_index_list = self._top_k_indices(score)
        top_cands = [cands[index] for index in top_index_list]

        # Pick the best of them by strong branching.
        lp_scores = np.array(self.model.getCandsVsbScores_1(top_cands))
        self.get_sb_time += (time.time() - time_get_sb_time_start)

        best_cand = int(lp_scores.argmax())
        self.model.branchVar(top_cands[best_cand])

        self.branch_time = self.branch_time + time.time() - time_1
        print("one_time:", time.time() - time_1)
        print("self.count top + sb:",self.count, "\tnodes:", self.model.getNNodes(), "\ttime:", self.branch_time, "get_sb_time:", self.get_sb_time)

        return {'result': SCIP_RESULT.BRANCHED}

    def branchexeclp_gcnn_sb_ave(self, allowaddcons):
        """GCNN top-k + SB re-scoring, reusing averaged SB scores cached along the path."""
        self.count += 1
        time_1 = time.time()

        # Rank candidates with the learned model.
        state = utils.extract_state(self.model)
        g = utils.graph_transform(state).to(self.device)

        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]

        agent_out = self.agent.eval_forward(g)
        score = agent_out[action_set]

        time_get_sb_time_start = time.time()
        current_node_id = self.model.getCurrentNode().getNumber()
        # At the root there is no parent; 0 is used as its key.
        parent_id = 0 if self.model.getNNodes() == 1 else self.model.getCurrentNode().getParent().getNumber()

        # Top-k of the GCNN output (k clamped to the number of candidates).
        top_index_list = self._top_k_indices(score)
        top_pos_list = [action_set[index] for index in top_index_list]
        top_cands = [cands[index] for index in top_index_list]

        # SB scores (fresh or frozen averages), then update the per-node history.
        sb_score, history_sb_score = self.get_now_sb_score(top_pos_list, top_cands, parent_id)
        self.add_now_sb_score(history_sb_score, current_node_id, parent_id)

        self.get_sb_time += time.time() - time_get_sb_time_start

        best_cand = int(sb_score.argmax())
        self.model.branchVar(top_cands[best_cand])

        self.branch_time = self.branch_time + time.time() - time_1

        print("one_time:", time.time() - time_1)
        print("self.count top + sb:",self.count, "\tnodes:", self.model.getNNodes(), "\ttime:", self.branch_time, "get_sb_time:", self.get_sb_time)

        return {'result': SCIP_RESULT.BRANCHED}

    def branchexeclp_gcnn_sb_pool(self, allowaddcons):
        """Debug: compare sequential vs pooled SB scoring, then abort on purpose.

        NOTE(review): the ``assert 1==2`` below always stops execution after
        printing the comparison; everything after it is unreachable.
        """
        def get_top_indexes(lst, top_nums):
            top_nums = min(top_nums, len(lst))
            # Negate values so the largest ones come out of the min-heap first.
            heap = [(-value, index) for index, value in enumerate(lst)]
            heapq.heapify(heap)
            top_nums = heapq.nsmallest(top_nums, heap)
            return [index for value, index in top_nums]

        self.count += 1
        time_1 = time.time()

        state = utils.extract_state(self.model)
        g = utils.graph_transform(state).to(self.device)

        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]

        agent_out = self.agent.eval_forward(g)
        score = agent_out[action_set]

        top_index_list = get_top_indexes(score, self.top_sb_k)
        top_cands = [cands[pos] for pos in top_index_list]

        time_get_sb_time_start = time.time()
        time_lps_start = time.time()
        lp_scores  = np.array(self.model.getCandsVsbScores_1(top_cands))
        print("time lp_scores:", time.time() - time_lps_start)

        time_lpsp_start = time.time()
        lp_scores_pool  = np.array(self.model.getCandsVsbScores_pool(top_cands, 3))
        print("time lp_pool_scores:", time.time() - time_lpsp_start)

        for i in range(len(lp_scores)):
            print(f"i:{i} \t lps:{lp_scores[i]} \t lpsp:{lp_scores_pool[i]}")

        assert 1==2

        lp_scores  = np.array(self.model.getCandsVsbScores_1(top_cands))
        self.get_sb_time += (time.time() - time_get_sb_time_start)

        best_cand = int(lp_scores.argmax())
        self.model.branchVar(top_cands[best_cand])

        self.branch_time = self.branch_time + time.time() - time_1
        print("one_time:", time.time() - time_1)
        print("self.count top + sb:",self.count, "\tnodes:", self.model.getNNodes(), "\ttime:", self.branch_time, "get_sb_time:", self.get_sb_time)

        return {'result': SCIP_RESULT.BRANCHED}

    def branchexeclp_old(self, allowaddcons):
        """Legacy debug version (thread-pool SB). Not usable as-is.

        NOTE(review): ``top_cands`` is read before it is assigned (raises
        UnboundLocalError), ``self.time_gcn_time`` / ``self.time_sb_time`` are
        never initialised, and ``assert 1==2`` aborts on purpose.
        """
        def get_top_indexes(lst, top_nums):
            top_nums = min(top_nums, len(lst))
            # Negate values so the largest ones come out of the min-heap first.
            heap = [(-value, index) for index, value in enumerate(lst)]
            heapq.heapify(heap)
            top_nums = heapq.nsmallest(top_nums, heap)
            return [index for value, index in top_nums]

        result = SCIP_RESULT.DIDNOTRUN
        self.count += 1

        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]
        lp_scores = np.array(self.model.getCandsVsbScores(top_cands, 1))
        best_cand = lp_scores.argmax()

        state = utils.extract_state(self.model)
        g = utils.graph_transform(state).to(self.device)

        time_1 = time.time()
        agent_out = self.agent.eval_forward(g)
        time_2 = time.time()
        self.time_gcn_time.append(time_2-time_1)

        score = agent_out[action_set]

        top_index_list = get_top_indexes(score, 10)
        top_cands = [cands[pos] for pos in top_index_list]

        time_1 = time.time()

        pool = ThreadPool(30)
        results = pool.map(self.model.process, top_cands)
        pool.close()
        pool.join()
        print("results:",results)
        print("lp_scores.len:",len(results))
        time_2 = time.time()
        print("lp_scores:",time_2 - time_1)

        assert 1==2

        self.time_sb_time.append(time_2-time_1)

        best_cand = lp_scores.argmax()

        self.model.branchVar(top_cands[best_cand])

        result = SCIP_RESULT.BRANCHED

        if result == SCIP_RESULT.REDUCEDDOM:
            self.dom_reduction += 1
        elif result == SCIP_RESULT.CUTOFF:
            self.cut_off += 2
        return {'result': result}

    def branchexeclp_time(self, allowaddcons):
        """Debug timing comparison of SB over all candidates vs GCNN top-3.

        NOTE(review): relies on ``self.time_*`` attributes that ``__init__``
        never creates, and aborts via ``assert 1==2`` at the 100th call.
        """
        def get_top_indexes(lst, top_nums):
            # Negate values so the largest ones come out of the min-heap first.
            heap = [(-value, index) for index, value in enumerate(lst)]
            heapq.heapify(heap)
            top_nums = heapq.nsmallest(top_nums, heap)
            return [index for value, index in top_nums]

        result = SCIP_RESULT.DIDNOTRUN
        self.count += 1

        print("self.count:",self.count)

        if self.count == 100:
            print("self.time_gcn_time:",sum(self.time_gcn_time)/len(self.time_gcn_time))
            print("self.time_all_cands_time:",sum(self.time_all_cands_time))
            print("self.time_all_count:",self.time_all_count)
            print("self.time_all_cands_time.ave:",sum(self.time_all_cands_time)/self.time_all_count)
            print("self.time_3_cands_time:",sum(self.time_3_cands_time))
            print("self.time_3_count:",self.time_3_count)
            print("self.time_3_cands_time:",sum(self.time_3_cands_time)/self.time_3_count)

            assert 1==2

        cands, *_ = self.model.getPseudoBranchCands()

        time_1 = time.time()
        lp_scores = self.model.getCandsVsbScores(cands)
        self.time_all_cands_time.append(time.time()-time_1)
        self.time_all_count = self.time_all_count + len(cands)

        action_set = [c.getCol().getLPPos() for c in cands]

        time_1 = time.time()
        state = utils.extract_state(self.model)
        g = utils.graph_transform(state).to(self.device)
        agent_out = self.agent.eval_forward(g)
        if self.count > 3:
            self.time_gcn_time.append(time.time() - time_1)

        score = agent_out[action_set]

        top_index_list = get_top_indexes(score, 3)
        top_cands = [cands[pos] for pos in top_index_list]

        time_1 = time.time()
        lp_scores = self.model.getCandsVsbScores(top_cands)
        self.time_3_cands_time.append(time.time()-time_1)
        self.time_3_count = self.time_3_count + len(top_cands)

        best_cand = int(lp_scores.argmax())

        self.model.branchVar(top_cands[best_cand])
        result = SCIP_RESULT.BRANCHED

        if result == SCIP_RESULT.REDUCEDDOM:
            self.dom_reduction += 1
        elif result == SCIP_RESULT.CUTOFF:
            self.cut_off += 2
        return {'result': result}

class PolicyBranching_27(Branchrule):
    """Branch on the candidate chosen by SCIP's ``vanillafullstrong`` rule.

    ``branchinit`` configures ``vanillafullstrong`` so that calling it through
    ``executeBranchRule`` is expected to return DIDNOTRUN (asserted below);
    the chosen candidate is then read back via ``getVanillafullstrongData``
    and branched on here.
    """

    # Boolean parameters switched on in branchinit, in this order.
    _VFS_BOOL_PARAMS = (
        'branching/vanillafullstrong/integralcands',
        'branching/vanillafullstrong/scoreall',
        'branching/vanillafullstrong/collectscores',
        'branching/vanillafullstrong/donotbranch',
        'branching/vanillafullstrong/idempotent',
    )

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        # Counters: number of calls, domain reductions, cutoffs.
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}
        for name in self._VFS_BOOL_PARAMS:
            self.model.setBoolParam(name, True)

    def branchexeclp(self, allowaddcons):
        self.count += 1

        status = self.model.executeBranchRule('vanillafullstrong', allowaddcons)
        assert status == SCIP_RESULT.DIDNOTRUN

        # Only the candidate list and SCIP's chosen index are used here.
        candidates, _scores, _npriocands, chosen, _ = self.model.getVanillafullstrongData()
        self.model.branchVar(candidates[chosen])

        # Always BRANCHED, so dom_reduction / cut_off are never updated by this rule.
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_28(Branchrule):
    """Branch on the arg-max of the product of the agent's two output heads.

    Features come from ``utils.extract_state_new_28`` /
    ``utils.graph_transform_new_28``. Each head returned by
    ``agent.eval_forward`` is clamped from below at 1/6 before multiplying.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        # Counters: number of calls, domain reductions, cutoffs.
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}

    def branchexeclp(self, allowaddcons):
        self.count += 1

        graph = utils.graph_transform_new_28(utils.extract_state_new_28(self.model)).to(self.device)

        candidates, *_ = self.model.getPseudoBranchCands()
        lp_positions = [var.getCol().getLPPos() for var in candidates]

        head_0, head_1 = self.agent.eval_forward(graph)
        floor = 1.0 / 6.0
        # Clamping both heads keeps the product >= 1/36.
        combined = torch.clamp(head_0, min=floor) * torch.clamp(head_1, min=floor)

        # Rows of the score tensor are looked up by the candidates' LP column positions.
        chosen = int(combined[lp_positions].argmax())
        self.model.branchVar(candidates[chosen])

        # Always BRANCHED, so dom_reduction / cut_off are never updated by this rule.
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_29(Branchrule):
    """Branch on the candidate ranked highest by ``agent.eval_forward_sb``.

    Features come from ``utils.extract_state_new_29`` /
    ``utils.graph_transform_new_29``.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        # Counters: number of calls, domain reductions, cutoffs.
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}

    def branchexeclp(self, allowaddcons):
        self.count += 1

        graph = utils.graph_transform_new_29(utils.extract_state_new_29(self.model)).to(self.device)

        candidates, *_ = self.model.getPseudoBranchCands()
        lp_positions = [var.getCol().getLPPos() for var in candidates]

        # Rows of the score tensor are looked up by the candidates' LP column positions.
        candidate_scores = self.agent.eval_forward_sb(graph)[lp_positions]
        self.model.branchVar(candidates[int(candidate_scores.argmax())])

        # Always BRANCHED, so dom_reduction / cut_off are never updated by this rule.
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_30(Branchrule):
    """Graph-agent branching rule on the ``*_new_30`` features.

    ``args.model_id`` selects the scoring head:

    * ``'sb'`` -- ``agent.eval_forward_sb(g)``;
    * ``'lp'`` -- ``agent.eval_forward_lp(g)`` returns two heads, each clamped
      from below at 1/6 and multiplied.

    Raises:
        ValueError: from ``branchexeclp`` if ``args.model_id`` is anything else.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        # Counters: number of calls, domain reductions, cutoffs.
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}

    def branchexeclp(self, allowaddcons):
        self.count += 1

        state = utils.extract_state_new_30(self.model)
        g = utils.graph_transform_new_30(state).to(self.device)

        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]

        model_id = self.args.model_id
        if model_id == 'sb':
            logits = self.agent.eval_forward_sb(g)
        elif model_id == 'lp':
            logits_0, logits_1 = self.agent.eval_forward_lp(g)
            # Clamping both heads keeps the product >= 1/36.
            logits = torch.mul(torch.clamp(logits_0, min=1.0 / 6.0),
                               torch.clamp(logits_1, min=1.0 / 6.0))
        else:
            # Without this branch `logits` would be unbound below.
            raise ValueError(
                f"PolicyBranching_30: unsupported args.model_id {model_id!r} "
                "(expected 'sb' or 'lp')")

        # Rows of the score tensor are looked up by the candidates' LP column positions.
        best_cand = int(logits[action_set].argmax())
        self.model.branchVar(cands[best_cand])

        # Always BRANCHED, so dom_reduction / cut_off are never updated by this rule.
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_31(Branchrule):
    """Graph-agent branching rule on the ``*_new_31`` features.

    ``args.model_id`` selects the scoring head:

    * ``'sb'``  -- ``agent.eval_forward_sb(g)``;
    * ``'lp'``  -- product of the two heads of ``agent.eval_forward_lp(g)``;
    * ``'all'`` -- equal-weight blend of the SB head and the LP-head product
      from ``agent.eval_forward_all(g)``.

    Raises:
        ValueError: from ``branchexeclp`` if ``args.model_id`` is anything else.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        # Counters: number of calls, domain reductions, cutoffs.
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}

    def branchexeclp(self, allowaddcons):
        self.count += 1

        state = utils.extract_state_new_31(self.model)
        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]

        g = utils.graph_transform_new_31(state).to(self.device)

        model_id = self.args.model_id
        if model_id == 'sb':
            logits = self.agent.eval_forward_sb(g)
        elif model_id == 'lp':
            logits_0, logits_1 = self.agent.eval_forward_lp(g)
            logits = torch.mul(logits_0, logits_1)
        elif model_id == 'all':
            logits_sb, logits_0, logits_1 = self.agent.eval_forward_all(g)
            logits_lp = torch.mul(logits_0, logits_1)
            logits = 0.5 * logits_sb + 0.5 * logits_lp
        else:
            # Without this branch `logits` would be unbound below.
            raise ValueError(
                f"PolicyBranching_31: unsupported args.model_id {model_id!r} "
                "(expected 'sb', 'lp' or 'all')")

        # Rows of the score tensor are looked up by the candidates' LP column positions.
        best_cand = int(logits[action_set].argmax())
        self.model.branchVar(cands[best_cand])

        # Always BRANCHED, so dom_reduction / cut_off are never updated by this rule.
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_32(Branchrule):
    """Graph-agent branching rule on the ``*_new_32`` features.

    ``args.model_id`` selects the scoring head:

    * ``'sb'``  -- ``agent.eval_forward_sb(g)``;
    * ``'lp'``  -- product of the two heads of ``agent.eval_forward_lp(g)``;
    * ``'all'`` -- equal-weight blend of the SB head and the LP-head product
      from ``agent.eval_forward_all(g)``.

    Raises:
        ValueError: from ``branchexeclp`` if ``args.model_id`` is anything else.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        # Counters: number of calls, domain reductions, cutoffs.
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}

    def branchexeclp(self, allowaddcons):
        self.count += 1

        state = utils.extract_state_new_32(self.model)
        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]

        g = utils.graph_transform_new_32(state).to(self.device)

        model_id = self.args.model_id
        if model_id == 'sb':
            logits = self.agent.eval_forward_sb(g)
        elif model_id == 'lp':
            logits_0, logits_1 = self.agent.eval_forward_lp(g)
            logits = torch.mul(logits_0, logits_1)
        elif model_id == 'all':
            logits_sb, logits_0, logits_1 = self.agent.eval_forward_all(g)
            logits_lp = torch.mul(logits_0, logits_1)
            logits = 0.5 * logits_sb + 0.5 * logits_lp
        else:
            # Without this branch `logits` would be unbound below.
            raise ValueError(
                f"PolicyBranching_32: unsupported args.model_id {model_id!r} "
                "(expected 'sb', 'lp' or 'all')")

        # Rows of the score tensor are looked up by the candidates' LP column positions.
        best_cand = int(logits[action_set].argmax())
        self.model.branchVar(cands[best_cand])

        # Always BRANCHED, so dom_reduction / cut_off are never updated by this rule.
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_33(Branchrule):
    """Branch with one agent from ``agents``, picked by ``args.model_select_flag``.

    Flags 1-4 select ``agents[0]``-``agents[3]``. Flag -1 selects ``agents[1]``
    while the value passed to :meth:`model_select` is <= 7, and ``agents[2]``
    after that. Any other flag raises when the rule first branches.
    """

    # (flag, agent index) pairs for the fixed single-agent modes, checked in order.
    _FIXED_FLAG_TO_AGENT = ((1, 0), (2, 1), (3, 2), (4, 3))

    def __init__(self, scip, agents, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agents = agents

        self.args = args
        # Counters: number of calls, domain reductions, cutoffs.
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0
        self.model_select_flag = args.model_select_flag

    def model_select(self, g, now_depth):
        """Run the agent chosen by ``self.model_select_flag`` on graph ``g``."""
        flag = self.model_select_flag
        if flag == -1:
            # Every value <= 7 maps to agents[1]; larger values use agents[2].
            chosen = self.agents[1] if now_depth <= 7 else self.agents[2]
            return chosen.forward(g)
        for value, index in self._FIXED_FLAG_TO_AGENT:
            if flag == value:
                return self.agents[index].forward(g)
        raise Exception("model_select_flag is not set")

    def branchinit(self):
        self.root_buffer = {}
        self.count = 0

    def branchexeclp(self, allowaddcons):
        graph = utils.graph_transform(utils.extract_state(self.model)).to(self.device)

        # NOTE(review): the node count (getNNodes) is passed as `now_depth`, not the
        # current node's depth -- confirm this is intended.
        var_scores = self.model_select(graph, self.model.getNNodes()).nodes['v'].data['s']

        candidates, *_ = self.model.getPseudoBranchCands()
        lp_positions = [c.getCol().getLPPos() for c in candidates]
        self.model.branchVar(candidates[int(var_scores[lp_positions].argmax())])

        # Always BRANCHED, so dom_reduction / cut_off are never updated by this rule.
        return {'result': SCIP_RESULT.BRANCHED}
    
class PolicyBranching_33_rollout(Branchrule):
    """Work-in-progress rollout variant of the multi-agent branching rule.

    Agent selection is identical to ``PolicyBranching_33``. The branching
    decision is currently the agent's arg-max candidate. A top-``k``
    shortlist is computed for a planned rollout / strong-branching
    re-ranking step that is not implemented yet (see TODO).
    """

    def __init__(self, scip, agents, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agents = agents

        self.args = args
        # Counters: number of calls, domain reductions, cutoffs.
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0
        self.model_select_flag = args.model_select_flag

    def model_select(self, g, now_depth):
        """Run the agent chosen by ``self.model_select_flag`` on graph ``g``."""
        if self.model_select_flag == -1:
            # Every value <= 7 maps to agents[1]; larger values use agents[2].
            if now_depth <= 7:
                return self.agents[1].forward(g)
            return self.agents[2].forward(g)

        elif self.model_select_flag == 1:
            return self.agents[0].forward(g)

        elif self.model_select_flag == 2:
            return self.agents[1].forward(g)

        elif self.model_select_flag == 3:
            return self.agents[2].forward(g)

        elif self.model_select_flag == 4:
            return self.agents[3].forward(g)

        else:
            raise Exception("model_select_flag is not set")

    def branchinit(self):
        self.root_buffer = {}
        self.count = 0

    def branchexeclp(self, allowaddcons, k=5):
        """Branch on the agent's best candidate.

        Args:
            allowaddcons: passed by SCIP; unused here.
            k: size of the candidate shortlist kept for the planned rollout
                step (clamped to the number of candidates).
        """
        state = utils.extract_state(self.model)
        g = utils.graph_transform(state).to(self.device)

        # NOTE(review): the node count (getNNodes) is passed as `now_depth`, not the
        # current node's depth -- confirm this is intended.
        agent_out = self.model_select(g, self.model.getNNodes())
        agent_out = agent_out.nodes['v'].data['s']

        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]
        score = agent_out[action_set]

        # torch.topk raises if k exceeds the candidate count; dim=0 also works when
        # `score` is a column tensor.
        _, top_indices = torch.topk(score, k=min(k, len(cands)), dim=0)
        rollout_cands = [cands[int(i)] for i in top_indices]  # noqa: F841 (WIP)
        # TODO: for the not-yet-fixed variables, gather their bounds and
        # coefficients, build the child LPs for each candidate in `rollout_cands`
        # (as strong branching does, possibly in parallel) and branch on the best.

        best_cand = int(score.argmax())
        self.model.branchVar(cands[best_cand])

        # Always BRANCHED, so dom_reduction / cut_off are never updated by this rule.
        return {'result': SCIP_RESULT.BRANCHED}
 
class PolicyBranching_34(Branchrule):
    """Branch using the candidate feature matrix returned by ``getCandsState``.

    The agent scores the matrix; the arg-max entry selects the candidate,
    whose position is printed before branching.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        # Counters: number of calls, domain reductions, cutoffs.
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}
        self.count = 0
        self.time_base = 0
        self.time_37 = 0

    def branchexeclp(self, allowaddcons):
        var_dim = 25  # presumably the per-candidate feature width -- confirm in getCandsState

        my_cands, my_cands_pos, cands_state_mat = self.model.getCandsState(var_dim, self.count)
        self.count = self.count + 1

        features = torch.from_numpy(cands_state_mat.astype('float32')).to(self.device)
        best_cand = int(self.agent(features).squeeze().argmax())

        print(f"self.count:{self.count} \t best_cand:{my_cands_pos[best_cand]}")

        self.model.branchVar(my_cands[best_cand])

        # Always BRANCHED, so dom_reduction / cut_off are never updated by this rule.
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_35(Branchrule):
    """Branch on the arg-max of the graph agent on the ``*_new_35`` features."""

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        # Counters: number of calls, domain reductions, cutoffs.
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}
        self.count = 0

    def branchexeclp(self, allowaddcons, k=5):
        """Branch on the highest-scoring candidate.

        ``k`` is accepted for signature compatibility with the top-k variant
        and is not used here.
        """
        graph = utils.graph_transform_new_35(utils.extract_state_new_35(self.model)).to(self.device)
        var_scores = self.agent(graph).nodes['v'].data['s']

        candidates, *_ = self.model.getPseudoBranchCands()
        lp_positions = [c.getCol().getLPPos() for c in candidates]

        chosen = int(var_scores[lp_positions].argmax())
        self.model.branchVar(candidates[chosen])

        # Always BRANCHED, so dom_reduction / cut_off are never updated by this rule.
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_35_topk(Branchrule):
    """Graph-agent branching with strong-branching re-ranking near the root.

    At depth 0-2 the top-``k`` candidates by agent score are re-scored by
    probing: each is fixed at its global lower bound and then at its global
    upper bound, and the candidate maximising
    ``max(down - z, 1e-6) * max(up - z, 1e-6)`` is branched on, where ``z`` is
    the current LP objective. Deeper nodes use the agent's arg-max directly.

    Experiment log: job 187225 used k=1 and 187226 used k=5. Job 187286 was
    run without top-k; 187287 used top-5 at depth 0-2.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        # Counters: number of calls, domain reductions, cutoffs.
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}
        self.count = 0

    def _probe_lp_obj(self, var, value):
        """Fix ``var`` to ``value`` in a probing node and return the LP objective.

        Probing is always ended, even if a step raises.
        NOTE(review): the return value of ``solveProbingLP`` (LP error /
        cutoff flags) is ignored, as before -- confirm what ``getLPObjVal``
        reports for an infeasible probing LP.
        """
        self.model.startProbing()
        try:
            assert not self.model.isObjChangedProbing()
            self.model.fixVarProbing(var, value)
            self.model.constructLP()
            self.model.solveProbingLP()
            return self.model.getLPObjVal()
        finally:
            self.model.endProbing()

    def _sb_score(self, var, cur_z):
        """Product score of the LP gains with ``var`` fixed at its global lb / ub.

        For binary variables these two fixings are exactly the strong-branching
        children; for general integers they are not (NOTE(review)).
        """
        lb = var.getLbGlobal()
        ub = var.getUbGlobal()
        down = self._probe_lp_obj(var, lb)
        up = self._probe_lp_obj(var, ub)
        return max(down - cur_z, 1e-6) * max(up - cur_z, 1e-6)

    def branchexeclp(self, allowaddcons, k=5):
        """Branch on the agent's choice, re-ranked by probing at depth 0-2.

        Args:
            allowaddcons: passed by SCIP; unused here.
            k: number of top-scored candidates to probe (clamped to the
                number of candidates).
        """
        state = utils.extract_state_new_35(self.model)
        g = utils.graph_transform_new_35(state).to(self.device)
        agent_out = self.agent(g).nodes['v'].data['s']

        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]
        score = agent_out[action_set]

        depth = self.model.getCurrentNode().getDepth()
        if 0 <= depth <= 2:
            # torch.topk raises if k exceeds the number of candidates.
            k = min(k, len(cands))
            _, topk_cand_index = torch.topk(score, k, dim=0)
            topk_cands = [cands[int(i)] for i in topk_cand_index]

            start = time.time()
            cur_z = self.model.getLPObjVal()
            sb_list = [self._sb_score(var, cur_z) for var in topk_cands]
            print('SCIP time', time.time() - start)

            best_cand = topk_cands[int(np.argmax(sb_list))]
        else:
            best_cand = cands[int(score.argmax())]

        self.model.branchVar(best_cand)

        # Always BRANCHED, so dom_reduction / cut_off are never updated by this rule.
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_37(Branchrule):
    """Branch with ``agent_leaf`` (``agent[1]``) on the ``*_new_37_all`` graph.

    ``agent_root`` (``agent[0]``) is stored but is only referenced by
    commented-out root-embedding code. ``branchexeclp_time`` is an
    instrumented timing variant; it is not the ``branchexeclp`` name, so
    presumably it is swapped in manually for profiling.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        # `agent` is indexed as a sequence: [root agent, leaf agent].
        self.agent_root = agent[0]
        self.agent_leaf = agent[1]
        self.args = args
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0
        
        # Cumulative timing counters used by branchexeclp_time.
        self.count = 0
        self.time_base = 0
        self.time_get_cands = 0
        self.time_37 = 0
        self.time_37_buffer = 0
        self.time_get_cands_buffer = 0
        self.time_base_buffer = 0
        # Root-node cache slots; only written by commented-out code.
        self.root_state_base = None
        self.root_embeds = None

    def branchinit(self):
        # Per-solve caches handed to the extract_state_* helpers in branchexeclp_time.
        self.root_buffer = {}
        self.root_37_buffer = {}
        
    
    def branchexeclp(self, allowaddcons):
        """Score all candidates with agent_leaf and branch on the arg-max."""
        if self.model.getNNodes() == 1:
            # Root-embedding caching is currently disabled; this branch is a no-op.
            pass
            # self.root_state_base, _, _ = utils.extract_state_new_37_base(self.model)
            # g = utils.graph_transform_new_37_all(self.root_state_base).to(self.device)
            # self.root_embeds = self.agent_root.get_var_embeds(g)

        state_cands_all, cands, cands_pos = utils.extract_state_new_37_all(self.model)
        
        g = utils.graph_transform_new_37_all(state_cands_all).to(self.device)

        # now_root_embeds = self.root_embeds[cands_pos]

        # agent_out = self.agent_leaf(now_root_embeds, g)

        # NOTE(review): spelled `forword_all`; presumably that is the agent's actual
        # method name in src.model -- confirm.
        agent_out =self.agent_leaf.forword_all(g)

        agent_out = agent_out.nodes['v'].data['s']

        # The arg-max over all 'v' rows indexes `cands` directly, so this assumes the
        # graph's 'v' nodes are exactly the candidates, in `cands` order -- TODO confirm.
        best_cand = int(agent_out.argmax())

        self.model.branchVar(cands[best_cand])
        result = SCIP_RESULT.BRANCHED

        # Always BRANCHED here, so neither counter below is ever updated.
        if result == SCIP_RESULT.REDUCEDDOM:
            self.dom_reduction += 1
        elif result == SCIP_RESULT.CUTOFF:
            self.cut_off += 2
        return {'result': result}


    def branchexeclp_time(self, allowaddcons):
        """Instrumented variant that times state extraction and prints the totals.

        NOTE(review): ``self.agent`` is never assigned in ``__init__`` (only
        ``agent_root`` / ``agent_leaf``), so ``self.agent(g)`` below raises
        AttributeError as written. ``self.time_base`` and ``self.time_37`` are
        never updated here (their updates are commented out), so the printed
        "base" is always 0 and "37" is ``-time_get_cands``. The state is
        extracted three times per call (baseline, buffered 37, unbuffered 37).
        """

        # time_1 = time.time()
        # state = utils.extract_state_baseline_time(self.model)
        # self.time_base += (time.time()-time_1)
        # print(f"############ now time_base:{time.time()-time_1} \n")

        # Time candidate retrieval alone; `cands` is overwritten below and
        # `action_set` is unused.
        time_1 = time.time()
        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]
        self.time_get_cands += (time.time()-time_1)
        print(f"############ now time_get_cands:{time.time()-time_1} \n")


        # time_1 = time.time()
        # constraint_features, edge_features, variable_features, cands, cands_pos = utils.extract_state_new_37(self.model)
        # self.time_37 += (time.time()-time_1)
        # print(f"############ now time_37:{time.time()-time_1} \n")

        time_1 = time.time()
        state = utils.extract_state_baseline_time(self.model, self.root_buffer)
        self.time_base_buffer += (time.time()-time_1)
        print(f"############ now time_base_buffer:{time.time()-time_1} \n")

        time_1 = time.time()
        constraint_features, edge_features, variable_features, cands, cands_pos = utils.extract_state_new_37(self.model, self.root_37_buffer)
        self.time_37_buffer += (time.time()-time_1)
        print(f"############ now time_37_buffer:{time.time()-time_1} \n")


        print(f"count:{self.count} \t get_cands:{str(self.time_get_cands)[:5]} base:{str(self.time_base)[:5]} \t base_buffer:{str(self.time_base_buffer)[:5]} \t 37:{str(self.time_37-self.time_get_cands)[:5]} \t buffer:{str(self.time_37_buffer-self.time_get_cands)[:5]} ")

        self.count = self.count + 1

        result = SCIP_RESULT.DIDNOTRUN
        
        # Unbuffered extraction actually used for the branching decision.
        constraint_features, edge_features, variable_features, cands, cands_pos = utils.extract_state_new_37(self.model)
        state = constraint_features, edge_features, variable_features

        g = utils.graph_transform_new_37(state).to(self.device)

        agent_out = self.agent(g)
        agent_out = agent_out.nodes['v'].data['s']
        
        score = agent_out[cands_pos]
        best_cand = int(score.argmax())

        self.model.branchVar(cands[best_cand])
        result = SCIP_RESULT.BRANCHED

        if result == SCIP_RESULT.REDUCEDDOM:
            self.dom_reduction += 1
        elif result == SCIP_RESULT.CUTOFF:
            self.cut_off += 2
        return {'result': result}

class PolicyBranching_37_1(Branchrule):
    """Graph-agent branching rule with a persistent state buffer and timing.

    ``utils.extract_state_new_37_1`` receives ``self.state_buffer``, which
    lives as long as the rule object. Wall-clock time spent in
    ``branchexeclp`` is accumulated in ``branch_time`` and printed after 10,
    100, 1000 and 10000 calls. ``branchinit`` resets ``count`` but not
    ``branch_time``.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        # Counters: number of calls, domain reductions, cutoffs.
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0
        self.state_buffer = {}

        self.branch_time = 0

    def branchinit(self):
        self.root_buffer = {}
        self.count = 0

    def branchexeclp(self, allowaddcons):
        started = time.time()

        state, candidates, _cands_pos = utils.extract_state_new_37_1(self.model, self.state_buffer)
        graph = utils.graph_transform_new_37_1(state).to(self.device)
        var_scores = self.agent(graph).nodes['v'].data['s']

        lp_positions = [c.getCol().getLPPos() for c in candidates]
        self.model.branchVar(candidates[int(var_scores[lp_positions].argmax())])

        self.branch_time += time.time() - started
        self.count += 1
        if self.count in (10, 100, 1000, 10000):
            print(f"PolicyBranching_37_1 self.count:{self.count}, branch time:{self.branch_time}")

        # Always BRANCHED, so dom_reduction / cut_off are never updated by this rule.
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_38(Branchrule):
    """Branch on the arg-max of the graph agent on the ``*_new_38`` features."""

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        # Counters: number of calls, domain reductions, cutoffs.
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}
        self.count = 0

    def branchexeclp(self, allowaddcons):
        graph = utils.graph_transform_new_38(utils.extract_state_new_38(self.model)).to(self.device)
        var_scores = self.agent(graph).nodes['v'].data['s']

        candidates, *_ = self.model.getPseudoBranchCands()
        lp_positions = [c.getCol().getLPPos() for c in candidates]

        chosen = int(var_scores[lp_positions].argmax())
        self.model.branchVar(candidates[chosen])

        # Always BRANCHED, so dom_reduction / cut_off are never updated by this rule.
        return {'result': SCIP_RESULT.BRANCHED}
    
class PolicyBranching_38_topk(Branchrule):
    """Graph-agent branching (``*_new_38`` features) with probing near the root.

    At depth 0-2 the top-``k`` candidates by agent score are re-scored by
    probing: each is fixed at its global lower bound and then at its global
    upper bound, and the candidate maximising
    ``max(down - z, 1e-6) * max(up - z, 1e-6)`` is branched on, where ``z`` is
    the current LP objective. Deeper nodes use the agent's arg-max directly.
    At depth 0-5 the depth and candidate count are printed.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        # Counters: number of calls, domain reductions, cutoffs.
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}
        self.count = 0

    def _probe_lp_obj(self, var, value):
        """Fix ``var`` to ``value`` in a probing node and return the LP objective.

        Probing is always ended, even if a step raises.
        NOTE(review): the return value of ``solveProbingLP`` (LP error /
        cutoff flags) is ignored, as before -- confirm what ``getLPObjVal``
        reports for an infeasible probing LP.
        """
        self.model.startProbing()
        try:
            assert not self.model.isObjChangedProbing()
            self.model.fixVarProbing(var, value)
            self.model.constructLP()
            self.model.solveProbingLP()
            return self.model.getLPObjVal()
        finally:
            self.model.endProbing()

    def _sb_score(self, var, cur_z):
        """Product score of the LP gains with ``var`` fixed at its global lb / ub.

        For binary variables these two fixings are exactly the strong-branching
        children; for general integers they are not (NOTE(review)).
        """
        lb = var.getLbGlobal()
        ub = var.getUbGlobal()
        down = self._probe_lp_obj(var, lb)
        up = self._probe_lp_obj(var, ub)
        return max(down - cur_z, 1e-6) * max(up - cur_z, 1e-6)

    def branchexeclp(self, allowaddcons, k=3):
        """Branch on the agent's choice, re-ranked by probing at depth 0-2.

        Args:
            allowaddcons: passed by SCIP; unused here.
            k: number of top-scored candidates to probe (clamped to the
                number of candidates).
        """
        state = utils.extract_state_new_38(self.model)
        g = utils.graph_transform_new_38(state).to(self.device)
        agent_out = self.agent(g).nodes['v'].data['s']

        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]
        score = agent_out[action_set]

        depth = self.model.getCurrentNode().getDepth()
        if 0 <= depth <= 5:
            print('current depth', depth, len(action_set))

        if 0 <= depth <= 2:
            # torch.topk raises if k exceeds the number of candidates.
            k = min(k, len(cands))
            _, topk_cand_index = torch.topk(score, k, dim=0)
            topk_cands = [cands[int(i)] for i in topk_cand_index]

            cur_z = self.model.getLPObjVal()
            sb_list = [self._sb_score(var, cur_z) for var in topk_cands]
            best_cand = topk_cands[int(np.argmax(sb_list))]
        else:
            best_cand = cands[int(score.argmax())]

        self.model.branchVar(best_cand)

        # Always BRANCHED, so dom_reduction / cut_off are never updated by this rule.
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_39(Branchrule):
    """Branching rule driven by the ``new_39`` GNN policy.

    On every LP branching call the SCIP state is encoded with
    ``utils.extract_state_new_39`` / ``utils.graph_transform_new_39``, scored
    by ``agent``, and SCIP branches on the pseudo-branch candidate whose score
    is highest.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        # Counters read by the evaluation code; this rule always returns
        # BRANCHED, so dom_reduction / cut_off stay at zero.
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}
        self.count = 0

    def branchexeclp(self, allowaddcons):
        """Branch on the policy's highest-scoring candidate; always BRANCHED."""
        graph = utils.graph_transform_new_39(
            utils.extract_state_new_39(self.model)
        ).to(self.device)
        var_scores = self.agent(graph).nodes['v'].data['s']

        candidates, *_ = self.model.getPseudoBranchCands()
        # Each candidate's LP column position selects its row of var_scores.
        lp_positions = [var.getCol().getLPPos() for var in candidates]
        chosen = int(var_scores[lp_positions].argmax())

        self.model.branchVar(candidates[chosen])
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_39_topk(Branchrule):
    """``new_39`` policy branching, re-ranked by probing near the root.

    The GNN policy scores every pseudo-branch candidate. At depths 0-3 the
    ``k`` best-scored candidates (k = 10 at the root, 5 at depths 1-2, 2 at
    depth 3, capped at the number of candidates) are re-ranked by a product
    score built from two probing LPs per candidate. Deeper nodes branch
    directly on the policy's argmax.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}
        self.count = 0

    def _probe_fixed_lp(self, var, value):
        """Return the LP objective after fixing ``var`` to ``value`` in a probing dive."""
        self.model.startProbing()
        assert not self.model.isObjChangedProbing()
        self.model.fixVarProbing(var, value)
        self.model.constructLP()
        # NOTE(review): the return value of solveProbingLP is ignored; confirm
        # getLPObjVal is meaningful when the probing LP is infeasible or errors.
        self.model.solveProbingLP()
        objval = self.model.getLPObjVal()
        self.model.endProbing()
        return objval

    def _pick_by_probing(self, cands, score, k):
        """Among the ``k`` best-scored candidates, return the one with the
        largest product of (floored) objective gains from the two probes."""
        _, topk_cand_index = torch.topk(score, k, dim=0)
        topk_cands = [cands[int(cand_index)] for cand_index in topk_cand_index]
        cur_z = self.model.getLPObjVal()

        sb_list = []
        for cand_var in topk_cands:
            # NOTE(review): fixes to the *global* bounds. For binary variables
            # this is exactly the down/up branch; for general integers it is
            # only an approximation.
            lb = cand_var.getLbGlobal()
            ub = cand_var.getUbGlobal()
            lb_score = self._probe_fixed_lp(cand_var, lb)
            ub_score = self._probe_fixed_lp(cand_var, ub)
            sb_list.append(max(lb_score - cur_z, 1e-6) * max(ub_score - cur_z, 1e-6))

        return topk_cands[int(np.argmax(sb_list))]

    def branchexeclp(self, allowaddcons, k=5):
        """Choose a branching variable and branch on it; always BRANCHED.

        ``k`` is overridden by the depth schedule at depths 0-3 and is unused
        deeper in the tree; it is kept for interface compatibility.
        """
        state = utils.extract_state_new_39(self.model)
        g = utils.graph_transform_new_39(state).to(self.device)
        agent_out = self.agent(g).nodes['v'].data['s']

        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]
        score = agent_out[action_set]

        depth = self.model.getCurrentNode().getDepth()
        if 0 <= depth <= 3:
            if depth < 1:
                k = 10
            elif depth <= 2:
                k = 5
            else:
                k = 2
            # torch.topk raises if k exceeds the number of candidates.
            k = min(k, len(cands))
            start = time.time()
            best_cand = self._pick_by_probing(cands, score, k)
            print('SCIP time', time.time() - start)
        else:
            best_cand = cands[int(score.argmax())]

        self.model.branchVar(best_cand)
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_39_data(Branchrule):
    """Argmax branching with the ``new_39`` policy that also records per-node data.

    Every call to ``branchexeclp`` appends one dict to ``self.app_data``. The
    dict holds:

    - solver-progress counters;
    - the current node's id, parent id, depth and lower bound;
    - the ids of the two children created;
    - the wall time spent on policy inference plus branching;
    - the policy's highest-scoring candidates.
    """

    # Maximum number of top-scoring candidates recorded per node.
    NUM_RECORDED_CANDS = 10

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0
        # One record per branching call, appended by branchexeclp.
        self.app_data = []

    def branchinit(self):
        self.root_buffer = {}
        self.count = 0

    def branchexeclp(self, allowaddcons):
        """Branch on the policy's argmax candidate and record diagnostics."""
        # Global solver progress at the moment this node is branched.
        one_data = {
            "ques_time": self.model.getTotalTime(),
            "ques_nodes_sum": self.model.getNNodes(),
            "ques_lp_sum": self.model.getNLPIterations(),
            "ques_gap": self.model.getGap(),
            "ques_Primalbound": self.model.getPrimalbound(),
            "ques_Dualbound": self.model.getDualbound(),
        }

        now_node = self.model.getCurrentNode()
        parent_node = now_node.getParent()
        one_data["node_id"] = now_node.getNumber()
        # A node without a parent is recorded with parent id -1.
        one_data["node_parent_id"] = -1 if parent_node is None else parent_node.getNumber()
        one_data["node_depth"] = now_node.getDepth()
        one_data["now_low_bound"] = now_node.getLowerbound()

        time_1 = time.time()

        state = utils.extract_state_new_39(self.model)
        g = utils.graph_transform_new_39(state).to(self.device)
        agent_out = self.agent(g).nodes['v'].data['s']

        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]
        score = agent_out[action_set]
        best_cand = int(score.argmax())

        downchild, _eqchild, upchild = self.model.branchVar(cands[best_cand])
        one_data["node_son_0_id"] = downchild.getNumber()
        one_data["node_son_1_id"] = upchild.getNumber()

        one_data["node_branch_time"] = time.time() - time_1

        # Flatten to a 1-D tensor (the policy output may carry a trailing
        # singleton dimension) and never request more entries than there are
        # candidates: torch.topk raises otherwise, and squeeze() on a single
        # candidate would produce a 0-d tensor.
        flat_score = score.reshape(-1)
        n_recorded = min(self.NUM_RECORDED_CANDS, flat_score.numel())
        values, indices = torch.topk(flat_score, n_recorded)

        variable_info_list = []
        for value, idx in zip(values, indices):
            idx = int(idx)
            variable_info_list.append({
                "index": action_set[idx],
                "lp_sol": cands[idx].getLPSol(),
                "score": value.item(),
            })
        one_data["variable_info_list"] = variable_info_list

        self.app_data.append(one_data)
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_40(Branchrule):
    """Branching rule driven by the ``new_40`` GNN policy.

    The SCIP state is encoded with ``utils.extract_state_new_40`` /
    ``utils.graph_transform_new_40``. The policy scores the variables, and
    SCIP branches on the best-scored pseudo-branch candidate.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        # Counters read by the evaluation code; they are never incremented
        # here because the rule always reports BRANCHED.
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}
        self.count = 0

    def branchexeclp(self, allowaddcons):
        """Branch on the policy's argmax candidate; always BRANCHED."""
        encoded = utils.extract_state_new_40(self.model)
        graph = utils.graph_transform_new_40(encoded).to(self.device)
        node_scores = self.agent(graph).nodes['v'].data['s']

        pseudo_cands, *_ = self.model.getPseudoBranchCands()
        col_positions = [var.getCol().getLPPos() for var in pseudo_cands]
        winner = int(node_scores[col_positions].argmax())

        self.model.branchVar(pseudo_cands[winner])
        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_40_topk(Branchrule):
    """``new_40`` policy branching, re-ranked by probing near the root.

    The GNN policy scores every pseudo-branch candidate. At depths 0-2 the
    ``k`` best-scored candidates (capped at the number of candidates) are
    re-ranked by a product score from two probing LPs per candidate. Deeper
    nodes branch directly on the policy's argmax.
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        self.root_buffer = {}
        self.count = 0

    def _probe_fixed_lp(self, var, value):
        """Return the LP objective after fixing ``var`` to ``value`` in a probing dive."""
        self.model.startProbing()
        assert not self.model.isObjChangedProbing()
        self.model.fixVarProbing(var, value)
        self.model.constructLP()
        # NOTE(review): the return value of solveProbingLP is ignored; confirm
        # getLPObjVal is meaningful when the probing LP is infeasible or errors.
        self.model.solveProbingLP()
        objval = self.model.getLPObjVal()
        self.model.endProbing()
        return objval

    def _pick_by_probing(self, cands, score, k):
        """Among the ``k`` best-scored candidates, return the one with the
        largest product of (floored) objective gains from the two probes."""
        _, topk_cand_index = torch.topk(score, k, dim=0)
        topk_cands = [cands[int(cand_index)] for cand_index in topk_cand_index]
        cur_z = self.model.getLPObjVal()

        sb_list = []
        for cand_var in topk_cands:
            # NOTE(review): fixes to the *global* bounds. For binary variables
            # this is exactly the down/up branch; for general integers it is
            # only an approximation.
            lb = cand_var.getLbGlobal()
            ub = cand_var.getUbGlobal()
            lb_score = self._probe_fixed_lp(cand_var, lb)
            ub_score = self._probe_fixed_lp(cand_var, ub)
            sb_list.append(max(lb_score - cur_z, 1e-6) * max(ub_score - cur_z, 1e-6))

        return topk_cands[int(np.argmax(sb_list))]

    def branchexeclp(self, allowaddcons, k=5):
        """Choose a branching variable and branch on it; always BRANCHED.

        ``k`` is the number of policy candidates re-ranked at depths 0-2.
        """
        state = utils.extract_state_new_40(self.model)
        g = utils.graph_transform_new_40(state).to(self.device)
        agent_out = self.agent(g).nodes['v'].data['s']

        cands, *_ = self.model.getPseudoBranchCands()
        action_set = [c.getCol().getLPPos() for c in cands]
        score = agent_out[action_set]

        depth = self.model.getCurrentNode().getDepth()
        if 0 <= depth <= 2:
            # torch.topk raises if k exceeds the number of candidates.
            k = min(k, len(cands))
            start = time.time()
            best_cand = self._pick_by_probing(cands, score, k)
            print('SCIP time', time.time() - start)
        else:
            best_cand = cands[int(score.argmax())]

        self.model.branchVar(best_cand)
        return {'result': SCIP_RESULT.BRANCHED}

def log_end_info(args, statuss, sol_time, nodes, fair_nodes, gaps, device, result_dir=None):
    """Append one summary row for an evaluation run to ``<result_dir>/0.<ins_config>.csv``.

    The CSV is created with a fixed header if it does not exist yet. If it
    already exists, it is read back and the new row is appended.

    Args:
        args: parsed CLI namespace. Reads ``ins_type``, ``ins_config``,
            ``num_instance`` and ``method``. Also reads ``model_select_flag``
            for new_33 / new_33_root, and ``check_point`` otherwise.
        statuss: per-instance SCIP status strings. Those equal to
            ``'optimal'`` are counted in the ``over_rate`` column.
        sol_time, nodes, fair_nodes: per-instance metrics, averaged with
            ``num_mean``.
        gaps: per-instance gaps. Accepted for interface compatibility but not
            written to this CSV.
        device: written as ``str(device)`` to the ``device`` column.
        result_dir: directory holding the CSV. Defaults to the module-level
            ``result_dir`` that the ``__main__`` block defines.

    Raises:
        ValueError: if no ``result_dir`` is passed and none exists at module level.
    """
    if result_dir is None:
        # Backward compatibility: callers that do not pass result_dir rely on
        # the module-level variable set by the script entry point.
        result_dir = globals().get('result_dir')
        if result_dir is None:
            raise ValueError('log_end_info needs result_dir (no module-level result_dir is set)')

    over_count = sum(1 for status in statuss if status == 'optimal')

    csv_path = os.path.join(result_dir, f'0.{args.ins_config}.csv')
    print("csv_path:", csv_path)

    if os.path.exists(csv_path):
        print("find csv_path")
        df = pd.read_csv(csv_path)
    else:
        os.makedirs(result_dir, exist_ok=True)
        df = pd.DataFrame(columns=[
            'device',
            'type',
            'config',
            'num_instance',
            'method',
            'check_point',
            'over_rate',
            'avg_time',
            'avg_nodes',
            'avg_fair_nodes'])

    # The multi-model new_33 variants are identified by the model-selection
    # flag rather than by a single checkpoint name.
    if args.method in ['new_33', 'new_33_root']:
        check_point = args.model_select_flag
    else:
        check_point = args.check_point

    df.loc[len(df.index)] = [
        str(device),
        args.ins_type,
        args.ins_config,
        args.num_instance,
        args.method,
        check_point,
        str(over_count) + "/" + str(len(statuss)),
        num_mean(sol_time),
        num_mean(nodes),
        num_mean(fair_nodes),
    ]

    df.to_csv(csv_path, index=False)


def _make_brancher(scip, agent, device, args):
    """Return the branching rule selected by ``args.method``, bound to ``scip``.

    Raises:
        ValueError: if ``args.method`` has no associated branching rule.
    """
    method = args.method
    if method in ['relpscost', 'presolve', 'fullstrong', 'vanillafullstrong', 'vanilla', 'mostinf', 'random']:
        brancher_cls = PolicyBranching
    elif method == 'baseline':
        brancher_cls = PolicyBranching_baseline
    elif method in ['gcnn', 'pd', 'new_1', 'new_2', 'new_5', 'new_6', 'new_7', 'new_16', 'new_17', 'new_18']:
        brancher_cls = PolicyBranching_gcn
    elif method == 'new_3':
        brancher_cls = PolicyBranching_3
    elif method == 'new_3_0':
        brancher_cls = PolicyBranching_3_0
    elif method == 'new_3_1':
        brancher_cls = PolicyBranching_3_1
    elif method == 'new_3_2':
        brancher_cls = PolicyBranching_3_2
    elif method == 'new_4':
        brancher_cls = PolicyBranching_4
    elif method == 'new_8':
        brancher_cls = PolicyBranching_8_2
    elif method == 'new_10':
        brancher_cls = PolicyBranching_10_2
    elif method == 'new_12':
        brancher_cls = PolicyBranching_12_2
    elif method == 'new_13':
        brancher_cls = PolicyBranching_13_2
    elif method == 'new_14':
        brancher_cls = PolicyBranching_14_2
    elif method == 'new_15':
        brancher_cls = PolicyBranching_15_2
    elif method == 'fsb_gcnn':
        brancher_cls = PolicyBranching_fsb_gcnn
    elif method == 'new_20':
        brancher_cls = PolicyBranching_20
    elif method == 'new_21':
        brancher_cls = PolicyBranching_21
    elif method == 'new_20_topk':
        brancher_cls = PolicyBranching_20_topk
    elif method == 'new_21_topk':
        brancher_cls = PolicyBranching_21_topk
    elif method == 'new_22':
        brancher_cls = PolicyBranching_22
    elif method == 'new_22_topk':
        brancher_cls = PolicyBranching_22_topk
    elif method == 'new_23':
        brancher_cls = PolicyBranching_23
    elif method == 'new_25':
        brancher_cls = PolicyBranching_25
    elif method == 'new_26':
        brancher_cls = PolicyBranching_26
    elif method == 'new_27':
        brancher_cls = PolicyBranching_27
    elif method == 'new_28':
        brancher_cls = PolicyBranching_28
    elif method == 'new_29':
        brancher_cls = PolicyBranching_29
    elif method == 'new_30':
        brancher_cls = PolicyBranching_30
    elif method == 'new_31':
        brancher_cls = PolicyBranching_31
    elif method == 'new_32':
        brancher_cls = PolicyBranching_32
    elif method in ['new_33', 'new_33_root', 'new_41']:
        brancher_cls = PolicyBranching_33
    elif method == 'new_34':
        brancher_cls = PolicyBranching_34
    elif method == 'new_35':
        brancher_cls = PolicyBranching_35
    elif method == 'new_35_topk':
        brancher_cls = PolicyBranching_35_topk
    elif method == 'new_37':
        brancher_cls = PolicyBranching_37
    elif method == 'new_37_1':
        brancher_cls = PolicyBranching_37_1
    elif method == 'new_38':
        brancher_cls = PolicyBranching_38
    elif method == 'new_38_topk':
        brancher_cls = PolicyBranching_38_topk
    elif method == 'new_39':
        brancher_cls = PolicyBranching_39
    elif method == 'new_39_topk':
        brancher_cls = PolicyBranching_39_topk
    elif method == 'new_39_data':
        brancher_cls = PolicyBranching_39_data
    elif method == 'new_40':
        brancher_cls = PolicyBranching_40
    elif method == 'new_40_topk':
        brancher_cls = PolicyBranching_40_topk
    else:
        raise ValueError(f'Unknown branching method: {method!r}')
    return brancher_cls(scip, agent, device, args)


def evaluate(agent, instances, result_dir, device, args):
    """Solve each instance with the branching rule chosen by ``args.method``.

    For every instance path, a fresh SCIP model is built, configured with
    ``set_scip``, a wall-clock time limit and the selected branching rule, and
    then solved. Per-instance results are collected and logged.

    After the loop, one summary row is appended via ``log_end_info``. In
    'presolve' mode, ``vals.pkl`` (instance -> best objective, or None when no
    solution was found) is written to ``result_dir``. In every other mode, a
    timestamped per-instance CSV with a trailing 'avg' row is written instead.

    Returns:
        ``(num_mean(nodes), geo_mean(nodes))`` over the per-instance node counts.

    Raises:
        ValueError: if ``args.method`` is not a known branching method.
    """
    statuss = []
    sol_time = []
    nodes = []
    fair_nodes = []
    gaps = []
    # instance path -> best objective value (None if SCIP found no solution).
    vals = {}

    for instance in instances:
        # Named `scip` so it does not shadow the module-level `model` (src.model).
        scip = Model()
        scip.hideOutput()
        scip.readProblem(f"{instance}")
        set_scip(scip, args.scip_seed, restart=False, separator=False, primal_heuristic=False)
        scip.setIntParam('timing/clocktype', 2)  # 1: CPU user seconds, 2: wall clock time
        scip.setRealParam('limits/time', args.time_limit)

        brancher = _make_brancher(scip, agent, device, args)
        scip.includeBranchrule(brancher, "Evaluate", "Policy branching on variable",
                               priority=99999, maxdepth=-1, maxbounddist=1)

        if args.method == 'vanilla':
            scip.setBoolParam('branching/vanillafullstrong/integralcands', True)
            scip.setBoolParam('branching/vanillafullstrong/scoreall', False)
            scip.setBoolParam('branching/vanillafullstrong/collectscores', False)
            scip.setBoolParam('branching/vanillafullstrong/donotbranch', False)
            scip.setBoolParam('branching/vanillafullstrong/idempotent', True)

        scip.optimize()
        status = scip.getStatus()
        # getObjVal cannot report a value when no feasible solution was found
        # (e.g. infeasible instance, or the time limit hit first).
        val = scip.getObjVal() if scip.getNSols() > 0 else None
        solving_time = scip.getSolvingTime()
        num_nodes = scip.getNNodes()
        fair_node = num_nodes + brancher.cut_off + brancher.dom_reduction
        gap = scip.getGap()

        vals[instance] = val
        statuss.append(status)
        sol_time.append(solving_time)
        nodes.append(num_nodes)
        fair_nodes.append(fair_node)
        gaps.append(gap)
        logger.info('{}: status={}, sol_time={:.2f}, nodes={}, fair_nodes={}'.format(instance, status, solving_time,
                                                                                     num_nodes, fair_node))
        scip.freeProb()

    log_end_info(args, statuss, sol_time, nodes, fair_nodes, gaps, device)

    # Summaries over real instances only: computed before any 'avg' row is
    # added, since including the mean itself would bias the geometric means.
    avg_time, geo_time = num_mean(sol_time), geo_mean(sol_time)
    avg_nodes, geo_nodes = num_mean(nodes), geo_mean(nodes)
    avg_fair_nodes, geo_fair_nodes = num_mean(fair_nodes), geo_mean(fair_nodes)

    if args.method == 'presolve':
        with open(os.path.join(result_dir, 'vals.pkl'), 'wb') as handle:
            pickle.dump(vals, handle)
    else:
        # Per-instance rows followed by one 'avg' summary row.
        d = {
            'status': statuss + ['avg'],
            'time': sol_time + [avg_time],
            'nodes': nodes + [avg_nodes],
            'fair_nodes': fair_nodes + [avg_fair_nodes],
            'gap': gaps + ['gaps'],
        }
        df = pd.DataFrame(data=d)

        formatted_date = datetime.now().strftime("%m-%d_%H:%M")
        df.to_csv(os.path.join(result_dir, f'{formatted_date}_{args.method}_{args.start_num}-{args.num_instance}.csv'))

    logger.info('[{}]: avg_time={:.2f}, geo_time={:.2f}, avg_nodes={:.2f}, geo_nodes={:.2f}, '
                'avg_fair_nodes={:.2f}, geo_fair_nodes={:.2f}'
                .format(len(instances), avg_time, geo_time, avg_nodes, geo_nodes,
                        avg_fair_nodes, geo_fair_nodes))
    return avg_nodes, geo_nodes

if __name__ == "__main__":
    # ---- command-line interface ----
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--method', help='Branch method.', default='new_8')
    # Network / reproducibility settings.
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--check_point', type=str, default=None)
    parser.add_argument('--device_id', '-d', type=int, default=0)
    # Settings for evaluating the test instances.
    parser.add_argument('--scip_seed', type=int, default=0)
    parser.add_argument('--ins_type', type=str, help='instances directory',
                        default='setcover_400r_1000c_0.05d_100mc_0se')
    parser.add_argument('--ins_config', type=str, help='instances directory', default='transfer_1000r')
    parser.add_argument('--time_limit', type=int, help='limit time for solving each instance', default=3600)
    parser.add_argument('--num_instance', type=int, default=1)
    parser.add_argument('--start_num', type=int, default=0)
    parser.add_argument('--model_id', type=str, default='sb')
    # Checkpoint ids (bc_<id>.pt) of the four new_33 sub-models.
    for _h_idx, _h_default in enumerate(('84', '93', '43', '10'), start=1):
        parser.add_argument(f'--h_model_{_h_idx}', type=str, default=_h_default)
    parser.add_argument('--model_select_flag', type=int, default=-1)

    args = parser.parse_args()
    device = set_device_seed(args)

    # ---- input / output locations ----
    instance_dir = os.path.join('../data/instances/', f'{args.ins_type}/{args.ins_config}')
    result_dir = os.path.join('../results/', f'{args.ins_type}/{args.ins_config}')
    os.makedirs(result_dir, exist_ok=True)

    logger.info(f'evaluate on {instance_dir}, method is {args.method}')
    if args.check_point is not None:
        logger.info(f'evaluate use {args.check_point}')

    # ---- policy network: method -> (class name in src.model, v_dim) ----
    # Classes are looked up lazily with getattr so only the selected one is touched.
    _agent_specs = {}
    for _m in ('baseline', 'fsb_gcnn', 'new_16', 'new_17', 'new_18', 'new_26', 'new_27',
               'new_35', 'new_35_topk', 'new_37_1', 'new_38', 'new_38_topk',
               'new_39', 'new_39_topk', 'new_39_data'):
        _agent_specs[_m] = ('GCNN_Net', 17)
    for _m in ('new_2', 'new_3', 'new_3_0', 'new_3_1', 'new_4', 'new_5', 'new_6', 'new_7',
               'new_8', 'new_10', 'new_12', 'new_13', 'new_14', 'new_15', 'new_23',
               'new_25', 'new_28', 'new_32'):
        _agent_specs[_m] = (f'GCNN_Net_{_m}', 17)
    for _m in ('new_20', 'new_20_topk', 'new_21', 'new_21_topk'):
        _agent_specs[_m] = ('GCNN_Net_new_20', 17)
    for _m in ('new_22', 'new_22_topk'):
        _agent_specs[_m] = ('GCNN_Net_new_22', 17)
    for _m in ('new_40', 'new_40_topk'):
        _agent_specs[_m] = ('GCNN_Net_new_40', 19)
    _agent_specs.update({
        'new_1': ('GCNN_Net_new_1', 23),
        'new_3_2': ('GCNN_Net_new_3_2', 25),
        'new_29': ('GCNN_Net_new_29', 35),
        'new_30': ('GCNN_Net_new_30', 33),
        'new_31': ('GCNN_Net_new_31', 29),
        'new_34': ('GCNN_Net_new_34', None),  # constructed without v_dim
    })

    def _build_agent(class_name, v_dim):
        """Instantiate ``model.<class_name>`` (optionally with v_dim) on ``device``."""
        net_cls = getattr(model, class_name)
        net = net_cls() if v_dim is None else net_cls(v_dim=v_dim)
        return net.to(device)

    if args.method in ("new_33", "new_33_root", 'new_41'):
        # Four sub-models, built in order 1..4.
        agent = [_build_agent(f'GCNN_Net_new_33_{_part}', 17) for _part in range(1, 5)]
    elif args.method in _agent_specs:
        agent = _build_agent(*_agent_specs[args.method])
    else:
        agent = None

    instances = [
        os.path.join(f'../data/instances/{args.ins_type}/{args.ins_config}', f'instance_{_n}.lp')
        for _n in range(args.start_num + 1, args.num_instance + 1)
    ]

    # ---- checkpoint location(s) ----
    _ckpt_root = f'../check_points/{args.ins_type}'
    # Methods that reuse the checkpoint directory of another method.
    _ckpt_dir_alias = {
        'fsb_gcnn': 'baseline', 'new_23': 'baseline', 'new_26': 'baseline', 'new_27': 'baseline',
        'new_39_data': 'new_39', 'new_39_topk': 'new_39', 'new_40_topk': 'new_40',
        'new_38_topk': 'new_38', 'new_35_topk': 'new_35', 'new_20_topk': 'new_20',
        'new_21_topk': 'new_21', 'new_22_topk': 'new_22',
    }
    if args.method in ("new_31", "new_32"):
        check_point_path = f'{_ckpt_root}/{args.method}/{args.model_id}/{args.check_point}'
    elif args.method in ("new_33", "new_33_root"):
        _suffix = '_root' if args.method == "new_33_root" else ''
        _heads = (args.h_model_1, args.h_model_2, args.h_model_3, args.h_model_4)
        check_point_path = [f'{_ckpt_root}/new_33_{_part}{_suffix}/bc_{_head}.pt'
                            for _part, _head in enumerate(_heads, start=1)]
    else:
        _ckpt_dir = _ckpt_dir_alias.get(args.method, args.method)
        check_point_path = f'{_ckpt_root}/{_ckpt_dir}/{args.check_point}'

    # ---- load weights ----
    if agent is None:
        print("Not load model")
    elif args.method in ["new_33", "new_33_root"]:
        for _net, _path in zip(agent, check_point_path):
            if os.path.exists(_path):
                _net.load_state_dict(torch.load(_path, map_location=device))
            else:
                print("check_point_path[i]:", _path)
        print("############# new 33 load state sucess! #####################")
    elif args.check_point is not None:
        agent.load_state_dict(torch.load(check_point_path, map_location=device))
        print("############## load state sucess! ##############")

    evaluate(agent, instances, result_dir, device, args)
    print('finish solving an instance!')
