import random
import numpy as np
import torch
import argparse
import scipy.sparse as sp
import dgl
from datetime import datetime
import pickle, gzip
import sys
import itertools
from pyscipopt import SCIP_PARAMSETTING, Model, Branchrule, Nodesel, SCIP_RESULT
from torch.utils.data import Dataset
from treelib import Tree
from .utils_new_0 import *
from .utils_new_1 import *
from .utils_new_3 import *
from .utils_new_3_0 import *
from .utils_new_3_1 import *
from .utils_new_3_2 import *
from .utils_new_3_root import *
from .utils_new_4 import *
from .utils_new_20 import *
from .utils_new_21 import *
from .utils_new_22 import *
from .utils_new_28 import *
from .utils_new_29 import *
from .utils_new_30 import *
from .utils_new_31 import *
from .utils_new_32 import *
from .utils_new_33 import *
from .utils_new_33_1 import *
from .utils_new_33_2 import *
from .utils_new_33_3 import *
from .utils_new_33_4 import *
from .utils_new_34 import *
from .utils_new_35 import *
from .utils_new_36 import *
from .utils_new_37 import *
from .utils_new_37_1 import *
from .utils_new_38 import *
from .utils_new_39 import *
from .utils_new_40 import *
from .utils_new_41 import *

sys.path.append('../')
from src.logger import Logger
from src.model import GCNN_Net, PD_Net
from matplotlib import pyplot as plt

import time
# Logger instance taken from the project's Logger class.
logger = Logger.logger


# Selects the policy network built by BranchAgent: "gcnn" -> GCNN_Net(),
# any other value -> PD_Net(T=2). It also switches how
# BranchAgent.select_action reads the logits from the network output.
flag = "gcnn"

# Choose the torch device and seed the random, numpy and torch RNGs.
def set_device_seed(args):
    """Pick the torch device and seed every random number generator.

    Parameters
    ----------
    args : object with ``device_id`` and ``seed`` attributes
        ``device_id < 0`` forces the CPU; otherwise ``cuda:<device_id>`` is
        used when CUDA is available, falling back to the CPU.

    Returns
    -------
    torch.device
        The selected device.
    """
    # Torch is limited to a single intra-op thread.
    torch.set_num_threads(1)

    wants_gpu = args.device_id >= 0 and torch.cuda.is_available()
    if wants_gpu:
        device = torch.device(f"cuda:{args.device_id}")
    else:
        device = torch.device('cpu')

    # Seed Python, numpy, torch (CPU) and torch (all CUDA devices), in that order.
    for seed_fn in (random.seed,
                    np.random.seed,
                    torch.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(args.seed)

    return device

# Check whether a seed is a valid random seed.
def valid_seed(seed):
    """Convert ``seed`` to an int and verify it lies in ``[0, 2**32 - 1]``.

    Raises
    ------
    argparse.ArgumentTypeError
        If the converted value is outside the allowed range.
    """
    value = int(seed)
    if 0 <= value <= 2**32 - 1:
        return value
    raise argparse.ArgumentTypeError("seed must be any integer between 0 and 2**32 - 1 inclusive")

# Configure the SCIP solver parameters.
def set_scip(model, seed,
             random=True,
             presolver=True,
             separator=True,
             separator_root=True,
             propagator=True,
             restart=True,
             primal_heuristic=True,
             conflict=True):
    """Apply randomization settings and optionally switch SCIP components off.

    Parameters
    ----------
    model : pyscipopt.scip.Model
        The model whose parameters are set.
    seed : int
        Seed; reduced modulo 2147483648 before being handed to SCIP.
    random : bool
        When true, enable variable permutation and set both seed parameters.
    presolver, separator, separator_root, propagator, restart,
    primal_heuristic, conflict : bool
        A false value disables the corresponding component; a true value
        leaves the model's current setting untouched.
    """
    scip_seed = seed % 2147483648  # SCIP seed range

    # Randomization
    if random:
        model.setBoolParam('randomization/permutevars', True)
        for seed_param in ('randomization/permutationseed',
                           'randomization/randomseedshift'):
            model.setIntParam(seed_param, scip_seed)

    # Components that are disabled by forcing an integer limit to zero.
    zeroed_when_off = (
        (separator, 'separating/maxrounds'),
        (separator_root, 'separating/maxroundsroot'),
        (restart, 'presolving/maxrestarts'),
        (presolver, 'presolving/maxrounds'),
    )
    for enabled, int_param in zeroed_when_off:
        if not enabled:
            model.setIntParam(int_param, 0)

    # Conflict analysis
    if not conflict:
        model.setBoolParam('conflict/enable', False)

    # Primal heuristics
    if not primal_heuristic:
        model.setHeuristics(SCIP_PARAMSETTING.OFF)

    # Propagation
    if not propagator:
        model.disablePropagation(False)

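# Convert the current SCIP model state into the bipartite-graph representation used by the GCNN model, with variables and constraints as the two node sets.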
def _flatten_feature_dict(feats):
    """Turn an ordered ``{name: (n, k) array}`` dict into ``(names, values)``.

    Single-column features keep their key as name; features with ``k > 1``
    columns are expanded to ``key_0 .. key_{k-1}``. ``values`` is the
    column-wise concatenation of all arrays, in dict order.
    """
    names = []
    for key, arr in feats.items():
        width = arr.shape[1]
        if width == 1:
            names.append(key)
        else:
            names.extend(f'{key}_{i}' for i in range(width))
    return names, np.concatenate(list(feats.values()), axis=-1)


def _extract_bipartite_state(model, buffer=None):
    """Shared implementation behind ``extract_state`` and ``extract_state_with_lb_ub``.

    Returns ``(constraint_features, edge_features, variable_features, s)``
    where ``s`` is the raw dict returned by ``model.getState``.

    Buffering: when ``buffer`` is None or ``model.getNNodes() == 1`` a fresh
    local dict is used, so a dict passed by the caller is left untouched in
    that case. Otherwise ``buffer['scip_state']`` is refreshed on every call
    and ``buffer['state']`` is filled on the first call and reused afterwards.
    """
    if buffer is None or model.getNNodes() == 1:
        buffer = {}

    # Update the solver state, handing back the previous one if we have it.
    s = model.getState(buffer['scip_state'] if 'scip_state' in buffer else None)
    buffer['scip_state'] = s

    has_cache = 'state' in buffer
    cache = buffer['state'] if has_cache else None

    if has_cache:
        obj_norm = cache['obj_norm']
    else:
        obj_norm = np.linalg.norm(s['col']['coefs'])
        obj_norm = 1 if obj_norm <= 0 else obj_norm

    # Zero row norms are replaced by 1 (in place) to avoid dividing by zero.
    row_norms = s['row']['norms']
    row_norms[row_norms == 0] = 1

    # ---- Column features -> variable nodes -------------------------------
    n_cols = len(s['col']['types'])

    if has_cache:
        col_feats = cache['col_feats']
    else:
        # These two features are computed once and then served from the cache.
        col_feats = {}
        col_feats['type'] = np.zeros((n_cols, 4))  # BINARY INTEGER IMPLINT CONTINUOUS
        col_feats['type'][np.arange(n_cols), s['col']['types']] = 1
        col_feats['coef_normalized'] = s['col']['coefs'].reshape(-1, 1) / obj_norm

    # The remaining column features are recomputed on every call. Assigning
    # to an existing key keeps its position, so the column order is stable.
    col_feats['has_lb'] = ~np.isnan(s['col']['lbs']).reshape(-1, 1)
    col_feats['has_ub'] = ~np.isnan(s['col']['ubs']).reshape(-1, 1)
    col_feats['sol_is_at_lb'] = s['col']['sol_is_at_lb'].reshape(-1, 1)
    col_feats['sol_is_at_ub'] = s['col']['sol_is_at_ub'].reshape(-1, 1)
    col_feats['sol_frac'] = s['col']['solfracs'].reshape(-1, 1)
    col_feats['sol_frac'][s['col']['types'] == 3] = 0  # continuous have no fractionality
    col_feats['basis_status'] = np.zeros((n_cols, 4))  # LOWER BASIC UPPER ZERO
    col_feats['basis_status'][np.arange(n_cols), s['col']['basestats']] = 1
    col_feats['reduced_cost'] = s['col']['redcosts'].reshape(-1, 1) / obj_norm
    col_feats['age'] = s['col']['ages'].reshape(-1, 1) / (s['stats']['nlps'] + 5)
    col_feats['sol_val'] = s['col']['solvals'].reshape(-1, 1)
    col_feats['inc_val'] = s['col']['incvals'].reshape(-1, 1)
    col_feats['avg_inc_val'] = s['col']['avgincvals'].reshape(-1, 1)

    # 4 + 1 + 1 + 1 + 1 + 1 + 1 + 4 + 1 + 1 + 1 + 1 + 1 = 19 columns in total.
    col_feat_names, col_feat_vals = _flatten_feature_dict(col_feats)
    variable_features = {
        'names': col_feat_names,
        'values': col_feat_vals,}

    # ---- Row features -> constraint nodes --------------------------------
    # Every row side that is not NaN becomes its own constraint node:
    # left-hand sides first (sign-flipped), then right-hand sides.
    if has_cache:
        row_feats = cache['row_feats']
        has_lhs = cache['has_lhs']
        has_rhs = cache['has_rhs']
    else:
        row_feats = {}
        has_lhs = np.nonzero(~np.isnan(s['row']['lhss']))[0]
        has_rhs = np.nonzero(~np.isnan(s['row']['rhss']))[0]

        row_feats['obj_cosine_similarity'] = np.concatenate((
            -s['row']['objcossims'][has_lhs],
            +s['row']['objcossims'][has_rhs])).reshape(-1, 1)

        row_feats['bias'] = np.concatenate((
            -(s['row']['lhss'] / row_norms)[has_lhs],
            +(s['row']['rhss'] / row_norms)[has_rhs])).reshape(-1, 1)

    row_feats['is_tight'] = np.concatenate((
        s['row']['is_at_lhs'][has_lhs],
        s['row']['is_at_rhs'][has_rhs])).reshape(-1, 1)

    row_feats['age'] = np.concatenate((
        s['row']['ages'][has_lhs],
        s['row']['ages'][has_rhs])).reshape(-1, 1) / (s['stats']['nlps'] + 5)

    # todo: check whether this change helps
    # (alternative that was tried: dualsols * row_norms / obj_norm)
    tmp = s['row']['dualsols'] / (row_norms * obj_norm)
    row_feats['dualsol_val_normalized'] = np.concatenate((
        -tmp[has_lhs],
        +tmp[has_rhs])).reshape(-1, 1)

    row_feat_names, row_feat_vals = _flatten_feature_dict(row_feats)
    constraint_features = {
        'names': row_feat_names,
        'values': row_feat_vals,}

    # ---- Edge features ---------------------------------------------------
    if has_cache:
        edge_row_idxs = cache['edge_row_idxs']
        edge_col_idxs = cache['edge_col_idxs']
        edge_feats = cache['edge_feats']
    else:
        # Coefficients divided by the norm of their row.
        coef_matrix = sp.csr_matrix(
            (s['nzrcoef']['vals'] / row_norms[s['nzrcoef']['rowidxs']],
            (s['nzrcoef']['rowidxs'], s['nzrcoef']['colidxs'])),
            shape=(len(s['row']['nnzrs']), len(s['col']['types'])))
        # Same lhs-then-rhs stacking (and sign flip) as the row features.
        coef_matrix = sp.vstack((
            -coef_matrix[has_lhs, :],
            coef_matrix[has_rhs, :])).tocoo(copy=False)

        edge_row_idxs, edge_col_idxs = coef_matrix.row, coef_matrix.col
        edge_feats = {}
        edge_feats['coef_normalized'] = coef_matrix.data.reshape(-1, 1)

    edge_feat_names, edge_feat_vals = _flatten_feature_dict(edge_feats)
    edge_feat_indices = np.vstack([edge_row_idxs, edge_col_idxs])
    edge_features = {
        'names': edge_feat_names,
        'indices': edge_feat_indices,
        'values': edge_feat_vals,}

    if not has_cache:
        buffer['state'] = {
            'obj_norm': obj_norm,
            'col_feats': col_feats,
            'row_feats': row_feats,
            'has_lhs': has_lhs,
            'has_rhs': has_rhs,
            'edge_row_idxs': edge_row_idxs,
            'edge_col_idxs': edge_col_idxs,
            'edge_feats': edge_feats,
        }

    return constraint_features, edge_features, variable_features, s


def extract_state(model, buffer=None):
    """
    Compute a bipartite graph representation of the solver. In this
    representation, the variables and constraints of the MILP are the
    left- and right-hand side nodes, and an edge links two nodes iff the
    variable is involved in the constraint. Both the nodes and edges carry
    features.

    Parameters
    ----------
    model : pyscipopt.scip.Model
        The current model.
    buffer : dict
        A buffer to avoid re-extracting redundant information from the solver
        each time. Ignored (a fresh dict is used) when it is None or when
        ``model.getNNodes() == 1``.

    Returns
    -------
    constraint_features : dictionary of type {'names': list, 'values': np.ndarray}
        The features associated with the constraint nodes in the bipartite graph.
    edge_features : dictionary of type {'names': list, 'indices': np.ndarray, 'values': np.ndarray}
        The features associated with the edges in the bipartite graph.
        This is given as a sparse matrix in COO format.
    variable_features : dictionary of type {'names': list, 'values': np.ndarray}
        The features associated with the variable nodes in the bipartite graph.
    """
    constraint_features, edge_features, variable_features, _ = \
        _extract_bipartite_state(model, buffer)
    return constraint_features, edge_features, variable_features


def extract_state_with_lb_ub(model, buffer=None):
    """
    Same as ``extract_state`` but additionally returns the raw variable
    lower and upper bounds.

    Parameters
    ----------
    model : pyscipopt.scip.Model
        The current model.
    buffer : dict
        A buffer to avoid re-extracting redundant information from the solver
        each time. Ignored (a fresh dict is used) when it is None or when
        ``model.getNNodes() == 1``.

    Returns
    -------
    constraint_features : dictionary of type {'names': list, 'values': np.ndarray}
        The features associated with the constraint nodes in the bipartite graph.
    edge_features : dictionary of type {'names': list, 'indices': np.ndarray, 'values': np.ndarray}
        The features associated with the edges in the bipartite graph.
        This is given as a sparse matrix in COO format.
    variable_features : dictionary of type {'names': list, 'values': np.ndarray}
        The features associated with the variable nodes in the bipartite graph.
    bound_feats : dictionary of type {'lbs': np.ndarray, 'ubs': np.ndarray}
        The solver's column lower / upper bounds, each reshaped to one column.
    """
    constraint_features, edge_features, variable_features, s = \
        _extract_bipartite_state(model, buffer)
    bound_feats = {
        'lbs': s['col']['lbs'].reshape(-1, 1),
        'ubs': s['col']['ubs'].reshape(-1, 1),
    }
    return constraint_features, edge_features, variable_features, bound_feats


# Extract the features used by the ranking model (possibly also used in the GCNN pipeline -- unverified).
def extract_khalil_variable_features(model, candidates, root_buffer):
    """
    Extract features following Khalil et al. (2016) Learning to Branch in
    Mixed Integer Programming.

    Parameters
    ----------
    model : pyscipopt.scip.Model
        The current model.
    candidates : list of pyscipopt.scip.Variable's
        A list of variables for which to compute the variable features.
    root_buffer : dict
        A buffer to avoid re-extracting redundant root node information
        (None to deactivate buffering).

    Returns
    -------
    variable_features : 2D np.ndarray
        The features associated with the candidate variables; one column per
        feature, columns ordered by sorted feature name.
    """
    khalil_state = model.getKhalilState(root_buffer, candidates)

    # Sort by feature name so the column order is deterministic.
    columns = [khalil_state[name] for name in sorted(khalil_state.keys())]
    return np.stack(columns, axis=1)

# Build a DGL heterograph from the bipartite graph extracted from SCIP.
def graph_transform(state, n_var_feats=17):
    """Build a DGL heterograph from the features extracted from SCIP.

    :param state: sequence ``(constraint_features, edge_features,
        variable_features, ...)`` as returned by the ``extract_state*``
        helpers; only the first three entries are read.
    :param n_var_feats: number of leading variable-feature columns to keep.
        Defaults to 17, the value that used to be hard-coded here; pass
        ``None`` to keep every column.
    :return: dgl heterograph with node types ``'c'`` (constraints) and
        ``'v'`` (variables), edge types ``'v2c'`` and ``'c2v'``, and all
        features stored under the key ``'h'``.
    """
    # Row 0 of the COO indices holds constraint ids, row 1 variable ids.
    edge_idx = state[1]['indices']
    cons_ids, var_ids = edge_idx[0, :], edge_idx[1, :]

    # Edge, constraint and variable features as float tensors.
    edge_feats = torch.tensor(state[1]['values'], dtype=torch.float).view(-1, 1)
    c_feats = torch.tensor(state[0]['values'], dtype=torch.float)
    v_feats = torch.tensor(state[2]['values'][:, :n_var_feats], dtype=torch.float)

    # Fix the node counts explicitly so the feature shapes match the graph
    # even when some node has no incident edge.
    num_nodes_dict = {'c': c_feats.shape[0], 'v': v_feats.shape[0]}

    # One relation per direction; each relation is identified by a
    # (source node type, edge type, destination node type) triple.
    graph = dgl.heterograph({
        ('v', 'v2c', 'c'): (var_ids, cons_ids),
        ('c', 'c2v', 'v'): (cons_ids, var_ids),
    }, num_nodes_dict)

    # Both directions carry the same edge features.
    graph.edges['v2c'].data['h'] = edge_feats
    graph.edges['c2v'].data['h'] = edge_feats
    graph.nodes['v'].data['h'] = v_feats
    graph.nodes['c'].data['h'] = c_feats

    return graph


# Dataset that loads branching samples.
class BranchDataset(Dataset):
    """In-memory dataset of branching samples.

    Each file is a gzip-compressed pickle holding a dict whose ``'data'``
    entry is the tuple ``(gcn_state, bestcand, action_set, scores)``. Items
    are stored, and returned by indexing, as
    ``(gcn_state, action_set, bestcand, scores)``.
    """

    def __init__(self, dirs):
        """Eagerly load every sample file listed in ``dirs``."""
        super().__init__()
        self.data = []
        for path in dirs:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # load sample files that come from a trusted source.
            # The context manager closes the gzip handle after reading;
            # previously it was left open for every file.
            with gzip.open(path, 'rb') as fh:
                sample = pickle.load(fh)['data']
            gcn_state, bestcand, action_set, scores = sample
            self.data.append((gcn_state, action_set, bestcand, scores))

    def __len__(self):
        """Number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(state, candidates, label, scores)`` tuple at ``idx``."""
        return self.data[idx]


# Collate function: merge a list of samples into one batched DGL graph.
def dgl_collate(batch):
    """Collate ``(state, candidates, label, scores)`` samples into one batch.

    Every state is converted with ``graph_transform`` and the graphs are
    merged with ``dgl.batch``. Labels become a ``torch.LongTensor``;
    candidates and scores are returned as plain per-sample lists.
    """
    states, candidates, labels, scores = (list(column) for column in zip(*batch))
    batched_graph = dgl.batch([graph_transform(state) for state in states])
    return batched_graph, candidates, torch.LongTensor(labels), scores


# SCIP environment: samples problem instances and loads them as SCIP models.
class scip_env():
    """Samples a batch of problem files and loads them as SCIP models.

    Iterating over the environment yields one freshly loaded model per
    sampled instance; ``reset`` draws a new batch.
    """

    def __init__(self, dirs, vals, batch_size, scip_seed, seed, timelimit):
        """Store the configuration and draw the first batch.

        ``dirs`` are the problem files, ``vals`` maps each file to the
        objective value used for the objective limit, ``seed`` drives the
        instance sampling and ``scip_seed`` is forwarded to ``set_scip``.
        """
        self.dirs = dirs
        self.vals = vals
        self.batch_size = batch_size
        self.scip_seed = scip_seed
        self.timelimit = timelimit
        self.rng = np.random.RandomState(seed)
        self.reset()

    def __len__(self):
        """The batch size."""
        return self.batch_size

    def __iter__(self):
        """Lazily load and yield the models of the current batch."""
        for position in range(self.batch_size):
            yield self.load(self.list[position])

    def load(self, dir):
        """Read the problem at ``dir`` and return a configured SCIP model."""
        scip_model = Model()
        scip_model.hideOutput()
        scip_model.readProblem(dir)

        # Objective limit: the stored value, shifted by 1e-3 upwards for
        # minimization and downwards otherwise.
        reference = self.vals[dir]
        minimizing = scip_model.getObjectiveSense() == 'minimize'
        scip_model.setObjlimit(reference + 1e-3 if minimizing else reference - 1e-3)

        set_scip(scip_model, self.scip_seed, restart=False, separator=False, primal_heuristic=False)
        scip_model.setRealParam('limits/time', self.timelimit)
        return scip_model

    def reset(self):
        """Draw ``batch_size`` distinct instances for the next batch."""
        picked = self.rng.choice(len(self.dirs), self.batch_size, replace=False)
        self.list = [self.dirs[k] for k in picked]

# SCIP branching rule driven by the learned policy.
class scip_brancher(Branchrule):
    """Branching rule that lets a policy pick the branching variable."""

    def __init__(self, policy, *args, **kwargs):
        """Keep the policy and start the call counter at zero."""
        super().__init__(*args, **kwargs)
        self.policy = policy
        self.count = 0

    def branchinitsol(self):
        """Start with an empty state buffer."""
        self.state_buffer = {}

    def branchexeclp(self, allowaddcons):
        """Branch on the candidate chosen by the policy.

        Counts the call, extracts the solver state, asks the policy for the
        best candidate among the pseudo branching candidates and branches on
        it.
        """
        self.count += 1

        candidates = self.model.getPseudoBranchCands()[0]
        # extract_state_new_3 is not defined in this file; presumably it
        # comes from one of the star-imported utils modules.
        solver_state = extract_state_new_3(self.model, self.state_buffer)
        graph = graph_transform(solver_state)

        # The policy indexes its output by the LP positions of the candidates.
        lp_positions = []
        for var in candidates:
            lp_positions.append(var.getCol().getLPPos())

        with torch.no_grad():
            chosen = self.policy.select_action(graph, lp_positions)

        # After the very first call the buffer is thrown away.
        if self.count == 1:
            self.state_buffer = {}

        self.model.branchVar(candidates[chosen])
        return {'result': SCIP_RESULT.BRANCHED}

# Custom SCIP node selector: prefers deeper nodes, then smaller lower bounds.
class scip_nodesel(Nodesel):
    """Node selector that records every selected node in a tree."""

    def __init__(self):
        """Create the selector with an empty recording tree."""
        super().__init__()
        self.tree = Tree()

    def nodeselect(self):
        """Pick the next open node and record it in ``self.tree``.

        Returns ``{}`` when there is no open node, otherwise
        ``{'selnode': node}``.
        """
        open_nodes = [node for group in self.model.getOpenNodes() for node in group]
        if not open_nodes:
            return {}

        best = open_nodes[0]

        if self.model.getCurrentNode() is None:
            # No node has been branched yet, so we are at the root node.
            assert len(open_nodes) == 1
            # The root entry stores the number of variables of the instance.
            root_data = {'num': len(self.model.getVars()),
                         'vars': [], 'bounds': [], 'types': []}
            self.tree.create_node('root', best.getNumber(), data=root_data)
            return {'selnode': best}

        # Keep the node that nodecomp ranks best; on ties the later one wins.
        for challenger in open_nodes[1:]:
            if self.nodecomp(challenger, best) <= 0:
                best = challenger

        branch_vars, bounds, types = best.getParentBranchings()
        lp_positions = [var.getCol().getLPPos() for var in branch_vars]
        parent = best.getParent()
        # A child entry stores the bound change caused by the branching.
        self.tree.create_node('node', best.getNumber(), parent=parent.getNumber(),
                              data={'vars': lp_positions, 'bounds': bounds, 'types': types})
        return {'selnode': best}

    def nodecomp(self, node1, node2):
        '''
        Compare two leaves of the current branching tree.

        Deeper nodes come first; at equal depth the node with the smaller
        lower bound comes first. Returns:
        value < 0, if node 1 comes before (is better than) node 2
        value = 0, if both nodes are equally good
        value > 0, if node 1 comes after (is worse than) node 2.
        '''
        depth1, depth2 = node1.getDepth(), node2.getDepth()
        if depth1 > depth2:
            return -1
        if depth1 < depth2:
            return 1

        lb1, lb2 = node1.getLowerbound(), node2.getLowerbound()
        if lb1 < lb2:
            return -1
        if lb1 > lb2:
            return 1
        return 0

# Agent class wrapping the branching policy network.
class BranchAgent():
    """Wraps the policy network and plugs it into SCIP for solving."""

    def __init__(self, lr=1e-3, device=None, check_point=None):
        """Build the policy network and optionally load a checkpoint.

        The network is ``GCNN_Net()`` when the module-level ``flag`` equals
        "gcnn" and ``PD_Net(T=2)`` otherwise. ``device`` defaults to the CPU
        and is used as ``map_location`` when loading ``check_point``.
        """
        super().__init__()
        self.lr = lr
        self.device = torch.device('cpu') if device is None else device

        self.policy = GCNN_Net() if flag == "gcnn" else PD_Net(T=2)

        if check_point:
            logger.info('load check points: {}'.format(check_point))
            weights = torch.load(check_point, map_location=self.device)
            self.policy.load_state_dict(weights)

    def select_action(self, state, action_set):
        """Return the position within ``action_set`` with the highest logit.

        In "gcnn" mode the network output is read from
        ``.nodes['v'].data['s']``; otherwise the output is indexed directly.
        """
        self.policy.eval()
        with torch.no_grad():
            output = self.policy(state)
            if flag == "gcnn":
                output = output.nodes['v'].data['s']
            logits = output[action_set]

        return logits.argmax(dim=0).item()

    def policy_params(self):
        """
        The params that should be trained by ES (all of them), as a list of
        ``(name, tensor)`` pairs taken from the policy's state dict.
        """
        return list(self.policy.state_dict().items())

    def solve(self, model):
        """Solve an already loaded problem with the policy-driven rules.

        Registers the branching rule and the node selector on ``model``,
        optimizes, and returns ``(number of nodes, recorded tree)``.
        """
        # Branching rule driven by this agent.
        brancher = scip_brancher(self)
        model.includeBranchrule(brancher, "Evaluate", "Policy branching on variable",
                                priority=99999, maxdepth=-1, maxbounddist=1)

        # Node selector that records the selected nodes.
        selector = scip_nodesel()
        model.includeNodesel(selector, "Evaluate", "Policy node selection on nodes",
                             1073741823, 536870911)

        model.optimize()
        return model.getNNodes(), selector.tree


def get_loss_img(epochs, acc_list, loss_list, shrink_epochs, args):
    """Plot accuracy and loss curves and save them as a JPEG.

    Parameters
    ----------
    epochs : sequence
        X values shared by both curves.
    acc_list : list
        Accuracy per epoch; the best entry is reported in the title.
    loss_list : sequence
        Loss per epoch.
    shrink_epochs : iterable
        Epochs marked with a dashed vertical line in both plots.
    args : object with ``ins_type``, ``code_id``, ``loss``, ``lr`` and
        ``shrink_lr`` attributes
        Determine the output path
        ``loss_imgs/<subdir>/<id>_<code_id>_<loss>_<lr>_<shrink_lr>.jpg``.

    Raises
    ------
    ValueError
        If ``args.ins_type`` is not one of the known instance types.
        Previously this case only failed after plotting, with an
        UnboundLocalError on the output path.
    """
    import os  # local import: os is not imported at module level

    # Instance type -> (numeric id, sub-directory of loss_imgs/).
    known_types = {
        "setcover_400r_1000c_0.05d_100mc_0se": (1, "1.setcover"),
        "cauctions_0se": (2, "2.cauctions"),
        "facility_0se": (3, "3.facility"),
        "indset_400n_4a_0se": (4, "4.indset"),
        "gisp": (5, "5.gisp"),
        "wpms": (6, "6.wpms"),
        "fcmcnf": (7, "7.fcmcnf"),
    }
    if args.ins_type not in known_types:
        raise ValueError(
            f"unknown ins_type {args.ins_type!r}; expected one of {sorted(known_types)}")
    ins_type, subdir = known_types[args.ins_type]
    img_name = f"loss_imgs/{subdir}/{ins_type}_{args.code_id}_{args.loss}_{args.lr}_{args.shrink_lr}"

    # Draw on a fresh figure: plotting on the implicit current figure made
    # repeated calls pile their curves onto the same axes.
    fig = plt.figure()
    try:
        # Accuracy plot
        plt.subplot(2, 1, 1)
        plt.plot(epochs, acc_list, '.-')
        for epoch in shrink_epochs:
            plt.axvline(x=epoch, color='gray', linestyle='--')

        max_acc_index = acc_list.index(max(acc_list))

        current_datetime = datetime.now()
        year = current_datetime.year
        month = current_datetime.month
        day = current_datetime.day

        plt.title(f'{year}-{month}-{day}:model accuracy is {epochs[max_acc_index]+1} epoch:{acc_list[max_acc_index]} acc')
        plt.ylabel('accuracy unit:%')

        # Loss plot
        plt.subplot(2, 1, 2)
        plt.plot(epochs, loss_list, '.-')
        for epoch in shrink_epochs:
            plt.axvline(x=epoch, color='gray', linestyle='--')

        plt.xlabel('model loss')
        plt.ylabel('loss')

        # Make sure the target directory exists before saving.
        os.makedirs(os.path.dirname(img_name), exist_ok=True)
        plt.savefig(f"{img_name}.jpg")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    # plt.show()
