import torch
from torch.nn import Linear, Parameter
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric
from torch_geometric.nn import MessagePassing, global_mean_pool, GINConv
from torch_geometric.utils import add_self_loops, degree
from models.model_utils import get_activation_function, initialize_layer_w_zero
from models.norms import Normalization

#TODO(tm): send in the device while creating the model.
# Module-level default device, resolved once when this module is imported:
# CUDA if torch reports it as available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')




class EdgewiseGINConv(GINConv):
    """GIN-style convolution applied to per-edge messages.

    The forward pass computes::

        out = propagate(edgewise_edge_index, nn(message))
              + node_feature[edge_index[0]]

    i.e. the messages are transformed by the wrapped network ``self.nn``,
    propagated over ``edgewise_edge_index`` using the aggregation configured
    on the parent ``GINConv``, and the features of each edge's first endpoint
    are added on top.

    NOTE(review): earlier notes on this class described a tanh-gated variant
    (``(W2 * m + b_2) * tanh(W_1 * m + b_1)``); no gating is implemented in
    the code below. ``self.eps`` from ``GINConv`` is also not used here.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument unchanged to ``GINConv.__init__``."""
        super().__init__(*args, **kwargs)

    def forward(self, message, edge_index, node_feature, edgewise_edge_index=None):
        """Update the messages with one edge-wise GIN step.

        Args:
            message: Input message tensor; indexed by ``edgewise_edge_index``
                during propagation (presumably one row per edge of
                ``edge_index`` - confirm against the caller).
            edge_index: Connectivity of the main graph; row 0 selects the
                node whose features are added to each output row.
            node_feature: Node feature tensor indexed by ``edge_index[0]``;
                its feature width must match the output of ``self.nn``.
            edgewise_edge_index: Connectivity handed to ``propagate``.
                Defaults to ``None``, but ``propagate`` is always called
                with it, so callers are expected to supply it.

        Returns:
            The propagated messages plus the gathered node features.
        """
        transformed = self.nn(message)
        propagated = self.propagate(edgewise_edge_index, x=transformed)
        endpoint_features = node_feature[edge_index[0]]
        return propagated + endpoint_features
    
class EdgeGINConvLayer(torch.nn.Module):
    """An ``EdgewiseGINConv`` wrapped around a two-layer MLP.

    The MLP is ``Linear -> Normalization -> ReLU -> Linear`` and maps
    ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Width of the incoming message features.
        out_channels: Width of the MLP's hidden and output features.
        initialize_w_zero: Accepted for interface compatibility.
            NOTE(review): this flag is not read anywhere in this class.
        norm_type: Normalization kind passed to ``Normalization``.
    """

    def __init__(self, in_channels, out_channels, initialize_w_zero=False, norm_type='Identity'):
        super().__init__()
        # train_eps=False: the convolution is built without a learnable eps.
        self.layer = EdgewiseGINConv(
            nn=torch.nn.Sequential(
                torch.nn.Linear(in_channels, out_channels),
                Normalization(out_channels, norm_type),
                torch.nn.ReLU(),
                torch.nn.Linear(out_channels, out_channels),
            ),
            train_eps=False,
        )

    def forward(self, message, edge_index, node_feature, edgewise_edge_index):
        """Delegate to the wrapped ``EdgewiseGINConv`` and return its output."""
        conv = self.layer
        return conv(message, edge_index, node_feature, edgewise_edge_index)

