import argparse
import pulp
from ast import arg
import sys, time
import os
import pandas as pd
import pickle
import gzip
# import torch
from pyscipopt import Model, Branchrule, SCIP_RESULT
sys.path.append('../')
# import  src.utils as utils
# import src.model as model
# from src.utils import set_device_seed, set_scip
# from src.logger import Logger
# logger = Logger.logger
# import shutil
# from collections import deque,OrderedDict
# import dgl
# import heapq
import numpy as np
from multiprocessing.dummy import Pool as ThreadPool
# import time
import copy
# import pickle
# from datetime import datetime
# from scipy.optimize import linprog
# from concurrent.futures import ProcessPoolExecutor
# from functools import partial
# import multiprocessing
import random

def _rewrite_edges_for_substitution(edge_feats, x_1, x_2):
    """Rebuild the edge features after substituting ``x_2 := x_1 + x_2``.

    Edges are grouped by constraint (in first-seen order). In every
    constraint that contains ``x_2``, the coefficient of ``x_1`` becomes
    ``coef(x_1) - coef(x_2)``; when ``x_1`` is absent from that constraint it
    is appended with coefficient ``-coef(x_2)``. Constraints without ``x_2``
    are emitted unchanged.

    Args:
        edge_feats: dict with ``'names'``, ``'indices'`` (row 0 = constraint
            index, row 1 = variable index) and ``'values'`` (one coefficient
            per edge).
        x_1: index of the variable whose coefficients are adjusted.
        x_2: index of the variable that is being replaced by the sum.

    Returns:
        A new dict with the same ``'names'``, ``'indices'`` as a
        ``(2, n_edges)`` array and ``'values'`` reshaped to ``(n_edges, 1)``.
        The input dict and its arrays are not modified.
    """
    con_indices = edge_feats['indices'][0]
    var_indices = edge_feats['indices'][1]
    coefs = edge_feats['values']

    # Split the coefficients by constraint index.
    rows = {}
    for con_idx, var_idx, coef in zip(con_indices, var_indices, coefs):
        var_list, coef_list = rows.setdefault(con_idx, ([], []))
        var_list.append(var_idx)
        coef_list.append(coef)

    new_con, new_var, new_coef = [], [], []
    for con_idx, (var_list, coef_list) in rows.items():
        if x_2 in var_list:
            x_2_coef = coef_list[var_list.index(x_2)]
            if x_1 in var_list:
                pos = var_list.index(x_1)
                # Not in-place: keeps the caller's coefficient array intact.
                coef_list[pos] = coef_list[pos] - x_2_coef
            else:
                # x_1 did not appear in this constraint: add it.
                var_list.append(x_1)
                coef_list.append(-x_2_coef)
        new_con.extend([con_idx] * len(var_list))
        new_var.extend(var_list)
        new_coef.extend(coef_list)

    return {
        'names': copy.deepcopy(edge_feats['names']),
        'indices': np.array([new_con, new_var]),
        'values': np.array(new_coef).reshape(-1, 1),
    }


