import sys, os
import glob
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
import igraph as ig
from einops import repeat
import torch_geometric.utils as geo_utils
from torch_geometric.utils import k_hop_subgraph
from torch_geometric.data import Data
import torch_geometric

# Device that tensors built in this module are moved to: the current CUDA
# device when torch can see a GPU, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def remove_intersection(keep, remove):
    """Return the values of ``keep`` that do not occur in ``remove``.

    Args:
        keep: 1-D tensor (any device) of candidate values.
        remove: 1-D tensor (any device) of values to exclude.

    Returns:
        A Python list of the surviving values with duplicates dropped, in the
        order in which they first appear in ``keep``. The order is
        deterministic and does not depend on set iteration order.
    """
    # ``Tensor.tolist`` works for CPU and CUDA tensors alike.
    removed = set(remove.tolist())
    survivors = (value for value in keep.tolist() if value not in removed)
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(survivors))


def get_edgewise_graph(edge_index):
    """Build edge-to-edge connectivity (a line-graph variant) for ``edge_index``.

    Every column of ``edge_index`` is treated as a node of the new graph,
    identified by its column position. Column ``j`` is linked to column ``i``
    when ``edge_index[1, j] == edge_index[0, i]`` (``j`` ends where ``i``
    starts, under the usual row 0 = source / row 1 = target convention).
    The exception is when ``edge_index[0, j] == edge_index[1, i]``: that pair
    immediately returns to its starting node (e.g. ``v -> u`` followed by
    ``u -> v``), so it is dropped. Self-loops therefore never link to
    themselves.

    Args:
        edge_index: ``(2, E)`` tensor of edges.

    Returns:
        ``(2, M)`` ``torch.int64`` tensor on the module-level ``device``.
        Row 0 holds the feeding edge ``j`` and row 1 the edge ``i`` it feeds
        into. Columns are grouped by ``i`` in increasing order, with ``j``
        increasing within each group. For no edges the shape is ``(2, 0)``.
    """
    src = edge_index[0].tolist()
    dst = edge_index[1].tolist()

    # Positions of the edges ending at each node. Built once, so the whole
    # pass is O(E + M) instead of re-scanning all edges for every edge.
    incoming = {}
    for j, v in enumerate(dst):
        incoming.setdefault(v, []).append(j)

    feeding, fed = [], []
    for i, (u, v) in enumerate(zip(src, dst)):
        for j in incoming.get(u, ()):
            if src[j] != v:  # skip the edge that just walks back to i's end
                feeding.append(j)
                fed.append(i)

    # Build the int64 tensor directly. A float32 staging buffer would lose
    # precision for edge positions above 2**24.
    return torch.tensor([feeding, fed], dtype=torch.int64, device=device)

class _EdgewiseData(Data):
    """``Data`` whose ``edgewise_edge_index`` survives mini-batch collation.

    ``edgewise_edge_index`` stores edge positions (column indices into
    ``edge_index``), not node ids. PyG's default ``__inc__`` offsets keys
    containing "index" by the node count. Offsetting by the edge count keeps
    the entries pointing at the correct columns of the concatenated
    ``edge_index`` when several subgraphs are batched.
    """

    def __inc__(self, key, value, *args, **kwargs):
        if key == 'edgewise_edge_index':
            return self.edge_index.size(1)
        return super().__inc__(key, value, *args, **kwargs)


class TransductiveDataLoader(Dataset):
    """Map-style dataset with one k-hop subgraph per node of a single graph.

    All subgraphs (and their edge-to-edge connectivity) are extracted once
    in ``__init__``. ``__getitem__`` only packages the precomputed tensors.

    Args:
        data: graph object exposing ``x``, ``y`` and ``edge_index``.
        num_hops: neighbourhood radius passed to ``k_hop_subgraph``.
    """

    def __init__(self, data, num_hops=1):
        self.data = data
        self.num_hops = num_hops
        (
            self.edgewise_edge_index_list,
            self.k_hop_edge_index_list,
            self.subset_list,
        ) = self.get_edge_index_list()

    def get_edge_index_list(self):
        """Extract the ``num_hops`` subgraph around every node.

        Returns:
            Tuple ``(edgewise_edge_index_list, k_hop_edge_index_list,
            subset_list)`` with one entry per node:
            - the ``get_edgewise_graph`` connectivity of the subgraph;
            - the subgraph's relabelled ``edge_index``;
            - the original ids of its nodes.
        """
        edgewise_edge_index_list = []
        k_hop_edge_index_list = []
        subset_list = []
        for index in range(len(self.data.x)):
            # relabel_nodes=True: edge_index refers to positions within subset.
            subset, edge_index, _, _ = geo_utils.k_hop_subgraph(
                node_idx=index,
                num_hops=self.num_hops,
                relabel_nodes=True,
                edge_index=self.data.edge_index,
                flow='target_to_source',
            )
            edgewise_edge_index_list.append(get_edgewise_graph(edge_index))
            k_hop_edge_index_list.append(edge_index)
            subset_list.append(subset)
        return edgewise_edge_index_list, k_hop_edge_index_list, subset_list

    def __getitem__(self, index):
        """Return the precomputed subgraph around node ``index``.

        ``x`` holds the original node ids of the subgraph (not features) and
        ``y`` the labels of those nodes.
        """
        subset = self.subset_list[index]
        return _EdgewiseData(
            x=subset,
            edge_index=self.k_hop_edge_index_list[index],
            y=self.data.y[subset],
            edgewise_edge_index=self.edgewise_edge_index_list[index],
        )

    def __len__(self):
        """Return the number of nodes, i.e. the number of subgraphs."""
        return self.data.x.shape[0]

