from typing import Optional, Callable, List

import os
import os.path as osp
import shutil

import networkx as nx

import torch
from torch_geometric.data import InMemoryDataset, download_url, extract_zip, Data
from torch_geometric.io import read_tu_data


class S2VGraph(object):
    def __init__(self, g, label, node_tags=None, node_features=None):
        '''
            Container for one graph sample.

            g: a networkx graph
            label: an integer graph label
            node_tags: a list of integer node tags
            node_features: optional node feature tensor (one-hot representation of the tags,
                used as input to neural nets); kept as given, or the placeholder 0 when omitted

            The remaining attributes start as placeholders and are filled in by the loader:
            edge_mat: a torch long tensor, contain edge list, will be used to create torch sparse tensor
            neighbors: list of neighbors (without self-loop)
            max_neighbor: largest neighbor-list length of the graph
        '''
        self.g = g
        self.label = label
        self.node_tags = node_tags

        # Previously this argument was accepted but always discarded; keep the
        # historical placeholder (0) only when the caller passes nothing.
        self.node_features = 0 if node_features is None else node_features

        # Placeholders, overwritten once the graph structure has been processed.
        self.neighbors = []
        self.edge_mat = 0
        self.max_neighbor = 0

        
def load_data(dataset, degree_as_tag):
    '''
        Read the graphs stored in dataset/<dataset>/<dataset>.txt.

        File layout, as parsed below:
            first line: number of graphs
            per graph:  one line "<num_nodes> <graph_label>", followed by one line per node
            per node:   "<tag> <k> <k neighbor ids> [extra tokens]"
        Extra tokens after the neighbor ids (continuous node attributes) are ignored:
        node features are always rebuilt as a one-hot encoding of the node tags.

        dataset: name of dataset
        degree_as_tag: if true, use each node's degree as its tag instead of the tag from the file

        Returns (g_list, num_classes): the list of S2VGraph objects and the number of distinct
        graph labels. Raises ValueError when a node line lists fewer neighbors than it declares.
    '''

    print('loading data')
    g_list = []
    label_dict = {}
    feat_dict = {}

    with open('dataset/%s/%s.txt' % (dataset, dataset), 'r') as f:
        n_g = int(f.readline().strip())
        for _ in range(n_g):
            n, l = [int(w) for w in f.readline().strip().split()]
            if l not in label_dict:
                label_dict[l] = len(label_dict)
            g = nx.Graph()
            node_tags = []
            for j in range(n):
                g.add_node(j)
                row = f.readline().strip().split()
                # The tag, the neighbor count and the neighbor ids occupy the first `tmp` tokens.
                tmp = int(row[1]) + 2
                if len(row) < tmp:
                    raise ValueError(
                        'node %d declares %d neighbors but lists only %d'
                        % (j, tmp - 2, len(row) - 2))
                # Anything past `tmp` is a continuous attribute. It used to be parsed with an
                # un-imported numpy (NameError) and was never stored, so it is skipped here.
                row = [int(w) for w in row[:tmp]]

                if row[0] not in feat_dict:
                    feat_dict[row[0]] = len(feat_dict)
                node_tags.append(feat_dict[row[0]])

                for neighbor in row[2:]:
                    g.add_edge(j, neighbor)

            # add_edge creates missing nodes, so a neighbor id >= n shows up as extra nodes here.
            assert len(g) == n

            g_list.append(S2VGraph(g, l, node_tags))

    #add labels and edge_mat
    for g in g_list:
        g.neighbors = [[] for _ in range(len(g.g))]
        for i, j in g.g.edges():
            g.neighbors[i].append(j)
            g.neighbors[j].append(i)
        # default=0 keeps a graph without nodes from raising.
        g.max_neighbor = max((len(nbrs) for nbrs in g.neighbors), default=0)

        g.label = label_dict[g.label]

        # Store every edge in both directions.
        edges = [list(pair) for pair in g.g.edges()]
        edges.extend([[i, j] for j, i in edges])

        if edges:
            g.edge_mat = torch.LongTensor(edges).transpose(0, 1)
        else:
            # An empty edge list yields a 1-D tensor that cannot be transposed.
            g.edge_mat = torch.zeros((2, 0), dtype=torch.long)

    if degree_as_tag:
        for g in g_list:
            g.node_tags = list(dict(g.g.degree).values())

    #Extracting unique tag labels
    # Built exactly as before so the column order of the one-hot features is unchanged.
    tagset = set([])
    for g in g_list:
        tagset = tagset.union(set(g.node_tags))

    tagset = list(tagset)
    tag2index = {tag: i for i, tag in enumerate(tagset)}

    for g in g_list:
        g.node_features = torch.zeros(len(g.node_tags), len(tagset))
        g.node_features[range(len(g.node_tags)), [tag2index[tag] for tag in g.node_tags]] = 1


    print('# classes: %d' % len(label_dict))
    print('# maximum node tag: %d' % len(tagset))

    print("# data: %d" % len(g_list))

    return g_list, len(label_dict)

