import pickle, gzip
import numpy as np
import torch
import torch.nn.functional as F
from collections import defaultdict
import copy
from collections import Counter
import matplotlib.pyplot as plt
from pyscipopt import Model

length_counter = Counter()  # counts how often each candidate-set length occurs

def flip_variables(sample):
    """Mirror a random subset of variables (x_j -> -x_j) inside ``sample``.

    Every variable is selected independently with probability 0.5.  For the
    selected variables the sign-dependent features are updated **in place**.
    Column numbers follow the variable-feature list recorded in this file
    (``type_0..3, coef_normalized, has_lb, has_ub, sol_is_at_lb,
    sol_is_at_ub, sol_frac, basis_status_0..3, reduced_cost, age, sol_val,
    inc_val, avg_inc_val``):

    * col 4  ``coef_normalized``               -> negated
    * col 5/6  ``has_lb`` / ``has_ub``          -> swapped
    * col 7/8  ``sol_is_at_lb`` / ``sol_is_at_ub`` -> swapped
    * col 9  ``sol_frac``                       -> ``1 - frac`` when ``frac > 0``
    * col 14 ``reduced_cost``                  -> negated
    * col 16 ``sol_val``                       -> negated
    * edge coefficients of the flipped columns -> negated

    Args:
        sample: indexable whose items 1 and 2 are dicts.  ``sample[1]`` holds
            ``'values'`` (edge coefficients, first axis = edges) and
            ``'indices'`` (row 1 = variable index of each edge);
            ``sample[2]['values']`` is the 2-D variable-feature array.

    Returns:
        The tuple ``(sample, 'flip')``; ``sample`` is the same, mutated object.
    """
    edge_feats = sample[1]['values']
    edge_var_idx = sample[1]['indices'][1]
    v_feats = sample[2]['values']
    # The original printed the constraint-feature shape under this label.
    print('v_feats', v_feats.shape)

    flip_mask = np.random.rand(v_feats.shape[0]) < 0.5
    flip_indices = np.where(flip_mask)[0]
    rows = flip_indices[:, None]  # column vector so it broadcasts against column lists

    v_feats[flip_indices, 4] *= -1
    # Fancy indexing on the right-hand side yields a copy, so this is a safe
    # simultaneous swap of (has_lb, has_ub) and (sol_is_at_lb, sol_is_at_ub).
    v_feats[rows, [5, 6, 7, 8]] = v_feats[rows, [6, 5, 8, 7]]

    # frac(-x) = 1 - frac(x) only for a non-zero fractional part; an integral
    # value stays integral.  The unconditional `1 - frac` used before turned
    # every integral variable into the out-of-range value 1.
    frac = v_feats[flip_indices, 9]
    v_feats[flip_indices, 9] = np.where(frac > 0, 1 - frac, frac)

    v_feats[flip_indices, 14] *= -1
    v_feats[flip_indices, 16] *= -1
    # NOTE(review): basis_status (cols 10-13), inc_val (17) and avg_inc_val (18)
    # are left untouched, exactly as before -- confirm whether they should be
    # mirrored as well.

    # Negate the constraint coefficients of every edge that touches a flipped variable.
    edge_feats[np.isin(edge_var_idx, flip_indices)] *= -1
    return sample, 'flip'

