from torch_geometric.nn import global_mean_pool
import torch
from Policies.GCN import GCN 
from Policies.GCN_GAT import GCN_GAT
from Policies.GCN_DIR import GCN_DIR 

class Net(torch.nn.Module):
    """GCN encoder with a per-graph state-value head and a per-node advantage head.

    A GCN (``GCN``, or ``GCN_DIR`` when ``dir_graph`` is set) embeds every node.
    Each node embedding is concatenated with one extra per-node feature
    (``dis_req``) and passed through an MLP to produce a per-node advantage.
    The node embeddings are also mean-pooled per graph (``global_mean_pool``)
    and mapped to a scalar state value. The output for each node is
    ``state_value[graph of node] + advantage[node]``.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        num_layers,
        num_nodes=9,
        shared_weights=False,
        hidden_channels_node=32,
        num_layers_node=5,
        uniform_random=False,
        use_batch_norm=False,
        var_distance=False,
        tr_dist=False,
        tr_att=False,
        dir_graph=False,
    ):
        """Build the encoder, the state-value head and the advantage MLP.

        Args:
            in_channels: input feature size, forwarded to the encoder.
            hidden_channels: encoder output width; also the input width of the
                state-value head.
            num_layers: forwarded to the encoder.
            num_nodes: forwarded only to ``GCN_DIR`` (unused when ``dir_graph``
                is False).
            shared_weights, use_batch_norm, tr_dist, tr_att: forwarded to the
                encoder.
            hidden_channels_node: width of the per-node advantage MLP.
            num_layers_node: number of hidden (Linear + ReLU) layers in the
                advantage MLP, in addition to its input and output layers.
            uniform_random: stored on the instance; not used by this class.
            var_distance: forwarded to ``GCN``. When truthy (and ``dir_graph``
                is not), ``forward`` also passes ``edge_weight`` to the encoder.
            dir_graph: use ``GCN_DIR`` instead of ``GCN``.
        """
        super().__init__()
        self.var_distance = var_distance
        self.dir_graph = dir_graph
        if dir_graph:
            self.gcn = GCN_DIR(
                in_channels,
                hidden_channels,
                num_layers,
                num_nodes,
                shared_weights=shared_weights,
                use_batch_norm=use_batch_norm,
                tr_dist=tr_dist,
                tr_att=tr_att,
            )
        else:
            self.gcn = GCN(
                in_channels,
                hidden_channels,
                num_layers,
                shared_weights=shared_weights,
                use_batch_norm=use_batch_norm,
                var_distance=self.var_distance,
                tr_dist=tr_dist,
                tr_att=tr_att,
            )
        self.pool = global_mean_pool
        self.state_value_lin = torch.nn.Linear(hidden_channels, 1)
        self.uniform_random = uniform_random

        # Advantage MLP: (hidden_channels + 1) -> hidden_channels_node -> ... -> 1.
        # The "+ 1" is the single dis_req feature appended to each node embedding
        # in forward(). Module order is unchanged so saved state_dict keys load.
        layers = [
            torch.nn.Linear(hidden_channels + 1, hidden_channels_node, bias=True),
            torch.nn.ReLU(),
        ]
        for _ in range(num_layers_node):
            layers.append(torch.nn.Linear(hidden_channels_node, hidden_channels_node, bias=True))
            layers.append(torch.nn.ReLU())
        layers.append(torch.nn.Linear(hidden_channels_node, 1))
        self.act_adv_ff = torch.nn.Sequential(*layers)

    def forward(self, x, edge_index, batch, dis_req, edge_weight=None):
        """Return per-node values ``state_value[batch] + advantage``.

        Args:
            x: node features passed to the encoder.
            edge_index: connectivity passed to the encoder.
            batch: per-row graph index; used for mean pooling and to broadcast
                each graph's state value back to its nodes.
            dis_req: one extra feature per node, shape ``(N, 1)`` or ``(N,)``.
            edge_weight: passed to the encoder only when ``var_distance`` is
                truthy and ``dir_graph`` is not; ignored otherwise.

        Returns:
            Tensor of shape ``(N, 1)``.
        """
        if self.var_distance and not self.dir_graph:
            gcn_out = self.gcn(x, edge_index, edge_weight)
        else:
            gcn_out = self.gcn(x, edge_index)

        # Collapse any leading dims to (rows, hidden). reshape (not view) so a
        # non-contiguous encoder output does not raise.
        gcn_out = gcn_out.reshape(-1, gcn_out.size(-1))

        # Accept a flat per-node vector as well as an (N, 1) column.
        if dis_req.dim() == 1:
            dis_req = dis_req.unsqueeze(-1)
        act_adv = self.act_adv_ff(torch.cat((gcn_out, dis_req), dim=1))

        state_value = self.state_value_lin(self.pool(gcn_out, batch))
        # NOTE(review): advantages are added as-is, without subtracting their
        # per-graph mean as in the usual dueling formulation — confirm intended.
        return state_value[batch] + act_adv
