import random
import numpy as np
import torch
import argparse
import scipy.sparse as sp
import dgl
import pickle, gzip
import sys
import itertools
from pyscipopt import SCIP_PARAMSETTING, Model, Branchrule, Nodesel, SCIP_RESULT
from torch.utils.data import Dataset
from treelib import Tree
sys.path.append('../')
from src.logger import Logger
from src.model import GCNN_Net, PD_Net
import time
import torch.nn.functional as F
from collections import defaultdict
import copy

def flip_variables(sample):
    """Randomly negate (x -> -x) roughly half of the variables of a sample, in place.

    ``sample`` is accessed as ``sample[1]['values']`` / ``sample[1]['indices']``
    (edge features and their (row, col) index array) and ``sample[2]['values']``
    (variable features). Each variable is selected independently with
    probability 0.5. For the selected ones, the matrix column and every
    sign-dependent variable feature are updated so that the sample describes the
    problem written in terms of -x.

    Variable feature columns assumed here:
    0-3 type, 4 coef_normalized, 5 has_lb, 6 has_ub, 7 sol_is_at_lb,
    8 sol_is_at_ub, 9 sol_frac (fractional part, for integer variables),
    10-13 basis_status, 14 reduced_cost, 15 age, 16 sol_val,
    17 inc_val (value in the incumbent), 18 avg_inc_val.

    NOTE(review): the basis_status one-hot columns (10-13) are left unchanged.
    If they encode "non-basic at lower/upper bound", those two columns should
    also be swapped for flipped variables -- confirm the encoding.

    Returns:
        ``(sample, 'flip')`` -- the same (mutated) sample object and a tag.
    """
    edge_feats = sample[1]['values']
    edge_feat_indices = sample[1]['indices']
    v_feats = sample[2]['values']

    flip_mask = np.random.rand(v_feats.shape[0]) < 0.5
    flip_indices = np.flatnonzero(flip_mask)

    # Sign-odd quantities: objective coefficient, reduced cost, and the
    # variable's value in the LP solution / incumbent / average incumbent.
    # Columns 17/18 are only touched if the feature matrix has them.
    sign_cols = [c for c in (4, 14, 16, 17, 18) if c < v_feats.shape[1]]
    v_feats[np.ix_(flip_mask, sign_cols)] *= -1

    # A lower bound on x is an upper bound on -x (and vice versa).
    # The right-hand sides are fancy-indexed copies, so the swaps are safe.
    v_feats[np.ix_(flip_mask, [5, 6])] = v_feats[np.ix_(flip_mask, [6, 5])]
    v_feats[np.ix_(flip_mask, [7, 8])] = v_feats[np.ix_(flip_mask, [8, 7])]

    # frac(-x) = 1 - frac(x) for fractional x, but an integral value
    # (frac == 0) must stay integral instead of becoming 1.
    frac = v_feats[flip_mask, 9]
    v_feats[flip_mask, 9] = np.where(frac > 0, 1 - frac, frac)

    # Negate the constraint-matrix coefficients of the flipped columns.
    edge_mask = np.isin(edge_feat_indices[1], flip_indices)
    edge_feats[edge_mask] *= -1
    return sample, 'flip'

def add_redundant_constraint(sample):
    """Append a redundant constraint built by summing two random existing rows.

    Two distinct constraint rows are drawn at random. Their stored
    (normalized) coefficient rows are added column-wise, and the sum is
    re-normalized to unit norm. The result is appended as a new last row of the
    constraint matrix, together with a matching constraint-feature row.

    Constraint feature columns assumed here:
    0 obj_cosine_similarity, 1 bias, 2 is_tight, 3 age, 4 dualsol_val_normalized.
    Variable feature column 4 is used as the objective coefficient.

    ``sample[0]['values']``, ``sample[1]['values']`` and ``sample[1]['indices']``
    are rebound to newly allocated, one-row-larger arrays. The old arrays are
    not resized in place.

    Returns:
        ``(sample, 'add')``. If there are fewer than two constraints, or the two
        rows cancel out completely, the sample is returned unchanged.
    """
    edge_feats = sample[1]['values']
    edge_feat_indices = sample[1]['indices']
    c_feats = sample[0]['values']
    v_feats = sample[2]['values']

    num_rows = c_feats.shape[0]
    if num_rows < 2:
        # Nothing to combine; keep the (sample, tag) return shape consistent.
        return sample, 'add'
    row1, row2 = np.random.choice(num_rows, 2, replace=False)

    # Gather the stored coefficients of both rows.
    row1_mask = edge_feat_indices[0] == row1
    row2_mask = edge_feat_indices[0] == row2
    cols = np.concatenate([edge_feat_indices[1][row1_mask],
                           edge_feat_indices[1][row2_mask]])
    coefs = np.concatenate([np.ravel(edge_feats[row1_mask]),
                            np.ravel(edge_feats[row2_mask])])

    # Sum coefficients that share a column (A_new = A1 + A2).
    new_colidxs, inverse = np.unique(cols, return_inverse=True)
    new_coef_values = np.zeros(len(new_colidxs), dtype=float)
    np.add.at(new_coef_values, np.ravel(inverse), coefs)

    # Drop columns that cancelled exactly so no explicit zero edges are stored.
    keep = new_coef_values != 0
    new_colidxs = new_colidxs[keep]
    new_coef_values = new_coef_values[keep]
    new_norm = np.linalg.norm(new_coef_values)
    if new_norm == 0:
        # The two rows cancel completely; there is no meaningful constraint to add.
        return sample, 'add'
    new_coef_normalized = new_coef_values / new_norm

    # The stored rows are positive rescalings of the original constraints, so
    # their sum is again implied by them. This assumes all rows share the same
    # inequality sense -- TODO confirm. The right-hand side (bias) is therefore
    # combined the same way, b_new = b1 + b2, and re-scaled by the new norm.
    new_row = np.zeros((1, c_feats.shape[1]), dtype=c_feats.dtype)
    new_row[0, 1] = (c_feats[row1, 1] + c_feats[row2, 1]) / new_norm
    # The sum is tight only when both summands are tight.
    new_row[0, 2] = 1 if c_feats[row1, 2] == 1 and c_feats[row2, 2] == 1 else 0
    # Heuristic dual value: average of the two source rows.
    new_row[0, 4] = (c_feats[row1, 4] + c_feats[row2, 4]) / 2
    # Dot product of the unit-norm new row with the objective-coefficient
    # feature. This equals the cosine similarity if that feature column is
    # unit-norm (presumably, given its name 'coef_normalized').
    new_row[0, 0] = np.dot(v_feats[new_colidxs, 4], new_coef_normalized)

    # The new constraint gets index num_rows (appended after all existing rows).
    new_rowidxs = np.full(len(new_colidxs), num_rows, dtype=edge_feat_indices.dtype)
    new_indices = np.vstack([new_rowidxs, new_colidxs.astype(edge_feat_indices.dtype)])
    new_edges = new_coef_normalized.astype(edge_feats.dtype).reshape(
        (-1,) + edge_feats.shape[1:])

    # np.hstack / np.vstack / np.concatenate allocate new arrays, so they must
    # be written back into the sample for the augmentation to take effect.
    sample[1]['indices'] = np.hstack([edge_feat_indices, new_indices])
    sample[1]['values'] = np.concatenate([edge_feats, new_edges])
    sample[0]['values'] = np.vstack([c_feats, new_row])
    return sample, 'add'

