# -*- coding: utf-8 -*-
"""
Created on Tue Jun 11 12:53:49 2022

@author: Anonymous, Anonymous

Heterogeneous Graph Attention Layer for General Metagraph Applications with Output Nodes
"""
# import os
# import sys
# import inspect

# currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
# for i in range(1):
#     currentdir = os.path.dirname(currentdir)
# parentdir = os.path.dirname(currentdir)
# sys.path.insert(0, parentdir)

from enum import Enum
import sys
from dgl.ops import edge_softmax
import dgl.function as fn
import torch
import torch.nn as nn

from models.graph.node_layers.linear_nn import LinearNN

class MixedActivation(nn.Module):
    """ReLU followed by a ``[log(x), exp(x), x]`` feature expansion.

    The ReLU output ``x`` is expanded as ``torch.cat([log, exp, x], dim=1)``,
    so dimension 1 of the result is three times the size of the input's.

    Args:
        eps: lower bound applied to ``x`` before taking the log. ReLU emits
            exact zeros, and ``log(0) = -inf`` would otherwise propagate
            ``inf``/``nan`` into every layer that consumes this output.
    """

    def __init__(self, eps=1e-6):
        super(MixedActivation, self).__init__()
        self.activate = nn.ReLU()
        self.eps = eps

    def forward(self, x):
        x = self.activate(x)
        # Clamp only for the log term; exp and identity see the raw ReLU output.
        log = torch.log(torch.clamp(x, min=self.eps))
        exp = torch.exp(x)
        return torch.cat([log, exp, x], dim=1)

class Activation(Enum):
    """Activation names understood by ``get_activation``.

    ``MAX`` and ``NO_ACTIVATION`` are valid names, but ``get_activation``
    builds no module for them and returns ``None``.
    """
    RELU = 'relu'
    LEAKY_RELU = 'leaky_relu'
    SOFTMAX = 'softmax'
    SIGMOID = 'sigmoid'
    MAX = 'max'
    NO_ACTIVATION = 'no_activation'
    MIXED = 'mixed'

    @classmethod
    def from_string(cls, string):
        """Return the member whose *name* matches ``string`` case-insensitively.

        Args:
            string: member name such as ``'leaky_relu'`` or ``'RELU'``.

        Raises:
            ValueError: if ``string`` does not name a member.
        """
        key = string.upper()
        try:
            return cls[key]
        except KeyError:
            # Single formatted message (not a tuple of args); the internal
            # KeyError is an implementation detail, so drop it from the chain.
            raise ValueError("Invalid String for Activation: " + key) from None


def get_activation(mode_str: str, l_alpha=0.2):
    """Build the activation module named by ``mode_str``.

    Args:
        mode_str: case-insensitive ``Activation`` member name.
        l_alpha: negative slope used only for ``'leaky_relu'``.

    Returns:
        A new ``nn.Module``, or ``None`` for names with no module here
        (``'max'`` and ``'no_activation'``).

    Raises:
        ValueError: if ``mode_str`` is not a known activation name.
    """
    builders = {
        Activation.RELU: nn.ReLU,
        Activation.LEAKY_RELU: lambda: nn.LeakyReLU(l_alpha),
        # NOTE(review): no ``dim`` is given, so PyTorch infers one (and warns);
        # under its implicit-dim rule a 3-D input is normalised over dim 0.
        # Confirm that is intended before changing it.
        Activation.SOFTMAX: nn.Softmax,
        Activation.SIGMOID: nn.Sigmoid,
        Activation.MIXED: MixedActivation,
    }
    build = builders.get(Activation.from_string(mode_str))
    return None if build is None else build()