class TUDataset(InMemoryDataset):
    '''
        In-memory dataset whose graphs are produced by load_data(name, False).

        root, transform, pre_transform, pre_filter: forwarded to InMemoryDataset
        name: dataset name; passed to load_data and used for the processed file name
        use_node_attr: when false, the leading num_node_attributes columns of x are dropped
        use_edge_attr: when false, the leading num_edge_attributes columns of edge_attr are dropped
        cleaned: stored on the instance; not read anywhere else in this class
    '''

    def __init__(self, root: str, name: str,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 pre_filter: Optional[Callable] = None,
                 use_node_attr: bool = True, use_edge_attr: bool = False,
                 cleaned: bool = False):
        # Both attributes must exist before the base constructor runs, because
        # process() and processed_file_names read self.name.
        self.name = name
        self.cleaned = cleaned
        super().__init__(root, transform, pre_transform, pre_filter)

        self.data, self.slices = torch.load(self.processed_paths[0])
        if self.data.x is not None and not use_node_attr:
            num_node_attributes = self.num_node_attributes
            self.data.x = self.data.x[:, num_node_attributes:]
        if self.data.edge_attr is not None and not use_edge_attr:
            num_edge_attributes = self.num_edge_attributes
            self.data.edge_attr = self.data.edge_attr[:, num_edge_attributes:]

    @staticmethod
    def _num_trailing_one_hot(mat) -> int:
        # Width of the widest trailing column block of `mat` whose entries are
        # all 0/1 and whose rows each sum to exactly 1 (0 when there is none).
        width = mat.size(1)
        for i in range(width):
            tail = mat[:, i:]
            if ((tail == 0) | (tail == 1)).all() and (tail.sum(dim=1) == 1).all():
                return width - i
        return 0

    @property
    def num_node_labels(self) -> int:
        '''Number of trailing columns of x that form a one-hot block (0 if x is None).'''
        if self.data.x is None:
            return 0
        return self._num_trailing_one_hot(self.data.x)

    @property
    def num_node_attributes(self) -> int:
        '''Number of columns of x that precede the one-hot label block.'''
        if self.data.x is None:
            return 0
        return self.data.x.size(1) - self.num_node_labels

    @property
    def num_edge_labels(self) -> int:
        '''Number of trailing columns of edge_attr that form a one-hot block (0 if edge_attr is None).'''
        if self.data.edge_attr is None:
            return 0
        return self._num_trailing_one_hot(self.data.edge_attr)

    @property
    def num_edge_attributes(self) -> int:
        '''Number of columns of edge_attr that precede the one-hot label block.

        __init__ reads this property; it was previously missing from the class.
        '''
        if self.data.edge_attr is None:
            return 0
        return self.data.edge_attr.size(1) - self.num_edge_labels

    @property
    def raw_file_names(self) -> List[str]:
        # No raw files are declared; download() below is a no-op.
        return []

    @property
    def processed_file_names(self) -> str:
        return self.name + '.pt'

    def download(self):
        # Nothing to fetch: load_data reads the dataset from a local text file.
        pass

    def process(self):
        '''
            Convert every graph returned by load_data into a Data object and save the
            collated result to processed_paths[0].

            Each Data carries x (node features), edge_index (coalesced edge list),
            y (graph label) and nidx (node ids numbered consecutively across the
            whole dataset). pre_filter and pre_transform are applied when given.
        '''
        graphs, _ = load_data(self.name, False)

        counter = 0
        data_list = []
        for graph in graphs:
            X_concat = graph.node_features
            num_nodes = X_concat.shape[0]
            nidx = counter + torch.arange(num_nodes)
            counter = counter + num_nodes
            Adj_block = self.__preprocess_neighbors_sumavepool([graph])
            data_list.append(Data(x=X_concat, edge_index=Adj_block.coalesce().indices(),
                                  y=graph.label, nidx=nidx))

        # These hooks are accepted by the constructor but used to be ignored here.
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        self.data, self.slices = self.collate(data_list)
        torch.save((self.data, self.slices), self.processed_paths[0])

    def __repr__(self) -> str:
        return f'{self.name}({len(self)})'

    def __preprocess_neighbors_sumavepool(self, batch_graph):
        '''
            Build the block-diagonal sparse adjacency matrix of the graphs in batch_graph.

            The edge list of each graph is shifted by the number of nodes of the graphs
            before it; every stored entry is 1 and no self-loops are added.
        '''
        edge_mat_list = []
        start_idx = [0]
        for i, graph in enumerate(batch_graph):
            start_idx.append(start_idx[i] + len(graph.g))
            edge_mat_list.append(graph.edge_mat + start_idx[i])
        Adj_block_idx = torch.cat(edge_mat_list, 1)
        Adj_block_elem = torch.ones(Adj_block_idx.shape[1])

        total_nodes = start_idx[-1]
        # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor constructor.
        Adj_block = torch.sparse_coo_tensor(Adj_block_idx, Adj_block_elem,
                                            torch.Size([total_nodes, total_nodes]))

        return Adj_block

    
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T

