from torch_geometric.data import InMemoryDataset, download_url, Data
from torch_geometric.utils.undirected import to_undirected
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_sparse import coalesce
import torch
import pickle
import os.path as osp
import os
import numpy as np
from utils import get_rnd_seed


def random_split(labels, class_size, node_size, percls_trn, all_val):
    # adopted from https://github.com/ivam-he/ChebNetII/blob/main/main/utils.py
    trn_idx, rest_idx = [], []
    seed = get_rnd_seed()
    rnd_state = np.random.RandomState(seed)
    # training-splits
    for lb in range(class_size):
        class_idx = np.where(labels == lb)[0]
        if percls_trn > len(class_idx):
            trn_idx.extend(class_idx)
        else:
            trn_idx.extend(rnd_state.choice(class_idx, percls_trn, replace=False))
    # val & tst-splits
    rest_idx = [i for i in range(node_size) if i not in trn_idx]
    val_idx = rnd_state.choice(rest_idx, all_val, replace=False)
    tst_idx = [j for j in rest_idx if j not in val_idx]
    return trn_idx, val_idx, tst_idx, seed


def split_data(labels, trn_pro, val_pro, nb_split=10):
    # full-supervised: 0.6/0.2/0.2
    # semi-supervised: 0.025/0.025/0.95
    class_size, node_size = labels.max() + 1, len(labels)
    percls_trn = int(round(trn_pro * node_size / class_size))
    all_val = int(round(val_pro * node_size))
    split_dic_ls = []
    for _ in range(nb_split):
        trn_idx, val_idx, tst_idx, seed = random_split(labels, class_size, node_size, percls_trn, all_val)
        assert len(set(trn_idx).intersection(val_idx)) == 0
        assert len(set(trn_idx).intersection(tst_idx)) == 0
        assert len(set(val_idx).intersection(tst_idx)) == 0
        split_dic_ls.append({"trn_idx": trn_idx, "val_idx": val_idx, "tst_idx": tst_idx, "seed": seed})
    return split_dic_ls


def random_planetoid_split(labels, class_size, node_size, percls_trn, all_val, all_tst):
    trn_idx, rest_idx = [], []
    seed = get_rnd_seed()
    rnd_state = np.random.RandomState(seed)
    # training-splits
    for lb in range(class_size):
        class_idx = np.where(labels == lb)[0]
        if percls_trn > len(class_idx):
            trn_idx.extend(class_idx)
        else:
            trn_idx.extend(rnd_state.choice(class_idx, percls_trn, replace=False))
    # val & tst-splits
    rest_idx = [i for i in range(node_size) if i not in trn_idx]
    val_idx = rnd_state.choice(rest_idx, all_val, replace=False)
    all_tst_idx = [j for j in rest_idx if j not in val_idx]
    tst_idx = rnd_state.choice(all_tst_idx, all_tst, replace=False)
    return trn_idx, val_idx, tst_idx, seed


def split_sparse_planetoid(labels, percls_trn=20, all_val=500, all_tst=1000, nb_split=10):
    class_size, node_size = labels.max() + 1, len(labels)
    split_ls = []
    for _ in range(nb_split):
        trn_idx, val_idx, tst_idx, seed = random_planetoid_split(labels, class_size, node_size, percls_trn, all_val, all_tst)
        assert len(set(trn_idx).intersection(val_idx)) == 0
        assert len(set(trn_idx).intersection(tst_idx)) == 0
        assert len(set(val_idx).intersection(tst_idx)) == 0
        split_ls.append({"trn_idx": trn_idx, "val_idx": val_idx, "tst_idx": tst_idx, "seed": seed})
    return split_ls


class Chame_Squir_Actor(InMemoryDataset):
    # adopted from https://github.com/ivam-he/ChebNetII/blob/main/main/dataset_loader.py
    def __init__(self, root, name=None, p2raw=None, transform=None, pre_transform=None):
        if name == 'actor':
            name = 'film'
        existing_dataset = ['chameleon', 'film', 'squirrel']
        if name not in existing_dataset:
            raise ValueError(f'name of hypergraph dataset must be one of: {existing_dataset}')
        else:
            self.name = name

        if (p2raw is not None) and osp.isdir(p2raw):
            self.p2raw = p2raw
        elif p2raw is None:
            self.p2raw = None
        elif not osp.isdir(p2raw):
            raise ValueError(
                f'path to raw hypergraph dataset "{p2raw}" does not exist!')

        if not osp.isdir(root):
            os.makedirs(root)

        self.root = root

        super(Chame_Squir_Actor, self).__init__(root, transform, pre_transform)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self):
        file_names = [self.name]
        return file_names

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass

    def process(self):
        p2f = osp.join(self.raw_dir, self.name)
        with open(p2f, 'rb') as f:
            data = pickle.load(f)
        data = data if self.pre_transform is None else self.pre_transform(data)
        torch.save(self.collate([data]), self.processed_paths[0])

    def __repr__(self):
        return '{}()'.format(self.name)


