import os, sys
import torch
import torch.nn.functional as F
import torch.nn as nn
import wandb
import numpy as np
from torch_geometric.data import Data

import torch_geometric.transforms as T
import torch_geometric.utils as geo_utils
from einops import repeat
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
 
from torch_geometric.datasets import WikipediaNetwork
import torch_geometric.utils as geo_utils
import pickle



def remove_intersection(keep, remove):
    """Return the values of ``keep`` that do not also occur in ``remove``.

    Both arguments are tensors, copied to CPU and converted to Python
    scalars, so 1-D input is expected. Duplicates collapse, and the result
    is a plain list whose order is set-iteration order (not meaningful).
    """
    kept_values = set(keep.cpu().numpy().tolist())
    dropped_values = set(remove.cpu().numpy().tolist())
    return list(kept_values.difference(dropped_values))


def get_edgewise_graph(edge_index):
    """Build the non-backtracking line graph of a directed edge list.

    Column ``i`` of ``edge_index`` is the edge ``u -> v`` with
    ``u = edge_index[0, i]`` and ``v = edge_index[1, i]``. Each such edge
    becomes a node of the returned graph. Edge ``j`` (``w -> u``) is linked
    to edge ``i`` (``u -> v``) whenever ``j`` ends where ``i`` starts,
    except when ``w == v``. That case is the reverse edge ``v -> u``, which
    is skipped so a walk never immediately backtracks.

    Args:
        edge_index: integer tensor with two rows (sources, targets); it may
            live on any device.

    Returns:
        A ``torch.int64`` tensor of shape ``(2, M)`` on the default device.
        Row 0 holds the predecessor edge index ``j`` and row 1 the successor
        edge index ``i``. Columns are ordered by ``i``, then by ``j``.
    """
    src = edge_index[0].tolist()
    dst = edge_index[1].tolist()

    # Bucket edge ids by the node they point to. The edges entering a node
    # can then be read directly instead of rescanning all E edges for every
    # edge (O(E^2)).
    edges_into = {}
    for j, head in enumerate(dst):
        edges_into.setdefault(head, []).append(j)

    pred_edges = []
    succ_edges = []
    for i, (tail, head) in enumerate(zip(src, dst)):
        for j in edges_into.get(tail, ()):
            if src[j] != head:  # skip the reverse edge head -> tail
                pred_edges.append(j)
                succ_edges.append(i)

    # Build the int64 tensor directly. Staging indices in a float32 buffer
    # cannot represent every integer above 2**24, so large edge ids would be
    # corrupted.
    return torch.tensor([pred_edges, succ_edges], dtype=torch.int64)



root_folder = '/home/user/data'
sub_name = 'squirrel'
# sub_name = 'cora_2'
num_hops = 2
# Which geom-gcn 60/20/20 split file to load; the file name ends in
# `_{split_index}.npz`.
split_index = 0

final_folder_name = f"{root_folder}/graph_datasets/k_hop_nbd/{sub_name}"
os.makedirs(final_folder_name, exist_ok=True)

dataset = WikipediaNetwork(root=root_folder, name=sub_name)

data = dataset[0]
# Derive the split-file location from root_folder instead of repeating the
# literal path, so relocating the data root also relocates the split lookup.
splits_path = os.path.join(
    root_folder, sub_name, 'geom_gcn', 'raw',
    f'{sub_name}_split_0.6_0.2_{split_index}.npz')
# np.load on an .npz returns an NpzFile that holds the archive open until it
# is closed. Indexing it reads each array fully into memory, so closing it
# after the three reads is safe.
with np.load(splits_path) as splits_file:
    train_mask = splits_file['train_mask']
    val_mask = splits_file['val_mask']
    test_mask = splits_file['test_mask']
# Store the loaded split on `data` as boolean tensors; this replaces any
# masks the dataset object already carried.
data.train_mask = torch.tensor(train_mask, dtype=torch.bool)
data.val_mask = torch.tensor(val_mask, dtype=torch.bool)
data.test_mask = torch.tensor(test_mask, dtype=torch.bool)



# Save the whole graph, including the split masks attached to `data`, as a
# single pickle in the output folder next to the per-node files.
full_graph_path = os.path.join(final_folder_name, 'data.pkl')
with open(full_graph_path, 'wb') as pickle_file:
    pickle.dump(data, pickle_file)

# Symmetrise the edge list once, up front, so every k-hop neighbourhood is
# gathered along edges in either direction.
undirected_edge = geo_utils.to_undirected(data.edge_index)

for node_index in range(data.x.shape[0]):
    print(node_index)  # coarse progress output

    # num_hops-hop neighbourhood around node_index. With relabel_nodes=True
    # the returned edge_index refers to positions in `subset` rather than to
    # node ids of the full graph. `mapping` and `edge_mask` are unused here.
    subset, edge_index, mapping, edge_mask = geo_utils.k_hop_subgraph(
        node_index,
        num_hops,
        undirected_edge,
        relabel_nodes=True,
    )
    edgewise_edge_index = get_edgewise_graph(edge_index)

    # Restrict each split mask to the nodes kept in this subgraph.
    split_masks = {
        mask_name: getattr(data, mask_name)[subset]
        for mask_name in ('train_mask', 'test_mask', 'val_mask')
    }
    graph = Data(
        x=data.x[subset],
        y=data.y[subset],
        edge_index=edge_index,
        original_x=subset,  # node ids in the full graph, before relabelling
        edgewise_edge_index=edgewise_edge_index,
        **split_masks,
    )

    out_path = os.path.join(final_folder_name, f'data_{node_index}.pkl')
    with open(out_path, 'wb') as pickle_file:
        pickle.dump(graph, pickle_file)