class HetGATLayer(nn.Module):
    """Heterogeneous graph attention layer for a metagraph with output nodes.

    Constructor arguments:
        nodes: list of node types of the metagraph.
        edges: list of ``(input_node, edge, output_node)`` triples, e.g. the
            return of ``G.canonical_etypes``.
        edge_features: dict mapping each edge type that carries a feature
            value to a node type; for edges outside ``attention_nodes`` the
            messages are that node type's transformed features.
        attention_nodes: set of edge types (that are also in
            ``edge_features``) whose attention scores and messages also
            include the transformed edge features.
        output_nodes: set of output node types; ``forward`` only computes the
            one selected by its ``mode`` argument. Must be a subset of
            ``nodes``.
        in_dim: dict of input feature dimension per node/edge type.
        out_dim: dict of output feature dimension per node/edge type (both
            dicts include weighted edge types).
        num_heads: number of attention heads.
        l_alpha: negative slope of the LeakyReLU applied to attention logits,
            also used by the ``'leaky_relu'`` output activation.
        mode: output activation name (see ``get_activation``).
        device: device the sub-modules are created on.
        verbose: stored on the instance; not read by this layer.

    Edges not in ``edge_features`` are aggregated by a plain sum without
    attention.
    """

    def __init__(self, nodes, edges, edge_features, attention_nodes,
                 output_nodes, in_dim, out_dim, num_heads,
                 l_alpha=0.2, mode='leaky_relu',
                 device=torch.device("cpu"), verbose=False):
        super(HetGATLayer, self).__init__()
        self._num_heads = num_heads
        self._in_dim = in_dim
        self._out_dim = out_dim

        self.nodes = nodes  # list of node types
        self.edges = edges  # list of (input_node, edge, output_node)
        self.edge_features = edge_features  # edge type -> source node type
        self.attention_nodes = attention_nodes  # edge types using edge attention
        self.output_nodes = output_nodes  # set of output node types

        self.device = device
        self.verbose = verbose

        # Per node type: in_dim[node] -> out_dim[node] * num_heads.
        node_dict = {}
        for node in self.nodes:
            node_dict[node] = LinearNN(in_dim[node], out_dim[node] * num_heads)
        self.fc_node = nn.ModuleDict(node_dict).to(device)

        # Per relation: transform of the input node's features into the
        # output node's space. Relations into output nodes are kept in a
        # separate ModuleDict per output node.
        edge_dict = {}
        output_dict = {}
        for input_node, edge, output_node in self.edges:
            layer = LinearNN(
                in_dim[input_node], out_dim[output_node] * num_heads).to(device)
            if output_node in self.output_nodes:
                output_dict.setdefault(output_node, {})[edge] = layer
            else:
                edge_dict[edge] = layer
        self.fc_base = nn.ModuleDict(edge_dict).to(device)

        outputs = {}
        for output_node in self.output_nodes:
            if output_node in output_dict:
                outputs[output_node] = nn.ModuleDict(
                    output_dict[output_node]).to(device)
        self.fc_outputs = nn.ModuleDict(outputs).to(device)

        # Per weighted edge type: transform of the edge feature values.
        value_edge_dict = {}
        for edge in self.edge_features:
            value_edge_dict[edge] = LinearNN(
                in_dim[edge], out_dim[edge] * num_heads).to(device)
        self.fc_edge = nn.ModuleDict(value_edge_dict).to(device)

        self.leaky_relu = nn.LeakyReLU(negative_slope=l_alpha)
        self.activation = get_activation(mode, l_alpha)

        # Attention parameters. These used to live in plain dicts, which left
        # them unregistered: missing from parameters()/state_dict(), never
        # optimized, and not moved by .to()/.cuda(). ModuleDict registers them.
        # NOTE: state_dict() now contains src_param/dst_param/edge_param keys.
        src_param = {}
        dst_param = {}
        edge_param = {}
        for input_node, edge, output_node in self.edges:
            src_param[edge] = LinearNN(
                out_dim[output_node], out_dim[output_node] * num_heads).to(device)
            dst_param[edge] = LinearNN(
                out_dim[output_node], out_dim[output_node] * num_heads).to(device)
            if edge in self.edge_features:
                edge_param[edge] = LinearNN(
                    out_dim[edge], out_dim[output_node] * num_heads).to(device)
        self.src_param = nn.ModuleDict(src_param)
        self.dst_param = nn.ModuleDict(dst_param)
        self.edge_param = nn.ModuleDict(edge_param)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every sub-module via ``LinearNN.reset`` with the sigmoid gain."""
        gain = nn.init.calculate_gain('sigmoid')
        for node in self.nodes:
            self.fc_node[node].reset(gain)
        # Output nodes are also entries of fc_node, so this resets them a
        # second time; kept so the random initialisation sequence is unchanged.
        for node in self.output_nodes:
            self.fc_node[node].reset(gain)

        for input_node, edge, output_node in self.edges:
            self.src_param[edge].reset(gain)
            self.dst_param[edge].reset(gain)

            if output_node in self.output_nodes:
                self.fc_outputs[output_node][edge].reset(gain)
            else:
                self.fc_base[edge].reset(gain)

            if edge in self.edge_features:
                self.edge_param[edge].reset(gain)

    def _is_inactive_output(self, node, mode):
        """True for output node types other than the one selected by ``mode``."""
        return node in self.output_nodes and node != mode

    def forward(self, g, node_feat, edge_feat, mode):
        """Run one round of heterogeneous attention message passing.

        Intermediate tensors are written into ``g``'s node/edge data under
        names such as ``'Wh_<type>'``, ``'a_<edge>'`` and ``'ft_<edge>'``,
        and the combined node features under ``'h'``.

        Args:
            g: graph indexed as ``g.nodes[ntype]`` / ``g[edge]``, with types
                matching ``self.nodes`` / ``self.edges``.
            node_feat: dict node type -> input feature tensor.
            edge_feat: dict edge type -> input edge feature tensor (may have
                zero rows).
            mode: the output node type to compute; other output node types
                are skipped.

        Returns:
            ``(results_node, results_edge)``. ``results_node`` maps node type
            to its combined features (``view``-ed to
            ``(-1, num_heads, out_dim)`` before any activation). Non-selected
            output node types are omitted. ``results_edge`` maps each
            weighted edge type in ``self.edges`` to its transformed edge
            features. The output activation is applied to both when one is
            configured.
        """
        # Feature transform for each node type.
        Wh_node = {}
        for node in self.nodes:
            if self._is_inactive_output(node, mode):
                continue
            Wh_node[node] = self.fc_node[node](node_feat[node]).view(
                -1, self._num_heads, self._out_dim[node])
            g.nodes[node].data['Wh_' + node] = Wh_node[node]

        # Feature transform for each weighted edge type.
        We_edge = {}
        for edge in self.edge_features:
            if edge_feat[edge].size()[0] == 0:
                We_edge[edge] = torch.Tensor([]).to(self.device)
            else:
                We_edge[edge] = self.fc_edge[edge](edge_feat[edge]).view(
                    -1, self._num_heads, self._out_dim[edge])
            g[edge].edata['We_' + edge] = We_edge[edge]

        # Feature transform for each relation (stored on the input node type).
        Wh_edge = {}
        for input_node, edge, output_node in self.edges:
            if self._is_inactive_output(output_node, mode):
                continue
            if output_node in self.output_nodes:
                transform = self.fc_outputs[output_node][edge]
            else:
                transform = self.fc_base[edge]
            Wh_edge[edge] = transform(node_feat[input_node]).view(
                -1, self._num_heads, self._out_dim[output_node])
            g.nodes[input_node].data['Wh_' + edge] = Wh_edge[edge]

        # Message passing per relation.
        for input_node, edge, output_node in self.edges:
            # Relations into non-selected output nodes were not transformed
            # above (previously the attention path crashed on them).
            if self._is_inactive_output(output_node, mode):
                continue

            if edge not in self.edge_features:
                # No attention: sum the transformed input features.
                z_edge = 'z_' + edge
                g[edge].update_all(fn.copy_u('Wh_' + edge, z_edge),
                                   fn.sum(z_edge, 'ft_' + edge))
                continue

            attn_src_name = 'Attn_src_' + edge
            attn_dst_name = 'Attn_dst_' + edge
            e_name = 'e_' + edge
            a_edge = 'a_' + edge

            # Per-head scores: LinearNN output summed over the last dim and
            # unsqueezed (presumably (N, num_heads, 1) if LinearNN acts on
            # the last dim -- confirm against LinearNN).
            attn_src = self.src_param[edge](Wh_edge[edge]).sum(dim=-1).unsqueeze(-1)
            attn_dst = self.dst_param[edge](Wh_node[output_node]).sum(dim=-1).unsqueeze(-1)
            g[edge].srcdata.update({attn_src_name: attn_src})
            g[edge].dstdata.update({attn_dst_name: attn_dst})

            if edge not in self.attention_nodes:
                # Node-only attention; messages are the transformed features
                # of the node type named by edge_features[edge].
                source_node = self.edge_features[edge]
                g[edge].apply_edges(fn.u_add_v(attn_src_name, attn_dst_name, e_name))
                logits = self.leaky_relu(g[edge].edata.pop(e_name))
                g[edge].edata[a_edge] = edge_softmax(g[edge], logits)

                m_edge = 'm_' + edge
                g[edge].update_all(fn.u_mul_e('Wh_' + source_node, a_edge, m_edge),
                                   fn.sum(m_edge, 'ft_' + edge))
            else:
                # Attention over input node, edge feature and output node.
                attn_edg_name = 'Attn_edg_' + edge
                if edge_feat[edge].size()[0] == 0:
                    attn_edg = torch.Tensor([]).to(self.device)
                else:
                    attn_edg = self.edge_param[edge](We_edge[edge]).sum(dim=-1).unsqueeze(-1)
                g[edge].edata.update({attn_edg_name: attn_edg})

                sum_name = 'sum_' + edge
                g[edge].apply_edges(fn.u_add_e(attn_src_name, attn_edg_name, sum_name))
                g[edge].apply_edges(fn.e_add_v(sum_name, attn_dst_name, e_name))
                logits = self.leaky_relu(g[edge].edata.pop(e_name))
                g[edge].edata[a_edge] = edge_softmax(g[edge], logits)

                # (1) attention-weighted transformed input-node features.
                m_edge_1 = 'm_' + edge + '_1'
                g[edge].update_all(fn.u_mul_e('Wh_' + edge, a_edge, m_edge_1),
                                   fn.sum(m_edge_1, 'ft_' + edge + '_1'))

                # (2) attention-weighted transformed edge features.
                a_edge_2 = a_edge + '_2'
                m_edge_2 = 'm_' + edge + '_2'
                g[edge].edata[a_edge_2] = g[edge].edata[a_edge] * We_edge[edge]
                g[edge].update_all(fn.copy_e(a_edge_2, m_edge_2),
                                   fn.sum(m_edge_2, 'ft_' + edge + '_2'))

        # Combine: h' = Wh + sum of aggregated messages over incoming relations.
        Wh_new = {}
        for node in self.nodes:
            if self._is_inactive_output(node, mode):
                continue
            Wh_new[node] = g.nodes[node].data['Wh_' + node].clone()

        for input_node, edge, output_node in self.edges:
            if self._is_inactive_output(output_node, mode):
                continue
            dst_data = g.nodes[output_node].data
            # Same condition as the message-passing branch that produced the
            # '_1'/'_2' features (an attention edge without edge features
            # only produces 'ft_<edge>').
            if edge in self.edge_features and edge in self.attention_nodes:
                Wh_new[output_node] += dst_data['ft_' + edge + '_1']
                if edge_feat[edge].size()[0] > 0:
                    Wh_new[output_node] += dst_data['ft_' + edge + '_2']
            else:
                Wh_new[output_node] += dst_data['ft_' + edge]

        for node, h in Wh_new.items():
            g.nodes[node].data['h'] = h

        # Apply the output activation and collect results.
        results_node = {}
        for ntype in node_feat:
            # Skip non-selected outputs even if an earlier call left a stale
            # 'h' on them; also skip the selected output if it has no 'h'.
            if ntype in self.output_nodes and (
                    ntype != mode or 'h' not in g.nodes[ntype].data):
                continue
            h = g.nodes[ntype].data['h']
            results_node[ntype] = h if self.activation is None else self.activation(h)

        results_edge = {}
        for _, edge, _ in self.edges:
            if edge not in self.edge_features:
                continue
            we = We_edge[edge]
            # Empty edge features are passed through unchanged (some
            # activations, e.g. MixedActivation, fail on the 1-D empty tensor).
            if self.activation is None or edge_feat[edge].size()[0] == 0:
                results_edge[edge] = we
            else:
                results_edge[edge] = self.activation(we)
        return results_node, results_edge


class MultiHetGATLayer(nn.Module):
    """``HetGATLayer`` followed by a merge of its attention heads (dim 1).

    ``merge`` selects the reduction applied to each node/edge output:
        'cat'    -- flatten heads into the feature dim (``torch.flatten(x, 1)``)
        'sum'    -- sum over heads
        'max'    -- max over heads
        'linear' -- flatten, leaky ReLU, then a per-type ``LinearNN`` mapping
                    ``out_dim[type] * num_heads`` back to ``out_dim[type]``
        anything else (e.g. 'avg') -- mean over heads
    Edge types whose ``edge_feat`` has no rows yield an empty tensor on
    ``device``.
    """

    def __init__(self, nodes, edges, edge_features, attention_nodes, output_nodes, in_dim, out_dim, num_heads, merge='cat', mode='leaky_relu',
                 device=torch.device("cpu"), verbose=False):
        super().__init__()
        self._num_heads = num_heads
        self._merge = merge

        # ``mode`` here is the activation name passed to HetGATLayer, not the
        # output node type selected in ``forward``.
        self.gat_conv = HetGATLayer(nodes, edges,
                                    edge_features, attention_nodes,
                                    output_nodes, in_dim, out_dim, num_heads,
                                    mode=mode, device=device, verbose=verbose).to(device)
        if self._merge == 'linear':
            self.predictor = nn.ModuleDict({
                ntype: LinearNN(out_dim[ntype] * num_heads, out_dim[ntype]).to(device)
                for ntype in out_dim
            })
        self.device = device

    def get_param(self):
        """Return the wrapped attention layer's parameters keyed by name.

        HetGATLayer defines no ``get_param``, so the previous delegation
        always raised ``AttributeError``; use ``named_parameters`` instead.
        """
        return dict(self.gat_conv.named_parameters())

    def _merge_heads(self, x, key):
        """Reduce the head dimension (dim 1) of ``x`` according to ``self._merge``."""
        if self._merge == 'cat':
            return torch.flatten(x, 1)
        if self._merge == 'sum':
            return torch.sum(x, 1)
        if self._merge == 'linear':
            return self.predictor[key](nn.functional.leaky_relu(x.flatten(1)))
        if self._merge == 'max':
            return torch.max(x, 1)[0]
        return torch.mean(x, 1)

    def forward(self, g, node_feat, edge_feat, mode):
        """Run the attention layer and merge its heads.

        Args:
            g, node_feat, edge_feat, mode: forwarded unchanged to
                ``HetGATLayer.forward``.

        Returns:
            ``(results_node, results_edge)`` dicts of merged tensors.
        """
        tmp_node, tmp_edge = self.gat_conv(g, node_feat, edge_feat, mode)

        results_node = {}
        for ntype, h in tmp_node.items():
            results_node[ntype] = self._merge_heads(h, ntype)

        results_edge = {}
        for etype, h in tmp_edge.items():
            if edge_feat[etype].size()[0] == 0:
                # Every merge mode (including 'avg') yields the empty tensor
                # on the layer's device.
                results_edge[etype] = torch.Tensor([]).to(self.device)
            else:
                results_edge[etype] = self._merge_heads(h, etype)
        return results_node, results_edge
