import torch
import torch.nn.functional as F
import torch.nn as nn

import torch_geometric
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool, BatchNorm
from ogb.graphproppred.mol_encoder import BondEncoder
from models.norms import Normalization
from models.model_utils import get_activation_function
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree
from models.model_utils import get_activation_function, initialize_layer_w_zero

# Module-wide flag for gradient checkpointing; it is not referenced by any
# definition in the visible part of this file.
GRADIENT_CHECKPOINTING = False

# Default compute device: CUDA when torch reports a usable GPU, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')



class ZINCGINConv(MessagePassing):
    """GIN-style convolution applied to edge messages with bond embeddings.

    Neighbouring messages are combined with an embedded bond type, passed
    through a ReLU and summed (``aggr="add"``). The sum is added to the
    ``(1 + eps)``-scaled input message, fed through a two-layer MLP, and
    finally offset by the feature of each edge's source node.

    Args:
        in_dim: Width of the incoming messages and of the bond embedding.
        emb_dim: Width of the MLP hidden layer and of the output.
    """

    def __init__(self, in_dim, emb_dim):
        super().__init__(aggr="add")

        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(in_dim, emb_dim),
            torch.nn.BatchNorm1d(emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_dim, emb_dim),
        )
        # Learnable epsilon of the GIN update, initialised to zero.
        self.eps = torch.nn.Parameter(torch.zeros(1))

        # The embedding table has four rows, so bond indices must be in [0, 3].
        self.bond_encoder = torch.nn.Embedding(4, in_dim)

    def forward(self, message, edge_index,
                node_feature=None,
                edge_attr=None,  # this has to be #edgewise_edge_index indices x 1
                edgewise_edge_index=None):
        """Run one convolution step over the messages.

        Args:
            message: Current messages; used as ``x`` for propagation.
            edge_index: Edge index whose first row selects, per message, the
                row of ``node_feature`` that is added to the output.
            node_feature: Node features indexed by ``edge_index[0]``.
            edge_attr: Integer bond-type indices, one per column of
                ``edgewise_edge_index`` (an optional trailing unit dimension
                is accepted).
            edgewise_edge_index: Connectivity used to propagate between
                messages.

        Returns:
            The MLP output plus ``node_feature[edge_index[0]]``.

        Raises:
            ValueError: If ``node_feature``, ``edge_attr`` or
                ``edgewise_edge_index`` is ``None``.
        """
        if node_feature is None or edge_attr is None or edgewise_edge_index is None:
            raise ValueError(
                "ZINCGINConv.forward requires node_feature, edge_attr and "
                "edgewise_edge_index")

        # Flatten instead of a bare squeeze(): squeeze() turns a single-edge
        # [1, 1] input into a 0-d tensor, which would drop the edge dimension
        # of the embedding.
        edge_embedding = self.bond_encoder(edge_attr.reshape(-1))

        aggregated = self.propagate(
            edgewise_edge_index, x=message, edge_attr=edge_embedding)
        out = self.mlp((1 + self.eps) * message + aggregated)

        return out + node_feature[edge_index[0]]

    def message(self, x_j, edge_attr):
        """Combine a neighbouring message with its bond embedding."""
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        """Return the aggregated messages unchanged."""
        return aggr_out