def add_redundant_constraint(sample):
    """Append one redundant constraint: the sum of two randomly chosen rows.

    Two distinct constraint rows are drawn with ``np.random.choice``.  Their
    stored edge coefficients are added column-wise, the result is scaled to
    unit Euclidean norm (when the norm is non-zero) and appended as a new row
    whose index is the previous number of rows.  The new constraint-feature
    row (layout ``obj_cosine_similarity, bias, is_tight, age,
    dualsol_val_normalized``) is filled as follows:

    * col 0: dot product of the objective column (variable feature 4) with
      the new normalised coefficients
    * col 1: ``(bias1 + bias2) / norm``
    * col 2: 1 only when *both* source rows have ``is_tight == 1``
    * col 3: 0
    * col 4: mean of the two source dual values

    The results are written back into ``sample[0]['values']``,
    ``sample[1]['values']`` and ``sample[1]['indices']``.  (Previously the
    new arrays were built but discarded, so the function had no effect, and
    ``is_tight`` looked at the first row twice.)

    NOTE(review): the stored coefficients and biases are summed as they are;
    no de-normalisation by the original row norms is performed because those
    norms are not available here.  The features of existing rows are not
    recomputed.

    Args:
        sample: indexable of dicts; item 0 = constraint features, item 1 =
            edges (``'values'`` with one column, ``'indices'`` with row 0 =
            constraint index and row 1 = variable index), item 2 = variable
            features.

    Returns:
        The same ``sample`` object (unchanged when it has fewer than 2 rows).
    """
    c_feats = sample[0]['values']
    edge_feats = sample[1]['values']
    edge_feat_indices = sample[1]['indices']
    v_feats = sample[2]['values']

    num_rows = c_feats.shape[0]
    if num_rows < 2:
        return sample  # nothing to combine

    row1, row2 = np.random.choice(num_rows, 2, replace=False)

    # Sum the two rows column-wise; a dict keeps first-seen column order.
    merged_coefs = defaultdict(float)
    for row in (row1, row2):
        row_mask = edge_feat_indices[0] == row
        row_cols = edge_feat_indices[1][row_mask]
        row_coefs = edge_feats[row_mask].flatten()
        for col, coef in zip(row_cols, row_coefs):
            merged_coefs[int(col)] += float(coef)

    # Explicit dtypes keep an empty merge from turning the index array into floats.
    new_colidxs = np.array(list(merged_coefs.keys()), dtype=edge_feat_indices.dtype)
    new_rowidxs = np.full(len(merged_coefs), num_rows, dtype=edge_feat_indices.dtype)
    new_coefs = np.array(list(merged_coefs.values()), dtype=float)

    new_bias = c_feats[row1, 1] + c_feats[row2, 1]
    new_norm = np.linalg.norm(new_coefs)
    if new_norm > 0:
        new_coefs = new_coefs / new_norm
        new_bias = new_bias / new_norm

    both_tight = c_feats[row1, 2] == 1 and c_feats[row2, 2] == 1

    new_row = np.zeros((1, c_feats.shape[1]), dtype=c_feats.dtype)
    new_row[0, 0] = np.dot(v_feats[new_colidxs, 4], new_coefs)
    new_row[0, 1] = new_bias
    new_row[0, 2] = 1 if both_tight else 0
    new_row[0, 4] = (c_feats[row1, 4] + c_feats[row2, 4]) / 2  # averaging strategy

    # Store the enlarged arrays; the stacking calls create new arrays, so they
    # must be assigned back for the caller to see the added constraint.
    sample[0]['values'] = np.vstack([c_feats, new_row])
    sample[1]['indices'] = np.hstack([
        edge_feat_indices,
        np.vstack([new_rowidxs, new_colidxs]),
    ])
    sample[1]['values'] = np.vstack([
        edge_feats,
        new_coefs.reshape(-1, 1).astype(edge_feats.dtype),
    ])
    return sample
def perturb_objective(sample, sigma=0.01):
    """Add zero-mean Gaussian noise to the objective coefficients, in place.

    Only column 4 of ``sample[2]['values']`` (the objective-coefficient
    column used throughout this file) is modified; one independent draw from
    ``np.random.normal`` is added per variable.

    Args:
        sample: indexable whose item 2 is a dict with a 2-D ``'values'`` array.
        sigma: standard deviation of the noise (default 0.01, the value that
            was previously hard-coded).

    Returns:
        The same ``sample`` object.
    """
    v_feats = sample[2]['values']
    v_feats[:, 4] += np.random.normal(0, sigma, v_feats[:, 4].shape)
    return sample

def perturb_constraints(sample, sigma=0.01):
    """Add zero-mean Gaussian noise to every edge coefficient, in place.

    The array ``sample[1]['values']`` is modified with one independent draw
    from ``np.random.normal`` per entry; nothing else in ``sample`` changes.

    Args:
        sample: indexable whose item 1 is a dict with a ``'values'`` array.
        sigma: standard deviation of the noise (default 0.01, the value that
            was previously hard-coded).

    Returns:
        The same ``sample`` object.
    """
    edge_feats = sample[1]['values']
    edge_feats += np.random.normal(0, sigma, edge_feats.shape)
    return sample

def perturb_duals(sample, dual_sigma=0.05, redcost_sigma=0.05, min_abs=1e-5):
    """Perturb the dual values and the reduced costs, in place.

    * Column 4 of ``sample[0]['values']`` (dual value per constraint) is
      multiplied element-wise by draws from ``N(1, dual_sigma)``.
    * Column 14 of ``sample[2]['values']`` (reduced cost per variable) gets
      additive ``N(0, redcost_sigma)`` noise; the result keeps the sign of
      the original entry and its magnitude is floored at ``min_abs``.
      Entries that are exactly zero stay zero because their sign is zero.

    The dual multipliers are drawn before the reduced-cost noise.

    Args:
        sample: indexable whose items 0 and 2 are dicts with 2-D ``'values'``.
        dual_sigma: std of the multiplicative dual noise (default 0.05).
        redcost_sigma: std of the additive reduced-cost noise (default 0.05).
        min_abs: lower bound on the magnitude of a perturbed non-zero reduced
            cost (default 1e-5).

    Returns:
        The same ``sample`` object.
    """
    c_feats = sample[0]['values']
    v_feats = sample[2]['values']

    c_feats[:, 4] *= np.random.normal(1, dual_sigma, c_feats[:, 4].shape)

    redcost = v_feats[:, 14]
    noise = np.random.normal(0, redcost_sigma, redcost.shape)
    # The right-hand side is fully evaluated before the column is overwritten.
    v_feats[:, 14] = np.sign(redcost) * np.maximum(np.abs(redcost + noise), min_abs)
    return sample