class WebKB(InMemoryDataset):
    # adopted from https://github.com/ivam-he/ChebNetII/blob/main/main/dataset_loader.py
    url = ('https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/master/new_data')

    def __init__(self, root, name, transform=None, pre_transform=None):
        self.name = name.lower()
        assert self.name in ['cornell', 'texas', 'washington', 'wisconsin']

        super(WebKB, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self):
        return ['out1_node_feature_label.txt', 'out1_graph_edges.txt']

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        for name in self.raw_file_names:
            download_url(f'{self.url}/{self.name}/{name}', self.raw_dir)

    def process(self):
        with open(self.raw_paths[0], 'r') as f:
            data = f.read().split('\n')[1:-1]
            x = [[float(v) for v in r.split('\t')[1].split(',')] for r in data]
            x = torch.tensor(x, dtype=torch.float)

            y = [int(r.split('\t')[2]) for r in data]
            y = torch.tensor(y, dtype=torch.long)

        with open(self.raw_paths[1], 'r') as f:
            data = f.read().split('\n')[1:-1]
            data = [[int(v) for v in r.split('\t')] for r in data]
            edge_index = torch.tensor(data, dtype=torch.long).t().contiguous()
            edge_index = to_undirected(edge_index)
            edge_index, _ = coalesce(edge_index, None, x.size(0), x.size(0))

        data = Data(x=x, edge_index=edge_index, y=y)
        data = data if self.pre_transform is None else self.pre_transform(data)
        torch.save(self.collate([data]), self.processed_paths[0])

    def __repr__(self):
        return '{}()'.format(self.name)


def DataLoader(name):
    # adopted from https://github.com/ivam-he/ChebNetII/blob/main/main/dataset_loader.py
    name = name.lower()
    if name in ['cora', 'citeseer', 'pubmed']:
        dataset = Planetoid(osp.join('./data/raw', name), name, transform=T.NormalizeFeatures())
    elif name in ['chameleon', 'squirrel', 'actor']:
        dataset = Chame_Squir_Actor(root='./data/raw/', name=name, transform=T.NormalizeFeatures())
    elif name in ['texas', 'cornell']:
        dataset = WebKB(root='./data/raw/', name=name, transform=T.NormalizeFeatures())
    else:
        raise ValueError(f'dataset {name} not supported in dataloader')
    return dataset


def eigenDecompose(A):
    D = torch.sum(A, dim=1, keepdim=True)
    D = torch.clamp(D, min=1).pow(-0.5)
    A_sym = (D * A).T * D
    L_sym = torch.eye(A.shape[0]).float() - A_sym
    W, U = torch.linalg.eigh(L_sym)
    return W, U


def main():
    # for datname in ['chameleon', 'squirrel', 'texas', 'cornell', 'actor', 'cora', 'citeseer', 'pubmed']:
    for datname in ['chameleon', 'squirrel', 'texas', 'cornell']:
        # loading
        dataset = DataLoader(datname)
        data = dataset[0]
        # parsing
        node_feat, label = data.x, data.y
        edge_index = data.edge_index
        num_classes, num_nodes = dataset.num_classes, data.num_nodes

        # saving processed data
        res_dic = {"node_feat": node_feat, "label": label, "edge_index": edge_index,
                   "num_classes": num_classes, "num_nodes": num_nodes}
        torch.save(res_dic, f"./data/processed/{datname}_dataDic.pt")
        print(f"{datname} processed...")

        # eigen-decomposition
        A = torch.zeros((num_nodes, num_nodes))
        A[edge_index[0], edge_index[1]] = 1
        W, U = eigenDecompose(A)
        torch.save({"W": W, "U": U}, f"./data/eigen_dcp/{datname}__dcp.pt")
        print(f"{datname} decomposed...")


if __name__ == '__main__':
    main()
