from re import I
import numpy as np
import networkx as nx

import torch
from torch.nn.functional import one_hot

import torch_geometric.transforms as T

from torch_geometric.loader import NeighborSampler
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected, from_scipy_sparse_matrix, remove_self_loops, to_scipy_sparse_matrix, degree, to_networkx
from torch_geometric.datasets import Planetoid, Actor, CitationFull, WikipediaNetwork
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
from collections import defaultdict

from model_utils import NormalizedAggr
from scipy import sparse as sp

import faiss

import pdb
import networkx as nx
from collections import defaultdict
import pickle 
from torch_geometric.utils import degree
import random

# Base directory prefix for every dataset path built in this module. Paths are
# formed by plain string concatenation (e.g. root_dir + "graph_data"), so the
# trailing slash is required.
root_dir = "./data/"

def _load_geom_gcn_graph(data_name):
    """Parse a Geom-GCN style dataset (chameleon / squirrel) from raw text files.

    Reads two tab-separated files under ``{root_dir}graph_data/{data_name}/``.
    The first line of each file is skipped as a header.

    * ``out1_node_feature_label.txt``: ``node_id<TAB>f1,f2,...<TAB>label``
    * ``out1_graph_edges.txt``: ``src_id<TAB>dst_id``

    Every node listed in the feature file is kept, including nodes that appear
    in no edge. Feature/label rows and adjacency rows therefore all follow the
    same sorted node-id order.

    Returns:
        (data, num_classes): a ``Data`` object with row-normalized float
        features ``x``, ``(N, 1)`` long labels ``y`` and ``edge_index``.
        ``num_classes`` is ``max(label) + 1``.

    Raises:
        ValueError: on a malformed line, a duplicated node id, or an edge that
            references a node missing from the feature file.
    """
    data_dir = root_dir + 'graph_data/{}/'.format(data_name)
    features_by_node = {}
    labels_by_node = {}

    with open(data_dir + 'out1_node_feature_label.txt') as feature_file:
        feature_file.readline()  # skip header line
        for line in feature_file:
            fields = line.rstrip().split('\t')
            if len(fields) != 3:
                raise ValueError(
                    f"expected 3 tab-separated fields in feature file, got {len(fields)}: {line!r}")
            node_id = int(fields[0])
            if node_id in features_by_node:
                raise ValueError(f"duplicate node id {node_id} in feature file")
            features_by_node[node_id] = np.array(fields[1].split(','), dtype=np.uint8)
            labels_by_node[node_id] = int(fields[2])

    nodes = sorted(features_by_node)
    G = nx.Graph()
    # Register all nodes up front so isolated nodes are not dropped from the graph.
    G.add_nodes_from(nodes)

    with open(data_dir + 'out1_graph_edges.txt') as edge_file:
        edge_file.readline()  # skip header line
        for line in edge_file:
            fields = line.rstrip().split('\t')
            if len(fields) != 2:
                raise ValueError(
                    f"expected 2 tab-separated fields in edge file, got {len(fields)}: {line!r}")
            src, dst = int(fields[0]), int(fields[1])
            for endpoint in (src, dst):
                if endpoint not in features_by_node:
                    raise ValueError(f"edge references node {endpoint} that has no features/label")
            G.add_edge(src, dst)

    adj = nx.adjacency_matrix(G, nodes)
    features = np.array([features_by_node[n] for n in nodes])
    labels = np.array([labels_by_node[n] for n in nodes])
    edge_index = from_scipy_sparse_matrix(adj)[0].long()
    x = torch.tensor(features).float()
    # Row-normalize features; the clamp avoids division by zero on all-zero rows.
    x = x / x.sum(1, keepdim=True).clamp(min=1)
    y = torch.tensor(labels).unsqueeze(-1).long()
    return Data(x=x, y=y, edge_index=edge_index), int(labels.max()) + 1


def load_data(data_name, device):
    """Load a node-classification graph and move it to ``device``.

    Supported names:

    * ``"cora"``, ``"citeseer"``, ``"pubmed"`` via ``Planetoid`` (public split).
    * ``"actor"`` via ``Actor``.
    * ``"chameleon"``, ``"squirrel"`` parsed from Geom-GCN text files.
    * Anything else is treated as an OGB node-property-prediction name.

    For Planetoid, Actor and the Geom-GCN datasets, ``data.y`` is reshaped to
    ``(N, 1)``. OGB labels are returned as stored by OGB. Except for
    ogbn-proteins, ``edge_index`` is made undirected.

    Args:
        data_name: dataset identifier (see above).
        device: torch device the returned ``Data`` is moved to.

    Returns:
        (num_classes, data, dataset). ``dataset`` is None for chameleon/squirrel,
        which are not loaded through a PyG dataset class.
    """
    dataset = None
    if data_name in ["cora", "citeseer", "pubmed"]:
        dataset = Planetoid(name=data_name, root=root_dir + "graph_data", split="public")
        data = dataset[0]
        num_classes = dataset.num_classes
        data.y = data.y.unsqueeze(-1)

    elif data_name in ["actor"]:
        dataset = Actor(root=root_dir + "graph_data")
        data = dataset[0]
        num_classes = dataset.num_classes
        data.y = data.y.unsqueeze(-1)

    elif data_name in ["chameleon", "squirrel"]:
        data, num_classes = _load_geom_gcn_graph(data_name)

    else:
        # The ogbn-proteins branch below derives node features from data.adj_t.
        # That attribute only exists when the graph is converted to a SparseTensor
        # that carries edge_attr as its values. Without this transform it is never
        # created and the access fails.
        # NOTE: ToSparseTensor removes data.edge_index by default.
        transform = T.ToSparseTensor(attr='edge_attr') if data_name == "ogbn-proteins" else None
        dataset = PygNodePropPredDataset(name=data_name, root=root_dir + "ogb", transform=transform)
        data = dataset[0]
        num_classes = dataset.num_classes

    if data_name == "ogbn-proteins":
        # Node features = mean of incident edge features; afterwards drop the edge
        # values so adj_t is a plain (unweighted) sparse structure.
        data.x = data.adj_t.mean(dim=1)
        data.adj_t.set_value_(None)
    else:
        data.edge_index = to_undirected(data.edge_index)

    data = data.to(device)
    return num_classes, data, dataset

