import torch
from torch.nn import Linear, Parameter
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

# TODO(tm): pass the device in when the model is created instead of relying
# on this module-level global.
# Use CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

class IsingMessagePassingLayer(MessagePassing):
    """Edge-to-edge message update built from a linear map, atanh and tanh.

    For every directed edge ``(s, t)`` in ``edge_index`` the layer returns

        tanh(edge_potential[s] + sum_k atanh(J(message[k])))

    where ``k`` runs over the edges that enter ``s`` and do not start at
    ``t``. Messages are stored one row per edge, in ``edge_index`` column
    order, both on input and on output.

    Args:
        in_channels: Input feature size of the linear map ``J``.
        out_channels: Output feature size of the linear map ``J``.
        cached: Accepted but not used by this layer.
        normalize: Accepted but not used by this layer.
        Jst: Scale used in the initialisation expressions. It only ever
            multiplies zero tensors there, so it does not currently change
            the initial parameters.
        add_bias: If true, ``J`` gets a bias term initialised to zero.
    """

    def __init__(self, in_channels=1, out_channels=1,
                 cached=False, normalize=False, Jst=1, add_bias=False):
        super().__init__(aggr='add')
        self.tanh_activation = nn.Tanh()
        self.Jst = Jst
        self.add_bias = add_bias
        if self.add_bias:
            self.J = torch.nn.Linear(in_channels, out_channels, bias=True)
            # Jst * zeros: the bias starts at exactly zero.
            self.J.bias = torch.nn.Parameter(
                self.Jst * torch.zeros(out_channels),
                requires_grad=True)
        else:
            self.J = torch.nn.Linear(in_channels, out_channels, bias=False)

        # tanh(Jst * 0) plus 0 * noise: the weight starts at exactly zero.
        # NOTE(review): the shape of the expression suggests a Jst-scaled
        # mean and a noise term were meant to be switchable -- confirm.
        J_init_weight = torch.tanh(self.Jst * torch.zeros_like(self.J.weight))
        J_init_weight = J_init_weight + 0 * torch.randn_like(J_init_weight)
        self.J.weight = torch.nn.Parameter(
            J_init_weight / max(in_channels, out_channels),
            requires_grad=True)

    def reset_parameters(self):
        """Re-initialise ``J`` with ``nn.Linear``'s default scheme; zero the bias.

        This does not restore the all-zero weight that ``__init__`` sets.
        """
        self.J.reset_parameters()
        if self.add_bias:
            self.J.bias.data.zero_()

    def forward(self, message, edge_index, edge_potential, edge_weight=None):
        """Run one round of the edge-message update.

        Args:
            message: Current messages, one row per column of ``edge_index``.
            edge_index: ``[2, num_edges]`` tensor of (source, target) pairs.
            edge_potential: Tensor indexed by node id along dim 0 (it is
                gathered with ``edge_index[0]``); it must broadcast against
                the ``out_channels``-wide messages.
            edge_weight: Accepted but not used.

        Returns:
            Updated messages, one row per edge, squashed by tanh.
        """
        # NOTE: do not add self loops (or any other extra edges) to
        # edge_index here: messages are stored one row per edge, so the
        # edge list has to stay aligned with `message`.
        # NOTE(review): atanh is only finite for inputs strictly inside
        # (-1, 1); nothing here bounds J(message) -- confirm this is intended.
        middle = torch.atanh(self.J(message))

        # No degree normalisation is applied: every incoming message
        # contributes with weight one.
        out = self.propagate(edge_index, x=middle)

        # Add the potential of each edge's source node, then squash.
        x = out + edge_potential[edge_index[0]]
        return self.tanh_activation(x)

    def aggregate(self, inputs):
        """Return the per-edge messages unchanged (no reduction onto nodes)."""
        return inputs

    def message(self, x, edge_index):
        """Collect, for each edge, the messages that feed into it.

        For edge ``i = (s, t)`` the result row is the sum of the rows of
        ``x`` belonging to edges that enter ``s``, leaving out edges that
        start at ``t``. Rows with no such edge stay zero.

        Args:
            x: Transformed messages, one row per column of ``edge_index``.
            edge_index: ``[2, num_edges]`` tensor of (source, target) pairs.

        Returns:
            Tensor shaped like ``x`` holding the summed messages per edge.
        """
        src, dst = edge_index
        belief = torch.zeros_like(x)

        for i in range(edge_index.shape[1]):
            # Edges k -> s feed the message on s -> t, except those coming
            # back from t. This is the same edge mask the 1-hop directed
            # subgraphs around s and t produce, without building them.
            feeds_edge = (dst == src[i]) & (src != dst[i])
            # An empty selection sums to zeros, so no guard is needed.
            belief[i] = x[feeds_edge].sum(dim=0)
        return belief


