import torch
import torch.nn.functional as F
import torch.nn as nn

from torch_geometric.nn import GCNConv, GINConv, global_mean_pool, BatchNorm
from ogb.graphproppred.mol_encoder import BondEncoder
from models.norms import Normalization
from models.model_utils import get_activation_function
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree
from hydra.utils import instantiate

# Module-level switch for gradient checkpointing; off by default.
GRADIENT_CHECKPOINTING = False

# Default compute device: CUDA when it is available, otherwise the CPU.
device = (
    torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
)


class GNN(torch.nn.Module):
    """Stack of graph convolutions followed by a linear read-out.

    The network consists of an input convolution (``conv1``), then
    ``num_hidden_layers`` further convolutions. Every convolution is followed
    by a normalisation layer, the activation and dropout. The collected layer
    outputs are combined according to ``jumping_knowledge``, optionally mean
    pooled per graph, and passed through a final linear layer and the final
    activation.

    Args:
        conv_args: Configuration handed to ``hydra.utils.instantiate`` to
            build each convolution; ``in_channels`` and ``out_channels`` are
            supplied as overrides.
        num_hidden_layers: Number of convolutions stacked after ``conv1``.
        in_channels: Feature size given to ``conv1``.
        hidden_channels: Feature size of every hidden representation.
        out_channels: Output size of the final linear layer.
        activation: Name passed to ``get_activation_function`` for the
            activation applied after each normalisation layer.
        final_activation: Name passed to ``get_activation_function`` for the
            activation applied to the read-out.
        add_self_loops: Stored on the instance; not used by this class.
        dropout: Dropout probability applied after each activation.
        use_gdc: Stored on the instance; not used by this class.
        cached: Stored on the instance; not used by this class.
        task_type: ``'graph_regression'`` and ``'graph_classification'``
            enable mean pooling over ``data.batch``; any other value yields
            one output row per node.
        norm_type: Passed to ``Normalization`` for every normalisation layer.
        jumping_knowledge: ``'last'`` uses the final representation only,
            ``'concat'`` concatenates the collected representations along
            the feature dimension.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 conv_args,
                 num_hidden_layers,
                 in_channels,
                 hidden_channels,
                 out_channels,
                 activation='linear',
                 final_activation='linear',
                 add_self_loops=True,
                 dropout=0.5,
                 use_gdc=False,
                 cached=False,
                 task_type='node_classification',
                 norm_type=None,
                 jumping_knowledge='last',
                 **kwargs,):
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.use_gdc = use_gdc
        self.cached = cached
        self.task_type = task_type
        self.add_self_loops = add_self_loops
        self.activation = get_activation_function(activation)
        self.final_activation = get_activation_function(final_activation)
        self.jumping_knowledge = jumping_knowledge
        self.dropout = dropout

        self.norm_type = norm_type

        self.conv1 = instantiate(conv_args,
                                 in_channels=in_channels,
                                 out_channels=hidden_channels)
        self.conv1_norm = Normalization(hidden_channels, self.norm_type)

        self.hidden_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        for _ in range(self.num_hidden_layers):
            self.hidden_layers.append(
                instantiate(conv_args,
                            in_channels=hidden_channels,
                            out_channels=hidden_channels))
            self.norm_layers.append(
                Normalization(hidden_channels, self.norm_type))

        self.final_regression_layer = nn.Linear(
            hidden_channels, out_channels, bias=True)

        # The default read-out above is still created first in the concat
        # case so that the sequence of random initialisations (and therefore
        # seeded weights) is unchanged.
        if jumping_knowledge == 'concat':
            self.final_regression_layer = nn.Linear(
                (self.num_hidden_layers + 1) * hidden_channels,
                out_channels, bias=True)

    def residual_layer(self, hidden_layer, norm_layer,
                       x, edge_index,
                       edge_weight=None, edge_attr=None):
        """Apply one hidden block: convolution, norm, activation, dropout.

        Despite the name, no skip connection is added in this class.
        ``edge_weight`` is accepted but not used.
        """
        x = hidden_layer(x, edge_index, edge_attr)
        x = norm_layer(x)
        x = self.activation(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x

    def forward(self, data):
        """Run the network on ``data``.

        Args:
            data: Object exposing ``x`` and ``edge_index`` tensors and,
                optionally, ``edge_attr`` and ``batch``.

        Returns:
            The final activation applied to the linear read-out: one row per
            node, or one row per graph for graph-level task types.

        Raises:
            ValueError: If ``jumping_knowledge`` is neither ``'last'`` nor
                ``'concat'``.
        """
        if self.jumping_knowledge not in ('last', 'concat'):
            raise ValueError("Invalid jumping knowledge type")

        # Follow the device of the model's own weights so inputs always end
        # up where the parameters live.
        target = self.final_regression_layer.weight.device

        # ==== Initialize ====
        # A single ``.to`` casts and moves without a detour through the CPU.
        x = data.x.to(device=target, dtype=torch.float)
        edge_index = data.edge_index.to(device=target, dtype=torch.long)
        edge_attr = getattr(data, 'edge_attr', None)
        if edge_attr is not None:
            edge_attr = edge_attr.to(device=target, dtype=torch.long)
        batch = getattr(data, 'batch', None)
        if batch is not None:
            # Pooling needs the assignment vector on the same device as h.
            batch = batch.to(target)

        x = self.conv1(x, edge_index, edge_attr)
        x = self.conv1_norm(x)
        x = self.activation(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        h_list = []
        for hidden_layer, norm_layer in zip(self.hidden_layers,
                                            self.norm_layers):
            x = self.residual_layer(
                hidden_layer=hidden_layer,
                norm_layer=norm_layer,
                x=x,
                edge_index=edge_index,
                edge_attr=edge_attr,
            )
            h_list.append(x)

        # NOTE(review): the final representation is appended a second time
        # and the conv1 output is only included when there are no hidden
        # layers. Kept as is because the read-out width depends on it --
        # confirm whether this is intended.
        h_list.append(x)

        if self.jumping_knowledge == 'last':
            h = h_list[-1]
        else:
            h = torch.cat(h_list, dim=1)

        # The read-out always consumes the aggregated representation, so the
        # 'concat' width matches for node-level tasks as well.
        if self.task_type in ['graph_regression', 'graph_classification']:
            h = global_mean_pool(h, batch)

        out = self.final_regression_layer(h)
        out = self.final_activation(out)
        return out



class GNNResidual(torch.nn.Module):
    """Stack of graph convolutions with skip connections and a linear read-out.

    The network consists of an input convolution (``conv1``), then
    ``num_hidden_layers`` hidden blocks. Each hidden block applies a
    convolution, normalisation, activation and dropout and adds the block
    input back onto the result. The collected representations are combined
    according to ``jumping_knowledge``, optionally mean pooled per graph, and
    passed through a final linear layer and the final activation.

    Args:
        conv_args: Configuration handed to ``hydra.utils.instantiate`` to
            build each convolution; ``in_channels`` and ``out_channels`` are
            supplied as overrides.
        num_hidden_layers: Number of hidden blocks stacked after ``conv1``.
        in_channels: Feature size given to ``conv1``.
        hidden_channels: Feature size of every hidden representation.
        out_channels: Output size of the final linear layer.
        activation: Name passed to ``get_activation_function`` for the
            activation applied after each normalisation layer.
        final_activation: Name passed to ``get_activation_function`` for the
            activation applied to the read-out.
        add_self_loops: Stored on the instance; not used by this class.
        dropout: Dropout probability applied after each activation and once
            more to the final representation.
        use_gdc: Stored on the instance; not used by this class.
        cached: Stored on the instance; not used by this class.
        task_type: ``'graph_regression'`` and ``'graph_classification'``
            enable mean pooling over ``data.batch``; any other value yields
            one output row per node.
        norm_type: Passed to ``Normalization`` for every normalisation layer.
        jumping_knowledge: ``'last'`` uses the final representation only,
            ``'concat'`` concatenates the collected representations along
            the feature dimension.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 conv_args,
                 num_hidden_layers,
                 in_channels,
                 hidden_channels,
                 out_channels,
                 activation='linear',
                 final_activation='linear',
                 add_self_loops=True,
                 dropout=0.5,
                 use_gdc=False,
                 cached=False,
                 task_type='node_classification',
                 norm_type=None,
                 jumping_knowledge='last',
                 **kwargs,):
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.use_gdc = use_gdc
        self.cached = cached
        self.task_type = task_type
        self.add_self_loops = add_self_loops
        self.activation = get_activation_function(activation)
        self.final_activation = get_activation_function(final_activation)
        self.jumping_knowledge = jumping_knowledge
        self.dropout = dropout

        self.norm_type = norm_type

        self.conv1 = instantiate(conv_args,
                                 in_channels=in_channels,
                                 out_channels=hidden_channels)
        self.conv1_norm = Normalization(hidden_channels, self.norm_type)

        self.hidden_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        for _ in range(self.num_hidden_layers):
            self.hidden_layers.append(
                instantiate(conv_args,
                            in_channels=hidden_channels,
                            out_channels=hidden_channels))
            self.norm_layers.append(
                Normalization(hidden_channels, self.norm_type))

        self.final_regression_layer = nn.Linear(
            hidden_channels, out_channels, bias=True)

        # The default read-out above is still created first in the concat
        # case so that the sequence of random initialisations (and therefore
        # seeded weights) is unchanged.
        if jumping_knowledge == 'concat':
            self.final_regression_layer = nn.Linear(
                (self.num_hidden_layers + 1) * hidden_channels,
                out_channels, bias=True)

    def residual_layer(self, hidden_layer, norm_layer,
                       x, edge_index,
                       edge_weight=None, edge_attr=None):
        """Apply one hidden block and add its input back onto the output.

        The skip connection is added after dropout, so the block input is
        never dropped. ``edge_weight`` is accepted but not used.
        """
        residual = x
        x = hidden_layer(x, edge_index, edge_attr)
        x = norm_layer(x)
        x = self.activation(x)
        x = F.dropout(x, p=self.dropout, training=self.training) + residual
        return x

    def forward(self, data):
        """Run the network on ``data``.

        Args:
            data: Object exposing ``x`` and ``edge_index`` tensors and,
                optionally, ``batch``.

        Returns:
            The final activation applied to the linear read-out: one row per
            node, or one row per graph for graph-level task types.

        Raises:
            ValueError: If ``jumping_knowledge`` is neither ``'last'`` nor
                ``'concat'``.
        """
        if self.jumping_knowledge not in ('last', 'concat'):
            raise ValueError("Invalid jumping knowledge type")

        # Follow the device of the model's own weights so inputs always end
        # up where the parameters live.
        target = self.final_regression_layer.weight.device

        # ==== Initialize ====
        # A single ``.to`` casts and moves without a detour through the CPU.
        x = data.x.to(device=target, dtype=torch.float)
        edge_index = data.edge_index.to(device=target, dtype=torch.long)
        # Edge attributes are not read from ``data`` in this variant; the
        # convolutions always receive None.
        edge_attr = None
        batch = getattr(data, 'batch', None)
        if batch is not None:
            # Pooling needs the assignment vector on the same device as h.
            batch = batch.to(target)

        x = self.conv1(x, edge_index, edge_attr)
        x = self.conv1_norm(x)
        x = self.activation(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        h_list = []
        for hidden_layer, norm_layer in zip(self.hidden_layers,
                                            self.norm_layers):
            x = self.residual_layer(
                hidden_layer=hidden_layer,
                norm_layer=norm_layer,
                x=x,
                edge_index=edge_index,
                edge_attr=edge_attr,
            )
            h_list.append(x)

        # NOTE(review): the final representation is appended a second time
        # (after one more dropout) and the conv1 output is only included
        # when there are no hidden layers. Kept as is because the read-out
        # width depends on it -- confirm whether this is intended.
        x = F.dropout(x, p=self.dropout, training=self.training)
        h_list.append(x)

        if self.jumping_knowledge == 'last':
            h = h_list[-1]
        else:
            h = torch.cat(h_list, dim=1)

        # The read-out always consumes the aggregated representation, so the
        # 'concat' width matches for node-level tasks as well.
        if self.task_type in ['graph_regression', 'graph_classification']:
            h = global_mean_pool(h, batch)

        out = self.final_regression_layer(h)
        out = self.final_activation(out)
        return out