def _load_npz_split(data_name, run):
    """Read train/valid/test node indices for one run from a precomputed ``.npz`` file.

    The file ``{root_dir}graph_data/splits/{data_name}_split_0.6_0.2_{run}.npz`` must
    hold ``train_mask``, ``val_mask`` and ``test_mask`` arrays. Positions equal to 1
    are returned as index tensors under the keys ``'train'``, ``'valid'`` and ``'test'``.
    """
    splits_file_path = root_dir + f"graph_data/splits/{data_name}_split_0.6_0.2_{run}.npz"
    with np.load(splits_file_path) as splits_file:
        return {
            split: torch.tensor(np.where(splits_file[mask_key] == 1)[0])
            for split, mask_key in (('train', 'train_mask'),
                                    ('valid', 'val_mask'),
                                    ('test', 'test_mask'))
        }


def load_split(data_name, data, run, device, dataset=None):
    """Return the train/valid/test node split for ``data_name`` and run ``run``.

    * Planetoid datasets and chameleon/squirrel read a precomputed ``.npz`` split
      file per run. The filename suggests a 60/20/20 split. Planetoid's built-in
      public masks are not used.
    * ``"actor"`` uses column ``run`` of the ``(N, num_runs)`` masks stored on ``data``.
    * Any other name is treated as OGB and uses ``dataset.get_idx_split()``.

    Args:
        data_name: dataset identifier.
        data: graph object. It is only read for ``"actor"`` (its ``*_mask`` attributes).
        run: split index.
        device: device for the returned ``train_idx``.
        dataset: OGB dataset object, required for OGB names.

    Returns:
        (split_idx, train_idx). ``split_idx`` maps ``'train'``/``'valid'``/``'test'``
        to index tensors. ``train_idx`` is ``split_idx['train']`` moved to ``device``.

    Raises:
        ValueError: if an OGB split is requested but ``dataset`` is None.
    """
    if data_name in ["cora", "citeseer", "pubmed", "chameleon", "squirrel"]:
        split_idx = _load_npz_split(data_name, run)

    elif data_name in ["actor"]:
        split_idx = {
            'train': torch.where(data.train_mask[:, run])[0],
            'valid': torch.where(data.val_mask[:, run])[0],
            'test': torch.where(data.test_mask[:, run])[0],
        }

    else:
        if dataset is None:
            raise ValueError(f"dataset object is required to load the OGB split for {data_name!r}")
        split_idx = dataset.get_idx_split()

    train_idx = split_idx['train'].to(device)
    return split_idx, train_idx


def init_seed(seed=2020):
    """Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: integer seed applied to every generator.
    """
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning speed for reproducible kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Maps each supported sort type to the filename template of its pickled reference points.
_REFERENCE_FILE_TEMPLATES = {
    "bfs_load": "reference_points_per_node_{}.pkl",
    "ppnp_load": "ppnp_reference_points_per_node_{}.pkl",
    "feat_load": "feat_reference_points_per_node_{}.pkl",
}


def generate_reference(data, sort_types, num=1, data_name=None):
    """Load precomputed per-node reference points for each requested sort type.

    Args:
        data: graph object. Only ``data.x.device`` is read, as the target device.
        sort_types: comma-separated sort types, each one of ``"bfs_load"``,
            ``"ppnp_load"`` or ``"feat_load"``. Surrounding whitespace is ignored.
        num: unused. Kept for backward compatibility with existing callers.
        data_name: dataset name interpolated into the pickle filename, which is
            resolved relative to the current working directory.

    Returns:
        list with one loaded object per sort type, in the order given, each
        moved to ``data.x.device``.

    Raises:
        ValueError: if a sort type is not recognised. Previously this either
            crashed with UnboundLocalError or silently re-appended the
            previous entry.
    """
    reference_points_list = []
    for sort_type in sort_types.split(','):
        sort_type = sort_type.strip()
        try:
            template = _REFERENCE_FILE_TEMPLATES[sort_type]
        except KeyError:
            raise ValueError(
                f"unknown sort type {sort_type!r}; expected one of "
                f"{sorted(_REFERENCE_FILE_TEMPLATES)}"
            ) from None
        # SECURITY: pickle.load can execute arbitrary code; only load reference
        # files produced by trusted local preprocessing.
        with open(template.format(data_name), 'rb') as f:
            reference_points = pickle.load(f)
        reference_points_list.append(reference_points.to(data.x.device))

    return reference_points_list