def copy_data_new_20(data, k = 1):
    """Create up to ``k`` augmented copies of a branching sample.

    Each copy is obtained by drawing two variables ``x_1`` and ``x_2`` of the
    same type and rewriting the features as if ``x_2`` were replaced by
    ``x_1 + x_2``. The type used is the first of CONTINUOUS, INTEGER, BINARY
    (in that order) that has more than two variables.

    Variable feature columns, according to the original author's notes:
    0-3 type one-hot (BINARY, INTEGER, IMPLINT, CONTINUOUS), 4 objective
    coefficient, 5-6 has_lb/has_ub, 7-8 sol_is_at_lb/sol_is_at_ub, 9 sol_frac,
    10-13 basis status (LOWER, BASIC, UPPER, ZERO), 14 reduced_cost, 15 age,
    16 sol_val, 17 inc_val, 18 avg_inc_val.

    Args:
        data: dict with keys ``'data'``, ``'episode'``, ``'instance'``,
            ``'seed'``, ``'node_number'``, ``'node_depth'`` and
            ``'max_depth'``. ``data['data']`` unpacks to
            ``(gcn_state, bestcand, action_set, scores, l1, l2)`` where
            ``gcn_state`` is ``(constraint_features, edge_features,
            variable_features)``.
        k: number of random draws. A draw is skipped when its pair was
            already used or when ``x_2`` is the best candidate, so fewer than
            ``k`` samples may be returned.

    Returns:
        A list of new data dicts shaped like ``data``, or the integer ``0``
        when no variable type has more than two variables (kept for
        backward compatibility with existing callers).

    The input ``data`` is left untouched: features and the score arrays are
    copied for every generated sample.
    """
    sample = data['data']
    episode = data['episode']
    instance = data['instance']
    seed = data['seed']
    node_number = data['node_number']
    node_depth = data['node_depth']
    max_depth = data['max_depth']
    gcn_state, bestcand, action_set, scores, l1, l2 = sample

    # gcn_state = (constraint_features, edge_features, variable_features)
    print(gcn_state[2].keys())
    c_feats = gcn_state[0]
    edge_feats = gcn_state[1]
    v_feats = gcn_state[2]['values']

    # Variable indices per type: 0 BINARY, 1 INTEGER, 2 IMPLINT, 3 CONTINUOUS.
    type_dict = {0: [], 1: [], 2: [], 3: []}
    for j, row in enumerate(v_feats):
        for var_type in range(4):
            if row[var_type] == 1.0:
                type_dict[var_type].append(j)

    # Prefer continuous, then integer, then binary variables.
    change = next((t for t in (3, 1, 0) if len(type_dict[t]) > 2), None)
    if change is None:
        return 0

    candidates = list(action_set)  # also supports array-like action sets
    seen_pairs = set()
    new_data_list = []
    for _ in range(k):
        x_1, x_2 = random.sample(type_dict[change], 2)
        # Make sure x_2 has the smaller objective coefficient, so that
        # x_1 + x_2 replaces x_2.
        if v_feats[x_2][4] > v_feats[x_1][4]:
            x_1, x_2 = x_2, x_1
        # Skip repeated pairs and never rewrite the best candidate. This is
        # checked before copying so skipped draws cost nothing.
        if (x_1, x_2) in seen_pairs or x_2 == action_set[bestcand]:
            continue
        seen_pairs.add((x_1, x_2))

        v_feats_new = copy.deepcopy(v_feats)
        # Constraint features are left unchanged by the substitution.
        c_feats_new = copy.deepcopy(c_feats)
        row_1 = v_feats_new[x_1]
        row_2 = v_feats_new[x_2]

        # Only x_2 is redefined (x_2 := x_1 + x_2). Type: continuous and
        # integer stay as they are, binary becomes integer.
        if change == 0:
            row_2[0] = 0
            row_2[1] = 1.0
        # Objective coefficient of x_1 (not normalised).
        row_1[4] = row_1[4] - row_2[4]
        # has_lb, has_ub, sol_is_at_lb, sol_is_at_ub
        for col in (5, 6, 7, 8):
            row_2[col] = min(row_1[col], row_2[col])
        # sol_frac: average of both, reset to 0 for integer-typed x_2.
        # NOTE(review): the original comment says "0 for continuous", but the
        # check is on the INTEGER flag (column 1) - confirm which is intended.
        row_2[9] = (row_1[9] + row_2[9]) / 2
        if row_2[1] == 1.0:
            row_2[9] = 0

        # Basis status (LOWER, BASIC, UPPER, ZERO): keep it when both
        # variables agree, otherwise mark x_2 as BASIC.
        if any(row_2[col] != row_1[col] for col in (10, 11, 12, 13)):
            row_2[10] = 0
            row_2[11] = 1.0
            row_2[12] = 0
            row_2[13] = 0

        # reduced_cost
        row_2[14] = row_1[14] + row_2[14]
        # age
        row_2[15] = min(row_1[15], row_2[15])
        # sol_val, inc_val, avg_inc_val
        for col in (16, 17, 18):
            row_2[col] = row_1[col] + row_2[col]

        v_feats_new_dict = {'names': gcn_state[2]['names'],
                            'values': v_feats_new}
        edge_feats_new = _rewrite_edges_for_substitution(edge_feats, x_1, x_2)
        gcn_state_new = [c_feats_new, edge_feats_new, v_feats_new_dict]

        # Work on copies so the caller's arrays (and earlier samples) are not
        # modified by the adjustment below.
        new_scores = copy.deepcopy(scores)
        new_l1 = copy.deepcopy(l1)
        new_l2 = copy.deepcopy(l2)
        # The adjustment needs the values of both variables; when x_1 is not
        # a branching candidate there is nothing to compare against, so the
        # scores are left as they are.
        if x_2 in candidates and x_1 in candidates:
            pos_2 = candidates.index(x_2)
            pos_1 = candidates.index(x_1)
            new_l1[pos_2] = new_l1[pos_2] - abs(new_l1[pos_2] - new_l1[pos_1])
            new_l2[pos_2] = new_l2[pos_2] - abs(new_l2[pos_2] - new_l2[pos_1])
            new_scores[pos_2] = new_l1[pos_2] * new_l2[pos_2]

        # NOTE(original author): l1, l2 and the scores are placeholders; they
        # should eventually be recomputed by modelling the new problem in SCIP.
        new_data = {
            'episode': episode,
            'instance': instance,
            'seed': seed,
            'node_number': node_number,
            'node_depth': node_depth,
            'max_depth': max_depth,
            'data': (gcn_state_new, bestcand, action_set,
                     new_scores, new_l1, new_l2),
        }
        new_data_list.append(new_data)
    return new_data_list


