import os
import sys
import inspect

# Prepend the parent of this file's directory to sys.path so that the
# project-local ``graph`` package imported below (graph.hetgat_layer) can be
# resolved. Presumably needed when this module is run or imported without the
# project root already on the path — confirm against how it is launched.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 

import torch
import torch.nn as nn
from graph.hetgat_layer import MultiHetGATLayer

class HetNet(nn.Module):
    """Plain (non-residual) stack of ``MultiHetGATLayer`` modules.

    Layout for ``num_layers >= 2``:
      * first layer:  ``in_dim`` -> ``hid_dim`` (``merge``/``mode`` as given),
      * ``num_layers - 2`` middle layers: hidden width -> ``hid_dim``,
      * last layer: hidden width -> ``out_dim`` with ``merge=final_activation``
        and ``mode='no_activation'``.
    With ``num_layers == 1`` only the final-style layer is built, mapping
    ``in_dim`` straight to ``out_dim``.

    The "hidden width" fed to layers after the first is ``hid_dim[key]``
    multiplied by ``num_heads`` when ``merge == 'cat'``, and by 1 otherwise.
    """

    def __init__(self, nodes, edges, edge_features, attention_nodes, output_nodes, in_dim, hid_dim, out_dim, num_heads, num_layers = 4, merge='cat', mode='leaky_relu',
                 device = torch.device("cpu"), final_activation='avg'):
        super().__init__()
        assert(num_layers >= 1)
        self.num_layers = num_layers
        self.device = device

        # 'cat' merging stacks the heads side by side; any other merge keeps
        # a single head's width.
        head_factor = num_heads if merge == 'cat' else 1
        stacked_hid_dim = {key: hid_dim[key] * head_factor for key in hid_dim}

        def build(src_dim, dst_dim, layer_merge, layer_mode):
            # One GAT layer sharing the graph schema, placed on self.device.
            gat = MultiHetGATLayer(nodes, edges, edge_features, attention_nodes, output_nodes,
                                   src_dim, dst_dim, num_heads,
                                   merge=layer_merge, mode=layer_mode, device=self.device)
            return gat.to(self.device)

        # Layers are constructed strictly first -> middle -> last so parameter
        # initialisation order and ModuleList indices stay stable.
        if num_layers == 1:
            stack = [build(in_dim, out_dim, final_activation, 'no_activation')]
        else:
            stack = [build(in_dim, hid_dim, merge, mode)]
            stack += [build(stacked_hid_dim, hid_dim, merge, mode) for _ in range(num_layers - 2)]
            stack.append(build(stacked_hid_dim, out_dim, final_activation, 'no_activation'))

        self.layers = nn.ModuleList(stack)

    def get_param(self):
        """Invoke ``get_param()`` on each of the first ``num_layers`` layers, in order."""
        stack = self.layers
        for idx in range(self.num_layers):
            stack[idx].get_param()

    def forward(self, g, node_feat, edge_feat, mode):
        '''
        Run the features through every layer in sequence.

        input
            g: DGL heterograph
                number of Q-value nodes = number of available actions
            node_feat: dictionary of input node features
            edge_feat: dictionary of input edge features
            mode: 'agent' for pick agent graph, 'task' for pick task graph
        returns
            (node features, edge features) produced by the last layer
        '''
        # Shallow copies so the layers never receive the caller's own dicts.
        node_h, edge_h = node_feat.copy(), edge_feat.copy()
        for idx in range(self.num_layers):
            layer = self.layers[idx]
            node_h, edge_h = layer(g, node_h, edge_h, mode)
        return node_h, edge_h
    