def get_planetoid_dataset(name, normalize_features=True, transform=None, split="complete"):
    '''
        Load the Planetoid dataset `name` from ./data/<name> and attach its transform.

        split: 'complete' loads the dataset with its default split and then rewrites the
            masks in place: the last 500 nodes are test, the 500 before them are
            validation and all remaining nodes are train. Any other value is passed
            to Planetoid as its `split` argument.
        normalize_features: if true, T.NormalizeFeatures() is applied (before `transform`).
        transform: optional extra transform.

        Returns the Planetoid dataset object.
    '''
    path = osp.join('.', 'data', name)
    if split != 'complete':
        dataset = Planetoid(path, name, split=split)
    else:
        dataset = Planetoid(path, name)
        total = dataset[0].num_nodes
        # (mask attribute, slice start, slice stop), processed in train/val/test order.
        layout = (
            ('train_mask', None, total - 1000),
            ('val_mask', total - 1000, total - 500),
            ('test_mask', total - 500, None),
        )
        for attr, start, stop in layout:
            mask = getattr(dataset[0], attr)
            mask.fill_(False)
            mask[start:stop] = 1

    if normalize_features:
        if transform is None:
            dataset.transform = T.NormalizeFeatures()
        else:
            dataset.transform = T.Compose([T.NormalizeFeatures(), transform])
    elif transform is not None:
        dataset.transform = transform
    return dataset

def one_hot_embedding(labels, num_classes):
    '''
        One-hot encode `labels` by selecting rows of a num_classes x num_classes identity matrix.

        Returns a float tensor with one extra trailing dimension of size num_classes.
    '''
    return torch.eye(num_classes)[labels]