def perturb_objective(sample, std=0.01):
    """Add small Gaussian noise to the objective coefficients, in place.

    Args:
        sample: sample whose ``sample[2]['values']`` variable-feature matrix has
            the objective coefficient (coef_normalized) in column 4.
        std: standard deviation of the zero-mean Gaussian noise (default 0.01).

    Returns:
        ``(sample, 'per_obj')`` -- the same (mutated) sample object and a tag.
    """
    v_feats = sample[2]['values']
    v_feats[:, 4] += np.random.normal(0, std, v_feats[:, 4].shape)
    return sample, 'per_obj'

def perturb_constraints(sample, std=0.01):
    """Add small Gaussian noise to every constraint-matrix coefficient, in place.

    NOTE(review): the noise is added to the stored (normalized) coefficients,
    so rows are no longer exactly unit-norm afterwards, and the constraint
    features derived from them are not recomputed.

    Args:
        sample: sample whose ``sample[1]['values']`` holds the edge features.
        std: standard deviation of the zero-mean Gaussian noise (default 0.01).

    Returns:
        ``(sample, 'per_con')`` -- the same (mutated) sample object and a tag.
    """
    edge_feats = sample[1]['values']
    edge_feats += np.random.normal(0, std, edge_feats.shape)
    return sample, 'per_con'

def perturb_duals(sample, dual_std=0.05, rc_std=0.05):
    """Perturb dual values and reduced costs, in place.

    - Constraint feature column 4 (dualsol_val_normalized) is multiplied by
      Gaussian noise with mean 1 and std ``dual_std``.
    - Variable feature column 14 (reduced_cost) receives additive Gaussian
      noise with std ``rc_std``. The result keeps the ORIGINAL sign and has a
      magnitude of at least 1e-5. Entries that were exactly 0 stay 0, because
      ``np.sign(0) == 0``.

    Args:
        sample: sample with ``sample[0]['values']`` (constraint features) and
            ``sample[2]['values']`` (variable features).
        dual_std: std of the multiplicative dual-value noise (default 0.05).
        rc_std: std of the additive reduced-cost noise (default 0.05).

    Returns:
        ``(sample, 'duals')`` -- the same (mutated) sample object and a tag.
    """
    c_feats = sample[0]['values']
    v_feats = sample[2]['values']

    c_feats[:, 4] *= np.random.normal(1, dual_std, c_feats[:, 4].shape)

    reduced_cost = v_feats[:, 14]
    noise = np.random.normal(0, rc_std, reduced_cost.shape)
    v_feats[:, 14] = np.sign(reduced_cost) * np.maximum(np.abs(reduced_cost + noise), 1e-5)
    return sample, 'duals'

def augment(sample, probs=(0.5, 0.5, 0, 0, 0)):
    """Apply exactly one randomly chosen augmentation to ``sample``.

    Other weightings noted previously: (0.35, 0.35, 0.1, 0.1, 0.1) and
    (0, 0, 0, 0, 1).

    Args:
        sample: passed unchanged to the chosen augmentation function.
        probs: five selection weights summing to 1, in this order:
            0 - flip_variables
            1 - add_redundant_constraint
            2 - perturb_objective
            3 - perturb_constraints
            4 - perturb_duals
            A tuple default is used so the default cannot be mutated between calls.

    Returns:
        The return value of the chosen augmentation function, normally a
        ``(sample, tag)`` pair.

    Raises:
        ValueError: if ``probs`` does not have 5 entries or does not sum to 1
            (within 1e-6; NaN sums are rejected as well).
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(probs) != 5 or not abs(sum(probs) - 1.0) < 1e-6:
        raise ValueError("probs长度必须为5且和为1")

    funcs = [
        flip_variables,
        add_redundant_constraint,
        perturb_objective,
        perturb_constraints,
        perturb_duals,
    ]

    idx = random.choices(range(5), weights=probs, k=1)[0]
    return funcs[idx](sample)