import torch
from torch.nn import Linear, Parameter
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from models.model_utils import get_activation_function

# TODO(tm): send in the device while creating the model.
# Module-wide default device: CUDA when it is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def get_edgewise_graph(edge_index, target_device=None):
    """Build edge-to-edge connectivity from a node-level ``edge_index``.

    For every edge ``i`` with source node ``s = edge_index[0, i]``, one pair
    ``(i, j)`` is emitted for each edge ``j`` whose target ``edge_index[1, j]``
    equals ``s``, i.e. every edge ``j`` that ends at the node edge ``i`` starts
    from. Pairs are ordered by ``i`` and, for a fixed ``i``, by ascending ``j``.

    This is a vectorized replacement for a per-edge Python loop. That loop was
    O(E^2) and synced with the host once per edge when the input lived on a
    GPU. It also staged the result in a float32 buffer, which silently rounded
    edge ids above 2**24.

    Args:
        edge_index: integer tensor of shape (2, E); row 0 holds source nodes,
            row 1 holds target nodes.
        target_device: device for the returned tensor. Defaults to the
            module-level ``device`` (the previous, hard-wired behavior).

    Returns:
        int64 tensor of shape (2, P): row 0 holds ``i``, row 1 holds ``j``.
    """
    src = edge_index[0].contiguous()
    dst = edge_index[1].contiguous()
    num_edges = src.numel()
    out_device = device if target_device is None else target_device
    if num_edges == 0:
        return torch.zeros(2, 0, dtype=torch.int64, device=out_device)

    # A stable sort groups edge ids by target node while keeping them in
    # ascending order inside each group (the order `nonzero` used to give).
    sorted_dst, perm = torch.sort(dst, stable=True)
    # Edges whose target equals src[i] occupy perm[start[i]:end[i]].
    start = torch.searchsorted(sorted_dst, src, right=False)
    end = torch.searchsorted(sorted_dst, src, right=True)
    counts = end - start

    edge_ids = torch.arange(num_edges, device=src.device)
    idx1 = torch.repeat_interleave(edge_ids, counts)
    # Offset of each output pair inside its group: 0, 1, ..., counts[i] - 1.
    group_begin = torch.cumsum(counts, dim=0) - counts
    positions = torch.arange(idx1.numel(), device=src.device)
    within = positions - torch.repeat_interleave(group_begin, counts)
    idx2 = perm[torch.repeat_interleave(start, counts) + within]

    final_edge_index = torch.stack([idx1, idx2]).to(torch.int64)
    return final_edge_index.to(out_device)


def initialize_layer_w_zero(layer):
    """Replace ``layer.weight`` (and ``layer.bias`` when it is not None) with
    fresh, trainable, all-zero parameters of the same shape.

    The layer is modified in place and also returned for convenience.
    """
    def _zeroed(tensor):
        # A brand-new Parameter (not an in-place fill) with gradients enabled.
        return nn.Parameter(torch.zeros_like(tensor), requires_grad=True)

    layer.weight = _zeroed(layer.weight)
    if layer.bias is not None:
        layer.bias = _zeroed(layer.bias)
    return layer