# ---------------------------------------------------------------------------
# Scratch statistics script -- executes at import time.
#
# The live statements only create a few empty containers, format the instance
# path for indices 0..99 (nothing is opened) and print a header line.  The
# experiments that used to surround them were disabled; what they recorded is
# summarised in the notes below.
# ---------------------------------------------------------------------------

# Node depth -> number of samples.  A previous run recorded:
# {1: 70, 4: 388, 5: 558, 10: 1021, 3: 239, 7: 1000, 11: 953, 6: 734, 9: 1111,
#  8: 1091, 12: 762, 2: 133, 13: 587, 15: 330, 14: 430, 22: 9, 18: 82, 16: 220,
#  17: 135, 19: 52, 20: 38, 21: 15, 25: 3, 0: 32, 24: 2, 23: 5}
# Grouped as depth <= 1 / 2-4 / 5-7 / >= 8 that is 102 / 760 / 2292 / 6846
# samples, with a maximum depth of 25.
depth_dict = {}

# Dataset directory names noted for this script:
# setcover_400r_1000c_0.05d_100mc_0se cauctions_0se facility_0se fcmcnf
# indset_400n_4a_0se gisp wpms
depth_list = []

# One independent list per node depth 0..5.  The disabled sample loop appended
# len(action_set) of every sample with depth <= 5 here and also incremented
# length_counter with that length.
dep_dict = {depth: [] for depth in range(6)}

for i in range(100):
    # NOTE: `dir` shadows the builtin; the name is kept because it is a
    # module-level binding of this script.
    dir = f'../data/instances/setcover_400r_1000c_0.05d_100mc_0se/transfer_2000r/instance_{i}.lp'

# Disabled sample loop (for reference): it iterated i in 1..9999 over
# ../data/samples/setcover_400r_1000c_0.05d_100mc_0se/new_20/train/sample_{i}.pkl,
# loaded each file with pickle.load(gzip.open(path, 'rb')), read data['data'],
# data['node_depth'] and data['max_depth'], and unpacked
#   gcn_state, bestcand, action_set, scores, l1, l2 = sample
# It then printed the mean candidate count per depth, and a bar chart of
# length_counter was saved to "candidate_dist.png" (dpi=300).

print("\nCandidate set length distribution:")

# Feature layouts noted while inspecting gcn_state:
#   gcn_state[0] (constraints):
#     ['obj_cosine_similarity', 'bias', 'is_tight', 'age', 'dualsol_val_normalized']
#   gcn_state[1] (edges): ['coef_normalized'], with 'indices' = (con_idx, var_idx)
#   gcn_state[2] (variables):
#     ['type_0', 'type_1', 'type_2', 'type_3', 'coef_normalized', 'has_lb',
#      'has_ub', 'sol_is_at_lb', 'sol_is_at_ub', 'sol_frac', 'basis_status_0',
#      'basis_status_1', 'basis_status_2', 'basis_status_3', 'reduced_cost',
#      'age', 'sol_val', 'inc_val', 'avg_inc_val']
#     type_0..3      : variable type BINARY / INTEGER / IMPLINT / CONTINUOUS
#     sol_is_at_lb   : the current LP solution value equals the lower bound
#     sol_frac       : fractional part of the current solution value
#     basis_status_* : LOWER / BASIC / UPPER / ZERO
#     sol_val        : value in the current LP relaxation solution
#     inc_val        : value in the incumbent (best integer solution so far)
#     avg_inc_val    : average value over past incumbents
#
# Other disabled experiments:
#   * label selection: label_sb = scores.index(max(scores)) and the analogous
#     argmax over l1 / l2; top-k variants via np.argpartition(scores, -k)[-k:]
#     with k = 5, a top-k mask applied to the scores followed by F.softmax,
#     and an MSE loss against the masked soft targets.
#   * depth-based sample weight: after subtracting min_depth from depth and
#     max_depth, weight_init = depth / max_depth (1.0 when max_depth is 0) and
#     weight = -0.4 * weight_init + 1; depth buckets considered: 0-1, 2-4,
#     5-7, 8+.