# Added the `weight` parameter.

import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as opt
import argparse
import scipy.sparse as sp
import dgl
import pickle, gzip
import sys
import itertools
from pyscipopt import SCIP_PARAMSETTING, Model, Branchrule, Nodesel, SCIP_RESULT
from torch.utils.data import Dataset
from treelib import Tree
sys.path.append('../')
from src.logger import Logger
from src.model import GCNN_Net, PD_Net
import time
logger = Logger.logger


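# Convert the current features of the SCIP model into the bipartite-graph
# representation used by the GCNN model, where variables and constraints are
# the two sets of nodes.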
def extract_state_new_41(model, buffer=None):
    """
    Compute a bipartite graph representation of the solver state.

    In this representation, the variables and constraints of the MILP are the
    left- and right-hand side nodes, and an edge links two nodes iff the
    variable is involved in the constraint. Both the nodes and edges carry
    features.

    Every row of the solver state contributes one constraint node per finite
    side: rows with a non-NaN left-hand side come first (with their values
    negated), followed by rows with a non-NaN right-hand side.

    The column order of each 'values' array follows the insertion order of the
    corresponding feature dictionary, so the order of the assignments below is
    part of the output format.

    Parameters
    ----------
    model : pyscipopt.scip.Model
        The current model. It must expose `getNNodes()` and
        `getState(previous_state)`.
    buffer : dict, optional
        A buffer to avoid re-extracting redundant information from the solver
        each time. The raw solver state is stored under 'scip_state' and the
        derived features under 'state'. A fresh buffer is used when this is
        None or when `model.getNNodes() == 1`.
        NOTE(review): in the `getNNodes() == 1` case the name is rebound to a
        new local dict, so a dict passed in by the caller is neither cleared
        nor updated by that call.

    Returns
    -------
    constraint_features : dictionary of type {'names': list, 'values': np.ndarray}
        The features associated with the constraint nodes in the bipartite graph.
    edge_features : dictionary of type {'names': list, 'indices': np.ndarray, 'values': np.ndarray}
        The features associated with the edges in the bipartite graph.
        This is given as a sparse matrix in COO format: 'indices' stacks the
        constraint indices (row 0) on top of the variable indices (row 1).
    variable_features : dictionary of type {'names': list, 'values': np.ndarray}
        The features associated with the variable nodes in the bipartite graph.
    """
    if buffer is None or model.getNNodes() == 1:
        buffer = {}

    # Update the solver state, passing the previously cached state (if any)
    # back to the solver.
    s = model.getState(buffer['scip_state'] if 'scip_state' in buffer else None)


    buffer['scip_state'] = s

    if 'state' in buffer:
        obj_norm = buffer['state']['obj_norm']
    else:
        # Norm of s['col']['coefs'] (presumably the objective coefficients,
        # given the name); fall back to 1 to avoid dividing by zero.
        obj_norm = np.linalg.norm(s['col']['coefs'])
        obj_norm = 1 if obj_norm <= 0 else obj_norm

    # Zero row norms are replaced by 1 (in place, on the array held by `s`)
    # so that the divisions below are well defined.
    row_norms = s['row']['norms']
    row_norms[row_norms == 0] = 1

    # Column features -> variable features
    n_cols = len(s['col']['types'])

    if 'state' in buffer:
        # Reuse the cached dict: the static features ('type',
        # 'coef_normalized') are kept, the remaining keys are overwritten below.
        col_feats = buffer['state']['col_feats']
    else:
        col_feats = {}
        col_feats['type'] = np.zeros((n_cols, 4))  # BINARY INTEGER IMPLINT CONTINUOUS
        col_feats['type'][np.arange(n_cols), s['col']['types']] = 1
        col_feats['coef_normalized'] = s['col']['coefs'].reshape(-1, 1) / obj_norm

    # A NaN bound is treated as "no bound".
    col_feats['has_lb'] = ~np.isnan(s['col']['lbs']).reshape(-1, 1)
    col_feats['has_ub'] = ~np.isnan(s['col']['ubs']).reshape(-1, 1)
    
    col_feats['sol_is_at_lb'] = s['col']['sol_is_at_lb'].reshape(-1, 1)
    col_feats['sol_is_at_ub'] = s['col']['sol_is_at_ub'].reshape(-1, 1)
    col_feats['sol_frac'] = s['col']['solfracs'].reshape(-1, 1)
    col_feats['sol_frac'][s['col']['types'] == 3] = 0  # continuous have no fractionality
    col_feats['basis_status'] = np.zeros((n_cols, 4))  # LOWER BASIC UPPER ZERO
    col_feats['basis_status'][np.arange(n_cols), s['col']['basestats']] = 1
    col_feats['reduced_cost'] = s['col']['redcosts'].reshape(-1, 1) / obj_norm
    col_feats['age'] = s['col']['ages'].reshape(-1, 1) / (s['stats']['nlps'] + 5)
    col_feats['sol_val'] = s['col']['solvals'].reshape(-1, 1)
    col_feats['inc_val'] = s['col']['incvals'].reshape(-1, 1)
    col_feats['avg_inc_val'] = s['col']['avgincvals'].reshape(-1, 1)


    # Single-column features keep their key as name; multi-column features
    # are expanded to '<key>_<i>'.
    col_feat_names = [[k, ] if v.shape[1] == 1 else [f'{k}_{i}' for i in range(v.shape[1])] for k, v in col_feats.items()]
    col_feat_names = [n for names in col_feat_names for n in names]
    col_feat_vals = np.concatenate(list(col_feats.values()), axis=-1)
    # col_feat_vals has shape (n_cols, 19), e.g. (1000, 19)

    variable_features = {
        'names': col_feat_names,
        'values': col_feat_vals,}

    # Row features -> constraint features

    if 'state' in buffer:
        row_feats = buffer['state']['row_feats']
        has_lhs = buffer['state']['has_lhs']
        has_rhs = buffer['state']['has_rhs']
    else:
        row_feats = {}
        # Indices of the rows that have a (non-NaN) left / right-hand side.
        has_lhs = np.nonzero(~np.isnan(s['row']['lhss']))[0]
        has_rhs = np.nonzero(~np.isnan(s['row']['rhss']))[0]

        # Left-hand sides are negated, right-hand sides are kept as is.
        row_feats['obj_cosine_similarity'] = np.concatenate((
            -s['row']['objcossims'][has_lhs],
            +s['row']['objcossims'][has_rhs])).reshape(-1, 1)
        
        row_feats['bias'] = np.concatenate((
            -(s['row']['lhss'] / row_norms)[has_lhs],
            +(s['row']['rhss'] / row_norms)[has_rhs])).reshape(-1, 1)

    row_feats['is_tight'] = np.concatenate((
        s['row']['is_at_lhs'][has_lhs],
        s['row']['is_at_rhs'][has_rhs])).reshape(-1, 1)

    row_feats['age'] = np.concatenate((
        s['row']['ages'][has_lhs],
        s['row']['ages'][has_rhs])).reshape(-1, 1) / (s['stats']['nlps'] + 5)

    # todo: check whether this change helps
    tmp = s['row']['dualsols'] / (row_norms * obj_norm)
    # tmp = s['row']['dualsols'] * row_norms / obj_norm
    row_feats['dualsol_val_normalized'] = np.concatenate((
            -tmp[has_lhs],
            +tmp[has_rhs])).reshape(-1, 1)

    row_feat_names = [[k, ] if v.shape[1] == 1 else [f'{k}_{i}' for i in range(v.shape[1])] for k, v in row_feats.items()]
    row_feat_names = [n for names in row_feat_names for n in names]
    row_feat_vals = np.concatenate(list(row_feats.values()), axis=-1)

    constraint_features = {
        'names': row_feat_names,
        'values': row_feat_vals,}

    # Edge features
    if 'state' in buffer:
        edge_row_idxs = buffer['state']['edge_row_idxs']
        edge_col_idxs = buffer['state']['edge_col_idxs']
        edge_feats = buffer['state']['edge_feats']
    else:
        # Coefficient divided by the norm of its row
        coef_matrix = sp.csr_matrix(
            (s['nzrcoef']['vals'] / row_norms[s['nzrcoef']['rowidxs']],
            (s['nzrcoef']['rowidxs'], s['nzrcoef']['colidxs'])),
            shape=(len(s['row']['nnzrs']), len(s['col']['types'])))
        # Same row layout as the constraint features: negated lhs rows first,
        # then rhs rows.
        coef_matrix = sp.vstack((
            -coef_matrix[has_lhs, :],
            coef_matrix[has_rhs, :])).tocoo(copy=False)

        edge_row_idxs, edge_col_idxs = coef_matrix.row, coef_matrix.col
        edge_feats = {}

        edge_feats['coef_normalized'] = coef_matrix.data.reshape(-1, 1)

    edge_feat_names = [[k, ] if v.shape[1] == 1 else [f'{k}_{i}' for i in range(v.shape[1])] for k, v in edge_feats.items()]
    edge_feat_names = [n for names in edge_feat_names for n in names]
    edge_feat_indices = np.vstack([edge_row_idxs, edge_col_idxs])
    edge_feat_vals = np.concatenate(list(edge_feats.values()), axis=-1)


    edge_features = {
        'names': edge_feat_names,
        'indices': edge_feat_indices,
        'values': edge_feat_vals,}

    # First call with this buffer: cache the derived features for reuse.
    if 'state' not in buffer:
        buffer['state'] = {
            'obj_norm': obj_norm,
            'col_feats': col_feats,
            'row_feats': row_feats,
            'has_lhs': has_lhs,
            'has_rhs': has_rhs,
            'edge_row_idxs': edge_row_idxs,
            'edge_col_idxs': edge_col_idxs,
            'edge_feats': edge_feats,
        }

    return constraint_features, edge_features, variable_features

# Build a basic DGL heterograph from the bipartite graph extracted from SCIP.
def graph_transform_new_41(state):
    """
    Build a DGL heterograph from a bipartite feature triple.

    :param state: indexable holding three feature dicts, in the order
        (constraint features, edge features, variable features). The edge dict
        provides 'indices' (row 0: constraint ids, row 1: variable ids) and
        'values'; the two node dicts provide 'values'.
    :return: dgl heterograph with node types 'c' and 'v' and the two relations
        'v2c' and 'c2v'; every feature tensor is stored under the key 'h'.
    """
    constraint_part = state[0]
    edge_part = state[1]
    variable_part = state[2]

    # Endpoints of every edge, shared by both directions of the relation.
    endpoints = edge_part['indices']
    constraint_ids = endpoints[0, :]
    variable_ids = endpoints[1, :]

    # Feature tensors; the edge features are forced into a single column so
    # that their shape matches the number of edges in the graph.
    edge_h = torch.tensor(edge_part['values'], dtype=torch.float).view(-1, 1)
    constraint_h = torch.tensor(constraint_part['values'], dtype=torch.float)
    # Only the first 17 variable-feature columns are kept.
    variable_h = torch.tensor(variable_part['values'][:, :17], dtype=torch.float)

    # A heterograph is a collection of relations, each identified by a
    # (source node type, edge type, destination node type) triple.
    relations = {
        ('v', 'v2c', 'c'): (variable_ids, constraint_ids),
        ('c', 'c2v', 'v'): (constraint_ids, variable_ids),
    }
    node_counts = {'c': constraint_h.shape[0], 'v': variable_h.shape[0]}
    graph = dgl.heterograph(relations, node_counts)

    # Both edge directions carry the same feature tensor.
    for relation_name in ('v2c', 'c2v'):
        graph.edges[relation_name].data['h'] = edge_h
    graph.nodes['v'].data['h'] = variable_h
    graph.nodes['c'].data['h'] = constraint_h

    return graph

# Dataset that loads the branching samples.
class BranchDataset_new_41(Dataset):
    """In-memory dataset of branching samples stored as gzip-compressed pickles.

    Each file must unpickle to a mapping with two keys:

    * ``'data'``: a 4-tuple ``(gcn_state, bestcand, action_set, scores)``
    * ``'node_depth'``: the depth associated with the sample

    Every item of the dataset is the tuple
    ``(gcn_state, action_set, bestcand, scores, node_depth)``, i.e.
    (data, candidates, label, scores, depth).
    """

    def __init__(self, dirs):
        """Eagerly load every sample file.

        Parameters
        ----------
        dirs : iterable of path-like
            Paths of the gzip-compressed pickle files to load.

        Raises
        ------
        KeyError
            If a file lacks the 'data' or 'node_depth' entry.
        ValueError
            If the 'data' entry does not unpack into exactly four items.
        """
        super().__init__()
        self.data = []
        for path in dirs:
            # Open each file once and close it deterministically; the record
            # is unpickled a single time and both entries are read from it.
            # NOTE(security): pickle can execute arbitrary code while loading;
            # only use sample files that come from a trusted source.
            with gzip.open(path, 'rb') as handle:
                record = pickle.load(handle)

            gcn_state, bestcand, action_set, scores = record['data']
            depth = record['node_depth']
            self.data.append((gcn_state, action_set, bestcand, scores, depth))

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample tuple stored at position `idx`."""
        return self.data[idx]

# Collate a list of samples into a single batch backed by a batched DGL graph.
def dgl_collate_new_41(batch):
    """
    Collate samples of the form (data, candidates, label, scores, depth).

    :param batch: sequence of 5-tuples as listed above
    :return: (batched_graph, candidates, label, scores, depth) where the graphs
        built from each `data` entry are merged with `dgl.batch`, `label` is a
        LongTensor, `depth` is a FloatTensor, and `candidates` / `scores` stay
        plain per-sample lists.
    """
    # Transpose the list of samples into one list per field.
    data, candidates, label, scores, depth = (list(field) for field in zip(*batch))

    batched_graph = dgl.batch([graph_transform_new_41(sample) for sample in data])

    return (
        batched_graph,
        candidates,
        torch.LongTensor(label),
        scores,
        torch.FloatTensor(depth),
    )


def _select_model_output_new_41(models, bg, depth, epoch):
    """Run the model selected by `epoch` (and by `depth` when epoch is 0) on `bg`."""
    if epoch == 0:
        # Depth-based routing makes a single decision per batch, so the batch
        # must hold exactly one sample.
        if depth.numel() != 1:
            raise ValueError(
                'depth-based model selection (epoch == 0) requires a batch size of 1, '
                'got {} samples'.format(depth.numel()))
        now_depth = depth.item()
        if now_depth <= 1:
            model_index = 0
        elif now_depth <= 4:
            model_index = 1
        elif now_depth <= 7:
            model_index = 2
        else:
            model_index = 3
    elif epoch in (1, 2, 3, 4):
        model_index = epoch - 1
    else:
        raise ValueError(
            'unsupported epoch {!r}: expected 0 (depth routing) or 1-4 (single model)'.format(epoch))
    return models[model_index].forward(bg)


def test_new_41(args, models, testData, optimizer, epoch, device):
    """
    Evaluate the branching models on `testData` and log loss and top-k accuracy.

    `epoch` doubles as the evaluation mode: 0 routes each batch to one of the
    four models according to its node depth (<=1, <=4, <=7, otherwise), which
    requires batches of size 1; 1 to 4 evaluate models[0] to models[3] on every
    batch.

    A sample counts as a top-k hit when the best true score among the k
    candidates with the highest predicted logits equals the best true score
    over all candidates.

    Parameters
    ----------
    args : any
        Not used by this function; kept for interface compatibility.
    models : sequence of torch.nn.Module
        The models to evaluate; each is switched to eval mode. Its forward pass
        must return the batched graph with per-variable scores stored in
        ``nodes['v'].data['s']``.
    testData : iterable
        Yields (batched_graph, candidates, label, scores, depth) batches.
    optimizer : torch.optim.Optimizer or None
        Only `zero_grad()` is called on it (kept for backward compatibility);
        may be None.
    epoch : int
        Evaluation mode, see above. Also used (plus one) in the log messages.
    device : torch.device or str
        Device the graph, labels and depths are moved to.

    Returns
    -------
    tuple of float
        (mean cross-entropy loss, acc@1, acc@3). All three are NaN when
        `testData` yields no samples.

    Raises
    ------
    ValueError
        If `epoch` is not in 0..4, or if epoch is 0 and a batch holds more than
        one sample.
    """
    for model in models:
        model.eval()

    top_k = (1, 3, 5, 10)
    largest_k = max(top_k)
    acc_dict = {k: 0.0 for k in top_k}
    batch_loss = 0.0
    total_count = 0

    with torch.no_grad():
        for bg, candidates, label, scores, depth in testData:
            if optimizer is not None:
                optimizer.zero_grad()
            bg = bg.to(device)
            label = label.to(device)
            depth = depth.to(device)

            count = len(label)
            if count == 0:
                # Nothing to score in an empty batch.
                continue
            total_count += count

            bg = _select_model_output_new_41(models, bg, depth, epoch)

            graphs = dgl.unbatch(bg)
            logits = [g.nodes['v'].data['s'][candidates[i]] for i, g in enumerate(graphs)]

            for i in range(count):
                # reshape(-1) instead of squeeze(): squeeze() turns a single
                # candidate into a 0-d tensor, which has no len() and broke
                # the top-k computation.
                flat_logits = logits[i].reshape(-1)
                # topk returns indices sorted by decreasing logit, so the
                # top-k for every smaller k is a prefix of this list.
                ranked = flat_logits.topk(min(largest_k, flat_logits.numel()))[1].tolist()
                best_score = max(scores[i])
                for k in top_k:
                    if max(scores[i][j] for j in ranked[:k]) == best_score:
                        acc_dict[k] += 1

            # The .T + cross_entropy call implies logits of shape
            # (n_candidates, 1) per sample; each term is one sample's loss.
            sample_losses = [F.cross_entropy(logits[i].T, label[i:i + 1]) for i in range(count)]
            batch_loss += sum(sample_losses).item()

    if total_count == 0:
        # No sample was evaluated: report NaN instead of dividing by zero.
        logger.info('Test epoch {}:\t no test samples, metrics are undefined'.format(epoch + 1))
        nan = float('nan')
        return nan, nan, nan

    loss = batch_loss / total_count
    logger.info('Test epoch {}:\t Test Loss: {:.6f} acc@1: {:.4f} acc@3: {:.4f} acc@5: {:.4f} acc@10: {:.4f}'.format(epoch + 1, loss,
    acc_dict[1]/total_count, acc_dict[3]/total_count, acc_dict[5]/total_count, acc_dict[10]/total_count))
    logger.info('Test epoch {}:\t Test Loss: {:.6f} acc: {:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'.format(epoch + 1, loss,
    acc_dict[1]/total_count, acc_dict[3]/total_count, acc_dict[5]/total_count, acc_dict[10]/total_count))
    return loss, acc_dict[1]/total_count, acc_dict[3]/total_count
