import torch
from torch.nn import Linear, Parameter
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric
from torch_geometric.nn import MessagePassing, GCNConv
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.data import Data
import numpy as np
import os
import pickle
from  generate_random_inductive import * 

def save_as_pickle(data, filename, folder_path):
    """Pickle `data` into the file `filename` inside `folder_path`.

    The destination is opened in binary write mode, so an existing file at
    that path is overwritten. The folder itself is not created here.
    """
    destination = os.path.join(folder_path, filename)
    with open(destination, 'wb') as handle:
        pickle.dump(data, handle)

def load_from_pickle(filename):
    """Return the object unpickled from the file at path `filename`."""
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(filename, 'rb') as handle:
        return pickle.load(handle)

# TODO(tm): pass the device in when creating the model instead of relying on
# this module-level global.
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

class UpDownLeaveNodeDirectionality(MessagePassing):
    """Message-passing layer: linear map, neighbour aggregation, then tanh.

    `forward` applies the linear layer `J_source` to the incoming features,
    propagates the result along `edge_index`, and squashes the aggregated
    output with tanh.

    Args:
        in_channels: Input feature width of `J_source`.
        out_channels: Output feature width of `J_source`.
        Jst: Stored on the instance as `self.Jst`; not read by this class.
        add_bias: Accepted for signature compatibility; not read by this
            class (`J_source` is always created with its default bias).
        undirected_edgewise_edge_index: Stored on the instance; not read by
            this class.
        initialize_w_zero: Accepted for signature compatibility; not read by
            this class.
    """

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 Jst=1,
                 add_bias=True,
                 undirected_edgewise_edge_index=False,
                 initialize_w_zero=False):
        super().__init__(aggr='add')
        self.tanh_activation = nn.Tanh()
        self.Jst = Jst
        self.J_source = torch.nn.Linear(in_channels, out_channels)
        self.undirected_edgewise_edge_index = undirected_edgewise_edge_index
        # NOTE(review): this overwrites the `aggr` attribute after the base
        # class was configured with aggr='add'. Which reduction propagate()
        # actually applies depends on the installed torch_geometric version;
        # confirm before relying on either 'add' or mean semantics.
        self.aggr = torch_geometric.nn.aggr.MeanAggregation()

    def reset_parameters(self):
        """Re-initialise `J_source` and set its bias to zero."""
        # Only `J_source` exists on this layer; the previous implementation
        # also touched a non-existent `J_target` and raised AttributeError.
        self.J_source.reset_parameters()
        # init.zeros_ runs without gradient tracking, unlike a bare `.zero_()`
        # on a leaf parameter that requires grad.
        torch.nn.init.zeros_(self.J_source.bias)

    def get_norm(self, edge_index, feature):
        """Return a per-edge symmetric degree normalisation.

        For each edge (row, col) the value is
        deg(row) ** -0.5 * deg(col) ** -0.5, where the degree is computed
        from the second row of `edge_index` over `feature.size(0)` entries.
        Entries with degree zero contribute a factor of 0 instead of inf.
        """
        row, col = edge_index
        deg = degree(col, feature.size(0), dtype=feature.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # 0 ** -0.5 is inf; zero it so isolated entries do not poison products.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        return norm

    def forward(self, x_up, edge_index):
        """Transform `x_up`, propagate along `edge_index`, and apply tanh.

        Args:
            x_up: Input features; the first dimension is indexed by
                `edge_index`.
            edge_index: Connectivity passed to `propagate`.

        Returns:
            tanh of the propagated, linearly transformed features.
        """
        m_up = self.J_source(x_up)

        # NOTE(review): `norm` is handed to propagate(), but this class does
        # not override `message`, so nothing visible here multiplies the
        # messages by it. Confirm whether the normalisation is intended.
        norm = self.get_norm(edge_index, x_up)

        out_up = self.propagate(edge_index, x=m_up, norm=norm)

        x_up = self.tanh_activation(out_up)
        return x_up


class UndirectedBPModel(torch.nn.Module):
    """Stack of `UpDownLeaveNodeDirectionality` layers with a node readout.

    `forward` builds one feature vector per edge, passes it through
    `conv_block1` and the hidden layers using the edge-wise connectivity,
    and finally sums the resulting per-edge values into their target nodes.

    Args:
        num_hidden_layers: Number of layers placed in `hidden_layers`.
        in_channels: Width of the node features; `conv_block1` receives
            `2 * in_channels` because source and target features are
            concatenated per edge.
        hidden_channels: Width used by all message-passing layers.
        out_channels: Output width of the final linear readout.
        add_bias: Accepted, but the instance attribute is always set to
            False (the original behaviour, preserved here).
        undirected_edgewise_edge_index: Stored and forwarded to each layer.
        initialize_w_zero: Stored and forwarded to each layer.
        Jst: Stored on the instance as `self.Jst`.
        message_aggregate: Must be 'fixed' or 'learned'; stored on the
            instance.
        **kwargs: Ignored.
    """

    def __init__(self,
                 num_hidden_layers=0,
                 in_channels=1,
                 hidden_channels=1,
                 out_channels=1,
                 add_bias=False,
                 undirected_edgewise_edge_index=False,
                 initialize_w_zero=False,
                 Jst=1,
                 message_aggregate='learned',
                 **kwargs):
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.Jst = Jst
        self.initialize_w_zero = initialize_w_zero
        self.undirected_edgewise_edge_index = undirected_edgewise_edge_index
        assert message_aggregate in ['fixed', 'learned'], "invalid message aggregate"
        self.message_aggregate = message_aggregate
        self.layer_class = UpDownLeaveNodeDirectionality
        # The constructor argument `add_bias` was always overwritten with
        # False; keep that behaviour with a single, explicit assignment.
        self.add_bias = False
        self.conv_block1 = self.layer_class(
            in_channels=2 * in_channels,
            out_channels=hidden_channels,
            undirected_edgewise_edge_index=undirected_edgewise_edge_index,
            add_bias=False,
            initialize_w_zero=initialize_w_zero,)

        hidden_layers = nn.ModuleList()

        for _ in range(self.num_hidden_layers):
            hidden_layers.append(
                self.layer_class(
                    hidden_channels,
                    hidden_channels,
                    undirected_edgewise_edge_index=undirected_edgewise_edge_index,
                    add_bias=self.add_bias,
                    initialize_w_zero=initialize_w_zero,)
            )
        self.hidden_layers = hidden_layers
        # NOTE(review): `conv_block2` is constructed but `forward` never calls
        # it. It is kept so parameter creation order (and therefore random
        # initialisation of later parameters) is unchanged.
        self.conv_block2 = self.layer_class(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            undirected_edgewise_edge_index=undirected_edgewise_edge_index,
            add_bias=self.add_bias,
            initialize_w_zero=initialize_w_zero,)

        self.aggregation_weight = nn.Linear(
            hidden_channels, out_channels, bias=False)

    def final_expectation(self, x, edge_index, node_potentials=None):
        """Sum transformed per-edge values into target nodes and apply tanh.

        Args:
            x: Per-edge features, one row per column of `edge_index`.
            edge_index: Node-level connectivity; row 1 gives the node each
                edge value is accumulated into.
            node_potentials: Optional per-node term added before the tanh.

        Returns:
            tanh(node_potentials + sum) when potentials are given, otherwise
            tanh(sum).
        """
        aggr = torch_geometric.nn.aggr.SumAggregation()
        em = self.aggregation_weight(x)
        # Without an explicit size the aggregation output is as long as the
        # largest target index + 1, which is shorter than the node count when
        # trailing nodes have no incoming edge. Pin it to the node count so
        # the addition below always lines up.
        dim_size = None if node_potentials is None else node_potentials.size(0)
        sum_of_neighbors = aggr(em, edge_index[1], dim_size=dim_size)
        if node_potentials is not None:
            final = torch.tanh(node_potentials + sum_of_neighbors)
        else:
            final = torch.tanh(sum_of_neighbors)
        return final

    def forward(self, x, data, edgewise_edge_index):
        """Run the model on one graph.

        Args:
            x: Ignored; the per-edge input is rebuilt from `data` below.
            data: Object exposing `x` (node features) and `edge_index`.
            edgewise_edge_index: Connectivity between edges, used by every
                message-passing layer.

        Returns:
            The per-node output of `final_expectation`, using `data.x` as
            the node potentials.
        """
        edge_index = data.edge_index
        # One input row per edge: source-node features followed by
        # target-node features.
        x = torch.cat([data.x[edge_index[0]], data.x[edge_index[1]]], dim=1)
        x_up = self.conv_block1(
            x_up=x,
            edge_index=edgewise_edge_index)

        for layer in self.hidden_layers:
            x_up = layer(
                x_up=x_up,
                edge_index=edgewise_edge_index)

        final_x = self.final_expectation(x_up, edge_index, node_potentials=data.x)
        return final_x

def ising_node_potential(i, singleton_mean, singleton_var, num_dim=1):
    """Sample one node's potential vector from a normal distribution.

    `i` is accepted but not used. Note that `singleton_var` is handed to
    `np.random.normal` as its `scale` argument, which NumPy defines as the
    standard deviation rather than the variance.

    Returns:
        A float32 tensor with `num_dim` samples.
    """
    samples = np.random.normal(
        loc=singleton_mean, scale=singleton_var, size=num_dim)
    return torch.from_numpy(samples).float()

def get_node_potential(num_nodes, singleton_mean=0, singleton_var=1, num_dim=1):
    """Return a (num_nodes, num_dim) tensor of sampled node potentials.

    Each row is drawn independently by `ising_node_potential`, one call per
    node in increasing node order.
    """
    sampler_kwargs = {
        'singleton_mean': singleton_mean,
        'singleton_var': singleton_var,
        'num_dim': num_dim,
    }
    potentials = torch.zeros([num_nodes, num_dim])
    for node_id in range(num_nodes):
        potentials[node_id, :] = ising_node_potential(node_id, **sampler_kwargs)
    return potentials

def generate_star(num_nodes):
    """Return the edge index of a star graph centred on node 0.

    Every node 1 .. num_nodes - 1 is connected to node 0 by one directed
    edge (0 -> leaf).

    :param num_nodes: Total number of nodes, including the centre.
    :return: int64 tensor of shape (2, num_nodes - 1); row 0 is all zeros
        (the centre) and row 1 lists the leaves 1 .. num_nodes - 1.
    """
    # Build the indices as integers from the start instead of stacking a
    # float tensor with an integer one and casting later.
    leaves = torch.arange(1, num_nodes, dtype=torch.long)
    center_node = torch.zeros_like(leaves)
    edge_index = torch.stack([center_node, leaves], dim=0)
    return edge_index

def get_graph(num_nodes,
              singleton_mean=0,
              singleton_var=1,
              num_dim=1):
    """Build a star graph with randomly sampled node potentials.

    Args:
        num_nodes: Number of nodes passed to `generate_star` and
            `get_node_potential`.
        singleton_mean: Forwarded to `get_node_potential`.
        singleton_var: Forwarded to `get_node_potential`.
        num_dim: Feature width of each node potential.

    Returns:
        A `Data` object whose `x` holds the sampled potentials and whose
        `edge_index` holds the star connectivity as int64.
    """
    # `generate_star` already returns a tensor; `as_tensor` converts the
    # dtype without the copy-construct warning that `torch.tensor(tensor)`
    # emits.
    graph_edge_index = torch.as_tensor(
        generate_star(num_nodes), dtype=torch.long)
    x = get_node_potential(
        num_nodes,
        singleton_mean=singleton_mean,
        singleton_var=singleton_var,
        num_dim=num_dim)
    graph = Data(
        x=x,
        edge_index=graph_edge_index,
    )
    return graph

def get_edgewise_edge_index(edge_index):
    """Build edge-to-edge connectivity that excludes the reverse edge.

    For every edge i = (s -> t) in `edge_index`, each edge j = (u -> s)
    with u != t is connected to i. In the result, row 0 holds j (the
    feeding edge) and row 1 holds i (the receiving edge); both are column
    positions in `edge_index`.

    Args:
        edge_index: Integer tensor of shape (2, num_edges).

    Returns:
        int64 CPU tensor of shape (2, num_pairs). Pairs are grouped by
        receiving edge i in increasing order, and feeding edges within a
        group are listed in increasing order.
    """
    sources, targets = edge_index[0, :], edge_index[1, :]
    feeding = []
    receiving = []
    for edge_id in range(edge_index.size(1)):
        src = sources[edge_id]
        dst = targets[edge_id]
        # Edges that arrive at this edge's source, except those that start at
        # this edge's target (i.e. skip the edge coming straight back).
        incoming = (targets == src) & (sources != dst)
        feeders = incoming.nonzero(as_tuple=True)[0].tolist()
        feeding.extend(feeders)
        receiving.extend([edge_id] * len(feeders))

    # Construct the integer tensor directly rather than filling a float
    # buffer and casting, which would lose precision for large indices.
    return torch.tensor([feeding, receiving], dtype=torch.int64)

def get_edgewise_graph(graph, to_undirected=True):
    """Return a `Data` object describing connectivity between edges.

    Args:
        graph: Object exposing `edge_index`.
        to_undirected: When true, `graph.edge_index` is first passed through
            `torch_geometric.utils.to_undirected`.

    Returns:
        `Data` whose `x` is the target row of the (possibly undirected) edge
        index and whose `edge_index` comes from `get_edgewise_edge_index`.
    """
    base_edges = graph.edge_index
    if to_undirected:
        base_edges = torch_geometric.utils.to_undirected(base_edges)

    edge_to_edge = get_edgewise_edge_index(base_edges).to(torch.long)
    return Data(x=base_edges[1], edge_index=edge_to_edge)

def get_undirected_edgewise_graph(graph, to_undirected=True):
    """Return edge-to-edge connectivity built from both edge orientations.

    `get_edgewise_edge_index` is evaluated once on the edge index as given
    (optionally made undirected first) and once on the same edge index with
    its two rows swapped; the two results are concatenated column-wise.

    Args:
        graph: Object exposing `edge_index`.
        to_undirected: When true, `graph.edge_index` is first passed through
            `torch_geometric.utils.to_undirected`.

    Returns:
        `Data` whose `x` is the target row of the un-swapped edge index and
        whose `edge_index` is the concatenated connectivity.
    """
    forward_edges = graph.edge_index
    if to_undirected:
        forward_edges = torch_geometric.utils.to_undirected(forward_edges)

    # Same edges with source and target rows exchanged.
    swapped_edges = torch.stack([forward_edges[1, :], forward_edges[0, :]], dim=0)

    combined = torch.cat(
        [
            get_edgewise_edge_index(forward_edges),
            get_edgewise_edge_index(swapped_edges),
        ],
        dim=1)
    return Data(x=forward_edges[1], edge_index=combined.to(torch.long))

# Dataset generation: for every (num_nodes, depth) pair, create one randomly
# initialised model and label 10000 random star graphs with its output.
for num_nodes in [16, 32, 64]:
    for depth in [1, 3, 5]:
        print('=====================')
        print(num_nodes, depth)

        dest_prefix = '/home/user/data/graph_datasets/star_graph_random_scratch'
        folder = f'ns_10000_num_nodes_{num_nodes}_depth_{depth}_dim_10'
        folder_path = os.path.join(dest_prefix, folder)
        os.makedirs(folder_path, exist_ok=True)

        model = UndirectedBPModel(
                num_hidden_layers=depth,
                in_channels=10,
                hidden_channels=10,
                out_channels=10,
            )
        model = model.to(device)

        for sample in range(10000):
            print(sample)
            data = get_graph(num_nodes, singleton_mean=0, singleton_var=1, num_dim=10)
            undirected_edge = torch_geometric.utils.to_undirected(data.edge_index)

            edge_data = get_edgewise_graph(data, to_undirected=True)

            undir_edge_data = get_undirected_edgewise_graph(data, to_undirected=True)

            edge_data.undirected_edge_index = undir_edge_data.edge_index
            # TODO: despite its name, this attribute stores the original
            # (directed) edge index; rename it to reflect that.
            data.undirected_edge_index = data.edge_index
            data.edge_index = undirected_edge

            data = data.to(device)
            edge_data = edge_data.to(device)
            # The output is only stored as a label, so skip building the
            # autograd graph for this forward pass.
            with torch.no_grad():
                out = model(data.x, data, edge_data.edge_index)

            data = data.to('cpu')

            new_graph = Data(
                x=data.x,
                edge_index=data.edge_index,
                y=out.detach().cpu(),
                undirected_edge_index=data.undirected_edge_index,
            )
            name = f'example_{sample:04}.pt'
            graph_name = f'graph_example_{sample:04}.pt'
            save_as_pickle(new_graph, name, folder_path)
            save_as_pickle(edge_data.to('cpu'), graph_name, folder_path)