import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class GraphSAGEModel(nn.Module):
    """Stack of ``SAGEConv`` layers with ReLU and dropout between them.

    The stack is ``feat_dim -> hidden_dim``, followed by ``num_layers - 2``
    layers of ``hidden_dim -> hidden_dim``, followed by
    ``hidden_dim -> output_dim``. Because the first and last layers are
    always created, any ``num_layers`` of 2 or less yields exactly two layers.

    Args:
        feat_dim: Number of input channels of the first layer.
        hidden_dim: Number of channels of every intermediate representation.
        output_dim: Number of output channels of the last layer.
        num_layers: Requested depth of the stack (see note above).
        dropout: Dropout probability applied after each non-final layer.
        task_level: When equal to ``'node'`` the output is passed through
            ``log_softmax`` along dim 1; any other value returns the raw
            output of the last layer.
    """

    def __init__(self, feat_dim, hidden_dim, output_dim, num_layers, dropout, task_level):
        super().__init__()
        self.dropout = dropout
        self.task_level = task_level

        # Channel widths at every layer boundary; consecutive pairs describe
        # one convolution each, created in input-to-output order.
        extra_hidden = max(num_layers - 2, 0)
        widths = [feat_dim] + [hidden_dim] * (extra_hidden + 1) + [output_dim]
        self.convs = torch.nn.ModuleList()
        for in_channels, out_channels in zip(widths[:-1], widths[1:]):
            self.convs.append(SAGEConv(in_channels, out_channels))

    def forward(self, x, edge_index):
        """Run the layer stack on ``x`` using ``edge_index`` for every layer.

        Returns:
            The last layer's output, log-softmaxed along dim 1 when
            ``task_level == 'node'``.
        """
        *hidden_convs, final_conv = self.convs

        h = x
        for layer in hidden_convs:
            activated = F.relu(layer(h, edge_index))
            # Dropout is only active while the module is in training mode.
            h = F.dropout(activated, p=self.dropout, training=self.training)

        out = final_conv(h, edge_index)
        if self.task_level != 'node':
            return out
        return F.log_softmax(out, dim=1)

class GraphSAGE(nn.Module):
    """Two-layer GraphSAGE wrapper that caches the graph it operates on.

    ``preprocess`` stores the edge index and node features on the instance;
    ``forward`` then runs the wrapped ``GraphSAGEModel`` on that cached graph
    and optionally selects rows of the result.

    Args:
        feat_dim: Number of input feature channels.
        hidden_dim: Number of hidden channels.
        output_dim: Number of output channels.
        dropout: Dropout probability forwarded to the base model.
        task_level: Task level forwarded to the base model.
    """

    def __init__(self, feat_dim, hidden_dim, output_dim, dropout, task_level):
        super(GraphSAGE, self).__init__()

        self.base_model = GraphSAGEModel(
            feat_dim=feat_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            num_layers=2,
            dropout=dropout,
            task_level=task_level,
        )
        # Both caches are filled by preprocess(); they start as None so that
        # forward() can detect a missing preprocess() call.
        self.processed_feature = None
        self.processed_edge = None

    def preprocess(self, edge_index, feature, tensor_feat=True):
        """Cache the graph used by subsequent forward passes.

        Args:
            edge_index: Edge index stored as-is; ``forward`` calls ``.to`` on it.
            feature: Node features. ``None`` leaves the cached features
                untouched.
            tensor_feat: When true, ``feature`` is converted with
                ``torch.FloatTensor``; otherwise it is stored unchanged.
        """
        self.processed_edge = edge_index
        if feature is None:
            return
        self.processed_feature = torch.FloatTensor(feature) if tensor_feat else feature

    def model_forward(self, idx, device, ori=None):
        """Alias for :meth:`forward` with the same arguments."""
        return self.forward(idx, device, ori)

    def forward(self, idx, device, ori=None):
        """Run the base model on the cached graph.

        Args:
            idx: Index applied to the model output, or ``None`` to return the
                full output.
            device: Device the cached features and edges are moved to.
            ori: Optional value stored on the base model as ``query_edges``
                before the forward pass.

        Returns:
            ``output[idx]`` when ``idx`` is not ``None``, else the full output.

        Raises:
            RuntimeError: If features or edges have not been cached via
                :meth:`preprocess`.
        """
        if self.processed_feature is None or self.processed_edge is None:
            raise RuntimeError(
                "GraphSAGE.preprocess() must be called with edge_index and "
                "feature before forward()."
            )

        processed_feature = self.processed_feature.to(device)
        processed_edge = self.processed_edge.to(device)
        if ori is not None:
            # NOTE(review): GraphSAGEModel.forward in this file does not read
            # query_edges; kept for callers that may inspect it.
            self.base_model.query_edges = ori

        output = self.base_model(processed_feature, processed_edge)
        return output[idx] if idx is not None else output