class LinearMultiplyIsingMessagePassingLayer(MessagePassing):
    """Edge-to-edge message update with a gated (linear x tanh) transform.

    Each incoming message ``m`` is first transformed to

        (W_2 * m + b_2) * tanh(W_1 * m + b_1)

    (an earlier variant applied ReLU to the linear factor). For every
    directed edge ``(s, t)`` the transformed messages on edges that enter
    ``s`` and do not start at ``t`` are summed, the potential of ``s`` is
    added, and the result is squashed by tanh.

    Args:
        in_channels: Input feature size of both linear maps.
        out_channels: Output feature size of both linear maps.
        cached: Accepted but not used by this layer.
        normalize: Accepted but not used by this layer.
        Jst: Stored on the instance; not used in any computation here.
        add_bias: Accepted but not used; both linear maps always carry a
            bias.
    """

    def __init__(self, in_channels=1, out_channels=1,
                 cached=False, normalize=False, Jst=1, add_bias=True):
        super().__init__(aggr='add')
        self.tanh_activation = nn.Tanh()
        self.Jst = Jst
        self.J1 = torch.nn.Linear(in_channels, out_channels)
        self.J2 = torch.nn.Linear(in_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise both linear maps and set their biases to zero."""
        self.J1.reset_parameters()
        self.J2.reset_parameters()
        # The biases are leaf parameters that require grad; zeroing them
        # in place is only allowed with autograd recording switched off.
        with torch.no_grad():
            self.J1.bias.zero_()
            self.J2.bias.zero_()

    def forward(self, message, edge_index, edge_potential, edge_weight=None):
        """Run one round of the edge-message update.

        Args:
            message: Current messages, one row per column of ``edge_index``.
            edge_index: ``[2, num_edges]`` tensor of (source, target) pairs.
            edge_potential: Tensor indexed by node id along dim 0 (it is
                gathered with ``edge_index[0]``); it must broadcast against
                the ``out_channels``-wide messages.
            edge_weight: Accepted but not used.

        Returns:
            Updated messages, one row per edge, squashed by tanh.
        """
        # NOTE: do not add self loops (or any other extra edges) to
        # edge_index here: messages are stored one row per edge, so the
        # edge list has to stay aligned with `message`.
        gate = self.J1(message).tanh()
        linear = self.J2(message)
        middle = torch.mul(gate, linear)

        # No degree normalisation is applied: every incoming message
        # contributes with weight one.
        out = self.propagate(edge_index, x=middle)

        # Add the potential of each edge's source node, then squash.
        x = out + edge_potential[edge_index[0]]
        return self.tanh_activation(x)

    def aggregate(self, inputs):
        """Return the per-edge messages unchanged (no reduction onto nodes)."""
        return inputs

    def message(self, x, edge_index):
        """Collect, for each edge, the messages that feed into it.

        For edge ``i = (s, t)`` the result row is the sum of the rows of
        ``x`` belonging to edges that enter ``s``, leaving out edges that
        start at ``t``. Rows with no such edge stay zero.

        Args:
            x: Transformed messages, one row per column of ``edge_index``.
            edge_index: ``[2, num_edges]`` tensor of (source, target) pairs.

        Returns:
            Tensor shaped like ``x`` holding the summed messages per edge.
        """
        src, dst = edge_index
        belief = torch.zeros_like(x)

        for i in range(edge_index.shape[1]):
            # Edges k -> s feed the message on s -> t, except those coming
            # back from t. This is the same edge mask the 1-hop directed
            # subgraphs around s and t produce, without building them.
            feeds_edge = (dst == src[i]) & (src != dst[i])
            # An empty selection sums to zeros, so no guard is needed.
            belief[i] = x[feeds_edge].sum(dim=0)
        return belief

class BeliefPropLayers(torch.nn.Module):
    """Stack of edge-message layers followed by a per-node readout.

    The model applies ``conv1``, then ``num_hidden_layers`` hidden layers,
    then ``conv2`` -- all instances of ``layer_class`` -- to a tensor of
    per-edge messages, and finally combines the resulting messages with the
    node potentials in ``final_expectation``.

    Args:
        layer_class: Layer type to instantiate. It is called with
            ``in_channels``/``out_channels`` (by keyword or as the first two
            positional arguments) plus the keywords ``add_bias``,
            ``cached``, ``normalize`` and ``Jst``.
        num_hidden_layers: Number of hidden-to-hidden layers between
            ``conv1`` and ``conv2``.
        in_channels: Feature size of the input messages.
        hidden_channels: Feature size between layers.
        out_channels: Feature size of the messages produced by ``conv2``.
        add_bias: Forwarded to every layer.
        Jst: Forwarded to every layer and used as the message scale in the
            ``'fixed'`` readout.
        message_aggregate: ``'fixed'`` scales the final messages by ``Jst``;
            ``'learned'`` passes them through a bias-free linear map that is
            initialised to zero.
    """

    def __init__(self,
                 layer_class,
                 num_hidden_layers=0,
                 in_channels=1,
                 hidden_channels=1,
                 out_channels=1,
                 add_bias=True,
                 Jst=1,
                 message_aggregate='fixed'):
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.Jst = Jst
        self.add_bias = add_bias
        assert message_aggregate in ['fixed', 'learned'], "invalid message aggregate"
        self.message_aggregate = message_aggregate
        self.layer_class = layer_class

        # Every layer receives the same Jst, including the first one.
        self.conv1 = self.layer_class(
            in_channels=in_channels,
            out_channels=hidden_channels,
            add_bias=self.add_bias,
            cached=True,
            normalize=False,
            Jst=self.Jst)

        self.hidden_layers = nn.ModuleList([
            self.layer_class(
                hidden_channels, hidden_channels,
                add_bias=self.add_bias,
                cached=True, normalize=False, Jst=self.Jst)
            for _ in range(self.num_hidden_layers)
        ])

        self.conv2 = self.layer_class(
            hidden_channels, out_channels, cached=True,
            add_bias=self.add_bias,
            normalize=False, Jst=self.Jst)

        if self.message_aggregate == 'learned':
            # Learned readout weight, started at zero.
            self.aggregation_weight = nn.Linear(
                out_channels, out_channels, bias=False)
            self.aggregation_weight.weight = nn.Parameter(
                torch.zeros_like(self.aggregation_weight.weight),
                requires_grad=True
            )

    def final_expectation(self, edge_messages, edge_index, node_potentials):
        """Combine the final edge messages into one value per node.

        Args:
            edge_messages: Messages, one row per column of ``edge_index``.
            edge_index: ``[2, num_edges]`` tensor of (source, target) pairs.
            node_potentials: Tensor indexed by node id along dim 0.

        Returns:
            ``tanh(node_potentials + s)`` where ``s`` holds, for each node,
            the sum of the transformed messages on edges targeting it.
        """
        aggr = torch_geometric.nn.aggr.SumAggregation()
        # NOTE: maybe this needs a learnable parameter as well.
        if self.message_aggregate == 'fixed':
            # Scale by the plain number so the result stays on the device
            # of `edge_messages`, wherever the model happens to live.
            # NOTE(review): atanh(tanh(z)) equals z until tanh saturates,
            # after which it becomes infinite -- confirm this is intended.
            em = torch.atanh(torch.tanh(edge_messages * self.Jst))
        elif self.message_aggregate == 'learned':
            em = torch.tanh(self.aggregation_weight(edge_messages))
        # Fix the output length to the number of nodes so that nodes
        # without incoming edges still get a (zero) row.
        sum_of_neighbors = aggr(
            em, edge_index[1], dim_size=node_potentials.size(0))
        return torch.tanh(node_potentials + sum_of_neighbors)

    def forward(self, x, edge_index, node_feature,
                edgewise_edge_index=None, **kwargs):
        """Run all message layers and the node readout.

        Args:
            x: Initial messages, one row per column of ``edge_index``.
            edge_index: ``[2, num_edges]`` tensor of (source, target) pairs.
            node_feature: Per-node potentials; passed to every layer and to
                the readout.
            edgewise_edge_index: Accepted but not used.
            **kwargs: Accepted but not used.

        Returns:
            The output of ``final_expectation`` for the last messages.
        """
        x = self.conv1(x, edge_index, node_feature)
        for layer in self.hidden_layers:
            x = layer(x, edge_index, node_feature)
        x = self.conv2(x, edge_index, node_feature)
        return self.final_expectation(x, edge_index, node_feature)
