import torch
from torch import nn
from torch_geometric.nn import EdgeConv

class Score_GNN(nn.Module):
    """Graph network built from a stack of EdgeConv layers.

    The input is first lifted from ``target_len`` to ``mid_lay_input_len`` by a
    single linear + ReLU encoder. Every EdgeConv layer then receives the
    current hidden state concatenated (along dim 1) with the object, text and
    time embeddings. The final layer maps back to ``target_len``, so the output
    width equals the input width.

    Args:
        target_len: Width of ``pos`` on input and of the returned tensor.
        mid_lay_input_len: Width of the hidden state between layers.
        obj_feat_len: Width of ``obj_emb``.
        text_feat_len: Width of ``text_emb``.
        time_feat_len: Width of ``time_emb``.
        n_layers: Number of EdgeConv layers: ``n_layers - 1`` hidden layers
            plus the output layer. The output layer is always built, so values
            below 1 still produce one EdgeConv layer.
    """

    def __init__(self, target_len, mid_lay_input_len, obj_feat_len, text_feat_len, time_feat_len, n_layers):
        super().__init__()
        self.target_len = target_len
        self.mid_lay_input_len = mid_lay_input_len
        self.obj_feat_len = obj_feat_len
        self.text_feat_len = text_feat_len
        self.time_feat_len = time_feat_len
        self.n_layers = n_layers

        # Per-node width fed to every EdgeConv: hidden state + the three embeddings.
        node_width = mid_lay_input_len + obj_feat_len + text_feat_len + time_feat_len
        # EdgeConv's MLP sees a node's features concatenated with the
        # neighbour difference, which doubles the width.
        edge_width = 2 * node_width

        # relation_net is registered before pos_encoder, while pos_encoder's
        # weights are created before any EdgeConv layer; keeping both orders
        # keeps parameter ordering and seeded initialisation stable.
        self.relation_net = nn.ModuleList()
        self.pos_encoder = nn.Sequential(
            nn.Linear(target_len, mid_lay_input_len),
            nn.ReLU(inplace=True),
        )
        hidden_layers = [
            self.create_one_layer(edge_width, mid_lay_input_len, mid_lay_input_len)
            for _ in range(n_layers - 1)
        ]
        self.relation_net.extend(hidden_layers)
        self.last_layer = self.create_one_layer(edge_width, mid_lay_input_len, target_len)

    def create_one_layer(self, input_len, mid_len, output_len):
        """Return an EdgeConv wrapping a Linear -> ReLU -> Linear MLP.

        Args:
            input_len: Input width of the MLP.
            mid_len: Half of the MLP's inner width (the inner layer is ``2 * mid_len``).
            output_len: Output width of the MLP.
        """
        inner_len = 2 * mid_len
        expand = nn.Linear(input_len, inner_len)
        project = nn.Linear(inner_len, output_len)
        return EdgeConv(nn.Sequential(expand, nn.ReLU(inplace=True), project))

    def forward(self, pos, obj_emb, text_emb, time_emb, edge_index):
        """Run the encoder and all EdgeConv layers.

        The embeddings must be computed by the caller beforehand.

        Args:
            pos: Tensor whose last dimension is ``target_len``
                (presumably ``(num_nodes, target_len)`` - confirm with caller).
            obj_emb: Object embedding, concatenated to the hidden state on dim 1.
            text_emb: Text embedding, concatenated to the hidden state on dim 1.
            time_emb: Time embedding, concatenated to the hidden state on dim 1.
            edge_index: Graph connectivity, passed unchanged to every EdgeConv.

        Returns:
            The output of the last EdgeConv layer; no ReLU is applied to it.
        """

        def attach_context(state):
            # Re-attach the conditioning embeddings before every graph layer.
            return torch.cat([state, obj_emb, text_emb, time_emb], dim=1)

        hidden = self.pos_encoder(pos)
        for conv in self.relation_net:
            hidden = torch.relu(conv(attach_context(hidden), edge_index))
        return self.last_layer(attach_context(hidden), edge_index)
