import torch
from torch.nn import Linear, Parameter
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from hydra.utils import instantiate

# TODO(tm): pass the device in when the model is created instead of relying on
# this module-level default.
# Default device: CUDA when torch reports it as available, otherwise CPU.
device = (
    torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
)



class FlipFlopMessagePassingLayer(MessagePassing):
    """
        Message-passing layer made of two stacked linear maps.

        Current computation (see ``forward``):
            out = sigmoid(aggregate(J2(J1(message))))
        where ``aggregate`` is ``MessagePassing.propagate`` with ``aggr='add'``
        and the base class' message function (it is not overridden here).

        Earlier experiments used gated forms such as
            x = ReLU(W2 * message + b_2) * tanh(W_1 * message + b_1)
            x = (W2 * message + b_2) * tanh(W_1 * message + b_1)
        which are no longer what this class computes.

        Args:
            in_channels: input feature size of ``J1``.
            out_channels: output feature size of ``J1``; ``J2`` maps
                ``out_channels -> out_channels``.
            cached: accepted for interface compatibility; not read here.
            normalize: accepted for interface compatibility; not read here.
            add_bias: accepted for interface compatibility; not read here
                (both linear maps are always created with a bias).
            initialize_w_zero: when true, all weights and biases of ``J1`` and
                ``J2`` start at zero.
    """
    def __init__(self, in_channels=1, out_channels=1,
                 cached=False, normalize=False,
                 add_bias=True, initialize_w_zero=False):
        super().__init__(aggr='add')
        # tanh_activation is not used by forward(); kept so the attribute
        # remains available to any external code that references it.
        self.tanh_activation = nn.Tanh()
        self.sigmoid_activation = nn.Sigmoid()
        self.J1 = torch.nn.Linear(in_channels, out_channels)
        self.J2 = torch.nn.Linear(out_channels, out_channels)
        if initialize_w_zero:
            # nn.init.zeros_ fills in place under no_grad, so the parameters
            # stay trainable leaf tensors.
            for linear in (self.J1, self.J2):
                nn.init.zeros_(linear.weight)
                nn.init.zeros_(linear.bias)

    def reset_parameters(self):
        """Re-initialise ``J1`` and ``J2`` and set both biases to zero."""
        for linear in (self.J1, self.J2):
            linear.reset_parameters()
            # A bare ``bias.zero_()`` on a leaf parameter that requires grad
            # raises a RuntimeError; nn.init.zeros_ performs the fill under
            # no_grad instead.
            nn.init.zeros_(linear.bias)

    def forward(self, message, edge_index, edge_potential=None):
        """Apply ``J1`` then ``J2``, aggregate over ``edge_index``, squash.

        Args:
            message: input features passed to ``J1``.
            edge_index: graph connectivity handed to ``propagate``.
            edge_potential: currently unused.

        Returns:
            Sigmoid of the aggregated, linearly transformed features.
        """
        transformed = self.J2(self.J1(message))
        # add_self_loops is forwarded as an extra propagate kwarg exactly as
        # before; nothing defined in this class consumes it.
        out = self.propagate(edge_index, x=transformed, add_self_loops=False)
        return self.sigmoid_activation(out)
    

class FlipFlopModel(torch.nn.Module):
    """Sequential stack of graph layers created through Hydra.

    Layout: ``conv1`` (in -> hidden), ``num_hidden_layers`` layers
    (hidden -> hidden), then ``conv2`` (hidden -> out). Every layer is built
    by ``instantiate(conv_args, in_channels=..., out_channels=...,
    add_bias=...)``.

    Args:
        conv_args: Hydra config describing the layer to instantiate.
        num_hidden_layers: number of hidden -> hidden layers.
        in_channels: feature size fed to ``conv1``.
        hidden_channels: feature size between layers.
        out_channels: feature size produced by ``conv2``.
        add_bias: forwarded to every instantiated layer as ``add_bias``.
        initialize_w_zero: stored on the model; it is not forwarded to the
            layers here, so zero-initialisation has to come from
            ``conv_args`` if it is wanted.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 conv_args,
                 num_hidden_layers=2,
                 in_channels=1,
                 hidden_channels=1,
                 out_channels=1,
                 add_bias=False,
                 initialize_w_zero=False,
                 **kwargs):
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.add_bias = add_bias
        self.initialize_w_zero = initialize_w_zero
        self.conv_args = conv_args

        def build_layer(n_in, n_out):
            # One place for the instantiate call shared by all layers.
            return instantiate(
                conv_args,
                in_channels=n_in,
                out_channels=n_out,
                add_bias=self.add_bias,
            )

        # Construction order (conv1, hidden stack, conv2) is kept so module
        # registration order and random initialisation order do not change.
        self.conv1 = build_layer(in_channels, hidden_channels)
        self.hidden_layers = nn.ModuleList(
            [build_layer(hidden_channels, hidden_channels)
             for _ in range(self.num_hidden_layers)]
        )
        self.conv2 = build_layer(hidden_channels, out_channels)

    def forward(self, data):
        """Run ``data.x`` through every layer using ``data.edge_index``.

        Args:
            data: object exposing ``x`` and ``edge_index`` attributes.

        Returns:
            Output of ``conv2``.
        """
        edge_index = data.edge_index
        hidden = self.conv1(data.x, edge_index)
        for layer in self.hidden_layers:
            hidden = layer(hidden, edge_index)
        return self.conv2(hidden, edge_index)
