import torch
import torch.nn as nn
import torch.nn.functional as f
from torch_geometric.nn import GATv2Conv, GatedGraphConv, EGConv, JumpingKnowledge, MaxAggregation, MeanAggregation, SumAggregation, AttentionalAggregation, EquilibriumAggregation, MultiAggregation, PowerMeanAggregation, SoftmaxAggregation, MinAggregation, SAGEConv
from torch_geometric.nn import global_add_pool
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn import global_sort_pool
from torch_geometric.nn import global_max_pool
from torch_geometric.nn import GCNConv,GATConv
from torch_geometric.utils import remove_self_loops
from torch_scatter import scatter
'''
File - ggnn.py
This file defines several graph neural network models (EGC, GAT and a
family of GGNN variants) plus 1-D convolutional heads (Conv, Conv1);
several GGNN variants use Conv for their final readout.
'''

def build_aggregators(aggregator_type, fcInputLayerSize=None):
    """Build a PyG aggregation module from its short name.

    Args:
        aggregator_type: one of "add", "mean", "max", "min", "power",
            "softmax", "attention" or "equilibrium".
        fcInputLayerSize: callers pass the pooled feature width plus one; the
            learned parts of the "attention" and "equilibrium" aggregators act
            on ``fcInputLayerSize - 1`` input features. Required for those two
            types, ignored otherwise.

    Returns:
        The aggregation module.

    Raises:
        ValueError: if ``aggregator_type`` is unknown, or if it needs
            ``fcInputLayerSize`` and none was given.
    """
    # Validate with an exception (not `assert`, which `python -O` strips);
    # "equilibrium" previously failed with an opaque `None - 1` TypeError.
    if aggregator_type in ("attention", "equilibrium") and fcInputLayerSize is None:
        raise ValueError(f"Input Layer size must be set for {aggregator_type} aggregator")

    if aggregator_type == "add":
        return SumAggregation()
    if aggregator_type == "mean":
        return MeanAggregation()
    if aggregator_type == "max":
        return MaxAggregation()
    if aggregator_type == "min":
        return MinAggregation()
    if aggregator_type == "power":
        return PowerMeanAggregation(learn=True)
    if aggregator_type == "softmax":
        return SoftmaxAggregation(learn=True)
    if aggregator_type == "attention":
        return AttentionalAggregation(gate_nn=torch.nn.Linear(fcInputLayerSize - 1, 1))
    if aggregator_type == "equilibrium":
        return EquilibriumAggregation(fcInputLayerSize - 1, fcInputLayerSize - 1, [256, 256])
    raise ValueError(f"Not a valid aggregator: {aggregator_type!r}")