class HetResNet(nn.Module):
    """Stack of ``MultiHetGATLayer`` modules with input-feature skip connections.

    After every layer except the last, the ORIGINAL input features are
    concatenated (along dim 1, after the layer output) onto that layer's node
    and edge outputs. Every layer after the first therefore receives
    ``hid_dim[key] * heads + in_dim[key]`` features per key, where ``heads`` is
    ``num_heads`` for ``merge == 'cat'`` and 1 otherwise.

    With ``num_layers == 1`` a single layer maps ``in_dim`` to ``out_dim``
    (``merge=final_activation``, ``mode='no_activation'``) and no skip
    concatenation happens.
    """

    def __init__(self, nodes, edges, edge_features, attention_nodes, output_nodes, in_dim, hid_dim, out_dim, num_heads, num_layers = 4, merge='cat', mode='leaky_relu',
                 device = torch.device("cpu"), final_activation='avg'):
        super(HetResNet, self).__init__()
        assert(num_layers >= 1)
        self.num_layers = num_layers
        self.device = device
        # Kept so __repr__ can describe the model (it previously read
        # attributes that were never set and raised AttributeError).
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.out_dim = out_dim

        output_heads = num_heads if merge == 'cat' else 1

        # Input width of every non-first layer: the previous layer's output
        # (heads possibly concatenated) plus the re-attached input features.
        # NOTE(review): only keys present in hid_dim get a widened entry, while
        # forward() also concatenates edge features; this assumes hid_dim and
        # in_dim cover the edge-feature keys too — confirm against callers.
        hid_dim_input = {key: hid_dim[key] * output_heads + in_dim[key] for key in hid_dim}

        layers = []
        if self.num_layers == 1:
            layers.append(MultiHetGATLayer(nodes, edges, edge_features, attention_nodes, output_nodes, in_dim, out_dim, num_heads, merge=final_activation, mode='no_activation', device=self.device).to(self.device))
        else:
            layers.append(MultiHetGATLayer(nodes, edges, edge_features, attention_nodes, output_nodes, in_dim, hid_dim, num_heads, merge=merge, mode=mode, device=self.device).to(self.device))
            for _ in range(self.num_layers - 2):
                layers.append(MultiHetGATLayer(nodes, edges, edge_features, attention_nodes, output_nodes, hid_dim_input, hid_dim, num_heads, merge=merge, mode=mode, device=self.device).to(self.device))
            layers.append(MultiHetGATLayer(nodes, edges, edge_features, attention_nodes, output_nodes, hid_dim_input, out_dim, num_heads, merge=final_activation, mode='no_activation', device=self.device).to(self.device))

        self.layers = nn.ModuleList(layers)

    def forward(self, g, node_feat, edge_feat, mode):
        '''
        Run the features through every layer, re-attaching the inputs between layers.

        input
            g: DGL heterograph
                number of Q-value nodes = number of available actions
            node_feat: dictionary of input node features
            edge_feat: dictionary of input edge features
            mode: 'agent' for pick agent graph, 'task' for pick task graph
        returns
            (node features, edge features) produced by the last layer.
            The caller's node_feat / edge_feat dicts are never modified.
        '''
        # Shallow copies, consistent with HetNet, so layers never receive the
        # caller's own dicts.
        nf = node_feat.copy()
        ef = edge_feat.copy()
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            nf, ef = layer(g, nf, ef, mode)

            if i < last:
                # Build fresh dicts rather than assigning into the layer's
                # returned dict: that dict may alias the layer's input, and
                # writing into it would leak concatenated tensors back out.
                nf = {key: torch.cat((feat, node_feat[key]), dim=1) for key, feat in nf.items()}
                ef = {key: torch.cat((feat, edge_feat[key]), dim=1) for key, feat in ef.items()}
        return nf, ef

    def get_param(self):
        """Invoke ``get_param()`` on each of the first ``num_layers`` layers, in order."""
        for i in range(self.num_layers):
            self.layers[i].get_param()

    def __repr__(self):
        # getattr defaults keep repr usable on instances pickled before the
        # dims were stored.
        return '{}(n_inp={}, n_hid={}, n_out={}, n_layers={})'.format(
            self.__class__.__name__, getattr(self, 'in_dim', None),
            getattr(self, 'hid_dim', None), getattr(self, 'out_dim', None),
            self.num_layers)