import torch
from torch.nn import Linear, Parameter
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric
from torch_geometric.nn import MessagePassing, GCNConv
from torch_geometric.utils import add_self_loops, degree

# Module-level default device: CUDA when it is available, otherwise CPU.
# TODO(tm): pass the device in when creating the model instead of relying
# on this global.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class UpDownLeaveOneOut(MessagePassing):
    """Edge-level message-passing layer with separate source/target maps.

    Node features are projected by two linear maps (``J_source`` applied to
    ``x_up`` and ``J_target`` applied to ``x_down``), gathered onto the
    edges of ``edge_index``, propagated between edges along
    ``edgewise_edge_index`` (and its reverse), summed back onto the nodes
    and passed through ``tanh``.

    Args:
        in_channels: Input feature size of both linear maps.
        out_channels: Output feature size of both linear maps.
        cached: Accepted for interface compatibility; not used.
        normalize: Accepted for interface compatibility; not used.
        Jst: Stored as ``self.Jst``; not read by this layer.
        add_bias: Accepted for interface compatibility; not used (both
            linear maps are always created with a bias).
        initialize_w_zero: If true, the weights and biases of both linear
            maps start at zero.
    """

    def __init__(self, in_channels=1, out_channels=1,
                 cached=False, normalize=False, Jst=1,
                 add_bias=True, initialize_w_zero=False):
        super().__init__(aggr='add')
        self.tanh_activation = nn.Tanh()
        self.Jst = Jst
        self.J_source = torch.nn.Linear(in_channels, out_channels)
        self.J_target = torch.nn.Linear(in_channels, out_channels)
        # Node-level sum used in ``forward`` to pool edge messages.
        # NOTE(review): this reuses the attribute name ``aggr`` that is also
        # passed to ``MessagePassing.__init__`` above; the name is kept so
        # existing attribute access keeps working.
        self.aggr = torch_geometric.nn.aggr.SumAggregation()
        if initialize_w_zero:
            for linear in (self.J_source, self.J_target):
                nn.init.zeros_(linear.weight)
                nn.init.zeros_(linear.bias)

    def reset_parameters(self):
        """Re-initialise both linear maps and set their biases to zero."""
        self.J_source.reset_parameters()
        self.J_target.reset_parameters()
        # ``nn.init.zeros_`` runs without autograd tracking; calling
        # ``bias.zero_()`` directly on a parameter that requires grad raises
        # a RuntimeError.
        nn.init.zeros_(self.J_source.bias)
        nn.init.zeros_(self.J_target.bias)

    def get_norm(self, edge_index, feature, node_feature=None):
        """Return one symmetric degree-normalisation coefficient per edge.

        The degree of an index is the number of times it appears in
        ``edge_index[1]``; the coefficient of edge ``(row, col)`` is
        ``deg[row] ** -0.5 * deg[col] ** -0.5``, with zero-degree entries
        contributing 0 instead of ``inf``.

        Args:
            edge_index: Tensor whose two rows are unpacked as ``row, col``.
            feature: Tensor whose first dimension gives the number of
                indexed items and whose dtype is used for the result.
            node_feature: Unused.
        """
        row, col = edge_index
        deg = degree(col, feature.size(0), dtype=feature.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        return deg_inv_sqrt[row] * deg_inv_sqrt[col]

    def forward(self, x_up, x_down, edge_index, edgewise_edge_index=None):
        """Run one up/down message-passing step.

        Args:
            x_up: Node features fed to ``J_source`` and gathered with
                ``edge_index[0]``.
            x_down: Node features fed to ``J_target`` and gathered with
                ``edge_index[1]``.
            edge_index: Edge list of the graph; row 0 is used as the source
                index and row 1 as the target index.
            edgewise_edge_index: Connectivity between edges; its entries
                index positions in the edge list of ``edge_index``.
                Required despite the ``None`` default.

        Returns:
            Tuple ``(x_s, x_t)`` with one row per row of ``x_up`` /
            ``x_down`` respectively.

        Raises:
            ValueError: If ``edgewise_edge_index`` is ``None``.
        """
        if edgewise_edge_index is None:
            raise ValueError("edgewise_edge_index is required")

        m_s = self.J_source(x_up)
        m_t = self.J_target(x_down)

        # Same connectivity with source and target rows swapped.
        edgewise_edge_index_flip = edgewise_edge_index.flip(0)

        # One message per edge of the original graph.
        messages_s = m_s[edge_index[0]]
        messages_t = m_t[edge_index[1]]

        # NOTE(review): ``norm`` is handed to ``propagate`` but this class
        # defines no ``message`` override, so it looks like the coefficients
        # are never applied -- confirm whether normalisation was intended.
        norm = self.get_norm(edgewise_edge_index, messages_t)

        # Propagate the edge messages along the edge-to-edge connectivity.
        out_s = self.propagate(edgewise_edge_index, x=messages_s, norm=norm)
        out_t = self.propagate(edgewise_edge_index_flip, x=messages_t, norm=norm)

        # Pool edge messages back onto nodes. ``dim_size`` pins the number
        # of output rows to the number of input nodes; without it the result
        # is truncated when the highest-index nodes have no edges.
        x_s = self.aggr(out_s, edge_index[0], dim_size=x_up.size(0))
        x_t = self.aggr(out_t, edge_index[1], dim_size=x_down.size(0))
        return self.tanh_activation(x_s), self.tanh_activation(x_t)


class UpDownLeaveOneOutBP(MessagePassing):
    """Edge-level message-passing layer with a combined edge message.

    Node features are projected by two linear maps (``J_source`` applied to
    ``x_up`` and ``J_target`` applied to ``x_down``); each edge of
    ``edge_index`` receives the sum of its source and target projections.
    These edge messages are propagated along ``edgewise_edge_index`` and its
    reverse, summed back onto the nodes and passed through ``tanh``.

    Args:
        in_channels: Input feature size of both linear maps.
        out_channels: Output feature size of both linear maps.
        Jst: Stored as ``self.Jst``; not read by this layer.
        add_bias: Accepted for interface compatibility; not used (both
            linear maps are always created with a bias).
        undirected_edgewise_edge_index: If true, ``edgewise_edge_index`` is
            converted with ``to_undirected`` and used for both propagation
            directions.
        initialize_w_zero: If true, the weights and biases of both linear
            maps start at zero.
    """

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 Jst=1,
                 add_bias=True,
                 undirected_edgewise_edge_index=False,
                 initialize_w_zero=False):
        super().__init__(aggr='add')
        self.tanh_activation = nn.Tanh()
        self.Jst = Jst
        self.J_source = torch.nn.Linear(in_channels, out_channels)
        self.J_target = torch.nn.Linear(in_channels, out_channels)
        # Node-level sum used in ``forward`` to pool edge messages.
        # NOTE(review): this reuses the attribute name ``aggr`` that is also
        # passed to ``MessagePassing.__init__`` above; the name is kept so
        # existing attribute access keeps working.
        self.aggr = torch_geometric.nn.aggr.SumAggregation()
        self.undirected_edgewise_edge_index = undirected_edgewise_edge_index
        if initialize_w_zero:
            for linear in (self.J_source, self.J_target):
                nn.init.zeros_(linear.weight)
                nn.init.zeros_(linear.bias)

    def reset_parameters(self):
        """Re-initialise both linear maps and set their biases to zero."""
        self.J_source.reset_parameters()
        self.J_target.reset_parameters()
        # ``nn.init.zeros_`` runs without autograd tracking; calling
        # ``bias.zero_()`` directly on a parameter that requires grad raises
        # a RuntimeError.
        nn.init.zeros_(self.J_source.bias)
        nn.init.zeros_(self.J_target.bias)

    def get_norm(self, edge_index, feature, node_feature=None):
        """Return one symmetric degree-normalisation coefficient per edge.

        The degree of an index is the number of times it appears in
        ``edge_index[1]``; the coefficient of edge ``(row, col)`` is
        ``deg[row] ** -0.5 * deg[col] ** -0.5``, with zero-degree entries
        contributing 0 instead of ``inf``.

        Args:
            edge_index: Tensor whose two rows are unpacked as ``row, col``.
            feature: Tensor whose first dimension gives the number of
                indexed items and whose dtype is used for the result.
            node_feature: Unused.
        """
        row, col = edge_index
        deg = degree(col, feature.size(0), dtype=feature.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        return deg_inv_sqrt[row] * deg_inv_sqrt[col]

    def forward(self, x_up, x_down, edge_index, edgewise_edge_index=None):
        """Run one up/down message-passing step.

        Args:
            x_up: Node features fed to ``J_source`` and gathered with
                ``edge_index[0]``.
            x_down: Node features fed to ``J_target`` and gathered with
                ``edge_index[1]``.
            edge_index: Edge list of the graph; row 0 is used as the source
                index and row 1 as the target index.
            edgewise_edge_index: Connectivity between edges; its entries
                index positions in the edge list of ``edge_index``.
                Required despite the ``None`` default.

        Returns:
            Tuple ``(x_s, x_t)`` with one row per row of ``x_up`` /
            ``x_down`` respectively.

        Raises:
            ValueError: If ``edgewise_edge_index`` is ``None``.
        """
        if edgewise_edge_index is None:
            raise ValueError("edgewise_edge_index is required")

        m_s = self.J_source(x_up)
        m_t = self.J_target(x_down)

        if self.undirected_edgewise_edge_index:
            # The symmetrised connectivity serves both directions, so it is
            # computed once and shared.
            edgewise_edge_index = torch_geometric.utils.to_undirected(
                edgewise_edge_index)
            edgewise_edge_index_flip = edgewise_edge_index
        else:
            # Same connectivity with source and target rows swapped.
            edgewise_edge_index_flip = edgewise_edge_index.flip(0)

        # One combined message per edge of the original graph.
        messages = m_s[edge_index[0]] + m_t[edge_index[1]]

        # NOTE(review): ``norm`` is handed to ``propagate`` but this class
        # defines no ``message`` override, so it looks like the coefficients
        # are never applied -- confirm whether normalisation was intended.
        norm = self.get_norm(edgewise_edge_index, messages)

        # Propagate the edge messages along the edge-to-edge connectivity.
        out_s = self.propagate(edgewise_edge_index, x=messages, norm=norm)
        out_t = self.propagate(edgewise_edge_index_flip, x=messages, norm=norm)

        # Pool edge messages back onto nodes. ``dim_size`` pins the number
        # of output rows to the number of input nodes; without it the result
        # is truncated when the highest-index nodes have no edges.
        x_s = self.aggr(out_s, edge_index[0], dim_size=x_up.size(0))
        x_t = self.aggr(out_t, edge_index[1], dim_size=x_down.size(0))
        return self.tanh_activation(x_s), self.tanh_activation(x_t)



class NodeBeliefPropModel(torch.nn.Module):
    """Stack of up/down message-passing layers with a final node read-out.

    The model applies an input layer, ``num_hidden_layers`` residual hidden
    layers and an output layer, then combines the resulting up/down
    features with the input node features in ``final_expectation``.

    Args:
        layer_class: Currently ignored; see the note in ``__init__``.
        num_hidden_layers: Number of residual hidden layers.
        in_channels: Feature size of the input node features.
        hidden_channels: Feature size between layers.
        out_channels: Feature size produced by the last layer.
        add_bias: Forwarded to every layer as ``add_bias``.
        initialize_w_zero: Forwarded to every layer as
            ``initialize_w_zero``.
        Jst: Scalar multiplier applied to edge messages when
            ``message_aggregate == 'fixed'``.
        message_aggregate: ``'fixed'`` or ``'learned'``; selects how edge
            messages are transformed in ``final_expectation``.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``message_aggregate`` is not ``'fixed'`` or
            ``'learned'``.
    """

    def __init__(self,
                 layer_class,
                 num_hidden_layers=0,
                 in_channels=1,
                 hidden_channels=1,
                 out_channels=1,
                 add_bias=True,
                 initialize_w_zero=False,
                 Jst=1,
                 message_aggregate='learned',
                 **kwargs):
        super().__init__()
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if message_aggregate not in ('fixed', 'learned'):
            raise ValueError("invalid message aggregate")
        self.num_hidden_layers = num_hidden_layers
        self.Jst = Jst
        self.add_bias = add_bias
        self.initialize_w_zero = initialize_w_zero
        self.message_aggregate = message_aggregate
        # NOTE(review): the ``layer_class`` argument is ignored and the
        # layer type is pinned to UpDownLeaveOneOutBP. Confirm whether the
        # argument should be honoured before changing this, since doing so
        # alters results for existing callers.
        self.layer_class = UpDownLeaveOneOutBP
        self.conv_block1 = self.layer_class(
            in_channels=in_channels,
            out_channels=hidden_channels,
            add_bias=self.add_bias,
            initialize_w_zero=initialize_w_zero,
        )
        self.hidden_layers = nn.ModuleList([
            self.layer_class(
                hidden_channels,
                hidden_channels,
                add_bias=self.add_bias,
                initialize_w_zero=initialize_w_zero,
            )
            for _ in range(self.num_hidden_layers)
        ])
        self.conv_block2 = self.layer_class(
            in_channels=hidden_channels,
            out_channels=out_channels,
            add_bias=self.add_bias,
            initialize_w_zero=initialize_w_zero,
        )

        if self.message_aggregate == 'learned':
            # Bias-free linear map over edge messages, starting at zero.
            self.aggregation_weight = nn.Linear(
                out_channels, out_channels, bias=False)
            nn.init.zeros_(self.aggregation_weight.weight)

    def final_expectation(self, x_up, x_down, edge_index, node_potentials):
        """Combine up/down features into one value per node.

        Each edge message is ``x_up[source] + x_down[target]``. Messages are
        transformed according to ``self.message_aggregate``, summed over the
        edges pointing at each node (``edge_index[1]``), added to
        ``node_potentials`` and squashed with ``tanh``.

        Args:
            x_up: Features gathered with ``edge_index[0]``.
            x_down: Features gathered with ``edge_index[1]``.
            edge_index: Edge list; row 0 source index, row 1 target index.
            node_potentials: Per-node term added before the final ``tanh``;
                its first dimension fixes the number of output rows.

        Returns:
            Tensor with one row per row of ``node_potentials``.

        Raises:
            ValueError: If ``self.message_aggregate`` is not ``'fixed'`` or
                ``'learned'``.
        """
        edge_messages = x_up[edge_index[0]] + x_down[edge_index[1]]
        if self.message_aggregate == 'fixed':
            # NOTE(review): atanh(tanh(z)) is mathematically z but yields
            # inf once tanh saturates to 1.0 -- confirm this is the
            # intended formula.
            em = torch.atanh(torch.tanh(self.Jst * edge_messages))
        elif self.message_aggregate == 'learned':
            em = torch.tanh(self.aggregation_weight(edge_messages))
        else:
            raise ValueError("invalid message aggregate")
        aggr = torch_geometric.nn.aggr.SumAggregation()
        # ``dim_size`` keeps one output row per node even when the
        # highest-index nodes have no incoming edge; otherwise the addition
        # below fails on a shape mismatch.
        sum_of_neighbors = aggr(
            em, edge_index[1], dim_size=node_potentials.size(0))
        return torch.tanh(node_potentials + sum_of_neighbors)

    def forward(self, data, x=None, edgewise_edge_index=None):
        """Compute the final per-node output for a graph.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.
            x: Ignored; node features are always read from ``data.x``.
            edgewise_edge_index: Edge-to-edge connectivity forwarded to
                every layer.

        Returns:
            Output of ``final_expectation`` with ``data.x`` (as float32) as
            the node potentials.
        """
        # Follow the device the model's weights live on rather than a
        # module-level global, so inputs and weights always agree.
        model_device = next(self.parameters()).device
        x = data.x.to(device=model_device, dtype=torch.float32)
        edge_index = data.edge_index.to(model_device)
        if edgewise_edge_index is not None:
            edgewise_edge_index = edgewise_edge_index.to(model_device)

        x_up, x_down = self.conv_block1(
            x_up=x,
            x_down=x,
            edge_index=edge_index,
            edgewise_edge_index=edgewise_edge_index)
        for layer in self.hidden_layers:
            # Residual connection around every hidden layer.
            up_out, down_out = layer(
                x_up=x_up,
                x_down=x_down,
                edge_index=edge_index,
                edgewise_edge_index=edgewise_edge_index)
            x_up = up_out + x_up
            x_down = down_out + x_down
        x_up, x_down = self.conv_block2(
            x_up=x_up,
            x_down=x_down,
            edge_index=edge_index,
            edgewise_edge_index=edgewise_edge_index)
        return self.final_expectation(
            x_up, x_down, edge_index, node_potentials=x)