class EGC(torch.nn.Module):
    '''
    Graph-level model built from EGConv (Efficient Graph Convolution) layers.

    `passes` EGConv layers with leaky-ReLU activations are applied to the node
    features. With `shouldJump`, the input and every layer output are
    concatenated (JumpingKnowledge, mode 'cat'). The node features are then
    pooled per graph with one or more aggregators from `build_aggregators`,
    concatenated with `problemType`, and fed to a three-layer MLP.

    Args:
        passes: number of EGConv layers.
        inputLayerSize: node feature width.
        outputLayerSize: width of the final output.
        pool: aggregator name, or list of names, accepted by
            `build_aggregators`; several names are combined with
            MultiAggregation.
        aggregators: EGConv aggregators; defaults to ["symnorm"].
        shouldJump: enable jumping-knowledge concatenation.
        num_heads, num_bases: forwarded to EGConv.
    '''
    def __init__(self, passes, inputLayerSize, outputLayerSize, pool, aggregators=None, shouldJump=True, num_heads=8, num_bases=4):
        super(EGC, self).__init__()
        # Avoid a shared mutable default list.
        if aggregators is None:
            aggregators = ["symnorm"]
        # Accept a single aggregator name; a bare string used to be iterated
        # character by character.
        if isinstance(pool, str):
            pool = [pool]
        self.passes = passes
        self.shouldJump = shouldJump
        # Round the hidden width down to a multiple of num_heads (EGConv splits
        # its output channels across the heads).
        self.modSize = inputLayerSize - (inputLayerSize % num_heads)

        self.egcs = nn.ModuleList([
            EGConv(in_channels=inputLayerSize if i == 0 else self.modSize,
                   out_channels=self.modSize,
                   aggregators=aggregators,
                   num_heads=num_heads,
                   num_bases=num_bases)
            for i in range(passes)
        ])

        # Width of the node features that reach the readout.
        if self.shouldJump:
            pooledWidth = (self.passes * self.modSize) + inputLayerSize
            self.jump = JumpingKnowledge('cat', pooledWidth)
        else:
            pooledWidth = self.modSize
        # One pooled vector per aggregator, plus one column for problemType.
        fcInputLayerSize = pooledWidth * len(pool) + 1

        # build_aggregators expects "feature width + 1" and subtracts it again.
        if len(pool) > 1:
            self.pool = MultiAggregation(
                aggrs=[build_aggregators(pool_type, fcInputLayerSize=pooledWidth + 1) for pool_type in pool])
        else:
            self.pool = build_aggregators(pool[0], pooledWidth + 1)

        self.fc1 = nn.Linear(fcInputLayerSize, fcInputLayerSize // 2)
        self.fc2 = nn.Linear(fcInputLayerSize // 2, fcInputLayerSize // 2)
        self.fcLast = nn.Linear(fcInputLayerSize // 2, outputLayerSize)

    def forward(self, x, edge_index, batch, problemType=None):
        '''
        Args:
            x: node features.
            edge_index: graph connectivity passed to every EGConv.
            batch: node-to-graph assignment used by the pooling step.
            problemType: one value per graph (1-D, or already 2-D with a
                single column); defaults to a single 1.0.

        Returns:
            The output of the final linear layer, one row per pooled graph.
        '''
        if problemType is None:
            problemType = torch.ones(1)

        layer_outputs = [x]
        for egc in self.egcs:
            x = f.leaky_relu(egc(x, edge_index))
            layer_outputs.append(x)

        if self.shouldJump:
            x = self.jump(layer_outputs)

        x = self.pool(x, batch.long())

        if problemType.dim() == 1:
            problemType = problemType.unsqueeze(1)
        # Follow the features' device/dtype instead of hard-coding .cuda(); any
        # shape mismatch now raises normally instead of printing and exit()ing.
        x = torch.cat((x, problemType.to(device=x.device, dtype=x.dtype)), dim=1)

        x = f.leaky_relu(self.fc1(x))
        x = f.leaky_relu(self.fc2(x))
        return self.fcLast(x)

class GAT(torch.nn.Module):
    '''
    Graph-level model built from GATv2Conv layers.

    Args:
        passes: number of GATv2Conv layers (0 disables message passing).
        numEdgeSets: unused.
        inputLayerSize: node feature width (kept constant across layers).
        outputLayerSize: width of the final output.
        numAttentionLayers: attention heads per layer (averaged, concat=False).
        mode: JumpingKnowledge mode.
        pool: "add", "mean", "max" or "attention".
        k: unused; the readout always yields one vector per graph (self.k = 1).
        shouldJump: enable jumping knowledge over the input and layer outputs.

    edge_attr is expected to carry one feature per edge (edge_dim=1).
    '''
    def __init__(self, passes, numEdgeSets, inputLayerSize, outputLayerSize, numAttentionLayers, mode, pool, k, shouldJump=True):
        super(GAT, self).__init__()
        self.passes = passes
        self.mode = mode
        # NOTE(review): the `k` argument is ignored.
        self.k = 1
        self.shouldJump = shouldJump

        self.gats = nn.ModuleList([GATv2Conv(inputLayerSize, inputLayerSize, heads=numAttentionLayers, concat=False, edge_dim=1) for _ in range(passes)])
        if self.passes and self.shouldJump:
            self.jump = JumpingKnowledge(self.mode, channels=inputLayerSize, num_layers=self.passes)
        # 'cat' concatenates the input plus `passes` layer outputs; +1 is the
        # problemType column appended after pooling.
        if self.mode == 'cat' and self.shouldJump:
            fcInputLayerSize = ((self.passes + 1) * inputLayerSize * self.k) + 1
        else:
            fcInputLayerSize = (inputLayerSize * self.k) + 1

        if pool == "add":
            self.pool = global_add_pool
        elif pool == "mean":
            self.pool = global_mean_pool
        elif pool == "max":
            self.pool = global_max_pool
        elif pool == "attention":
            # GlobalAttention was never imported; AttentionalAggregation is its
            # PyG replacement and is called the same way: pool(x, batch).
            self.pool = AttentionalAggregation(gate_nn=nn.Sequential(
                torch.nn.Linear(fcInputLayerSize - 1, fcInputLayerSize // 2), nn.LeakyReLU(),
                torch.nn.Linear(fcInputLayerSize // 2, fcInputLayerSize // 2), nn.LeakyReLU(),
                torch.nn.Linear(fcInputLayerSize // 2, 1), nn.Tanh()))
        else:
            raise ValueError("Not a valid pool")

        self.fc1 = nn.Linear(fcInputLayerSize, fcInputLayerSize // 2)
        self.fc2 = nn.Linear(fcInputLayerSize // 2, fcInputLayerSize // 2)
        self.fcLast = nn.Linear(fcInputLayerSize // 2, outputLayerSize)

    def forward(self, x, edge_index, edge_attr, batch, problemType):
        if self.passes:
            if self.shouldJump:
                xs = [x]

            for gat in self.gats:
                x = f.leaky_relu(gat(x, edge_index, edge_attr=edge_attr))
                if self.shouldJump:
                    xs.append(x)

            if self.shouldJump:
                x = self.jump(xs)

        # (The former global_sort_pool branch was unreachable: __init__ never
        # assigns that function to self.pool.)
        x = self.pool(x, batch)

        # Keep one row per graph. The old reshape(1, B*C) merged the whole batch
        # into a single row, so the cat below failed for batches of >1 graph.
        x = x.reshape(x.size(0), -1)
        if problemType.dim() == 1:
            problemType = problemType.unsqueeze(1)
        x = torch.cat((x, problemType), dim=1)

        x = f.leaky_relu(self.fc1(x))
        x = f.leaky_relu(self.fc2(x))
        return self.fcLast(x)

class GGNN(nn.Module):
    '''
    Base GGNN class.

    Message passing uses GatedGraphConv (a GRU-based update), one convolution
    per edge type. Each of the `passes` rounds runs the current node features
    through the convolution of every edge type present in the batch and
    averages the results. Then come a final GatedGraphConv over all edges, a
    global readout chosen by `collate` ("sum", "mean" or "max"), and a
    three-layer MLP over [graph embedding, problemType].
    '''
    def __init__(self, passes, numEdgeSets, inputLayerSize, outputLayerSize, collate):
        super(GGNN, self).__init__()
        # One gated convolution per edge set, plus one final convolution.
        self.ggcs = nn.ModuleList(
            [GatedGraphConv(inputLayerSize, outputLayerSize, aggr=collate) for _ in range(numEdgeSets)]
        )
        self.gcn = GatedGraphConv(inputLayerSize, outputLayerSize, aggr=collate)
        self.passes = passes
        # MLP head; the extra input unit receives the problemType column.
        self.fc1 = nn.Linear(inputLayerSize + 1, 80)
        self.fc2 = nn.Linear(80, 80)
        self.fcLast = nn.Linear(80, outputLayerSize)
        self.collate = collate

    def forward(self, x, edge_index, edge_attr, batch, problemType):
        x = f.dropout(x, p=0.6, training=self.training)

        for _ in range(self.passes):
            present_types = torch.unique(edge_attr)
            accumulated = torch.zeros_like(x)
            # NOTE(review): edge types are paired with self.ggcs by sorted
            # position, not by value, and the average divides by every present
            # type even if zip() stopped early.
            for edge_type, conv in zip(present_types, self.ggcs):
                keep = (edge_attr == edge_type).squeeze()
                accumulated += conv(x, edge_index.transpose(0, 1)[keep].transpose(0, 1))
            x = accumulated / len(present_types)

        x = self.gcn(x, edge_index)
        x = f.dropout(x, p=0.6, training=self.training)

        if self.collate == "sum":
            readout = global_add_pool
        elif self.collate == "mean":
            readout = global_mean_pool
        elif self.collate == "max":
            readout = global_max_pool
        else:
            raise ValueError("Not a valid collate type")
        x = readout(x, batch)

        if problemType.ndim == 1:
            problemType = problemType.unsqueeze(1)

        out = torch.cat((x, problemType), dim=1)
        out = f.leaky_relu(self.fc1(out))
        out = f.leaky_relu(self.fc2(out))
        return self.fcLast(out)

class GGNNtest(nn.Module):
    '''
    Experimental GGNN variant with residual GCN/GAT refinement.

    For every edge type present in the batch (first column of edge_attr), the
    GatedGraphConv selected by that type value is applied and accumulated. The
    accumulator is then refined by a residual GCN -> GAT -> GCN -> GAT stack
    restricted to edges of that type. Each pass divides the result by the
    number of edge types. The graph embedding is pooled per graph according
    to `collate`, concatenated with problemType and passed through a
    three-layer MLP.

    NOTE(review): the hard-coded sizes (GCN/GAT width 101, fc1 = Linear(102, 80))
    presumably require inputLayerSize == 101 and edge_attr with 101 columns
    (edge_dim=101); confirm against the caller.
    '''
    def __init__(self, passes, numEdgeSets, inputLayerSize, outputLayerSize, collate):
        super(GGNNtest, self).__init__()
        self.ggcs = nn.ModuleList([GatedGraphConv(inputLayerSize, outputLayerSize, aggr=collate) for _ in range(numEdgeSets)])
        self.passes = passes
        self.node_conv1 = GCNConv(inputLayerSize, 101)
        self.node_conv2 = GCNConv(101, 101)
        # concat=False averages the 4 attention heads, so each GAT layer emits
        # 101 features. With the default concat=True it emits 4 * 101 = 404,
        # and the residual additions in forward() fail with a shape error.
        self.edge_conv1 = GATConv(101, 101, edge_dim=101, heads=4, concat=False)
        self.edge_conv2 = GATConv(101, 101, edge_dim=101, heads=4, concat=False)
        self.fc1 = nn.Linear(102, 80)
        self.fc2 = nn.Linear(80, 80)
        self.fcLast = nn.Linear(80, outputLayerSize)
        self.collate = collate

    def forward(self, x, edge_index, edge_attr, batch, problemType):
        x = f.dropout(x, p=0.6, training=self.training)

        # Only the first column of edge_attr carries the edge type; its value is
        # also used directly as the index into self.ggcs.
        edge_type_col = edge_attr[:, 0]
        edge_types = torch.unique(edge_type_col)
        # Guard against batches without edges (0 / 0 would give NaN).
        num_types = max(len(edge_types), 1)

        for _ in range(self.passes):
            placeholderX = torch.zeros_like(x)

            for val in edge_types:
                gcn = self.ggcs[int(val.item())]
                type_mask = edge_type_col == val
                relevant_edges = edge_index[:, type_mask]
                relevant_edge_attr = edge_attr[type_mask]

                placeholderX = placeholderX + gcn(x, relevant_edges)

                # Residual GCN / GAT refinement on this edge type only. The GAT
                # layers get relevant_edges (not the full edge_index) so that
                # the edge count matches the rows of relevant_edge_attr.
                placeholderX = placeholderX + f.relu(self.node_conv1(placeholderX, relevant_edges))
                placeholderX = placeholderX + f.relu(self.edge_conv1(placeholderX, relevant_edges, relevant_edge_attr))
                placeholderX = placeholderX + f.relu(self.node_conv2(placeholderX, relevant_edges))
                placeholderX = placeholderX + f.relu(self.edge_conv2(placeholderX, relevant_edges, relevant_edge_attr))

            # Average over edge types.
            x = placeholderX / num_types

        x = f.dropout(x, p=0.6, training=self.training)

        if self.collate == "sum":
            x = global_add_pool(x, batch)
        elif self.collate == "mean":
            x = global_mean_pool(x, batch)
        elif self.collate == "max":
            x = global_max_pool(x, batch)
        else:
            raise ValueError("Not a valid collate type")

        if len(problemType.shape) == 1:
            problemType = problemType.unsqueeze(1)

        x = torch.cat((x, problemType), dim=1)
        x = f.leaky_relu(self.fc1(x))
        x = f.leaky_relu(self.fc2(x))
        return self.fcLast(x)

class GGNNtest2(nn.Module):
    '''
    GGNN variant that only propagates along edges of type 12, 13, 14 or 15.

    The edge type is the first column of edge_attr. Each pass averages the
    GatedGraphConv outputs over the selected edge types present in the batch.
    The result is pooled per graph according to `collate` ("sum", "mean" or
    "max"), concatenated with problemType and passed through a three-layer MLP.
    '''
    def __init__(self, passes, numEdgeSets, inputLayerSize, outputLayerSize, collate):
        super(GGNNtest2, self).__init__()
        self.ggcs = nn.ModuleList([GatedGraphConv(inputLayerSize, outputLayerSize, aggr=collate) for _ in range(numEdgeSets)])
        self.passes = passes
        self.fc1 = nn.Linear(inputLayerSize + 1, 80)
        self.fc2 = nn.Linear(80, 80)
        self.fcLast = nn.Linear(80, outputLayerSize)
        self.collate = collate

    def forward(self, x, edge_index, edge_attr, batch, problemType):
        x = f.dropout(x, p=0.6, training=self.training)

        # Keep only edge types 12, 13, 14 and 15 (first column of edge_attr).
        edge_type_col = edge_attr[:, 0]
        valid_vals = torch.tensor([12, 13, 14, 15], device=edge_attr.device)
        edge_types = torch.unique(edge_type_col)
        edge_types = edge_types[torch.isin(edge_types, valid_vals)]
        # Avoid 0 / 0 = NaN when none of the selected types is present; node
        # features then become zeros (an empty sum).
        num_types = max(len(edge_types), 1)

        for _ in range(self.passes):
            placeholderX = torch.zeros_like(x)
            # NOTE(review): the k-th present edge type is paired with
            # self.ggcs[k] by position, not by its value.
            for val, gcn in zip(edge_types, self.ggcs):
                placeholderX = placeholderX + gcn(x, edge_index[:, edge_type_col == val])
            x = placeholderX / num_types

        x = f.dropout(x, p=0.6, training=self.training)
        # The PyG global pooling functions take (x, batch[, size]); the former
        # extra `dim=0` keyword raised a TypeError.
        if self.collate == "sum":
            x = global_add_pool(x, batch)
        elif self.collate == "mean":
            x = global_mean_pool(x, batch)
        elif self.collate == "max":
            x = global_max_pool(x, batch)
        else:
            raise ValueError("Not a valid collate type")

        if len(problemType.shape) == 1:
            problemType = problemType.unsqueeze(1)

        x = torch.cat((x, problemType), dim=1)
        x = f.leaky_relu(self.fc1(x))
        x = f.leaky_relu(self.fc2(x))
        return self.fcLast(x)

class GGNNtest3(nn.Module):
    '''
    GGNN variant that propagates along every edge type present in the batch.

    The edge type is the first column of edge_attr. Each pass averages the
    GatedGraphConv outputs over the present edge types. The result is pooled
    per graph according to `collate` ("sum", "mean" or "max"), concatenated
    with problemType and passed through a three-layer MLP.
    '''
    def __init__(self, passes, numEdgeSets, inputLayerSize, outputLayerSize, collate):
        super(GGNNtest3, self).__init__()
        self.ggcs = nn.ModuleList([GatedGraphConv(inputLayerSize, outputLayerSize, aggr=collate) for _ in range(numEdgeSets)])
        self.passes = passes
        self.fc1 = nn.Linear(inputLayerSize + 1, 80)
        self.fc2 = nn.Linear(80, 80)
        self.fcLast = nn.Linear(80, outputLayerSize)
        self.collate = collate

    def forward(self, x, edge_index, edge_attr, batch, problemType):
        x = f.dropout(x, p=0.6, training=self.training)

        # All edge types present in the batch (first column of edge_attr); unlike
        # GGNNtest2, no type filtering is applied here.
        edge_type_col = edge_attr[:, 0]
        edge_types = torch.unique(edge_type_col)
        # Avoid 0 / 0 = NaN for batches without edges.
        num_types = max(len(edge_types), 1)

        for _ in range(self.passes):
            placeholderX = torch.zeros_like(x)
            # NOTE(review): the k-th present edge type is paired with
            # self.ggcs[k] by position, not by its value.
            for val, gcn in zip(edge_types, self.ggcs):
                placeholderX = placeholderX + gcn(x, edge_index[:, edge_type_col == val])
            x = placeholderX / num_types

        x = f.dropout(x, p=0.6, training=self.training)
        # The PyG global pooling functions take (x, batch[, size]); the former
        # extra `dim=0` keyword raised a TypeError.
        if self.collate == "sum":
            x = global_add_pool(x, batch)
        elif self.collate == "mean":
            x = global_mean_pool(x, batch)
        elif self.collate == "max":
            x = global_max_pool(x, batch)
        else:
            raise ValueError("Not a valid collate type")

        if len(problemType.shape) == 1:
            problemType = problemType.unsqueeze(1)

        x = torch.cat((x, problemType), dim=1)
        x = f.leaky_relu(self.fc1(x))
        x = f.leaky_relu(self.fc2(x))
        return self.fcLast(x)

def get_conv_mp_out_size(in_size, last_layer, mps):
    """Return the flattened feature count after a stack of MaxPool1d layers.

    Conv1d layers are assumed to preserve the sequence length (true for
    kernel_size=3, padding=1, stride=1). Only the pooling layers shrink it.

    Args:
        in_size: input sequence length.
        last_layer: kwargs of the last Conv1d; its "out_channels" is the
            number of channels that get flattened.
        mps: kwargs dicts of the MaxPool1d layers, in order. "kernel_size" is
            required. "stride" (default: kernel_size, as in nn.MaxPool1d),
            "padding" (0), "dilation" (1) and "ceil_mode" (False) are optional.

    Returns:
        ``length_after_pooling * last_layer["out_channels"]`` as an int.
    """
    size = in_size
    for mp in mps:
        kernel = mp["kernel_size"]
        stride = mp.get("stride") or kernel
        padding = mp.get("padding", 0)
        dilation = mp.get("dilation", 1)
        ceil_mode = mp.get("ceil_mode", False)
        # Same integer formula PyTorch uses for pooling output lengths.
        numerator = size + 2 * padding - dilation * (kernel - 1) - 1
        if ceil_mode:
            numerator += stride - 1
        out = numerator // stride + 1
        # With ceil_mode, the last window must start inside the input or the
        # left padding.
        if ceil_mode and (out - 1) * stride >= size + padding:
            out -= 1
        size = out
    return int(size * last_layer["out_channels"])

def stable_softmax(x, dim):
    """Softmax along `dim`, shifted by the per-slice maximum to avoid overflow in exp."""
    shift = x.max(dim=dim, keepdim=True).values
    weights = (x - shift).exp()
    normalizer = weights.sum(dim=dim, keepdim=True)
    return weights / normalizer

class Conv(nn.Module):
    '''
    Two-branch 1-D CNN readout head.

    Branch Z convolves the concatenation [hidden, x] (along dim 1); branch Y
    convolves `hidden` alone. Both pass through the same Conv1d/ReLU/MaxPool1d
    stack and are flattened. Each branch is projected to 128 features with
    leaky ReLU, the two are summed, and the sum is mapped to `num_classes`
    logits (no softmax is applied).

    The inputs are reshaped to (batch, in_channels, length) from a 2-D tensor,
    which only works when conv1d_1["in_channels"] == 1.

    Args:
        conv1d_1, conv1d_2: nn.Conv1d kwargs.
        maxpool1d_1, maxpool1d_2: nn.MaxPool1d kwargs.
        fc_1_size: length of the concatenated [hidden, x] input.
        fc_2_size: length of `hidden`.
        num_classes: number of output logits (default 18, as before).
    '''
    def __init__(self, conv1d_1, conv1d_2, maxpool1d_1, maxpool1d_2, fc_1_size, fc_2_size, num_classes=18):
        super(Conv, self).__init__()
        self.conv1d_1_args = conv1d_1
        self.conv1d_1 = nn.Conv1d(**conv1d_1)
        self.conv1d_2 = nn.Conv1d(**conv1d_2)
        # bn1/bn2 (and drop below) are not used in forward(); they are kept so
        # the module's attributes and state_dict keys stay unchanged.
        self.bn1 = nn.BatchNorm1d(conv1d_1['out_channels'])
        self.bn2 = nn.BatchNorm1d(conv1d_2['out_channels'])

        self.mp_1 = nn.MaxPool1d(**maxpool1d_1)
        self.mp_2 = nn.MaxPool1d(**maxpool1d_2)

        fc1_size = get_conv_mp_out_size(fc_1_size, conv1d_2, [maxpool1d_1, maxpool1d_2])
        fc2_size = get_conv_mp_out_size(fc_2_size, conv1d_2, [maxpool1d_1, maxpool1d_2])

        self.fc1 = nn.Linear(fc1_size, 128)
        self.fc2 = nn.Linear(fc2_size, 128)
        self.fcLast = nn.Linear(128, num_classes)

        self.drop = nn.Dropout(p=0.6)

    def _features(self, signal):
        """Apply conv1d_1 -> ReLU -> mp_1 -> conv1d_2 -> ReLU -> mp_2 and flatten per sample."""
        out = self.mp_1(f.relu(self.conv1d_1(signal)))
        out = self.mp_2(f.relu(self.conv1d_2(out)))
        return out.flatten(1)

    def forward(self, hidden, x):
        batch_size = hidden.shape[0]
        in_channels = self.conv1d_1_args["in_channels"]

        concat = torch.cat((hidden, x), 1)
        concat = concat.view(batch_size, in_channels, hidden.shape[1] + x.shape[1])
        Z = f.leaky_relu(self.fc1(self._features(concat)))

        hidden = hidden.view(batch_size, in_channels, hidden.shape[1])
        Y = f.leaky_relu(self.fc2(self._features(hidden)))

        # Raw logits; callers apply softmax / cross-entropy themselves.
        return self.fcLast(Z + Y)


class Conv1(nn.Module):
    '''
    Two-branch 1-D CNN feature extractor.

    forward() runs the same Conv1d/ReLU/MaxPool1d stack over [hidden, x]
    (concatenated along dim 1) and over `hidden` alone. It returns both
    flattened feature maps concatenated along dim 1. fc1, fc2, fcLast, bn1
    and bn2 are built but not used in forward().
    '''
    def __init__(self, conv1d_1, conv1d_2, maxpool1d_1, maxpool1d_2, fc_1_size, fc_2_size):

        super(Conv1, self).__init__()
        self.conv1d_1_args = conv1d_1
        self.conv1d_1 = nn.Conv1d(**conv1d_1)
        self.conv1d_2 = nn.Conv1d(**conv1d_2)
        self.bn1 = nn.BatchNorm1d(conv1d_1['out_channels'])
        self.bn2 = nn.BatchNorm1d(conv1d_2['out_channels'])

        self.mp_1 = nn.MaxPool1d(**maxpool1d_1)
        self.mp_2 = nn.MaxPool1d(**maxpool1d_2)

        pool_stack = [maxpool1d_1, maxpool1d_2]
        self.fc1 = nn.Linear(get_conv_mp_out_size(fc_1_size, conv1d_2, pool_stack), 512)
        self.fc2 = nn.Linear(get_conv_mp_out_size(fc_2_size, conv1d_2, pool_stack), 256)
        # 768 = 512 + 256; this layer maps to 18 outputs.
        self.fcLast = nn.Linear(768, 18)

    def forward(self, hidden, x):
        channels = self.conv1d_1_args["in_channels"]

        def encode(signal):
            # conv -> ReLU -> pool, twice, then flatten each sample.
            out = self.mp_1(f.relu(self.conv1d_1(signal)))
            out = self.mp_2(f.relu(self.conv1d_2(out)))
            return out.view(-1, int(out.shape[1] * out.shape[-1]))

        joint = torch.cat((hidden, x), 1).view(-1, channels, hidden.shape[1] + x.shape[1])
        joint_features = encode(joint)
        hidden_features = encode(hidden.view(-1, channels, hidden.shape[1]))
        return torch.cat((joint_features, hidden_features), dim=1)

class GGNNtest4(nn.Module):
    '''
    GGNN variant that injects edge attributes into node features.

    In each pass, and for each edge type (first column of edge_attr), every
    edge's attribute vector is summed into both of its endpoint nodes. That
    type's GatedGraphConv is then applied. A final GatedGraphConv runs over
    edges of type 12-15 only. The pooled graph embedding and the pooled raw
    node features (each with problemType appended) go to the Conv head.

    NOTE(review): adding the aggregated edge attributes to x requires
    edge_attr.size(1) == inputLayerSize; edge_conv, gat and linear are created
    but not used in forward().
    '''
    def __init__(self, passes, numEdgeSets, inputLayerSize, outputLayerSize, collate):
        super(GGNNtest4, self).__init__()
        self.ggcs = nn.ModuleList(
            [GatedGraphConv(inputLayerSize, outputLayerSize, aggr=collate) for _ in range(numEdgeSets)])
        self.gcn = GatedGraphConv(inputLayerSize, outputLayerSize, aggr=collate)
        self.passes = passes
        self.edge_conv = nn.ModuleList(
                [GATConv(inputLayerSize, inputLayerSize, edge_dim=inputLayerSize-1, heads=4) for _ in range(numEdgeSets)])
        self.gat = GATConv(inputLayerSize, inputLayerSize, edge_dim=inputLayerSize,heads=4)
        self.linear = nn.Linear(404, 101)
        self.conv1d_1 = {
            'in_channels': 1,  # input channels
            'out_channels': 32,  # output channels
            'kernel_size': 3,  # kernel size
            'stride': 1,  # stride
            'padding': 1  # padding
        }

        self.maxpool1d_1 = {
            'kernel_size': 2,  # pooling window size
            'stride': 2  # stride
        }

        self.conv1d_2 = {
            'in_channels': 32,  # must match conv1d_1's out_channels
            'out_channels': 64,  # output channels
            'kernel_size': 3,  # kernel size
            'stride': 1,  # stride
            'padding': 1  # padding
        }

        self.maxpool1d_2 = {
            'kernel_size': 2,  # pooling window size
            'stride': 2  # stride
        }

        # Convolutional readout head over [GNN embedding + problemType] and
        # [pooled raw features + problemType].
        self.conv = Conv(self.conv1d_1, self.conv1d_2, self.maxpool1d_1, self.maxpool1d_2,
                         fc_1_size=(inputLayerSize+1)*2,
                         fc_2_size=inputLayerSize+1)
        self.collate = collate

    def forward(self, x, edge_index, edge_attr, batch, problemType):
        # Put every input on the device that holds this module's parameters, so
        # weights and inputs always agree. (Previously the device came from
        # cuda.is_available() and only some inputs were moved.)
        device = next(self.parameters()).device
        x = x.to(device)
        # Raw node features (before dropout / message passing). They are pooled
        # separately and fed to the Conv head next to the GNN embedding.
        data_x = x
        edge_index = edge_index.to(device)
        edge_attr = edge_attr.float().to(device)
        batch = batch.to(device)
        problemType = problemType.to(device)
        x = f.dropout(x, p=0.6, training=self.training)

        # The first column of edge_attr holds the edge type.
        edge_type_col = edge_attr[:, 0]
        edge_types = torch.unique(edge_type_col)
        # Avoid 0 / 0 = NaN for batches without edges.
        num_types = max(len(edge_types), 1)

        for _ in range(self.passes):
            placeholderX = torch.zeros_like(x)
            # NOTE(review): the k-th present edge type is paired with
            # self.ggcs[k] by position, not by its value.
            for val, gcn in zip(edge_types, self.ggcs):
                mask = edge_type_col == val
                edge_subset = edge_index[:, mask].long()
                type_attr = edge_attr[mask]

                # Sum each edge's attribute vector into both endpoint nodes.
                width = type_attr.size(1)
                aggregated = torch.zeros(x.size(0), width, device=device)
                aggregated = aggregated.scatter_add(0, edge_subset[0].unsqueeze(1).expand(-1, width), type_attr)
                aggregated = aggregated.scatter_add(0, edge_subset[1].unsqueeze(1).expand(-1, width), type_attr)

                # Out-of-place on purpose. `x += ...` mutated the caller's
                # tensor (and data_x) in eval mode, where dropout returns its
                # input. It also invalidated tensors that earlier gcn calls had
                # saved for backward.
                x = x + aggregated.to(x.dtype)
                placeholderX = placeholderX + gcn(x, edge_subset)
            # Average over edge types.
            x = placeholderX / num_types

        # Final gated convolution restricted to edge types 12, 13, 14 and 15.
        valid_vals = torch.tensor([12, 13, 14, 15], device=edge_attr.device)
        valid_edge_types = edge_types[torch.isin(edge_types, valid_vals)]
        valid_indices = torch.isin(edge_type_col, valid_edge_types)
        selected_edge_index = edge_index[:, valid_indices]
        x = self.gcn(x, selected_edge_index)

        x = f.dropout(x, p=0.6, training=self.training)

        # Pool the GNN embedding and the raw features the same way. Previously
        # data_x was only pooled for "max", and "sum"/"mean" passed an invalid
        # `dim=0` keyword (TypeError).
        if self.collate == "sum":
            pool = global_add_pool
        elif self.collate == "mean":
            pool = global_mean_pool
        elif self.collate == "max":
            pool = global_max_pool
        else:
            raise ValueError("Not a valid collate type")
        x = pool(x, batch)
        data_x = pool(data_x, batch)

        if len(problemType.shape) == 1:
            problemType = problemType.unsqueeze(1)

        x = torch.cat((x, problemType), dim=1)
        data_x = torch.cat((data_x, problemType), dim=1)

        return self.conv(x, data_x)

class GGNNtest5(nn.Module):
    '''
    GGNN variant with a convolutional readout.

    Each pass averages the GatedGraphConv outputs over the edge types present
    in the batch (first column of edge_attr). The pooled graph embedding and
    the pooled raw node features (each with problemType appended) go to the
    Conv head. `ggc` is created but not used in forward().
    '''
    def __init__(self, passes, numEdgeSets, inputLayerSize, outputLayerSize, collate):
        super(GGNNtest5, self).__init__()
        self.ggcs = nn.ModuleList(
            [GatedGraphConv(inputLayerSize, outputLayerSize, aggr=collate) for _ in range(numEdgeSets)])
        self.passes = passes
        self.ggc=GatedGraphConv(inputLayerSize, outputLayerSize, aggr=collate)
        self.conv1d_1 = {
            'in_channels': 1,  # input channels
            'out_channels': 32,  # output channels
            'kernel_size': 3,  # kernel size
            'stride': 1,  # stride
            'padding': 1  # padding
        }

        self.maxpool1d_1 = {
            'kernel_size': 2,  # pooling window size
            'stride': 2  # stride
        }

        self.conv1d_2 = {
            'in_channels': 32,  # must match conv1d_1's out_channels
            'out_channels': 64,  # output channels
            'kernel_size': 3,  # kernel size
            'stride': 1,  # stride
            'padding': 1  # padding
        }

        self.maxpool1d_2 = {
            'kernel_size': 2,  # pooling window size
            'stride': 2  # stride
        }

        # Convolutional readout head over [GNN embedding + problemType] and
        # [pooled raw features + problemType].
        self.conv = Conv(self.conv1d_1, self.conv1d_2, self.maxpool1d_1, self.maxpool1d_2,
                         fc_1_size=(inputLayerSize+1)*2,
                         fc_2_size=inputLayerSize+1)
        self.collate = collate

    def forward(self, x, edge_index, edge_attr, batch, problemType):
        # Raw node features (before dropout / message passing), fed to the Conv
        # head alongside the GNN embedding.
        data_x = x

        x = f.dropout(x, p=0.6, training=self.training)

        edge_type_col = edge_attr[:, 0]
        edge_types = torch.unique(edge_type_col)
        # Avoid 0 / 0 = NaN for batches without edges.
        num_types = max(len(edge_types), 1)

        for _ in range(self.passes):
            placeholderX = torch.zeros_like(x)
            # NOTE(review): the k-th present edge type is paired with
            # self.ggcs[k] by position, not by its value.
            for val, gcn in zip(edge_types, self.ggcs):
                placeholderX = placeholderX + gcn(x, edge_index[:, edge_type_col == val])
            x = placeholderX / num_types

        x = f.dropout(x, p=0.6, training=self.training)

        # Pool the GNN embedding and the raw features the same way. Previously
        # data_x was only pooled for "max", and "sum"/"mean" passed an invalid
        # `dim=0` keyword (TypeError).
        if self.collate == "sum":
            pool = global_add_pool
        elif self.collate == "mean":
            pool = global_mean_pool
        elif self.collate == "max":
            pool = global_max_pool
        else:
            raise ValueError("Not a valid collate type")
        x = pool(x, batch)
        data_x = pool(data_x, batch)

        if len(problemType.shape) == 1:
            problemType = problemType.unsqueeze(1)

        x = torch.cat((x, problemType), dim=1)
        data_x = torch.cat((data_x, problemType), dim=1)

        return self.conv(x, data_x)

class GGNNtest6(nn.Module):
    """Gated graph network with optional edge-feature fusion and a Conv head.

    Node features pass through the first module of ``self.ggcs`` for
    ``passes`` rounds. If any node received a non-zero aggregate of its
    incident edge attributes, those aggregates are concatenated to the node
    embeddings and mixed by ``self.mlp``. Node embeddings and the raw node
    features are then pooled per graph, ``problemType`` is appended to both,
    and the pair is passed to ``self.conv``.
    """

    def __init__(self, passes, numEdgeSets, inputLayerSize, outputLayerSize, collate):
        super(GGNNtest6, self).__init__()
        # NOTE(review): self.gat, self.gcn and self.linear are not used by
        # forward(); presumably kept so saved state_dicts still match.
        self.gat = GATConv(inputLayerSize, inputLayerSize, edge_dim=inputLayerSize,heads=4)
        # NOTE(review): GatedGraphConv's positional parameters are
        # (out_channels, num_layers), so outputLayerSize sets the number of
        # propagation layers here, not an output width - confirm intended.
        self.ggcs = nn.ModuleList(
            [GatedGraphConv(inputLayerSize, outputLayerSize, aggr=collate) for _ in range(numEdgeSets)])
        self.gcn = GatedGraphConv(inputLayerSize, outputLayerSize, aggr=collate)
        self.linear = nn.Linear(404, 101)
        self.passes = passes
        # Mixes [node embedding | edge aggregate] back to inputLayerSize
        # columns so the result matches x's feature width.
        self.mlp = nn.Sequential(
            nn.Linear(inputLayerSize + inputLayerSize, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, inputLayerSize),
        )

        self.conv1d_1 = {
            'in_channels': 1,
            'out_channels': 32,
            'kernel_size': 3,
            'stride': 1,
            'padding': 1
        }
        self.conv1d_2 = {
            'in_channels': 32,
            'out_channels': 64,
            'kernel_size': 3,
            'stride': 1,
            'padding': 1
        }
        self.maxpool1d_1 = {
            'kernel_size': 2,
            'stride': 2
        }
        self.maxpool1d_2 = {
            'kernel_size': 2,
            'stride': 2
        }
        # Convolution module
        self.conv = Conv(self.conv1d_1, self.conv1d_2, self.maxpool1d_1, self.maxpool1d_2,
                         fc_1_size=(inputLayerSize + 1) * 2,
                         fc_2_size=inputLayerSize + 1)
        self.collate = collate

    def forward(self, x, edge_index, edge_attr, batch, problemType):
        """Run the network on a batch of graphs.

        Args:
            x: Node features (cast to float).
            edge_index: Edge list; row 0 holds source and row 1 target indices.
            edge_attr: Edge attributes (cast to float). When they are fused,
                ``self.mlp`` assumes inputLayerSize columns - TODO confirm.
            batch: Node-to-graph assignment used for pooling.
            problemType: Per-graph value appended after pooling; a 1-D tensor
                is reshaped into a column.

        Returns:
            The output of ``self.conv(x, data_x)``.

        Raises:
            ValueError: If ``self.collate`` is not "sum", "mean" or "max".
        """
        x = x.float()
        data_x = x  # raw features; pooled and handed to self.conv
        edge_attr = edge_attr.float()
        num_nodes = x.size(0)

        # Mean of the incident edge attributes at each edge's source node and
        # at its target node (nodes with no such edges get zeros), summed.
        src_mean = scatter(edge_attr, edge_index[0], dim=0, dim_size=num_nodes, reduce="mean")
        dst_mean = scatter(edge_attr, edge_index[1], dim=0, dim_size=num_nodes, reduce="mean")
        aggregated_edge_features = src_mean + dst_mean

        # Only the first gated convolution is applied, over all edges.
        for _ in range(self.passes):
            out = torch.zeros_like(x)
            for gcn in self.ggcs[:1]:
                out += gcn(x, edge_index)
            x = out

        # Fuse edge information only if some node received a non-zero aggregate.
        if torch.any(aggregated_edge_features != 0):
            x = self.mlp(torch.cat([x, aggregated_edge_features], dim=1))

        # Pool embeddings and raw features with the same reduction so both are
        # one row per graph before problemType (graph-level) is appended.
        if self.collate == "sum":
            # global_add_pool is not among the module-level imports.
            from torch_geometric.nn import global_add_pool
            pool = global_add_pool
        elif self.collate == "mean":
            pool = global_mean_pool
        elif self.collate == "max":
            pool = global_max_pool
        else:
            raise ValueError("Not a valid collate type")
        x = pool(x, batch)
        data_x = pool(data_x, batch)

        if problemType.dim() == 1:
            problemType = problemType.unsqueeze(1)

        x = torch.cat((x, problemType), dim=1)
        data_x = torch.cat((data_x, problemType), dim=1)

        return self.conv(x, data_x)

class GGNNtest7(nn.Module):
    """Gated graph network with edge-feature fusion, a node-level Conv1 block
    and an MLP head applied after graph pooling.

    Node features pass through the first module of ``self.ggcs`` for
    ``passes`` rounds. If any node received a non-zero edge aggregate, the
    aggregates are fused in by ``self.mlp``. ``self.conv`` then combines node
    embeddings with raw node features; the result is pooled per graph,
    ``problemType`` is appended, and fc1 -> fc2 -> fcLast produce the output.
    """

    def __init__(self, passes, numEdgeSets, inputLayerSize, outputLayerSize, collate):
        super(GGNNtest7, self).__init__()
        # NOTE(review): self.gat, self.gcn and self.linear are not used by
        # forward(); presumably kept so saved state_dicts still match.
        self.gat = GATConv(inputLayerSize, inputLayerSize, edge_dim=inputLayerSize,heads=4)
        # NOTE(review): GatedGraphConv's positional parameters are
        # (out_channels, num_layers), so outputLayerSize sets the number of
        # propagation layers here, not an output width - confirm intended.
        self.ggcs = nn.ModuleList(
            [GatedGraphConv(inputLayerSize, outputLayerSize, aggr=collate) for _ in range(numEdgeSets)])
        self.gcn = GatedGraphConv(inputLayerSize, outputLayerSize, aggr=collate)
        self.linear = nn.Linear(404, 101)
        self.passes = passes
        # Mixes [node embedding | source aggregate | target aggregate] back to
        # inputLayerSize columns (assumes edge_attr has inputLayerSize columns).
        self.mlp = nn.Sequential(
            nn.Linear(inputLayerSize * 3, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, inputLayerSize),
        )
        # NOTE(review): 901 is hard-coded; presumably the pooled Conv1 output
        # width plus one problemType column - confirm against Conv1.
        self.fc1 = nn.Linear(901, 512)
        self.fc2 = nn.Linear(512,128)
        self.fcLast = nn.Linear(128, outputLayerSize)

        self.conv1d_1 = {
            'in_channels': 1,
            'out_channels': 2,
            'kernel_size': 3,
            'stride': 1,
            'padding': 1
        }
        self.conv1d_2 = {
            'in_channels': 2,
            'out_channels': 4,
            'kernel_size': 3,
            'stride': 1,
            'padding': 1
        }
        self.maxpool1d_1 = {
            'kernel_size': 2,
            'stride': 2
        }
        self.maxpool1d_2 = {
            'kernel_size': 2,
            'stride': 2
        }
        # Convolution module
        self.conv = Conv1(self.conv1d_1, self.conv1d_2, self.maxpool1d_1, self.maxpool1d_2,
                         fc_1_size=inputLayerSize * 2,
                         fc_2_size=inputLayerSize)
        self.collate = collate

    def forward(self, x, edge_index, edge_attr, batch, problemType):
        """Run the network on a batch of graphs.

        Args:
            x: Node features (cast to float).
            edge_index: Edge list; row 0 holds source and row 1 target indices.
            edge_attr: Edge attributes (cast to float).
            batch: Node-to-graph assignment used for pooling.
            problemType: Per-graph value appended after pooling; a 1-D tensor
                is reshaped into a column.

        Returns:
            The output of ``self.fcLast``.

        Raises:
            ValueError: If ``self.collate`` is not "sum", "mean" or "max".
        """
        x = x.float()
        data_x = x  # raw node features, consumed at node level by self.conv
        edge_attr = edge_attr.float()
        num_nodes = x.size(0)

        # Mean of incident edge attributes at source and at target nodes
        # (zeros for nodes without such edges), kept side by side.
        src_mean = scatter(edge_attr, edge_index[0], dim=0, dim_size=num_nodes, reduce="mean")
        dst_mean = scatter(edge_attr, edge_index[1], dim=0, dim_size=num_nodes, reduce="mean")
        aggregated_edge_features = torch.cat([src_mean, dst_mean], dim=1)

        # Only the first gated convolution is applied, over all edges.
        for _ in range(self.passes):
            out = torch.zeros_like(x)
            for gcn in self.ggcs[:1]:
                out += gcn(x, edge_index)
            x = out

        # Fuse edge information only if some node received a non-zero aggregate.
        if torch.any(aggregated_edge_features != 0):
            x = self.mlp(torch.cat([x, aggregated_edge_features], dim=1))

        x = self.conv(x, data_x)

        if self.collate == "sum":
            # global_add_pool is not among the module-level imports.
            from torch_geometric.nn import global_add_pool
            x = global_add_pool(x, batch)
        elif self.collate == "mean":
            x = global_mean_pool(x, batch)
        elif self.collate == "max":
            x = global_max_pool(x, batch)
        else:
            raise ValueError("Not a valid collate type")

        if problemType.dim() == 1:
            problemType = problemType.unsqueeze(1)

        x = torch.cat((x, problemType), dim=1)
        x = f.leaky_relu(self.fc1(x))
        x = f.leaky_relu(self.fc2(x))
        return self.fcLast(x)

class GGNNtest8(nn.Module):
    """Gated graph network over a one-hot encoding of the first node feature.

    The first feature column of each node is reduced modulo 101 and one-hot
    encoded. If any node received a non-zero edge aggregate, it is added to
    that encoding. The result goes through the first module of ``self.ggcs``,
    then ``self.conv`` together with the raw node features. The output is
    pooled per graph, ``problemType`` is appended, and fc1 -> fc2 -> fcLast
    produce the output.
    """

    def __init__(self, passes, numEdgeSets, inputLayerSize, outputLayerSize, collate):
        super(GGNNtest8, self).__init__()
        # NOTE(review): self.gat, self.gcn, self.linear and self.mlp are not
        # used by forward(); presumably kept so saved state_dicts still match.
        self.gat = GATConv(inputLayerSize, inputLayerSize, edge_dim=inputLayerSize,heads=4)
        # NOTE(review): GatedGraphConv's positional parameters are
        # (out_channels, num_layers), so outputLayerSize sets the number of
        # propagation layers here, not an output width - confirm intended.
        self.ggcs = nn.ModuleList(
            [GatedGraphConv(inputLayerSize, outputLayerSize, aggr=collate) for _ in range(numEdgeSets)])
        self.gcn = GatedGraphConv(inputLayerSize, outputLayerSize, aggr=collate)
        self.linear = nn.Linear(404, 101)
        self.passes = passes
        self.mlp = nn.Sequential(
            nn.Linear(inputLayerSize * 3, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, inputLayerSize),
        )
        # NOTE(review): 601 is hard-coded; presumably the pooled Conv1 output
        # width plus one problemType column - confirm against Conv1.
        self.fc1 = nn.Linear(601, 256)
        self.fc2 = nn.Linear(256,128)
        self.fcLast = nn.Linear(128, outputLayerSize)

        self.conv1d_1 = {
            'in_channels': 1,
            'out_channels': 4,
            'kernel_size': 3,
            'stride': 1,
            'padding': 1
        }
        self.conv1d_2 = {
            'in_channels': 4,
            'out_channels': 8,
            'kernel_size': 3,
            'stride': 1,
            'padding': 1
        }
        self.maxpool1d_1 = {
            'kernel_size': 2,
            'stride': 2
        }
        self.maxpool1d_2 = {
            'kernel_size': 2,
            'stride': 2
        }
        # Convolution module
        self.conv = Conv1(self.conv1d_1, self.conv1d_2, self.maxpool1d_1, self.maxpool1d_2,
                         fc_1_size=(inputLayerSize) * 2,
                         fc_2_size=inputLayerSize)
        self.collate = collate

    def forward(self, x, edge_index, edge_attr, batch, problemType):
        """Run the network on a batch of graphs.

        Args:
            x: Node features (cast to float); column 0 is one-hot encoded.
            edge_index: Edge list; row 0 holds source and row 1 target indices.
            edge_attr: Edge attributes (cast to float). When added to the
                one-hot encoding, they must have 101 columns.
            batch: Node-to-graph assignment used for pooling.
            problemType: Per-graph value appended after pooling; a 1-D tensor
                is reshaped into a column.

        Returns:
            The output of ``self.fcLast``.

        Raises:
            ValueError: If ``self.collate`` is not "sum", "mean" or "max".
        """
        x = x.float()
        data_x = x  # raw node features, handed to self.conv
        edge_attr = edge_attr.float()
        num_nodes = x.size(0)

        # Average of the per-node mean edge attributes at source and target.
        src_mean = scatter(edge_attr, edge_index[0], dim=0, dim_size=num_nodes, reduce="mean")
        dst_mean = scatter(edge_attr, edge_index[1], dim=0, dim_size=num_nodes, reduce="mean")
        aggregated_edge_features = (src_mean + dst_mean) / 2

        # One-hot encode column 0 modulo max_dim. The encoding is built from
        # indices on x's device; filling a CPU buffer with CUDA-resident
        # indices would fail when the model runs on GPU.
        max_dim = 101
        codes = (x[:, 0] % max_dim).long()
        one_hot_x = f.one_hot(codes, num_classes=max_dim).to(
            device=edge_attr.device, dtype=x.dtype)

        if torch.any(aggregated_edge_features != 0):
            one_hot_x = one_hot_x + aggregated_edge_features

        # NOTE(review): every pass reads one_hot_x rather than the previous
        # pass's output, so passes > 1 recomputes the same result - confirm
        # whether chaining passes was intended.
        for _ in range(self.passes):
            out = torch.zeros_like(one_hot_x)
            for gcn in self.ggcs[:1]:
                out += gcn(one_hot_x, edge_index)
            x = out

        # NOTE(review): argument order is (raw features, embeddings), the
        # reverse of GGNNtest7's self.conv(x, data_x) - confirm intended.
        x = self.conv(data_x, x)

        if self.collate == "sum":
            # global_add_pool is not among the module-level imports.
            from torch_geometric.nn import global_add_pool
            x = global_add_pool(x, batch)
        elif self.collate == "mean":
            x = global_mean_pool(x, batch)
        elif self.collate == "max":
            x = global_max_pool(x, batch)
        else:
            raise ValueError("Not a valid collate type")

        if problemType.dim() == 1:
            problemType = problemType.unsqueeze(1)

        x = torch.cat((x, problemType), dim=1)
        x = f.leaky_relu(self.fc1(x))
        x = f.leaky_relu(self.fc2(x))
        return self.fcLast(x)

class EdgeAwareGraphSAGE0(nn.Module):
    """GraphSAGE over node features concatenated with max-aggregated incident
    edge attributes, followed by graph pooling and the ``Conv`` head.
    """

    def __init__(self, passes, numEdgeSets, inputLayerSize, outputLayerSize, collate):
        super(EdgeAwareGraphSAGE0, self).__init__()

        # GraphSAGE layer with max neighbourhood aggregation. Its input width
        # is 2 * inputLayerSize, i.e. node features plus edge aggregates, so
        # edge_attr is expected to have inputLayerSize columns.
        self.sage_layer = SAGEConv(inputLayerSize * 2, inputLayerSize, aggr='max')
        self.passes = passes  # stored but not used by forward()
        self.conv1d_1 = {
            'in_channels': 1,
            'out_channels': 32,
            'kernel_size': 3,
            'stride': 1,
            'padding': 1
        }
        self.conv1d_2 = {
            'in_channels': 32,
            'out_channels':64,
            'kernel_size': 3,
            'stride': 1,
            'padding': 1
        }
        self.maxpool1d_1 = {
            'kernel_size': 2,
            'stride': 2
        }
        self.maxpool1d_2 = {
            'kernel_size': 2,
            'stride': 2
        }
        # Convolution module
        self.conv = Conv(self.conv1d_1, self.conv1d_2, self.maxpool1d_1, self.maxpool1d_2,
                         fc_1_size=(inputLayerSize + 1) * 2,
                         fc_2_size=inputLayerSize + 1)
        self.collate = collate
        # MLP head. Not used by forward(); presumably kept so saved
        # state_dicts still match.
        self.fc1 = nn.Linear(1350 + 1, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fcLast = nn.Linear(64, outputLayerSize)

    def forward(self, x, edge_index, edge_attr, batch, problemType):
        """Run the network on a batch of graphs.

        Args:
            x: Node features (cast to float).
            edge_index: Edge list; row 0 holds source and row 1 target indices.
            edge_attr: Edge attributes (cast to float).
            batch: Node-to-graph assignment used for pooling.
            problemType: Per-graph value appended after pooling; a 1-D tensor
                is reshaped into a column.

        Returns:
            The output of ``self.conv(x, data_x)``.

        Raises:
            ValueError: If ``self.collate`` is not "sum", "mean" or "max".
        """
        x = x.float()
        data_x = x  # raw node features, pooled and handed to self.conv
        edge_attr = edge_attr.float()
        num_nodes = x.size(0)

        # Element-wise max of incident edge attributes at source and at target
        # nodes. The zero-initialised buffers are passed as `out` and updated
        # in place by scatter.
        src_max = torch.zeros(num_nodes, edge_attr.size(1), device=edge_attr.device)
        dst_max = torch.zeros(num_nodes, edge_attr.size(1), device=edge_attr.device)
        scatter(edge_attr, edge_index[0], dim=0, out=src_max, reduce="max")
        scatter(edge_attr, edge_index[1], dim=0, out=dst_max, reduce="max")
        aggregated_edge_features = src_max + dst_max

        x = self.sage_layer(torch.cat([x, aggregated_edge_features], dim=1), edge_index)

        # Pool both tensors with the same reduction so they are one row per
        # graph before problemType (graph-level) is appended. Previously any
        # collate other than "max" skipped pooling and broke the concat below.
        if self.collate == "sum":
            # global_add_pool is not among the module-level imports.
            from torch_geometric.nn import global_add_pool
            pool = global_add_pool
        elif self.collate == "mean":
            pool = global_mean_pool
        elif self.collate == "max":
            pool = global_max_pool
        else:
            raise ValueError("Not a valid collate type")
        x = pool(x, batch)
        data_x = pool(data_x, batch)

        if problemType.dim() == 1:
            problemType = problemType.unsqueeze(1)
        x = torch.cat((x, problemType), dim=1)
        data_x = torch.cat((data_x, problemType), dim=1)

        return self.conv(x, data_x)

class EdgeAwareGraphSAGE(nn.Module):
    """GraphSAGE over node features concatenated with max-aggregated incident
    edge attributes, followed by graph pooling and the ``Conv`` head.
    """

    def __init__(self, passes, numEdgeSets, inputLayerSize, outputLayerSize, collate):
        super(EdgeAwareGraphSAGE, self).__init__()

        # GraphSAGE layer with max neighbourhood aggregation. Its input width
        # is 2 * inputLayerSize, i.e. node features plus edge aggregates, so
        # edge_attr is expected to have inputLayerSize columns.
        self.sage_layer = SAGEConv(inputLayerSize * 2, inputLayerSize, aggr='max')
        self.passes = passes  # stored but not used by forward()
        self.conv1d_1 = {
            'in_channels': 1,
            'out_channels': 32,
            'kernel_size': 3,
            'stride': 1,
            'padding': 1
        }
        self.conv1d_2 = {
            'in_channels': 32,
            'out_channels':64,
            'kernel_size': 3,
            'stride': 1,
            'padding': 1
        }
        self.maxpool1d_1 = {
            'kernel_size': 2,
            'stride': 2
        }
        self.maxpool1d_2 = {
            'kernel_size': 2,
            'stride': 2
        }
        # Convolution module
        self.conv = Conv(self.conv1d_1, self.conv1d_2, self.maxpool1d_1, self.maxpool1d_2,
                         fc_1_size=(inputLayerSize + 1) * 2,
                         fc_2_size=inputLayerSize + 1)
        self.collate = collate
        # MLP head. Not used by forward(); presumably kept so saved
        # state_dicts still match.
        self.fc1 = nn.Linear(outputLayerSize + 1, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fcLast = nn.Linear(64, outputLayerSize)

    def forward(self, x, edge_index, edge_attr, batch, problemType):
        """Run the network on a batch of graphs.

        Args:
            x: Node features (cast to float).
            edge_index: Edge list; row 0 holds source and row 1 target indices.
            edge_attr: Edge attributes (cast to float).
            batch: Node-to-graph assignment used for pooling.
            problemType: Per-graph value appended after pooling; a 1-D tensor
                is reshaped into a column.

        Returns:
            The output of ``self.conv(x, data_x)``.

        Raises:
            ValueError: If ``self.collate`` is not "sum", "mean" or "max".
        """
        x = x.float()
        data_x = x  # raw node features, pooled and handed to self.conv
        edge_attr = edge_attr.float()
        num_nodes = x.size(0)

        # Element-wise max of incident edge attributes at source and at target
        # nodes. The zero-initialised buffers are passed as `out` and updated
        # in place by scatter.
        src_max = torch.zeros(num_nodes, edge_attr.size(1), device=edge_attr.device)
        dst_max = torch.zeros(num_nodes, edge_attr.size(1), device=edge_attr.device)
        scatter(edge_attr, edge_index[0], dim=0, out=src_max, reduce="max")
        scatter(edge_attr, edge_index[1], dim=0, out=dst_max, reduce="max")
        aggregated_edge_features = src_max + dst_max

        x = self.sage_layer(torch.cat([x, aggregated_edge_features], dim=1), edge_index)

        # Pool both tensors with the same reduction so they are one row per
        # graph before problemType (graph-level) is appended. Previously any
        # collate other than "max" skipped pooling and broke the concat below.
        if self.collate == "sum":
            # global_add_pool is not among the module-level imports.
            from torch_geometric.nn import global_add_pool
            pool = global_add_pool
        elif self.collate == "mean":
            pool = global_mean_pool
        elif self.collate == "max":
            pool = global_max_pool
        else:
            raise ValueError("Not a valid collate type")
        x = pool(x, batch)
        data_x = pool(data_x, batch)

        if problemType.dim() == 1:
            problemType = problemType.unsqueeze(1)
        x = torch.cat((x, problemType), dim=1)
        data_x = torch.cat((data_x, problemType), dim=1)

        return self.conv(x, data_x)


class toolSelector(nn.Module):
    """Three-layer perceptron: two leaky-ReLU hidden layers, then a linear output.

    Layer widths are inputLayerSize -> hidden_dim -> hidden_dim // 2 ->
    outputLayerSize.
    """

    def __init__(self, inputLayerSize, hidden_dim, outputLayerSize):
        super().__init__()
        narrow = hidden_dim // 2
        self.fc1 = nn.Linear(inputLayerSize, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, narrow)
        self.fc3 = nn.Linear(narrow, outputLayerSize)
        # Registered but not applied in forward().
        self.drop = nn.Dropout(p=0.2)

    def forward(self, x):
        """Return the raw (un-normalised) fc3 output for input ``x``."""
        for hidden in (self.fc1, self.fc2):
            x = f.leaky_relu(hidden(x))
        return self.fc3(x)

class StrategySelector(nn.Module):
    """Three-layer perceptron with ReLU + dropout(0.3) after each hidden layer.

    Layer widths are inputLayerSize -> hidden_dim -> hidden_dim // 2 ->
    outputLayerSize; the final layer is linear with no activation.
    """

    def __init__(self, inputLayerSize, hidden_dim, outputLayerSize):
        super().__init__()
        narrow = hidden_dim // 2
        # First hidden stage.
        self.fc1 = nn.Linear(inputLayerSize, hidden_dim)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.3)
        # Second hidden stage.
        self.fc2 = nn.Linear(hidden_dim, narrow)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.3)
        # Output projection.
        self.fc3 = nn.Linear(narrow, outputLayerSize)

    def forward(self, x):
        """Apply both hidden stages (linear, ReLU, dropout), then fc3."""
        stages = (
            (self.fc1, self.relu1, self.dropout1),
            (self.fc2, self.relu2, self.dropout2),
        )
        for linear, activation, dropout in stages:
            x = dropout(activation(linear(x)))
        return self.fc3(x)