class GIN(torch.nn.Module):
    """Stack of ``ZINCGINConv`` layers operating on edge messages.

    ``num_hidden_layers`` stages of conv -> norm -> activation -> dropout are
    followed by one extra conv (``conv2``) with norm and dropout but no
    activation. The resulting messages go through a linear map, are summed
    per target node, are mean-pooled per graph for graph-level tasks, and
    finally pass through a linear layer and the final activation.

    Args:
        num_hidden_layers: Number of conv stages before ``conv2``.
        in_channels: Accepted for interface compatibility; not used here.
        hidden_channels: Width of every conv layer.
        out_channels: Width of the final linear layer's output.
        activation: Name passed to ``get_activation_function`` for the
            hidden stages.
        final_activation: Name passed to ``get_activation_function`` for the
            output.
        add_self_loops: Stored on the instance; not read by this class.
        dropout: Dropout probability applied after each stage.
        use_gdc: Stored on the instance; not read by this class.
        cached: Stored on the instance; not read by this class.
        task_type: ``'graph_regression'`` and ``'graph_classification'``
            enable the jumping-knowledge readout and per-graph pooling.
        norm_type: Forwarded to ``Normalization``.
        jumping_knowledge: ``'last'`` or ``'concat'``; only consulted for
            graph-level tasks.
        **kwargs: Ignored.
    """

    # Task types that use the jumping-knowledge readout and graph pooling.
    _GRAPH_TASKS = ('graph_regression', 'graph_classification')

    def __init__(self,
                 num_hidden_layers,
                 in_channels,
                 hidden_channels,
                 out_channels,
                 activation='relu',
                 final_activation='linear',
                 add_self_loops=True,
                 dropout=0.5,
                 use_gdc=False,
                 cached=False,
                 task_type='node_classification',
                 norm_type=None,
                 jumping_knowledge='last',
                 **kwargs,):
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.use_gdc = use_gdc
        self.cached = cached
        self.task_type = task_type
        self.add_self_loops = add_self_loops
        self.activation = get_activation_function(activation)
        self.final_activation = get_activation_function(final_activation)
        self.jumping_knowledge = jumping_knowledge
        self.dropout = dropout

        self.norm_type = norm_type

        self.hidden_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        for _ in range(self.num_hidden_layers):
            self.hidden_layers.append(
                ZINCGINConv(hidden_channels, hidden_channels))
            self.norm_layers.append(
                Normalization(hidden_channels, self.norm_type))

        self.conv2 = ZINCGINConv(hidden_channels, hidden_channels)
        self.conv2_norm = Normalization(hidden_channels, self.norm_type)

        # forward() concatenates all layer outputs for *both* graph-level task
        # types, so the readout layers must be widened for both of them.
        readout_dim = hidden_channels
        if self.task_type in self._GRAPH_TASKS and jumping_knowledge == 'concat':
            readout_dim = (self.num_hidden_layers + 1) * hidden_channels

        self.aggregation_weight = initialize_layer_w_zero(
            nn.Linear(readout_dim, readout_dim, bias=True))
        self.final_regression_layer = nn.Linear(
            readout_dim, out_channels, bias=True)

    def residual_layer(self, hidden_layer, norm_layer,
                       x, node_feature, edge_index, edgewise_edge_index,
                       edge_weight=None, edge_attr=None):
        """Apply one conv -> norm -> activation -> dropout stage.

        Despite the name, this variant adds no skip connection. ``x`` holds
        the current messages; ``edge_weight`` is accepted but unused.
        """
        x = hidden_layer(x, edge_index, node_feature, edge_attr, edgewise_edge_index)
        x = norm_layer(x)
        x = self.activation(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x

    def final_expectation(self, edge_messages, edge_index, node_features, batch=None):
        """Linearly transform the messages and sum them per target node.

        Args:
            edge_messages: Messages, one row per column of ``edge_index``.
            edge_index: Edge index; its second row is the aggregation index.
            node_features: When given, its first dimension fixes the number
                of output rows so nodes without incoming edges are kept.
            batch: Unused.

        Returns:
            Per-node sums of the transformed messages.
        """
        aggr = torch_geometric.nn.aggr.SumAggregation()
        # No activation after aggregation weights here.
        em = self.aggregation_weight(edge_messages)
        # Without dim_size the output stops at the largest target index and
        # silently drops trailing nodes that receive no message.
        dim_size = None if node_features is None else node_features.size(0)
        return aggr(em, edge_index[1], dim_size=dim_size)

    def forward(self, x, data, edgewise_edge_index):
        """Compute predictions from initial messages ``x`` and graph ``data``.

        Args:
            x: Initial messages; its device determines where the graph
                tensors are placed.
            data: Object providing ``x``, ``edge_index``, ``edge_attr`` and
                ``batch`` attributes.
            edgewise_edge_index: Connectivity between messages; its first row
                also selects the rows of ``data.edge_attr`` that are used.

        Returns:
            Output of the final activation.

        Raises:
            ValueError: For a graph-level task with an unknown
                ``jumping_knowledge`` mode.
        """
        # ==== Initialize ====
        # Follow the device of the incoming messages (and hence of the model)
        # rather than a module-level global, and convert dtype/device in one
        # step without a detour through CPU tensors.
        target = x.device
        node_feature = data.x.to(device=target, dtype=torch.float)
        edge_index = data.edge_index.to(device=target, dtype=torch.long)
        edge_attr = data.edge_attr.to(
            device=target, dtype=torch.long)[edgewise_edge_index[0]]
        batch = data.batch
        if batch is not None:
            batch = batch.to(target)

        h_list = []
        for hidden_layer, norm_layer in zip(self.hidden_layers, self.norm_layers):
            x = self.residual_layer(
                hidden_layer=hidden_layer,
                norm_layer=norm_layer,
                x=x,  # x here is message
                node_feature=node_feature,
                edge_index=edge_index,
                edgewise_edge_index=edgewise_edge_index,
                edge_attr=edge_attr,
            )
            h_list.append(x)

        x = self.conv2(x, edge_index, node_feature, edge_attr, edgewise_edge_index)
        x = self.conv2_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        h_list.append(x)

        is_graph_task = self.task_type in self._GRAPH_TASKS
        if is_graph_task:
            # The jumping-knowledge readout selects which messages are
            # aggregated; it is still per-message at this point.
            if self.jumping_knowledge == 'last':
                x = h_list[-1]
            elif self.jumping_knowledge == 'concat':
                x = torch.cat(h_list, dim=1)
            else:
                raise ValueError("Invalid jumping knowledge type")

        x = self.final_expectation(x, edge_index, node_feature)

        if is_graph_task:
            # Pool only after messages were reduced to nodes: ``batch`` is
            # indexed per node, not per message.
            x = global_mean_pool(x, batch)

        x = self.final_regression_layer(x)

        return self.final_activation(x)



class GINResidual(torch.nn.Module):
    """Stack of ``ZINCGINConv`` layers on edge messages with skip connections.

    Each of the ``num_hidden_layers`` stages computes
    conv -> norm -> activation -> dropout and adds the stage input back to
    the result. One extra conv (``conv2``) with norm and dropout, but without
    activation or skip connection, follows. The resulting messages go through
    a linear map, are summed per target node, are mean-pooled per graph for
    graph-level tasks, and finally pass through a linear layer and the final
    activation.

    Args:
        num_hidden_layers: Number of residual conv stages before ``conv2``.
        in_channels: Accepted for interface compatibility; not used here.
        hidden_channels: Width of every conv layer.
        out_channels: Width of the final linear layer's output.
        activation: Name passed to ``get_activation_function`` for the
            hidden stages.
        final_activation: Name passed to ``get_activation_function`` for the
            output.
        add_self_loops: Stored on the instance; not read by this class.
        dropout: Dropout probability applied after each stage.
        use_gdc: Stored on the instance; not read by this class.
        cached: Stored on the instance; not read by this class.
        task_type: ``'graph_regression'`` and ``'graph_classification'``
            enable the jumping-knowledge readout and per-graph pooling.
        norm_type: Forwarded to ``Normalization``.
        jumping_knowledge: ``'last'`` or ``'concat'``; only consulted for
            graph-level tasks.
        **kwargs: Ignored.
    """

    # Task types that use the jumping-knowledge readout and graph pooling.
    _GRAPH_TASKS = ('graph_regression', 'graph_classification')

    def __init__(self,
                 num_hidden_layers,
                 in_channels,
                 hidden_channels,
                 out_channels,
                 activation='relu',
                 final_activation='linear',
                 add_self_loops=True,
                 dropout=0.5,
                 use_gdc=False,
                 cached=False,
                 task_type='node_classification',
                 norm_type=None,
                 jumping_knowledge='last',
                 **kwargs,):
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.use_gdc = use_gdc
        self.cached = cached
        self.task_type = task_type
        self.add_self_loops = add_self_loops
        self.activation = get_activation_function(activation)
        self.final_activation = get_activation_function(final_activation)
        self.jumping_knowledge = jumping_knowledge
        self.dropout = dropout

        self.norm_type = norm_type

        self.hidden_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        for _ in range(self.num_hidden_layers):
            self.hidden_layers.append(
                ZINCGINConv(hidden_channels, hidden_channels))
            self.norm_layers.append(
                Normalization(hidden_channels, self.norm_type))

        self.conv2 = ZINCGINConv(hidden_channels, hidden_channels)
        self.conv2_norm = Normalization(hidden_channels, self.norm_type)

        # forward() concatenates all layer outputs for *both* graph-level task
        # types, so the readout layers must be widened for both of them.
        readout_dim = hidden_channels
        if self.task_type in self._GRAPH_TASKS and jumping_knowledge == 'concat':
            readout_dim = (self.num_hidden_layers + 1) * hidden_channels

        self.aggregation_weight = initialize_layer_w_zero(
            nn.Linear(readout_dim, readout_dim, bias=True))
        self.final_regression_layer = nn.Linear(
            readout_dim, out_channels, bias=True)

    def residual_layer(self, hidden_layer, norm_layer,
                       x, node_feature, edge_index, edgewise_edge_index,
                       edge_weight=None, edge_attr=None):
        """Apply conv -> norm -> activation -> dropout and add the input back.

        ``x`` holds the current messages; ``edge_weight`` is accepted but
        unused.
        """
        residual = x
        x = hidden_layer(x, edge_index, node_feature, edge_attr, edgewise_edge_index)
        x = norm_layer(x)
        x = self.activation(x)
        # The skip connection is added after dropout, so it is never dropped.
        x = F.dropout(x, p=self.dropout, training=self.training) + residual
        return x

    def final_expectation(self, edge_messages, edge_index, node_features, batch=None):
        """Linearly transform the messages and sum them per target node.

        Args:
            edge_messages: Messages, one row per column of ``edge_index``.
            edge_index: Edge index; its second row is the aggregation index.
            node_features: When given, its first dimension fixes the number
                of output rows so nodes without incoming edges are kept.
            batch: Unused.

        Returns:
            Per-node sums of the transformed messages.
        """
        aggr = torch_geometric.nn.aggr.SumAggregation()
        # No activation after aggregation weights here.
        em = self.aggregation_weight(edge_messages)
        # Without dim_size the output stops at the largest target index and
        # silently drops trailing nodes that receive no message.
        dim_size = None if node_features is None else node_features.size(0)
        return aggr(em, edge_index[1], dim_size=dim_size)

    def forward(self, x, data, edgewise_edge_index):
        """Compute predictions from initial messages ``x`` and graph ``data``.

        Args:
            x: Initial messages; its device determines where the graph
                tensors are placed.
            data: Object providing ``x``, ``edge_index``, ``edge_attr`` and
                ``batch`` attributes.
            edgewise_edge_index: Connectivity between messages; its first row
                also selects the rows of ``data.edge_attr`` that are used.

        Returns:
            Output of the final activation.

        Raises:
            ValueError: For a graph-level task with an unknown
                ``jumping_knowledge`` mode.
        """
        # ==== Initialize ====
        # Follow the device of the incoming messages (and hence of the model)
        # rather than a module-level global, and convert dtype/device in one
        # step without a detour through CPU tensors.
        target = x.device
        node_feature = data.x.to(device=target, dtype=torch.float)
        edge_index = data.edge_index.to(device=target, dtype=torch.long)
        edge_attr = data.edge_attr.to(
            device=target, dtype=torch.long)[edgewise_edge_index[0]]
        batch = data.batch
        if batch is not None:
            batch = batch.to(target)

        h_list = []
        for hidden_layer, norm_layer in zip(self.hidden_layers, self.norm_layers):
            x = self.residual_layer(
                hidden_layer=hidden_layer,
                norm_layer=norm_layer,
                x=x,  # x here is message
                node_feature=node_feature,
                edge_index=edge_index,
                edgewise_edge_index=edgewise_edge_index,
                edge_attr=edge_attr,
            )
            h_list.append(x)

        x = self.conv2(x, edge_index, node_feature, edge_attr, edgewise_edge_index)
        x = self.conv2_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        h_list.append(x)

        is_graph_task = self.task_type in self._GRAPH_TASKS
        if is_graph_task:
            # The jumping-knowledge readout selects which messages are
            # aggregated; it is still per-message at this point, so no
            # pooling with the per-node ``batch`` vector happens here.
            if self.jumping_knowledge == 'last':
                x = h_list[-1]
            elif self.jumping_knowledge == 'concat':
                x = torch.cat(h_list, dim=1)
            else:
                raise ValueError("Invalid jumping knowledge type")

        x = self.final_expectation(x, edge_index, node_feature)

        if is_graph_task:
            # Pool only after messages were reduced to nodes: ``batch`` is
            # indexed per node, not per message.
            x = global_mean_pool(x, batch)

        x = self.final_regression_layer(x)

        return self.final_activation(x)