import os, sys
import torch
import torch.nn.functional as F
import torch.nn as nn
import wandb
import numpy as np
from torch_geometric.data import Data

import torch_geometric
import torch_geometric.transforms as T
import torch_geometric.utils as geo_utils
from einops import repeat
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
 
from torch_geometric.datasets import WikipediaNetwork
import torch_geometric.utils as geo_utils
import pickle


def remove_intersection(keep, remove):
    """Return the values of ``keep`` that do not occur in ``remove``.

    Both arguments are tensors. Each one is moved to the CPU and converted to
    a plain Python list before being turned into a set, so duplicate values
    collapse and the order of the returned list is whatever iterating the
    resulting set yields.

    Args:
        keep: tensor of candidate values.
        remove: tensor of values to exclude from the result.

    Returns:
        A Python list with the values present in ``keep`` but not in
        ``remove``.
    """
    kept_values = set(keep.cpu().numpy().tolist())
    excluded_values = set(remove.cpu().numpy().tolist())
    return list(kept_values - excluded_values)


def get_edgewise_graph(edge_index):
    """Build the edge-to-edge connectivity of a directed edge list.

    Every column ``i`` of ``edge_index`` is an edge ``u -> v``. For each such
    edge, every other column ``j`` describing an edge ``w -> u`` (an edge that
    ends where edge ``i`` starts) with ``w != v`` is connected to it, i.e. the
    immediate reverse edge ``v -> u`` is excluded.

    Args:
        edge_index: integer tensor of shape ``(2, num_edges)``; row 0 holds
            the first endpoint and row 1 the second endpoint of each edge.

    Returns:
        An ``int64`` CPU tensor of shape ``(2, num_pairs)``. Row 0 holds the
        index ``j`` of the incoming edge and row 1 the index ``i`` of the edge
        it feeds into. Pairs are grouped by ascending ``i`` and, within a
        group, listed by ascending ``j``.
    """
    first_endpoints = edge_index[0]
    second_endpoints = edge_index[1]

    incoming_edge_ids = []  # becomes row 0 of the result
    receiving_edge_ids = []  # becomes row 1 of the result
    for i in range(edge_index.shape[1]):
        start_node = first_endpoints[i]
        end_node = second_endpoints[i]
        # Edges that end at this edge's start node, minus those that come
        # straight back from this edge's end node.
        feeds_in = (second_endpoints == start_node) & (first_endpoints != end_node)
        matches = feeds_in.nonzero(as_tuple=True)[0].tolist()
        receiving_edge_ids.extend([i] * len(matches))
        incoming_edge_ids.extend(matches)

    # Build the result directly as int64: staging the ids in a float32 tensor
    # and casting afterwards would round any index above 2**24.
    return torch.tensor(
        [incoming_edge_ids, receiving_edge_ids], dtype=torch.int64)


# Root directory holding the downloaded dataset; the generated per-node
# neighbourhood graphs are written underneath it as well.
root_folder = '/home/user/data'
# Name of the WikipediaNetwork sub-dataset to process.
sub_name = 'squirrel'

# Output directory for the generated 1-hop neighbourhood graphs.
final_folder_name = f"{root_folder}/graph_datasets/1_hop_nbd/{sub_name}"
os.makedirs(final_folder_name, exist_ok=True)


dataset = WikipediaNetwork(root=root_folder, name=sub_name)

data = dataset[0]
# NOTE(review): `transform` is constructed but never applied in the visible
# script; the edge list is symmetrised later via geo_utils.to_undirected.
transform = T.ToUndirected()

# Fixed train/val/test split file. The path is derived from `root_folder`
# instead of repeating the literal directory, so changing the root in one
# place keeps the dataset and its split file in sync.
splits_path = (
    f'{root_folder}/{sub_name}/'
    f'geom_gcn/raw/'
    f'{sub_name}_split_0.6_0.2_0.npz')
# np.load on an .npz returns a lazily-read archive that keeps the file open;
# the arrays are pulled out inside the `with` so the handle is closed promptly.
with np.load(splits_path) as splits_file:
    train_mask = splits_file['train_mask']
    val_mask = splits_file['val_mask']
    test_mask = splits_file['test_mask']

# Attach the split masks to the graph as boolean tensors.
data.train_mask = torch.tensor(train_mask, dtype=torch.bool)
data.val_mask = torch.tensor(val_mask, dtype=torch.bool)
data.test_mask = torch.tensor(test_mask, dtype=torch.bool)


