import json
import os.path
import os.path as osp
import pickle as pkl
from itertools import chain

from torch_geometric.datasets import Planetoid, Twitch,LastFMAsia, WebKB
import torch_geometric.transforms as T

import numpy as np
import torch
from datasets.base_data import Graph
from datasets.base_dataset import NodeDataset
from datasets.link_split import link_class_split
from datasets.node_split import node_class_split
from datasets.utils import pkl_read_file,remove_self_loops,edge_homophily,node_homophily, linkx_homophily

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN_gen(nn.Module):
    """Stack of ``GCNConv`` layers used here to generate synthetic node features.

    ReLU is applied after every layer except the last.

    Args:
        in_channels: size of the input node features.
        hidden_channels: width of every intermediate layer.
        out_channels: size of the output node features.
        num_layers: total number of ``GCNConv`` layers; must be >= 1.

    Raises:
        ValueError: if ``num_layers`` < 1.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
        super(GCN_gen, self).__init__()
        if num_layers < 1:
            raise ValueError("num_layers must be >= 1, got {}".format(num_layers))

        self.convs = nn.ModuleList()
        if num_layers == 1:
            # A single layer maps inputs straight to outputs.
            self.convs.append(GCNConv(in_channels, out_channels))
        else:
            self.convs.append(GCNConv(in_channels, hidden_channels))
            for _ in range(num_layers - 2):
                self.convs.append(GCNConv(hidden_channels, hidden_channels))
            self.convs.append(GCNConv(hidden_channels, out_channels))

        self.activation = F.relu

    def reset_parameters(self):
        """Re-initialise the parameters of every layer."""
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        """Apply all layers to ``x`` over ``edge_index``.

        ``edge_weight`` (optional) is forwarded to every layer, including the last.
        """
        for conv in self.convs[:-1]:
            x = self.activation(conv(x, edge_index, edge_weight))
        return self.convs[-1](x, edge_index, edge_weight)

class PyGSDDataset(NodeDataset):
    """Node-classification dataset built from PyG sources for distribution-shift experiments.

    ``process`` builds one of two kinds of multi-environment graph and pickles
    it to ``processed_file_paths``:

    * ``cora``/``citeseer``/``pubmed``/``lastfmasia``:
      - The base graph is copied ``env_num`` times as disjoint node-id ranges.
      - Each copy gets extra spurious features generated from the labels.
      - Nodes are split randomly 60/20/20 into train/val/test.
    * ``twitch``/``webkb``:
      - Several real sub-graphs are concatenated as disjoint node-id ranges.
      - The split is by sub-graph: train ids come first, then val, then test.

    Split masks and the per-node environment id are cached in
    ``./data/<name>/raw/<name>-node-splits.npy``.
    """

    # Datasets whose graph and node split are produced by ``process`` below.
    _MULTI_ENV_NAMES = ("cora", "citeseer", "pubmed", "twitch", "lastfmasia", "webkb")
    # Number of edge classes for each supported ``edge_split`` task (exact match).
    _EDGE_CLASS_COUNTS = {
        "existence": 2,
        "direction": 2,
        "sign": 2,
        "three_class_digraph": 3,
        "four_class_signed_digraph": 4,
        "five_class_signed_digraph": 5,
    }

    def __init__(self, args, name, root, k, node_split, node_split_id=0, edge_split="direction", edge_split_id=0):
        """Load the processed graph and set up node and edge splits.

        Args:
            args: options object; only ``args.heterophily`` is read here.
            name: dataset name.
            root, k: forwarded to ``NodeDataset.__init__``.
            node_split, node_split_id: split strategy and index passed to ``node_class_split``.
            edge_split: edge task name, passed to ``link_class_split``.
                Also selects ``num_edge_classes`` (``None`` for unknown tasks).
            edge_split_id: edge split index passed to ``link_class_split``.
        """
        super(PyGSDDataset, self).__init__(root, name, k)
        self.read_file()
        self.node_split = node_split
        self.node_split_id = node_split_id
        self.edge_split = edge_split
        self.edge_split_id = edge_split_id
        self.cache_node_split = osp.join(self.raw_dir, "{}-node-splits".format(self.name))
        self.cache_edge_split = osp.join(self.raw_dir, "{}-edge-splits".format(self.name))
        self.env_list = None
        self.num_nodes = None
        self.official_split = None

        split_kwargs = dict(name=name.lower(), data=self.data, cache_node_split=self.cache_node_split,
                            official_split=self.official_split, split=self.node_split,
                            node_split_id=self.node_split_id, train_size_per_class=20, val_size=500)
        if name in self._MULTI_ENV_NAMES:
            # The split file is written by ``process``. These datasets also
            # return a per-node environment list. If the file is absent, the
            # split attributes are left unset.
            file_name = "./data/" + self.name + "/raw/" + self.name + "-node-splits.npy"
            if os.path.exists(file_name):
                (self.train_idx, self.val_idx, self.test_idx, self.env_list,
                 self.seed_idx, self.stopping_idx) = node_class_split(**split_kwargs)
        else:
            (self.train_idx, self.val_idx, self.test_idx,
             self.seed_idx, self.stopping_idx) = node_class_split(**split_kwargs)

        if self.name not in self._MULTI_ENV_NAMES:
            edge_index = torch.from_numpy(np.vstack((self.edge.row.numpy(), self.edge.col.numpy()))).long()
            self.observed_edge_idx, self.observed_edge_weight, self.train_edge_pairs_idx, self.val_edge_pairs_idx, \
            self.test_edge_pairs_idx, self.train_edge_pairs_label, self.val_edge_pairs_label, self.test_edge_pairs_label = \
                link_class_split(edge_index=edge_index, A=self.edge.sparse_matrix,
                                 cache_edge_split=self.cache_edge_split, task=self.edge_split,
                                 edge_split_id=self.edge_split_id, prob_val=0.15, prob_test=0.05, )
        self.num_node_classes = self.num_classes
        # Exact-match lookup (a plain-string ``in`` test would match substrings).
        self.num_edge_classes = self._EDGE_CLASS_COUNTS.get(edge_split)
        if args.heterophily:
            self.edge_homophily = edge_homophily(self.adj, self.y)
            self.node_homophily = node_homophily(self.adj, self.y)
            self.linkx_homophily = linkx_homophily(self.adj, self.y)

    @property
    def raw_file_paths(self):
        """Raw-data directory of this dataset."""
        return self.raw_dir

    @property
    def processed_file_paths(self):
        """Path of the pickled ``Graph`` written by ``process``."""
        return osp.join(self.processed_dir, f"{self.name}.graph")

    def read_file(self):
        """Load the pickled graph and expose its fields as attributes."""
        self.data = pkl_read_file(self.processed_file_paths)
        self.edge = self.data.edge
        self.node = self.data.node
        self.x = self.data.x
        self.y = self.data.y
        self.adj = self.data.adj
        self.num_features = self.data.num_features
        self.num_classes = self.data.num_classes
        self.num_node = self.data.num_node
        self.num_edge = self.data.num_edge

    def download(self):
        """No-op; the PyG loaders in ``process`` fetch the raw data."""
        return

    def process(self):
        """Build the graph for ``self.name``, cache its node split, and pickle it.

        Prints a message and exits with status 1 if the pickle cannot be written.

        Raises:
            ValueError: if ``self.name`` is not a dataset this class can build.
        """
        print("processing...")
        transform = T.NormalizeFeatures()
        data_dir = "./data/" + self.name + "/raw/"
        if self.name in ("cora", "citeseer", "pubmed"):
            base = Planetoid(root=data_dir, name=self.name, transform=transform)[0]
            edge_index, features, labels, env = self._build_synthetic_envs(base)
        elif self.name == "lastfmasia":
            base = LastFMAsia(root=data_dir, transform=transform)[0]
            edge_index, features, labels, env = self._build_synthetic_envs(base)
        elif self.name == "twitch":
            edge_index, features, labels, env, sizes = self._concat_sub_graphs(
                Twitch, data_dir, ['DE', 'PT', 'RU', 'ES', 'FR', 'EN'], transform)
            # Train on DE+PT, validate on RU+ES, test on FR+EN.
            self._set_contiguous_split(sizes['DE'] + sizes['PT'],
                                       sizes['RU'] + sizes['ES'],
                                       sizes['FR'] + sizes['EN'])
        elif self.name == "webkb":
            edge_index, features, labels, env, sizes = self._concat_sub_graphs(
                WebKB, data_dir, ['wisconsin', 'cornell', 'texas'], transform)
            # Train on wisconsin, validate on cornell, test on texas.
            self._set_contiguous_split(sizes['wisconsin'], sizes['cornell'], sizes['texas'])
        else:
            raise ValueError("PyGSDDataset cannot process dataset '{}'".format(self.name))

        # Deduplicate edge columns, then drop self-loops via ``remove_self_loops``.
        undi_edge_index = torch.unique(edge_index, dim=1)
        undi_edge_index = remove_self_loops(undi_edge_index)[0]
        row, col = undi_edge_index
        edge_weight = torch.ones(len(row))
        num_node = features.shape[0]

        self._save_node_split_masks(labels.shape[0], env)

        g = Graph(row, col, edge_weight, num_node, x=features, y=labels)
        self.num_nodes = num_node
        try:
            with open(self.processed_file_paths, 'wb') as rf:
                pkl.dump(g, rf)
        except IOError as e:
            print(e)
            raise SystemExit(1)

    def _build_synthetic_envs(self, dataset, env_num=5, spu_feat_num=160, train_env_num=5,
                              train_ratio=0.6, valid_ratio=0.2):
        """Replicate ``dataset`` into ``env_num`` environments with spurious features.

        Copy ``i`` keeps the original features and appends ``spu_feat_num``
        columns, the sum of three terms:
        - a randomly initialised ``GCN_gen`` applied to the one-hot labels;
        - a random linear map of Gaussian context (mean ``i``, column ``i`` set to 1);
        - unit Gaussian noise.
        Copies occupy consecutive node-id ranges.

        Also sets ``self.train_idx``/``val_idx``/``test_idx`` to a random split
        (``train_ratio``/``valid_ratio``/rest) of the nodes in the first
        ``train_env_num`` environments.

        Returns:
            (edge_index, features, labels, env) of the concatenated graph.
        """
        edge_index = dataset.edge_index
        x = dataset.x
        num_nodes = dataset.num_nodes
        class_num = dataset.y.max().item() + 1
        print("creating new synthetic data...")
        generator_x = GCN_gen(in_channels=class_num, hidden_channels=10, out_channels=spu_feat_num, num_layers=2)
        generator_noise = nn.Linear(env_num, spu_feat_num)
        x_list, edge_index_list, y_list, env_list = [], [], [], []
        with torch.no_grad():
            label_new = F.one_hot(dataset.y, class_num).squeeze(1).float()
            # The label-driven term is the same for every environment; compute it once.
            label_feat = generator_x(label_new, edge_index)
            for i in range(env_num):
                context_ = torch.normal(mean=i, std=1, size=(x.size(0), env_num))
                context_[:, i] = 1
                x2 = label_feat + generator_noise(context_)
                x2 += torch.ones_like(x2).normal_(0, 1)
                x_list.append(torch.cat([x, x2], dim=1))
                y_list.append(dataset.y)
                edge_index_list.append(edge_index + i * num_nodes)
                env_list.append(torch.ones(x.size(0)) * i)

        # Node ids of the first ``train_env_num`` copies are one contiguous range.
        ind_idx = torch.arange(min(train_env_num, env_num) * num_nodes)
        perm = torch.randperm(ind_idx.size(0))
        total = perm.size(0)
        train_end = int(total * train_ratio)
        val_end = int(total * (train_ratio + valid_ratio))
        self.train_idx = ind_idx[perm[:train_end]]
        self.val_idx = ind_idx[perm[train_end:val_end]]
        self.test_idx = ind_idx[perm[val_end:]]

        return (torch.cat(edge_index_list, dim=1), torch.cat(x_list, dim=0),
                torch.cat(y_list, dim=0), torch.cat(env_list, dim=0))

    @staticmethod
    def _concat_sub_graphs(loader, data_dir, sub_graphs, transform):
        """Load each named sub-graph with ``loader`` and stack them into one graph.

        Sub-graph ``i`` gets environment id ``i``. Its node ids are shifted past
        all earlier sub-graphs.

        Returns:
            (edge_index, x, y, env, sizes), where ``sizes`` maps each sub-graph
            name to its node count.
        """
        x_list, y_list, edge_index_list, env_list = [], [], [], []
        sizes = {}
        shift = 0
        for env_id, sub_name in enumerate(sub_graphs):
            data = loader(root=data_dir, name=sub_name, transform=transform)[0]
            x_list.append(data.x)
            y_list.append(data.y)
            edge_index_list.append(data.edge_index + shift)
            env_list.append(torch.ones(data.num_nodes) * env_id)
            sizes[sub_name] = data.num_nodes
            shift += data.num_nodes
        return (torch.cat(edge_index_list, dim=1), torch.cat(x_list, dim=0),
                torch.cat(y_list, dim=0), torch.cat(env_list, dim=0), sizes)

    def _set_contiguous_split(self, train_num, val_num, test_num):
        """Assign consecutive node-id ranges to train, then val, then test (int32 tensors)."""
        test_start = train_num + val_num
        self.train_idx = torch.arange(0, train_num, dtype=torch.int)
        self.val_idx = torch.arange(train_num, test_start, dtype=torch.int)
        self.test_idx = torch.arange(test_start, test_start + test_num, dtype=torch.int)

    def _save_node_split_masks(self, num_rows, env):
        """Write train/val/test masks and ``env`` to the node-split cache file.

        If the cache file loads and has an ``'env'`` entry, its content is kept
        and re-saved unchanged, and that entry is stored in ``self.env_list``.
        Otherwise new ``(num_rows, 1)`` boolean masks are built from
        ``self.train_idx``/``val_idx``/``test_idx``.
        """
        node_split_file = "./data/" + self.name + "/raw/" + self.name + "-node-splits.npy"
        try:
            # allow_pickle is needed because the cache stores a dict; the file
            # is the one this method writes.
            split_full = np.load(node_split_file, allow_pickle=True)
            masks = split_full.flatten()[0]
            self.env_list = masks['env']
        except Exception:  # missing or unreadable cache: rebuild it
            print("Execute node split, it may take a while")
            masks = {}
            for key, idx in (('train', self.train_idx), ('val', self.val_idx), ('test', self.test_idx)):
                mask = torch.zeros((num_rows, 1), dtype=torch.int)
                mask[idx.long(), 0] = 1
                masks[key] = mask.bool()
            masks['env'] = env
            self.env_list = env
        np.save(node_split_file, masks)
