import os, sys
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
sys.path.append("../..")
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric

from hydra import compose, initialize
from hydra.utils import instantiate
from omegaconf import OmegaConf
from torch_geometric.data import Data
from torch_geometric.datasets import ZINC

from torch.utils.data import Dataset

from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
from torch_geometric.loader import DataLoader
# Use the GPU when one is visible to this process (CUDA_VISIBLE_DEVICES is
# set at the top of the file), otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def get_edgewise_edge_index(edge_index):
    """Build non-backtracking edge-to-edge connectivity from a directed edge list.

    Every column of ``edge_index`` is treated as one edge-node. For an input
    edge ``i = (u -> v)``, each input edge ``j = (w -> u)`` whose head is ``u``
    sends a message to ``i``, except when ``w == v`` (i.e. ``j`` is the reverse
    edge ``v -> u``), so messages never bounce straight back along the edge
    they arrived on.

    Args:
        edge_index: Integer tensor of shape ``(2, E)``; row 0 holds edge
            sources and row 1 holds edge targets.

    Returns:
        A CPU ``torch.int64`` tensor of shape ``(2, M)``. Row 0 holds the
        sending edge id ``j`` and row 1 the receiving edge id ``i``. Pairs are
        ordered by ``i`` and, within one ``i``, by ascending ``j``.
    """
    sources, targets = edge_index.tolist()

    # Group edge ids by the node they point to, so the candidate senders of an
    # edge are found in O(in-degree) instead of by scanning all E edges.
    incoming = {}
    for edge_id, dst in enumerate(targets):
        incoming.setdefault(dst, []).append(edge_id)

    senders, receivers = [], []
    for edge_id, (src, dst) in enumerate(zip(sources, targets)):
        for prev_id in incoming.get(src, ()):
            if sources[prev_id] != dst:  # skip the reverse edge dst -> src
                senders.append(prev_id)
                receivers.append(edge_id)

    # Build directly as int64: staging indices in a float32 tensor would
    # silently round ids above 2**24.
    return torch.tensor([senders, receivers], dtype=torch.long)


def get_edgewise_graph(graph):
    """Convert ``graph`` into its edge-wise graph (used as a ``pre_transform``).

    The edges of ``graph`` are first symmetrised with
    ``torch_geometric.utils.to_undirected``. Every resulting directed edge
    becomes one node of the returned graph, and the connectivity between those
    edge-nodes comes from ``get_edgewise_edge_index``.

    Args:
        graph: A ``torch_geometric.data.Data`` object with an ``edge_index``.

    Returns:
        A ``Data`` object. Its ``x`` holds, for every edge-node, the id of the
        node that edge points to (row 1 of the symmetrised edge index). Its
        ``edge_index`` is the int64 edge-wise connectivity. Edge attributes of
        ``graph`` are not carried over.
    """
    # NOTE(review): x stores raw node ids of the original graph; confirm that
    # downstream code re-offsets them when graphs are batched together.
    sym_edges = torch_geometric.utils.to_undirected(graph.edge_index)
    line_edges = get_edgewise_edge_index(sym_edges).long()
    return Data(x=sym_edges[1], edge_index=line_edges)

class EdgeWiseDataLoader(Dataset):
    """Dataset yielding ``(graph, edgewise_graph)`` pairs for the same sample.

    Two copies of the dataset are built with Hydra's ``instantiate`` from
    ``dataset_args.loader_params``:

    * the plain dataset;
    * a copy rooted at ``<root>/pretransformed`` whose graphs were passed
      through ``get_edgewise_graph`` as a ``pre_transform``.

    Despite its name this is a ``torch.utils.data.Dataset``, not a loader.

    Split handling:

    * split config: the dataset class is expected to accept a ``split``
      argument. Both copies are instantiated with ``split=mode`` and every
      sample of that split is used.
    * unsplit config: one dataset is carved up. ``'train'`` takes the first
      ``int(num_samples * train_split)`` items. ``'test'`` takes the remaining
      items from the tail. Any other mode takes ``int(num_samples *
      train_split)`` items from the tail.

    Args:
        dataset_args: Config with ``loader_params`` (a Hydra config with a
            ``root`` entry), ``if_split``, and, when unsplit, ``train_split``
            and ``num_samples``.
        mode: Split name. With an unsplit dataset, ``'train'`` reads from the
            head and every other mode reads from the tail.
        if_split: Fallback used only when ``dataset_args`` has no
            ``if_split`` entry.
    """

    def __init__(self, dataset_args, mode='train', if_split=True):
        self.mode = mode
        # The config value wins. The argument only covers configs without it.
        self.if_split = bool(getattr(dataset_args, 'if_split', if_split))
        loader_params = dataset_args.loader_params
        # PyG-style datasets cache pre_transform output under their root, so
        # the edge-wise copy gets its own directory.
        pretransformed_root = os.path.join(loader_params.root, 'pretransformed')

        if self.if_split:
            self.main_loader = instantiate(loader_params, split=mode)
            self.edgewise_loader = instantiate(
                loader_params,
                root=pretransformed_root,
                split=mode,
                pre_transform=get_edgewise_graph)
            self.num_samples = len(self.main_loader)
            return

        # Validate before instantiating: pre-transforming the whole dataset is
        # expensive, so fail fast on an incomplete config.
        assert hasattr(dataset_args, 'train_split'), (
            "need percentage of datat that you want to take care of")
        self.main_loader = instantiate(loader_params)
        self.edgewise_loader = instantiate(
            loader_params,
            root=pretransformed_root,
            pre_transform=get_edgewise_graph)
        # NOTE(review): assumes dataset_args.num_samples matches the dataset
        # length. With mode='val', the tail slice overlaps the training head.
        self.num_samples = self._split_num_samples(
            dataset_args.num_samples, dataset_args.train_split, mode)

    @staticmethod
    def _split_num_samples(total, train_split, mode):
        """Return how many of ``total`` samples ``mode`` gets from an unsplit dataset.

        ``'test'`` receives the exact complement of the training share, so
        train and test partition the data. By contrast,
        ``int(total * (1 - train_split))`` can drop a sample to float rounding,
        e.g. ``10 * (1 - 0.8)`` evaluates to ``1.999...`` and truncates to 1.
        """
        num_train = int(total * train_split)
        return total - num_train if mode == 'test' else num_train

    def __getitem__(self, index):
        """Return ``(main_graph, edgewise_graph)`` for sample ``index``."""
        if self.mode != 'train' and not self.if_split:
            # Non-training samples come from the tail of the shared dataset.
            # -(index + 1) maps 0 -> last item, so index 0 no longer aliases
            # the first training sample.
            index = -index - 1
        return self.main_loader[index], self.edgewise_loader[index]

    def __len__(self):
        return self.num_samples