from torch_geometric.data import Data
class PlanetoidDataset(torch.utils.data.Dataset):
    '''
        Dataset wrapper that returns the same full Planetoid graph for every index.

        split: 'train', 'val' or 'test'; selects which node mask is attached to the graph
        samples_per_epoch: value reported by len(), i.e. how many times the graph is served
        name: Planetoid dataset name passed to get_planetoid_dataset
        device: device the tensors are moved to

        Raises ValueError for an unknown split.
    '''

    def __init__(self, split='train', samples_per_epoch=100, name='Cora', device='cpu'):
        mask_attr = {'train': 'train_mask', 'val': 'val_mask', 'test': 'test_mask'}
        # Fail before loading anything: an unknown split used to leave self.mask
        # unset and only surfaced later as an AttributeError in __getitem__.
        if split not in mask_attr:
            raise ValueError(
                "split must be one of 'train', 'val' or 'test', got %r" % (split,))

        dataset = get_planetoid_dataset(name)
        # Fetch the graph once instead of re-indexing the dataset for every attribute.
        graph = dataset[0]
        self.X = graph.x.float().to(device)
        self.y = graph.y.to(device)
        self.edge_index = graph.edge_index.long().to(device)
        self.n_features = graph.num_node_features
        self.num_classes = dataset.num_classes
        self.mask = getattr(graph, mask_attr[split]).to(device)

        self.samples_per_epoch = samples_per_epoch

    def __len__(self):
        return self.samples_per_epoch

    def __getitem__(self, idx):
        # idx is ignored: every sample is the whole graph with the split's mask.
        return Data(x=self.X, y=self.y, mask=self.mask, edge_index=self.edge_index)
    
    
    
import pickle
import torch_geometric.utils as tut
class StructuresDataset(torch.utils.data.Dataset):
    '''
        Binary graph dataset read from src/dataset/<structure>_dataset_size_10_20_p0.02.pkl.

        The pickle is indexed as D[0] (stored as self.motif), D[1] (positive samples)
        and D[2] (negative samples). Of the first `samples` entries of each list, the
        first 80% are used when train is true and the remaining ones when test is true.

        Odd indices return a positive sample (y = 1), even indices a negative one (y = 0).
    '''

    def __init__(self, train=True, test=False, structure='grid', samples=200):
        # NOTE(review): pickle.load can execute arbitrary code; only point this at trusted files.
        with open("src/dataset/%s_dataset_size_10_20_p0.02.pkl" % structure, 'rb') as f:
            D = pickle.load(f)

        cut = int(samples * 0.8)
        self.motif = D[0]
        self.Dpos = []
        self.Dneg = []

        if train:
            self.Dpos += D[1][:cut]
            self.Dneg += D[2][:cut]

        if test:
            self.Dpos += D[1][cut:samples]
            self.Dneg += D[2][cut:samples]

        self.n_features = 1
        self.num_classes = 2

    def __len__(self):
        # Samples are served in positive/negative pairs, so only indices backed by
        # both lists are valid; the old 2 * len(Dpos) overran Dneg when it was shorter.
        return 2 * min(len(self.Dpos), len(self.Dneg))

    def __getitem__(self, idx):
        pair = idx // 2
        if idx % 2:
            edge_index = torch.tensor(self.Dpos[pair]).long().t()
            y = torch.tensor(1)
        else:
            edge_index = torch.tensor(self.Dneg[pair]).long().t()
            y = torch.tensor(0)

        # One constant feature per node; the node count is the largest node id + 1.
        num_nodes = int(edge_index.max()) + 1
        X = torch.ones(num_nodes, 1)
        edge_index = tut.to_undirected(edge_index)

        return Data(x=X, y=y, edge_index=edge_index, nidx=torch.zeros(num_nodes))
    
    
    
    
    
    
    
    
    
    
    
    
    
    