class GraphEdgeNetwork(torch.nn.Module):
    """Edge-message network built from stacked ``EdgeGINConvLayer`` blocks.

    Pipeline implemented by ``forward``:

    1. project the node features of ``data`` to ``hidden_channels``;
    2. run ``conv1``, ``num_hidden_layers`` hidden layers and ``conv2`` over
       the message tensor ``x``;
    3. keep the last layer output (``jumping_knowledge='last'``) or
       concatenate all layer outputs along dim 1 (``'concat'``);
    4. apply ``aggregation_weight`` and sum the messages per target node;
    5. for graph-level task types, mean-pool the nodes of each graph;
    6. apply the final linear layer and ``final_activation``.

    Args:
        num_hidden_layers: Number of layers between ``conv1`` and ``conv2``.
        in_channels: Input width of both the messages and the node features.
        hidden_channels: Width used by every intermediate layer.
        out_channels: Width of the final output.
        norm_type: Normalization kind passed to every ``Normalization``.
        dropout: Dropout probability applied after every layer.
        activation: Name resolved through ``get_activation_function``; used
            after ``conv1`` and after every hidden layer.
        final_activation: Name resolved through ``get_activation_function``;
            applied to the output of the final linear layer.
        initialize_w_zero: Forwarded to every ``EdgeGINConvLayer``.
        task_type: ``'graph_regression'`` and ``'graph_classification'``
            enable the per-graph mean pooling; other values skip it.
        jumping_knowledge: ``'last'`` or ``'concat'``.
        **kwargs: Accepted and ignored.
    """

    def __init__(
            self,
            num_hidden_layers=0,
            in_channels=1,
            hidden_channels=10,
            out_channels=1,
            norm_type="Identity",
            dropout=0.5,
            activation='relu',
            final_activation='identity',
            initialize_w_zero=False,
            task_type='node_classification',
            jumping_knowledge='last',
            **kwargs,
        ):
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.task_type = task_type
        self.dropout = dropout
        self.out_channels = out_channels
        self.final_activation = get_activation_function(final_activation)
        self.norm_type = norm_type
        self.activation = get_activation_function(activation)
        self.jumping_knowledge = jumping_knowledge

        self.conv1 = EdgeGINConvLayer(
            in_channels=in_channels,
            out_channels=hidden_channels,
            norm_type=self.norm_type,
            initialize_w_zero=initialize_w_zero,)
        self.conv1_norm = Normalization(hidden_channels, self.norm_type)

        self.hidden_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()

        for _ in range(self.num_hidden_layers):
            self.hidden_layers.append(
                EdgeGINConvLayer(
                    in_channels=hidden_channels,
                    out_channels=hidden_channels,
                    norm_type=self.norm_type,
                    initialize_w_zero=initialize_w_zero,)
            )
            self.norm_layers.append(Normalization(hidden_channels, self.norm_type))

        self.conv2 = EdgeGINConvLayer(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            norm_type=self.norm_type,
            initialize_w_zero=initialize_w_zero,)
        self.conv2_norm = Normalization(hidden_channels, self.norm_type)

        # Readout layers sized for jumping_knowledge == 'last'.
        # initialize_layer_w_zero is a project helper that returns the layer
        # (presumably with zeroed weights - see models.model_utils).
        self.aggregation_weight = nn.Linear(
            hidden_channels, hidden_channels, bias=True)
        self.aggregation_weight = initialize_layer_w_zero(
            self.aggregation_weight)

        self.final_regression_layer = nn.Linear(
            hidden_channels, out_channels, bias=True)

        self.node_weight1 = nn.Linear(in_channels, hidden_channels, bias=True)

        if jumping_knowledge == 'concat':
            # conv1 + hidden layers + conv2 each contribute hidden_channels
            # columns to the concatenation, so the readout layers are
            # replaced by wider ones. The construction order is deliberately
            # left as it was so seeded initialisation is not disturbed.
            dimensions = (self.num_hidden_layers + 2) * hidden_channels
            self.final_regression_layer = nn.Linear(dimensions, out_channels, bias=True)
            self.aggregation_weight = nn.Linear(dimensions, dimensions, bias=True)
            self.aggregation_weight = initialize_layer_w_zero(
                self.aggregation_weight)

    def final_expectation(self, edge_messages, edge_index, node_features, batch=None):
        """Sum the linearly transformed messages over each target node.

        Args:
            edge_messages: Message tensor, one row per column of
                ``edge_index``.
            edge_index: Main-graph connectivity; row 1 is used as the
                aggregation index.
            node_features: Node feature tensor. Only its first dimension is
                read, to fix the number of output rows. May be ``None``, in
                which case the row count is inferred from ``edge_index``.
            batch: Unused; kept for interface compatibility.

        Returns:
            Tensor with one row per node holding the sum of its incoming
            transformed messages (zeros for nodes without incoming edges).
        """
        aggr = torch_geometric.nn.aggr.SumAggregation()
        # No activation after aggregation weights here.
        em = self.aggregation_weight(edge_messages)
        # Without an explicit dim_size the aggregation only produces
        # max(edge_index[1]) + 1 rows, silently dropping trailing nodes that
        # receive no message and breaking the alignment with `batch`/labels.
        num_nodes = None if node_features is None else node_features.size(0)
        return aggr(em, edge_index[1], dim_size=num_nodes)

    def residual_layer(self, hidden_layer, norm_layer, x, node_feature,
                       edge_index, edgewise_edge_index, edge_weight=None):
        """Apply one hidden block: conv -> norm -> activation -> dropout.

        Despite the name, no skip connection is added in this class.
        ``edge_weight`` is unused and kept for interface compatibility.
        """
        x = hidden_layer(x, edge_index,
                         node_feature,
                         edgewise_edge_index=edgewise_edge_index)
        x = norm_layer(x)
        x = self.activation(x)
        return F.dropout(x, p=self.dropout, training=self.training)

    def forward(self, x, data, edgewise_edge_index):
        """Run the network.

        Args:
            x: Initial message tensor fed to ``conv1``.
            data: Main-graph object providing ``x`` (node features),
                ``edge_index`` and ``batch``.
            edgewise_edge_index: Connectivity over which the messages are
                propagated inside every convolution.

        Returns:
            ``final_activation`` applied to the final linear layer's output;
            one row per graph for graph-level task types, otherwise one row
            per node.

        Raises:
            ValueError: If ``jumping_knowledge`` is neither ``'last'`` nor
                ``'concat'``.
        """
        # Fail before doing any work; previously the exception object was
        # created but never raised, so bad values silently acted as 'last'.
        if self.jumping_knowledge not in ('last', 'concat'):
            raise ValueError("Invalid jumping knowledge type")

        # ==== data from the main graph ====
        edge_index = data.edge_index
        batch = data.batch
        # Cast straight to float32 on the device the parameters live on,
        # instead of going through the CPU and a module-global device.
        node_feature = data.x.to(
            device=self.node_weight1.weight.device, dtype=torch.float32)
        node_feature = self.node_weight1(node_feature)

        x = self.conv1(x, edge_index,
                       node_feature,
                       edgewise_edge_index=edgewise_edge_index)
        x = self.conv1_norm(x)
        x = self.activation(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        h_list = [x]

        for hidden_layer, norm_layer in zip(self.hidden_layers, self.norm_layers):
            x = self.residual_layer(
                hidden_layer=hidden_layer,
                norm_layer=norm_layer,
                x=x,
                node_feature=node_feature,
                edge_index=edge_index,
                edgewise_edge_index=edgewise_edge_index
            )
            h_list.append(x)

        # No activation is applied after the last convolution.
        x = self.conv2(x, edge_index, node_feature,
                       edgewise_edge_index=edgewise_edge_index)
        x = self.conv2_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        h_list.append(x)

        if self.jumping_knowledge == 'last':
            x = h_list[-1]
        else:  # 'concat' (validated at the top of this method)
            x = torch.cat(h_list, dim=1)

        x = self.final_expectation(x, edge_index, node_feature)

        if self.task_type in ['graph_regression', 'graph_classification']:
            x = global_mean_pool(x, batch)

        x = self.final_regression_layer(x)

        return self.final_activation(x)


class GraphEdgeNetworkResidual(torch.nn.Module):
    """Edge-message network with skip connections around the hidden layers.

    Pipeline implemented by ``forward``:

    1. project the node features of ``data`` to ``hidden_channels``;
    2. run ``conv1``, then ``num_hidden_layers`` hidden layers whose input
       is added back to their output, then ``conv2``;
    3. keep the last layer output (``jumping_knowledge='last'``) or
       concatenate all layer outputs along dim 1 (``'concat'``);
    4. apply ``aggregation_weight`` and sum the messages per target node;
    5. for graph-level task types, mean-pool the nodes of each graph;
    6. apply the final linear layer and ``final_activation``.

    Args:
        num_hidden_layers: Number of layers between ``conv1`` and ``conv2``.
        in_channels: Input width of both the messages and the node features.
        hidden_channels: Width used by every intermediate layer.
        out_channels: Width of the final output.
        norm_type: Normalization kind passed to every ``Normalization``.
        activation: Name resolved through ``get_activation_function``; used
            after ``conv1`` and after every hidden layer.
        final_activation: Name resolved through ``get_activation_function``;
            applied to the output of the final linear layer.
        initialize_w_zero: Forwarded to every ``EdgeGINConvLayer``.
        task_type: ``'graph_regression'`` and ``'graph_classification'``
            enable the per-graph mean pooling; other values skip it.
        dropout: Dropout probability applied after every layer.
        jumping_knowledge: ``'last'`` or ``'concat'``.
        **kwargs: Accepted and ignored.
    """

    def __init__(
            self,
            num_hidden_layers=0,
            in_channels=1,
            hidden_channels=10,
            out_channels=1,
            norm_type="Identity",
            activation='relu',
            final_activation='identity',
            initialize_w_zero=False,
            task_type='node_classification',
            dropout=0.5,
            jumping_knowledge='last',
            **kwargs,
        ):
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.task_type = task_type
        self.out_channels = out_channels
        self.final_activation = get_activation_function(final_activation)
        self.norm_type = norm_type
        self.dropout = dropout
        self.activation = get_activation_function(activation)
        self.jumping_knowledge = jumping_knowledge

        self.conv1 = EdgeGINConvLayer(
            in_channels=in_channels,
            out_channels=hidden_channels,
            norm_type=self.norm_type,
            initialize_w_zero=initialize_w_zero,)
        self.conv1_norm = Normalization(hidden_channels, self.norm_type)

        self.hidden_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()

        for _ in range(self.num_hidden_layers):
            self.hidden_layers.append(
                EdgeGINConvLayer(
                    in_channels=hidden_channels,
                    out_channels=hidden_channels,
                    norm_type=self.norm_type,
                    initialize_w_zero=initialize_w_zero,)
            )
            self.norm_layers.append(Normalization(hidden_channels, self.norm_type))

        self.conv2 = EdgeGINConvLayer(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            norm_type=self.norm_type,
            initialize_w_zero=initialize_w_zero,)
        self.conv2_norm = Normalization(hidden_channels, self.norm_type)

        # Readout layers sized for jumping_knowledge == 'last'.
        # initialize_layer_w_zero is a project helper that returns the layer
        # (presumably with zeroed weights - see models.model_utils).
        self.aggregation_weight = nn.Linear(
            hidden_channels, hidden_channels, bias=True)
        self.aggregation_weight = initialize_layer_w_zero(
            self.aggregation_weight)

        self.final_regression_layer = nn.Linear(
            hidden_channels, out_channels, bias=True)
        self.node_weight1 = nn.Linear(in_channels, hidden_channels, bias=True)

        if jumping_knowledge == 'concat':
            # conv1 + hidden layers + conv2 each contribute hidden_channels
            # columns to the concatenation, so the readout layers are
            # replaced by wider ones. The construction order is deliberately
            # left as it was so seeded initialisation is not disturbed.
            dimensions = (self.num_hidden_layers + 2) * hidden_channels
            self.final_regression_layer = nn.Linear(dimensions, out_channels, bias=True)
            self.aggregation_weight = nn.Linear(dimensions, dimensions, bias=True)
            self.aggregation_weight = initialize_layer_w_zero(
                self.aggregation_weight)

    def final_expectation(self, edge_messages, edge_index, node_features, batch=None):
        """Sum the linearly transformed messages over each target node.

        Args:
            edge_messages: Message tensor, one row per column of
                ``edge_index``.
            edge_index: Main-graph connectivity; row 1 is used as the
                aggregation index.
            node_features: Node feature tensor. Only its first dimension is
                read, to fix the number of output rows. May be ``None``, in
                which case the row count is inferred from ``edge_index``.
            batch: Unused; kept for interface compatibility.

        Returns:
            Tensor with one row per node holding the sum of its incoming
            transformed messages (zeros for nodes without incoming edges).
        """
        aggr = torch_geometric.nn.aggr.SumAggregation()
        # No activation after aggregation weights here.
        em = self.aggregation_weight(edge_messages)
        # Without an explicit dim_size the aggregation only produces
        # max(edge_index[1]) + 1 rows, silently dropping trailing nodes that
        # receive no message and breaking the alignment with `batch`/labels.
        num_nodes = None if node_features is None else node_features.size(0)
        return aggr(em, edge_index[1], dim_size=num_nodes)

    def residual_layer(self, hidden_layer, norm_layer, x, node_feature,
                       edge_index, edgewise_edge_index, edge_weight=None):
        """Apply one hidden block and add its input back to the result.

        Computes ``dropout(activation(norm(conv(x)))) + x``; the skip term is
        added after dropout, so it is never dropped. ``edge_weight`` is
        unused and kept for interface compatibility.
        """
        residual = x
        x = hidden_layer(x, edge_index,
                         node_feature,
                         edgewise_edge_index=edgewise_edge_index)
        x = norm_layer(x)
        x = self.activation(x)
        return F.dropout(x, p=self.dropout, training=self.training) + residual

    def forward(self, x, data, edgewise_edge_index):
        """Run the network.

        Args:
            x: Initial message tensor fed to ``conv1``.
            data: Main-graph object providing ``x`` (node features),
                ``edge_index`` and ``batch``.
            edgewise_edge_index: Connectivity over which the messages are
                propagated inside every convolution.

        Returns:
            ``final_activation`` applied to the final linear layer's output;
            one row per graph for graph-level task types, otherwise one row
            per node.

        Raises:
            ValueError: If ``jumping_knowledge`` is neither ``'last'`` nor
                ``'concat'``.
        """
        # Fail before doing any work; previously the exception object was
        # created but never raised, so bad values silently acted as 'last'.
        if self.jumping_knowledge not in ('last', 'concat'):
            raise ValueError("Invalid jumping knowledge type")

        # ==== data from the main graph ====
        edge_index = data.edge_index
        batch = data.batch
        # Cast straight to float32 on the device the parameters live on,
        # instead of going through the CPU and a module-global device.
        node_feature = data.x.to(
            device=self.node_weight1.weight.device, dtype=torch.float32)
        node_feature = self.node_weight1(node_feature)

        x = self.conv1(x, edge_index,
                       node_feature,
                       edgewise_edge_index=edgewise_edge_index)
        x = self.conv1_norm(x)
        x = self.activation(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        h_list = [x]

        for hidden_layer, norm_layer in zip(self.hidden_layers, self.norm_layers):
            x = self.residual_layer(
                hidden_layer=hidden_layer,
                norm_layer=norm_layer,
                x=x,
                node_feature=node_feature,
                edge_index=edge_index,
                edgewise_edge_index=edgewise_edge_index
            )
            h_list.append(x)

        # No activation and no skip connection around the last convolution.
        x = self.conv2(x, edge_index, node_feature,
                       edgewise_edge_index=edgewise_edge_index)
        x = self.conv2_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        h_list.append(x)

        if self.jumping_knowledge == 'last':
            x = h_list[-1]
        else:  # 'concat' (validated at the top of this method)
            x = torch.cat(h_list, dim=1)

        x = self.final_expectation(x, edge_index, node_feature)

        if self.task_type in ['graph_regression', 'graph_classification']:
            x = global_mean_pool(x, batch)

        x = self.final_regression_layer(x)

        return self.final_activation(x)
