import torch
import torch.nn.functional as F
import torch.nn as nn

from torch_geometric.nn import SAGEConv, global_mean_pool
from models.model_utils import get_activation_function
from models.norms import Normalization

# Module-wide switch read by the models' forward passes: when True (and the
# model is in training mode) hidden layers are run through
# torch.utils.checkpoint instead of being called directly.
GRADIENT_CHECKPOINTING = False

# Default compute device: CUDA when it is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')



class GraphSAGE(torch.nn.Module):
    """Stack of SAGEConv layers followed by a linear output head.

    Layout: ``conv1`` -> ``num_hidden_layers`` hidden convolutions ->
    ``conv2`` -> (optional per-graph mean pooling) -> ``nn.Linear`` ->
    final activation.  Every convolution is followed by a ``Normalization``
    layer and dropout; all of them except ``conv2`` are also followed by the
    activation.

    Args:
        num_hidden_layers: Number of hidden convolutions between ``conv1``
            and ``conv2``.
        in_channels: Input width of ``conv1``.
        hidden_channels: Width of every convolution output.
        out_channels: Output width of the final linear layer.
        activation: Name resolved by ``get_activation_function``; applied
            after ``conv1`` and after every hidden layer.
        final_activation: Name resolved by ``get_activation_function``;
            applied to the output of the linear head.
        add_self_loops: Forwarded to every ``SAGEConv`` constructor.
        dropout: Dropout probability used after every convolution.
        use_gdc: Convolutions are built with ``normalize=not use_gdc``.
        cached: Forwarded to every ``SAGEConv`` constructor.
        task_type: When equal to ``'graph_regression'`` node embeddings are
            mean-pooled per graph before the linear head.
        norm_type: Forwarded to every ``Normalization`` layer.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self,
                 num_hidden_layers,
                 in_channels,
                 hidden_channels,
                 out_channels,
                 activation='relu',
                 final_activation='linear',
                 add_self_loops=True,
                 dropout=0.5,
                 use_gdc=False,
                 cached=False,
                 task_type='node_classification',
                 norm_type=None,
                 **kwargs,):
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.use_gdc = use_gdc
        self.cached = cached
        self.task_type = task_type
        self.add_self_loops = add_self_loops
        self.activation = get_activation_function(activation)
        self.final_activation = get_activation_function(final_activation)
        self.dropout = dropout
        self.norm_type = norm_type

        # Sub-modules are created in a fixed order (conv1, hidden stack,
        # conv2, head) so that seeded initialisation and state_dict key order
        # stay stable.
        # The first convolution must consume the raw feature width
        # (in_channels), not hidden_channels.
        self.conv1 = self._make_conv(in_channels, hidden_channels)
        self.conv1_norm = Normalization(hidden_channels, self.norm_type)

        self.hidden_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        for _ in range(self.num_hidden_layers):
            self.hidden_layers.append(
                self._make_conv(hidden_channels, hidden_channels))
            self.norm_layers.append(
                Normalization(hidden_channels, self.norm_type))

        self.conv2 = self._make_conv(hidden_channels, hidden_channels)
        self.conv2_norm = Normalization(hidden_channels, self.norm_type)
        self.final_regression_layer = nn.Linear(
            hidden_channels, out_channels, bias=True)

    def _make_conv(self, in_dim, out_dim):
        """Build one SAGEConv configured with this model's shared options."""
        # NOTE(review): `add_self_loops` and `cached` look like GCNConv
        # options -- confirm the installed SAGEConv accepts them.
        return SAGEConv(in_dim, out_dim,
                        add_self_loops=self.add_self_loops,
                        cached=self.cached,
                        normalize=not self.use_gdc)

    def residual_layer(self, hidden_layer, norm_layer, x, edge_index, edge_weight=None):
        """Apply conv -> norm -> activation -> dropout and return the result.

        Despite the name, the input ``x`` is not added back to the output.
        """
        x = hidden_layer(x, edge_index, edge_weight)
        x = norm_layer(x)
        x = self.activation(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x

    def forward(self, data):
        """Run the model on ``data`` and return the head's activated output.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``
                attributes (``batch`` is only used for
                ``task_type == 'graph_regression'``).
        """
        # ==== Initialize ====
        # Cast to float32 directly on the device that holds the weights; this
        # avoids a detour through a CPU FloatTensor and does not depend on a
        # module-level device choice.
        x = data.x.to(device=self.final_regression_layer.weight.device,
                      dtype=torch.float32)
        edge_index = data.edge_index
        # NOTE(review): always None; confirm what the convolution's third
        # positional argument means for the installed SAGEConv.
        edge_weight = None
        batch = data.batch

        x = self.conv1(x, edge_index, edge_weight)
        x = self.conv1_norm(x)
        x = self.activation(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        for hidden_layer, norm_layer in zip(self.hidden_layers, self.norm_layers):
            if self.training and GRADIENT_CHECKPOINTING:
                # Imported here so the submodule is guaranteed to be loaded.
                from torch.utils.checkpoint import checkpoint

                # Trade compute for memory: activations of this layer are
                # recomputed during the backward pass.
                x = checkpoint(
                    self.residual_layer,
                    hidden_layer,
                    norm_layer,
                    x,
                    edge_index,)
            else:
                x = self.residual_layer(
                    hidden_layer=hidden_layer,
                    norm_layer=norm_layer,
                    x=x,
                    edge_index=edge_index,
                )

        # Last convolution: normalised and regularised, but not activated.
        x = self.conv2(x, edge_index, edge_weight)
        x = self.conv2_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        if self.task_type in ['graph_regression']:
            x = global_mean_pool(x, batch)

        x = self.final_regression_layer(x)
        x = self.final_activation(x)
        return x

    def reset_parameters(self):
        """Re-initialise every learnable sub-module, including the head."""
        self.conv1.reset_parameters()
        self.conv1_norm.reset_parameters()
        self.conv2.reset_parameters()
        self.conv2_norm.reset_parameters()
        for hidden_layer, norm_layer in zip(self.hidden_layers, self.norm_layers):
            hidden_layer.reset_parameters()
            norm_layer.reset_parameters()
        self.final_regression_layer.reset_parameters()


class GraphSAGEResidual(torch.nn.Module):
    """SAGEConv stack with a linear output head.

    Layout: ``conv1`` -> ``num_hidden_layers`` hidden convolutions ->
    ``conv2`` -> (optional per-graph mean pooling) -> ``nn.Linear`` ->
    final activation.  Every convolution is followed by a ``Normalization``
    layer and dropout; all of them except ``conv2`` are also followed by the
    activation.

    NOTE(review): despite the class name, no skip connection is applied
    anywhere in ``forward`` -- confirm whether one was intended.

    Args:
        num_hidden_layers: Number of hidden convolutions between ``conv1``
            and ``conv2``.
        in_channels: Input width of ``conv1``.
        hidden_channels: Width of every convolution output.
        out_channels: Output width of the final linear layer.
        activation: Name resolved by ``get_activation_function``; applied
            after ``conv1`` and after every hidden layer.
        final_activation: Name resolved by ``get_activation_function``;
            applied to the output of the linear head.
        add_self_loops: Forwarded to every ``SAGEConv`` constructor.
        dropout: Dropout probability used after every convolution.
        use_gdc: Convolutions are built with ``normalize=not use_gdc``.
        cached: Forwarded to every ``SAGEConv`` constructor.
        task_type: When equal to ``'graph_regression'`` node embeddings are
            mean-pooled per graph before the linear head.
        norm_type: Forwarded to every ``Normalization`` layer.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self,
                 num_hidden_layers,
                 in_channels,
                 hidden_channels,
                 out_channels,
                 activation='relu',
                 final_activation='linear',
                 add_self_loops=True,
                 dropout=0.5,
                 use_gdc=False,
                 cached=False,
                 task_type='node_classification',
                 norm_type=None,
                 **kwargs,):
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.use_gdc = use_gdc
        self.cached = cached
        self.task_type = task_type
        self.add_self_loops = add_self_loops
        self.activation = get_activation_function(activation)
        self.final_activation = get_activation_function(final_activation)
        self.dropout = dropout
        self.norm_type = norm_type

        # Sub-modules are created in a fixed order (conv1, hidden stack,
        # conv2, head) so that seeded initialisation and state_dict key order
        # stay stable.
        self.conv1 = self._make_conv(in_channels, hidden_channels)
        self.conv1_norm = Normalization(hidden_channels, self.norm_type)

        self.hidden_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        for _ in range(self.num_hidden_layers):
            self.hidden_layers.append(
                self._make_conv(hidden_channels, hidden_channels))
            self.norm_layers.append(
                Normalization(hidden_channels, self.norm_type))

        self.conv2 = self._make_conv(hidden_channels, hidden_channels)
        self.conv2_norm = Normalization(hidden_channels, self.norm_type)
        self.final_regression_layer = nn.Linear(
            hidden_channels, out_channels, bias=True)

    def _make_conv(self, in_dim, out_dim):
        """Build one SAGEConv configured with this model's shared options."""
        # NOTE(review): `add_self_loops` and `cached` look like GCNConv
        # options -- confirm the installed SAGEConv accepts them.
        return SAGEConv(in_dim, out_dim,
                        add_self_loops=self.add_self_loops,
                        cached=self.cached,
                        normalize=not self.use_gdc)

    def residual_layer(self, hidden_layer, norm_layer, x, edge_index, edge_weight=None):
        """Apply conv -> norm -> activation -> dropout and return the result.

        Despite the name, the input ``x`` is not added back to the output.
        """
        x = hidden_layer(x, edge_index, edge_weight)
        x = norm_layer(x)
        x = self.activation(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x

    def forward(self, data):
        """Run the model on ``data`` and return the head's activated output.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``
                attributes (``batch`` is only used for
                ``task_type == 'graph_regression'``).
        """
        # ==== Initialize ====
        # Cast to float32 directly on the device that holds the weights; this
        # avoids a detour through a CPU FloatTensor and does not depend on a
        # module-level device choice.
        x = data.x.to(device=self.final_regression_layer.weight.device,
                      dtype=torch.float32)
        edge_index = data.edge_index
        # NOTE(review): always None; confirm what the convolution's third
        # positional argument means for the installed SAGEConv.
        edge_weight = None
        batch = data.batch

        x = self.conv1(x, edge_index, edge_weight)
        x = self.conv1_norm(x)
        x = self.activation(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        for hidden_layer, norm_layer in zip(self.hidden_layers, self.norm_layers):
            if self.training and GRADIENT_CHECKPOINTING:
                # Imported here so the submodule is guaranteed to be loaded.
                from torch.utils.checkpoint import checkpoint

                # Trade compute for memory: activations of this layer are
                # recomputed during the backward pass.
                x = checkpoint(
                    self.residual_layer,
                    hidden_layer,
                    norm_layer,
                    x,
                    edge_index,)
            else:
                x = self.residual_layer(
                    hidden_layer=hidden_layer,
                    norm_layer=norm_layer,
                    x=x,
                    edge_index=edge_index,
                )

        # Last convolution: normalised and regularised, but not activated.
        x = self.conv2(x, edge_index, edge_weight)
        x = self.conv2_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        if self.task_type in ['graph_regression']:
            x = global_mean_pool(x, batch)

        x = self.final_regression_layer(x)
        x = self.final_activation(x)
        return x

    def reset_parameters(self):
        """Re-initialise every learnable sub-module, including the head."""
        self.conv1.reset_parameters()
        self.conv1_norm.reset_parameters()
        self.conv2.reset_parameters()
        self.conv2_norm.reset_parameters()
        for hidden_layer, norm_layer in zip(self.hidden_layers, self.norm_layers):
            hidden_layer.reset_parameters()
            norm_layer.reset_parameters()
        self.final_regression_layer.reset_parameters()