if __name__ == '__main__':
    # Driver: load one gzip-compressed pickled sample and augment it once.
    for i in range(1,2):  # data root on the original machine: /home/LAB/lutk/linglong-data/2.rl4bb
        # One sample file per index (the original note says each belongs to a
        # different instance).
        sample_path = f'../data/samples/setcover_400r_1000c_0.05d_100mc_0se/new_20/train/sample_{i}.pkl'
        # NOTE: pickle.load can execute arbitrary code; only load sample files
        # from a trusted source.
        # The context manager closes the gzip handle, which the previous
        # inline gzip.open(...) call never did.
        with gzip.open(sample_path, 'rb') as sample_file:
            data = pickle.load(sample_file)
        print('data',type(data),data.keys())
        new_data_list = copy_data_new_20(data, k = 1)
# def get_Abc(topk_state):
#     obj_coefs = topk_state['obj_coefs']  # objective coefficients
#     n_cols = topk_state['n_cols']
#     # print('n_cols', n_cols)#1000
#     con_lhs = topk_state['con_lhs']
#     con_rhs = topk_state['con_rhs']  # left- and right-hand-side constants of the constraints
#     ## con_lhs <= activity <= con_rhs
#     n_rows = con_rhs.shape[0]
    
#     colidxs = topk_state['colidxs']  # variable indices
#     rowidxs = topk_state['rowidxs']  # constraint indices
#     con_coefs = topk_state['con_coefs']
#     c = obj_coefs
#     constraints = {}
#     for index, value in enumerate(rowidxs):
#         if value not in constraints.keys():
#             constraints[value] = []  # create an empty index list if this value has none yet
#         constraints[value].append(index)
#     A = []
#     b = []
#     first_index = []
#     rowidxs_list = list(set(rowidxs))
#     for i in range(len(rowidxs_list)):
#         con = rowidxs_list[i]
#         lhs = con_lhs[i]
#         rhs = con_rhs[i]
#         vars_index = constraints[con]
#         a1 = [0 for _ in range(n_cols)]
#         a2 = [0 for _ in range(n_cols)]
#         first_index.append(rowidxs[vars_index[0]])
#         if (not np.isnan(lhs)):  # the left-hand side is not infinite
#             for index in vars_index:
#                 var = colidxs[index]
#                 a1[var] = -con_coefs[index]
#             b.append(-lhs)
#             A.append(a1)
#         if (not np.isnan(rhs)):
#             for index in vars_index:
#                 var = colidxs[index]
#                 a2[var] = con_coefs[index]
#             b.append(rhs)
#             A.append(a2)
#     return A, b, c

# def solve_lp(problem_index, is_ub, A, b, c, vars_list):#, var_list, row_list, objective
#     """
#     Take the problem's name and the problem object, and return the solve result.
#     """
    
    
#     # (problem_index, is_ub) = tasks
#     if (is_ub):
#         numm = problem_index*2
#     else:
#         numm = problem_index*2 + 1
#     lp_model = Model(f"LP Example{numm}")
#     lp_model.hideOutput()
#     var_mapping = {}  # mapping from the original variables to the new variables
#     # vars_list = model.getVars()
#     for (name, lb, ub, index) in vars_list:
#         new_var = lp_model.addVar(
#             name= name,
#             lb= lb,
#             ub= ub,
#             vtype="C"  # force a continuous variable
#         )
#         var_mapping[index] = new_var
#         # print('var',var,type(var))
#     # print('var_mapping',len(var_mapping.keys()),var_mapping.keys())

#     for i in range(len(b)):#model.getVars(),model.getLPRowsData(),model.getObjective()
#         row = A[i]
#         hs = b[i]
#         # row_vars, row_vals = rows.getCols(), rows.getVals()

#         new_row = sum(var_mapping[j] * row[j] for j in range(len(row)))
#         # new_row1 = sum(-var_mapping[var.getLPPos()] * coef for var, coef in zip(row_vars, row_vals))
#         # print('new_row',new_row)
#         # lp_model.addCons(new_row1 <= -lhs)
#         lp_model.addCons(new_row <= hs)

#     # Add the objective function
#     # objective = model.getObjective()
#     # print('len(c)',len(c),'len(b)',len(b),var_mapping.keys())
#     lp_objective = sum(var_mapping[i] * c[i] for i in range(len(c)))
#     lp_model.setObjective(lp_objective, sense="minimize")
    
#     cand = var_mapping[problem_index]
#     if (is_ub):
#         lb = cand.getLbGlobal()
#         lp_model.chgVarUb(cand, lb)
#     else:
#         ub = cand.getUbGlobal()
#         lp_model.chgVarLb(cand, ub)
#     # print(lp_model)
#     lp_model.optimize()
#     # print('self.model.getLPObjVal()',self.model.getLPObjVal(), cur_z)
#     lp_score = lp_model.getObjVal()
    
#     return lp_score



# for i in range(1,2):
#     dir = f'/home/LAB/lutk/linglong-data/2.rl4bb/data/samples/facility_0se/new_20/train/sample_{i}.pkl'
#     data = pickle.load(gzip.open(dir, 'rb'))
#     sample = pickle.load(gzip.open(dir, 'rb'))['data']
#     depth = pickle.load(gzip.open(dir, 'rb'))['node_depth']

#     device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
#     print('len(sample)', len(sample))
#     gcn_state, bestcand, action_set, scores, l1, l2 = sample