class EdgeWiseLayer(MessagePassing):
    """Edge-level layer: a pointwise (kernel_size=1) Conv1d over each edge's message.

    What ``forward`` currently returns is

        x = activation(activation(Conv1d_k1(message)))

    Earlier versions of this docstring described gated variants. They are kept
    here for reference, but the code below does NOT implement a gate:

        # x = ReLU(W2 * message + b_2) * tanh(W_1 * message + b_1)
        v2:
            x = (W2 * message + b_2) * tanh(W_1 * message + b_1)

    Args:
        in_channels: input channels of the Conv1d applied to ``message``.
        node_in_channels: accepted for interface compatibility; not used here.
        out_channels: output channels of the Conv1d.
        add_bias: stored on the instance; not read by this class.
        activation: name resolved via ``get_activation_function``; applied in ``forward``.
        node_activation: resolved and stored, but not used by ``forward``.
        final_activation: resolved and stored, but not used by ``forward``.
        initialize_w_zero: if True, the Conv1d weight and bias start at zero.
    """
    def __init__(self, 
                 in_channels=1,
                 node_in_channels=1,
                 out_channels=1,
                 add_bias=True,
                 activation='relu', 
                 node_activation='relu' ,
                 final_activation='linear',
                 initialize_w_zero=False,
        ):
        # TODO: figure out this aggregation, does it need any normalization anywhere?
        # Technically we have gotten rid of any reason to have this aggregation in our experiments.
        super().__init__(aggr='add')
        self.activation = get_activation_function(activation)
        self.node_activation = get_activation_function(node_activation)
        self.final_activation = get_activation_function(final_activation)
        self.add_bias = add_bias
        
        # kernel_size=1: an independent linear map over the channel axis at each position.
        self.edge_weight = torch.nn.Conv1d(in_channels, out_channels, kernel_size=1)

        if initialize_w_zero:
            self.edge_weight = initialize_layer_w_zero(self.edge_weight)
        
    def reset_parameters(self):
        """Re-initialise ``edge_weight`` and set its bias to zero.

        The in-place ``zero_`` must run under ``torch.no_grad()``. The bias is
        a leaf tensor that requires grad, and autograd raises a RuntimeError
        for in-place ops on such tensors otherwise.
        """
        self.edge_weight.reset_parameters()
        if self.edge_weight.bias is not None:
            with torch.no_grad():
                self.edge_weight.bias.zero_()

    def get_norm(self, edge_index, features):
        """Return per-edge normalisation coefficients for ``edge_index``.

        The symmetric 1/sqrt(deg(row) * deg(col)) coefficients are computed,
        but then replaced by ones of the same shape/dtype/device. Normalisation
        is therefore effectively disabled.
        """
        row, col = edge_index 
        deg = degree(col, features.size(0), dtype=features.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # Isolated nodes have degree 0 -> inf after pow(-0.5); map them to 0.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        # NOTE(review): overrides the computed normalisation with all ones.
        norm = torch.ones_like(norm)
        return norm


    def forward(self, message, edge_index, node_feature, edgewise_edge_index=None):
        """Apply the pointwise Conv1d and activation to ``message``.

        ``message`` is fed to Conv1d, so presumably it is (num_edges, in_channels,
        length) -- TODO confirm. ``edge_index`` and ``node_feature`` are unused.
        ``edgewise_edge_index`` is effectively required despite its None
        default: ``get_norm`` unpacks it.
        """
        edge_feature = self.edge_weight(message)
        edge_feature = self.activation(edge_feature)


        # norm = self.get_norm(edge_index, node_feature)
        norm = self.get_norm(edgewise_edge_index, message)

        # Step 4-5: Start propagating messages.
        edge_shape = edge_feature.shape
        edge_feature = edge_feature.reshape(edge_shape[0], -1)
        out = self.propagate(edgewise_edge_index, x=edge_feature, norm=norm)

        # NOTE(review): the propagate() result above is discarded here -- the
        # output is built from edge_feature alone (see the aggregation TODO in
        # __init__). Confirm whether `out.reshape(edge_shape)` was intended.
        out = edge_feature.reshape(edge_shape)

        # NOTE(review): `activation` was already applied above; this is a no-op
        # for idempotent activations such as ReLU, but not in general.
        x = self.activation(out)
        return x


class GraphEdgeConvNetwork(torch.nn.Module):
    """Stack of edge-wise layers followed by a sum readout onto target nodes.

    Layer stack:
      * ``conv1``: in_channels -> hidden_channels
      * ``hidden_layers``: ``num_hidden_layers`` layers, hidden -> hidden
      * ``conv2``: hidden -> hidden
      * ``final_expectation``: Conv1d(hidden -> out, kernel_size=2) + ReLU per
        edge, summed over ``edge_index[1]``, then ``final_activation``.

    Every layer is built with ``layer_class`` using the same keyword arguments,
    except for its input width.
    """

    def __init__(
            self,
            layer_class,
            num_hidden_layers=0,
            in_channels=1,
            hidden_channels=1,
            out_channels=1,
            add_bias=False,
            activation='relu',
            node_activation='relu',
            final_activation='identity',
            initialize_w_zero=False,
        ):
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.add_bias = add_bias
        self.layer_class = layer_class
        self.final_activation = get_activation_function(final_activation)
        # NOTE(review): stored but not read by forward/final_expectation.
        self.activation = get_activation_function(activation)

        def build(layer_in_channels):
            # Shared recipe for every edge-wise layer; only the input width varies.
            return self.layer_class(
                in_channels=layer_in_channels,
                node_in_channels=in_channels,
                out_channels=hidden_channels,
                add_bias=self.add_bias,
                activation=activation,
                node_activation=node_activation,
                final_activation=final_activation,
                initialize_w_zero=initialize_w_zero,
            )

        # Construction order (conv1, hidden layers, conv2, aggregation) fixes
        # both parameter registration order and RNG consumption for init.
        self.conv1 = build(in_channels)
        self.hidden_layers = nn.ModuleList(
            [build(hidden_channels) for _ in range(self.num_hidden_layers)]
        )
        self.conv2 = build(hidden_channels)
        # Reduces the last axis with a width-2 kernel while mapping hidden -> out channels.
        self.aggregation_weight = nn.Conv1d(
            hidden_channels, out_channels, kernel_size=2
        )

    def final_expectation(self, edge_messages, edge_index, node_features):
        """Score each edge, sum the scores per target node, apply ``final_activation``.

        ``node_features`` is accepted but not used.
        NOTE(review): no ``dim_size`` is given to the aggregation, so the output
        has ``edge_index[1].max() + 1`` rows. Trailing nodes without incoming
        edges are dropped -- confirm this is intended.
        """
        edge_scores = torch.relu(self.aggregation_weight(edge_messages))
        flat_scores = edge_scores.reshape(edge_scores.shape[0], -1)
        summed = torch_geometric.nn.aggr.SumAggregation()(flat_scores, edge_index[1])
        return self.final_activation(summed)

    def forward(self, x, edge_index, node_feature, edgewise_edge_index=None, **kwargs):
        """Run the edge-wise layer stack on ``x`` and read out node values.

        Extra ``kwargs`` are accepted and ignored.
        """
        graph_args = (edge_index, node_feature, edgewise_edge_index)
        x = self.conv1(x, *graph_args)
        for idx in range(self.num_hidden_layers):
            x = self.hidden_layers[idx](x, *graph_args)
        x = self.conv2(x, *graph_args)
        return self.final_expectation(x, edge_index, node_feature)