def get_1_hop_edge_subset(node_index, edge_index):
    """Collect every edge that touches ``node_index``.

    Args:
        node_index: id of the node whose incident edges are wanted.
        edge_index: tensor of shape ``(2, num_edges)``.

    Returns:
        A ``(2, k)`` tensor: first all columns whose row-0 entry equals
        ``node_index``, followed by all columns whose row-1 entry equals it,
        each group in its original column order. A self-loop on
        ``node_index`` matches both conditions and therefore appears twice.
    """
    first_row, second_row = edge_index[0], edge_index[1]
    leaving = edge_index[:, first_row == node_index]
    entering = edge_index[:, second_row == node_index]
    return torch.cat([leaving, entering], dim=1)

def get_1_hop_ring_nbd(node_index, edge_index):
    """Extract the edges incident to a node and relabel them compactly.

    Args:
        node_index: id of the centre node.
        edge_index: tensor of shape ``(2, num_edges)`` with global node ids.

    Returns:
        A tuple ``(unique_nodes_within, subset_edge_index, relabeled_edges)``:

        * ``unique_nodes_within`` - sorted 1-D tensor of the distinct node ids
          that occur in the selected edges.
        * ``subset_edge_index`` - the selected edges with their original ids,
          as returned by ``get_1_hop_edge_subset``.
        * ``relabeled_edges`` - same shape and dtype as ``subset_edge_index``,
          with every id replaced by its position in ``unique_nodes_within``.
    """
    subset_edge_index = get_1_hop_edge_subset(node_index, edge_index)
    # `return_inverse` yields, for every entry of the input, its position in
    # the sorted unique values - exactly the compact relabelling needed, with
    # no per-edge Python loop or dictionary lookups.
    unique_nodes_within, positions = torch.unique(
        subset_edge_index, sorted=True, return_inverse=True)
    relabeled_edges = positions.to(subset_edge_index.dtype)
    return unique_nodes_within, subset_edge_index, relabeled_edges

# Persist the full dataset object (including the split masks attached to it
# earlier in this script) next to the per-node outputs.
with open(os.path.join(final_folder_name, 'data.pkl'), 'wb') as pickle_file:
    pickle.dump(data, pickle_file)


# All neighbourhoods below are extracted from the symmetrised edge list.
undirected_edge_index = geo_utils.to_undirected(data.edge_index)
# undirected_edge_index = data.edge_index

aggr = torch_geometric.nn.aggr.SumAggregation()
# Pass the node count explicitly: without it `degree` sizes its output from
# the largest index present, so trailing nodes without edges would be missing
# and the vectors could not be indexed by every node id.
source_degrees = geo_utils.degree(
    undirected_edge_index[0], num_nodes=data.x.shape[0])
target_degrees = geo_utils.degree(
    undirected_edge_index[1], num_nodes=data.x.shape[0])

# Build the 1-hop neighbourhood graph of every node, one node at a time.
for node_index in range(data.x.size(0)):
    print(node_index)

    # Edges touching this node, in global ids and relabelled to 0..k-1.
    subset, subset_edge_index, relabeled_edge_index = get_1_hop_ring_nbd(
        node_index, undirected_edge_index)

    # Edge-to-edge connectivity of the relabelled neighbourhood.
    edgewise_edge_index = get_edgewise_graph(relabeled_edge_index)
    print(edgewise_edge_index.shape)
    if edgewise_edge_index.size(-1) == 0:
        # No edge pairs were found; show how many edges there were to pair.
        print(relabeled_edge_index.shape)

    # Per-node attributes beyond the standard x / edge_index / y fields.
    # `original_x` keeps the global ids of the nodes in this neighbourhood.
    neighbourhood_fields = {
        'original_x': subset,
        'edgewise_edge_index': edgewise_edge_index,
        'train_mask': data.train_mask[subset],
        'test_mask': data.test_mask[subset],
        'val_mask': data.val_mask[subset],
    }
    graph = Data(
        x=data.x[subset],
        edge_index=relabeled_edge_index,
        y=data.y[subset],
        **neighbourhood_fields
    )

    # NOTE: saving is currently disabled, so each `graph` is built and then
    # discarded on the next iteration.
    # filename = os.path.join(final_folder_name, f'data_{node_index}.pkl')
    # with open(filename, 'wb') as pickle_file:
    #     pickle.dump(graph, pickle_file)

