import pickle
import numpy as np
import torch
from torch import Tensor
import torch.nn as nn
from scipy import sparse as sp
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from ogb.nodeproppred import DglNodePropPredDataset
import dgl
from dgl.data import CoraGraphDataset, CoraFullDataset, register_data_args, RedditDataset, CiteseerGraphDataset, \
    AmazonCoBuyComputerDataset
from ogb.graphproppred import DglGraphPropPredDataset, collate_dgl, Evaluator
import copy
from torch_geometric.utils import to_networkx, degree, to_dense_adj, to_scipy_sparse_matrix
from dgl.data.utils import download
import pandas as pd
import networkx as nx
import os
import json
import random


class Linear_IL(nn.Linear):
    """Linear head that scores only the first ``n_cls`` output rows, without bias.

    The layer is constructed like a regular ``nn.Linear``; only ``forward``
    differs. The bias parameter of the parent class is never applied here.
    """

    def forward(self, input: Tensor, n_cls=10000, normalize=True) -> Tensor:
        """Project ``input`` onto the first ``n_cls`` weight rows.

        Args:
            input: tensor whose last dimension matches ``in_features``.
            n_cls: number of leading weight rows (output units) to use; values
                larger than ``out_features`` simply select every row.
            normalize: when true, both the input and the selected weight rows
                are L2-normalised along the last dimension before the
                product, so the outputs are cosine similarities.

        Returns:
            Tensor with ``min(n_cls, out_features)`` entries in its last dimension.
        """
        active_weight = self.weight[0:n_cls]
        if not normalize:
            return F.linear(input, active_weight, bias=None)
        unit_input = F.normalize(input, dim=-1)
        unit_weight = F.normalize(active_weight, dim=-1)
        return F.linear(unit_input, unit_weight, bias=None)


def accuracy(logits, labels, cls_balance=True, ids_per_cls=None):
    """Return classification accuracy as a Python float.

    Args:
        logits: 2-D score tensor; the prediction is the arg-max over dim 1.
        labels: tensor of target class indices, compared element-wise with
            the predictions.
        cls_balance: if true, accuracy is computed separately for each index
            group in ``ids_per_cls`` and the group accuracies are averaged
            with equal weight; otherwise plain accuracy over all rows.
        ids_per_cls: iterable of index collections (one per class); required
            when ``cls_balance`` is true.

    Returns:
        The (balanced or plain) accuracy.
    """
    if not cls_balance:
        predicted = torch.max(logits, dim=1)[1]
        n_correct = torch.sum(predicted == labels)
        return n_correct.item() * 1.0 / len(labels)

    predicted = torch.max(logits.detach(), dim=1)[1]
    hits = predicted == labels
    per_class = []
    for ids in ids_per_cls:
        # Fraction of correctly classified nodes within this class' index group.
        per_class.append(torch.sum(hits[ids]) / len(ids))
    return sum(per_class).item() / len(per_class)


def mean_AP(args, logits, labels, cls_balance=True, ids_per_cls=None):
    """Score multi-label predictions.

    Args:
        args: namespace providing ``args.dataset`` (passed to the OGB
            ``Evaluator`` constructor).
        logits: 2-D score tensor.
        labels: target tensor. In the balanced branch it is compared with the
            arg-max prediction; in the other branch it is indexed as
            ``labels[ids, c]`` (i.e. treated as a 2-D multi-label matrix).
        cls_balance: if true, return the class-balanced arg-max accuracy over
            ``ids_per_cls``; otherwise return the precision-based score below.
        ids_per_cls: iterable of per-class index collections (used by both
            branches).

    Returns:
        A Python float. With ``cls_balance=False`` this is the sum over classes
        ``c`` of ``TP_c / (TP_c + FP_c + 1e-4)`` (predictions thresholded at
        ``sigmoid(logit) > 0.5``, restricted to the rows in ``ids_per_cls[c]``),
        divided by ``labels.shape[1]``.
    """
    # Constructed up front, exactly as before, so a dataset name the evaluator
    # rejects fails regardless of which branch is taken.
    eval_ogb = Evaluator(args.dataset)

    if cls_balance:
        # The previous implementation also executed `_.cpu().numpy()` here; the
        # result was never used, forced a device-to-host copy and raised when
        # `logits` required grad. It has been removed.
        _, indices = torch.max(logits, dim=1)
        acc_per_cls = [torch.sum((indices == labels)[ids]) / len(ids) for ids in ids_per_cls]
        return sum(acc_per_cls).item() / len(acc_per_cls)

    # Binary decisions are only needed on this branch.
    pos = torch.sigmoid(logits) > 0.5

    # The evaluator's result is not used for the returned score; the call is
    # kept so whatever input checking it performs still runs.
    eval_ogb.eval({"y_true": labels, "y_pred": logits})

    precision_sum = 0
    for c, ids in enumerate(ids_per_cls):
        true_pos = (pos[ids, c] * labels[ids, c]).sum()
        false_pos = (pos[ids, c] * (labels[ids, c] == False)).sum()
        # The small constant keeps the ratio finite when nothing is predicted positive.
        precision_sum += true_pos / (true_pos + false_pos + 0.0001)
    return (precision_sum / labels.shape[1]).item()


def evaluate_batch(args, model, g, features, labels, mask, label_offset1, label_offset2, cls_balance=True,
                   ids_per_cls=None):
    """Evaluate ``model`` on ``g`` in mini-batches and return its accuracy.

    Args:
        args: namespace providing ``gpu``, ``nb_sampler`` and ``batch_size``.
        model: module exposing ``forward_batch(blocks, input_features)`` that
            returns a ``(predictions, extra)`` pair.
        g: graph to evaluate on; it is moved to CPU for the dataloader and its
            sampled blocks must carry ``srcdata['feat']``.
        features: unused; node features are read from the sampled blocks.
        labels: target tensor; its first dimension defines how many node IDs
            (``0 .. labels.shape[0] - 1``) are evaluated.
        mask: index/mask applied to logits and labels when ``cls_balance`` is false.
        label_offset1, label_offset2: column range of the model output that is
            kept as logits.
        cls_balance: forwarded to ``accuracy``.
        ids_per_cls: forwarded to ``accuracy``.

    Returns:
        The accuracy computed by ``accuracy``.
    """
    model.eval()
    device = 'cuda:{}'.format(args.gpu)
    with torch.no_grad():
        dataloader = dgl.dataloading.DataLoader(g.cpu(), list(range(labels.shape[0])), args.nb_sampler,
                                                batch_size=args.batch_size, shuffle=False, drop_last=False)
        # Collect the per-batch predictions and concatenate once; growing a
        # tensor with torch.cat inside the loop re-copies all earlier batches.
        # The former per-batch label accumulator was never read and crashed on
        # single-node batches (squeeze() -> 0-dim tensor), so it is gone.
        batch_outputs = []
        for _input_nodes, _output_nodes, blocks in dataloader:
            blocks = [b.to(device=device) for b in blocks]
            input_features = blocks[0].srcdata['feat']
            output_predictions, _ = model.forward_batch(blocks, input_features)
            batch_outputs.append(output_predictions)

        if batch_outputs:
            output = torch.cat(batch_outputs, dim=0)
        else:
            # Same empty starting tensor the loop-based accumulation produced.
            output = torch.tensor([]).cuda(args.gpu)

        logits = output[:, label_offset1:label_offset2]
        if cls_balance:
            return accuracy(logits, labels.cuda(args.gpu), cls_balance=cls_balance, ids_per_cls=ids_per_cls)
        else:
            return accuracy(logits[mask], labels[mask].cuda(args.gpu), cls_balance=cls_balance, ids_per_cls=ids_per_cls)


def evaluate(model, g, features, labels, mask, label_offset1, label_offset2, cls_balance=True, ids_per_cls=None,
             save_logits_name=None,
             save_logits_dir='/store/continual_graph_learning/baselines_by_TWP/NCGL/results/logits_for_tsne'):
    """Run a full-graph forward pass and return the accuracy.

    Args:
        model: callable ``model(g, features)`` returning ``(output, extra)``.
        g: graph passed to the model.
        features: node features passed to the model.
        labels: target tensor.
        mask: index/mask applied to logits and labels when ``cls_balance`` is false.
        label_offset1, label_offset2: accepted for signature compatibility;
            not used by this function.
        cls_balance: forwarded to ``accuracy``.
        ids_per_cls: forwarded to ``accuracy`` and stored in the optional dump.
        save_logits_name: if not None, the logits and ``ids_per_cls`` are
            pickled to ``<save_logits_dir>/<save_logits_name>.pkl``.
        save_logits_dir: directory for the optional dump. Defaults to the
            location that used to be hard-coded; the directory must already exist.

    Returns:
        The accuracy computed by ``accuracy``.
    """
    model.eval()
    with torch.no_grad():
        logits, _ = model(g, features)
        if save_logits_name is not None:
            dump_path = '{}/{}.pkl'.format(save_logits_dir, save_logits_name)
            with open(dump_path, 'wb') as f:
                pickle.dump({'logits': logits, 'ids_per_cls': ids_per_cls}, f)

        if cls_balance:
            return accuracy(logits, labels, cls_balance=cls_balance, ids_per_cls=ids_per_cls)
        else:
            return accuracy(logits[mask], labels[mask], cls_balance=cls_balance, ids_per_cls=ids_per_cls)


def evaluatewp(output, labels, mask, cls_balance=True, ids_per_cls=None):
    """Compute accuracy from already-computed model outputs.

    Args:
        output: logits tensor.
        labels: target tensor.
        mask: index/mask applied to both ``output`` and ``labels`` when
            ``cls_balance`` is false; ignored otherwise.
        cls_balance: forwarded to ``accuracy``.
        ids_per_cls: forwarded to ``accuracy``.

    Returns:
        The accuracy computed by ``accuracy``.
    """
    if not cls_balance:
        # The plain accuracy is restricted to the masked rows.
        output, labels = output[mask], labels[mask]
    return accuracy(output, labels, cls_balance=cls_balance, ids_per_cls=ids_per_cls)


def _random_walk_encoding(subgraph, n_rw):
    """Return the diagonals of (A @ D^-1)^k, k = 1..n_rw, stacked as columns."""
    src, dst = subgraph.edges()
    edge_index = torch.stack((src, dst), dim=0)
    A = to_scipy_sparse_matrix(edge_index, num_nodes=subgraph.num_nodes())
    inv_degree = (subgraph.in_degrees() ** -1.0).numpy()
    # `@` is the matrix product for both scipy sparse matrices and sparse
    # arrays, whereas `*` is element-wise for the newer array types.
    M = A @ sp.diags(inv_degree)
    diagonals = [torch.from_numpy(M.diagonal()).float()]
    M_power = M
    for _ in range(n_rw - 1):
        M_power = M_power @ M
        diagonals.append(torch.from_numpy(M_power.diagonal()).float())
    return torch.stack(diagonals, dim=-1)


def _degree_encoding(subgraph, n_dg):
    """Return a one-hot encoding of the in-degree clipped to [1, n_dg]."""
    clipped = subgraph.in_degrees().numpy().clip(1, n_dg)
    encoding = torch.zeros([subgraph.num_nodes(), n_dg])
    rows = torch.arange(len(clipped))
    cols = torch.as_tensor(clipped - 1).long()  # degree d lights up column d - 1
    encoding[rows, cols] = 1
    return encoding


def init_structure_encoding(subgraph, type_init='rw', n_rw=16, n_dg=16):
    """Build structural node encodings for ``subgraph``.

    Args:
        subgraph: graph object exposing ``edges()``, ``in_degrees()`` and
            ``num_nodes()``.
        type_init: ``'rw'`` for random-walk features, ``'dg'`` for the one-hot
            in-degree encoding, ``'rw_dg'`` for both concatenated along dim 1.
        n_rw: number of random-walk powers (columns of the 'rw' part).
        n_dg: number of degree buckets (columns of the 'dg' part); in-degrees
            are clipped to ``[1, n_dg]``.

    Returns:
        A float tensor with one row per node, or ``None`` when ``type_init``
        is not one of the recognised values (unchanged legacy behaviour).
    """
    if type_init == 'rw':
        return _random_walk_encoding(subgraph, n_rw)
    if type_init == 'dg':
        return _degree_encoding(subgraph, n_dg)
    if type_init == 'rw_dg':
        return torch.cat([_random_walk_encoding(subgraph, n_rw), _degree_encoding(subgraph, n_dg)], dim=1)
    return None


class incremental_graph_trans_(nn.Module):
    """Wraps one full graph and hands out per-task subgraphs.

    The instance keeps the complete graph, its labels and a per-class
    train/valid/test split; ``get_graph`` extracts the node-induced subgraph
    for a set of classes and/or an explicit set of node IDs and, optionally,
    partitions the training nodes among several agents.
    """

    def __init__(self, dataset, n_cls):
        """Store the graph, labels and split.

        Args:
            dataset: ``[[graph, labels], tr_va_te_split]``. ``tr_va_te_split``
                is indexed by class and yields ``[train_ids, valid_ids, test_ids]``.
            n_cls: number of classes, stored as ``self.n_cls``.
        """
        super().__init__()
        # transductive setting
        self.graph, self.labels = dataset[0]
        self.graph.ndata['label'] = self.labels
        self.d_data = self.graph.ndata['feat'].shape[1]
        self.n_cls = n_cls
        self.n_nodes = self.labels.shape[0]
        self.tr_va_te_split = dataset[1]

    def get_graph(self, tasks_to_retain=[], n_agents=None, partition='noniid0.1', node_ids=None, remove_edges=True):
        """Build the partial graph for the requested classes and/or nodes.

        Args:
            tasks_to_retain: classes whose nodes (train + valid + test) are
                kept in the subgraph. The list is only read, never mutated.
            n_agents: number of agents the training nodes are divided among.
                NOTE(review): the default ``None`` cannot be used as-is; the
                code compares it with ``1`` and passes it to ``range``.
            partition: partition scheme; only ``'noniid0.1'`` and
                ``'noniid0.5'`` set the Dirichlet concentration ``beta``.
            node_ids: optional extra nodes to include, either a flat list of
                IDs or a list of per-task ID lists. A deep copy is used, so
                the caller's list is not modified.
            remove_edges: for a flat ``node_ids`` list, drop all edges of the
                extra subgraph so only the nodes themselves are kept.

        Returns:
            ``(subgraph, node_ids_per_task_reordered, [ids_train, ids_val, ids_test])``.
            ``subgraph`` has self-loops added. ``node_ids_per_task_reordered``
            holds, per class in ``tasks_to_retain``, the subgraph positions of
            the nodes with that label. ``ids_train`` is a list with one list of
            subgraph positions per agent; ``ids_val`` / ``ids_test`` are flat.

        NOTE(review): the global ``random`` and ``numpy`` RNGs are re-seeded
        with 42 on every call.
        NOTE(review): ``alla_ids_train_old_map`` is only assigned when
        ``tasks_to_retain`` is non-empty, yet it is read unconditionally near
        the end; calling with an empty ``tasks_to_retain`` looks unsupported.
        """
        # get the partial graph
        # tasks-to-retain: classes retained in the partial graph
        # tasks-to-infer: classes to predict on the partial graph
        node_ids_ = copy.deepcopy(node_ids)
        node_ids_retained = []
        ids_train_old, ids_valid_old, ids_test_old = [], [], []
        random.seed(42)
        np.random.seed(42)
        if len(tasks_to_retain) > 0:
            # retain nodes according to classes
            for t in tasks_to_retain:
                ids_train_old.extend(self.tr_va_te_split[t][0])
                ids_valid_old.extend(self.tr_va_te_split[t][1])
                ids_test_old.extend(self.tr_va_te_split[t][2])
                node_ids_retained.extend(
                    self.tr_va_te_split[t][0] + self.tr_va_te_split[t][1] + self.tr_va_te_split[t][2])

            # Only training set is divided, the test set is not divided
            # Default: a single agent (key 0) owns every training node.
            alla_ids_train_old_map = {0: ids_train_old}
            if n_agents > 1:
                n_train = len(ids_train_old)
                ids_train_old = np.array(ids_train_old)
                if partition[:6] == 'noniid':
                    min_size = 0
                    min_require_size = 3
                    # NOTE(review): any other 'noniid*' value leaves `beta` unassigned.
                    if partition == 'noniid0.1':
                        beta = 0.1
                    elif partition == 'noniid0.5':
                        beta = 0.5
                    alla_ids_train_old_map = {}
                    # Redraw the whole partition until every agent holds at
                    # least `min_require_size` training nodes.
                    while min_size < min_require_size:
                        idx_batch = [[] for _ in range(n_agents)]
                        for k in tasks_to_retain:
                            # Positions (into ids_train_old) of the training nodes of class k.
                            idx_k = np.where(self.labels.squeeze()[ids_train_old] == k)[0]
                            np.random.shuffle(idx_k)
                            # Per-agent share of class k drawn from Dirichlet(beta).
                            proportions = np.random.dirichlet(np.repeat(beta, n_agents))
                            # Agents already holding >= n_train / n_agents nodes get a zero share.
                            proportions = np.array(
                                [p * (len(idx_j) < n_train / n_agents) for p, idx_j in zip(proportions, idx_batch)])
                            proportions = proportions / proportions.sum()
                            # Turn the shares into split points within idx_k.
                            proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
                            idx_batch = [idx_j + idx.tolist() for idx_j, idx in
                                         zip(idx_batch, np.split(idx_k, proportions))]
                            min_size = min([len(idx_j) for idx_j in idx_batch])
                    for j in range(n_agents):
                        np.random.shuffle(idx_batch[j])
                        # Map positions back to original node IDs.
                        alla_ids_train_old_map[j] = ids_train_old[idx_batch[j]]
                else:
                    # NOTE(review): no partition is built here, so the map keeps
                    # only key 0 while the lookup below iterates range(n_agents).
                    pass
                traindata_cls_counts = record_net_data_stats(self.labels.squeeze(), alla_ids_train_old_map)
                data_distributions = traindata_cls_counts / traindata_cls_counts.sum(axis=1)[:, np.newaxis]
                print('data_distributions', traindata_cls_counts, data_distributions)

            subgraph_0 = dgl.node_subgraph(self.graph, node_ids_retained, store_ids=True)
            if node_ids_ is None:
                subgraph = subgraph_0

        if node_ids_ is not None:
            # retain the given nodes
            if not isinstance(node_ids_[0], list):
                # if nodes are not divided into different tasks
                subgraph_1 = dgl.node_subgraph(self.graph, node_ids_, store_ids=True)
                if remove_edges:
                    # to facilitate the methods like ER-GNN to only retrieve nodes
                    n_edges = subgraph_1.edges()[0].shape[0]
                    subgraph_1.remove_edges(list(range(n_edges)))
            elif isinstance(node_ids_[0], list):
                # if nodes are divided into different tasks
                subgraph_1 = dgl.node_subgraph(self.graph, node_ids_[0],
                                               store_ids=True)  # load the subgraph containing nodes of the first task
                node_ids_.pop(0)
                for ids in node_ids_:
                    # merge the remaining nodes
                    subgraph_1 = dgl.batch([subgraph_1, dgl.node_subgraph(self.graph, ids, store_ids=True)])

            if len(tasks_to_retain) == 0:
                subgraph = subgraph_1

        if len(tasks_to_retain) > 0 and node_ids is not None:
            # Both sources requested: combine the class subgraph with the extra nodes.
            subgraph = dgl.batch([subgraph_0, subgraph_1])

        node_ids_per_task_reordered, ids_train, ids_val, ids_test = [], [], [], []
        # '_ID' is read as the original node ID of each subgraph node
        # (presumably written by `store_ids=True` — confirm against DGL docs).
        old_ids = subgraph.ndata['_ID'].cpu()
        # Translate original node IDs into positions within the subgraph
        # (first match); each lookup scans all of `old_ids`.
        ids_train = [[(old_ids == i).nonzero()[0][0].item() for i in alla_ids_train_old_map[a_id]] for a_id in
                     range(n_agents)]
        ids_val = [(old_ids == i).nonzero()[0][0].item() for i in ids_valid_old]
        ids_test = [(old_ids == i).nonzero()[0][0].item() for i in ids_test_old]
        for c in tasks_to_retain:
            ids = (subgraph.ndata['label'] == c).nonzero()[:, 0].view(-1).tolist()
            node_ids_per_task_reordered.append(ids)
        subgraph = dgl.add_self_loop(subgraph)
        return subgraph, node_ids_per_task_reordered, [ids_train, ids_val, ids_test]


def train_valid_test_split(ids, ratio_valid_test):
    """Randomly split ``ids`` into train, validation and test parts.

    Args:
        ids: collection of IDs to split.
        ratio_valid_test: ``[r_val, r_test]`` fractions of ``ids``; the
            training part receives the remainder.

    Returns:
        ``[train_ids, valid_ids, test_ids]``.
    """
    holdout_fraction = sum(ratio_valid_test)
    # First carve off the combined validation + test portion ...
    train_part, holdout_part = train_test_split(ids, test_size=holdout_fraction)
    # ... then divide it so the test part keeps its share of the original set.
    valid_part, test_part = train_test_split(holdout_part, test_size=ratio_valid_test[1] / holdout_fraction)
    return [train_part, valid_part, test_part]


class NodeLevelDataset(incremental_graph_trans_):
    """Loads a node-classification dataset and its per-class train/val/test split."""

    def __init__(self, name='ogbn-arxiv', IL='class', default_split=False, ratio_valid_test=None, args=None):
        r"""
        name: name of the dataset
        IL: use task- or class-incremental setting (accepted but not used here)
        default_split: if True, each class is split according to the splitting of the original dataset, which may cause the train-val-test ratio of different classes greatly different
        ratio_valid_test: in form of [r_val,r_test] ratio of validation and test set, train set ratio is directly calculated by 1-r_val-r_test
        args: namespace providing ``ori_data_path`` (raw data location), ``data_path`` (where the split cache is stored) and ``n_cls`` (number of label values to scan)

        Raises:
            ValueError: if ``name`` is not one of the supported datasets.
        """

        if name[0:4] == 'ogbn':
            data = DglNodePropPredDataset(name, root=f'{args.ori_data_path}/ogb_downloaded')
            graph, label = data[0]
        elif name == 'CoraFull-CL':
            custom_download_path = f'{args.ori_data_path}'
            download_path = f'{args.ori_data_path}/cora_full.zip'
            if not os.path.exists(download_path):
                download('https://data.dgl.ai/dataset/cora_full.zip', path=download_path)
            data = CoraFullDataset(custom_download_path)
            graph, label = data[0], data[0].dstdata['label'].view(-1, 1)
        elif name == 'Arxiv-CL':
            data = DglNodePropPredDataset('ogbn-arxiv', root=f'{args.ori_data_path}/ogb_downloaded')
            graph, label = data[0]
        elif name == 'Reddit-CL':
            custom_download_path = f'{args.ori_data_path}'
            data = RedditDataset(self_loop=False, raw_dir=custom_download_path)
            graph, label = data.graph, data.labels.view(-1, 1)
        elif name == 'Cora-CL':
            custom_download_path = f'{args.ori_data_path}'
            download_path = f'{args.ori_data_path}/cora.zip'
            if not os.path.exists(download_path):
                download('https://data.dgl.ai/dataset/cora.zip', path=download_path)
            data = CoraGraphDataset(custom_download_path)
            graph, label = data[0], data[0].dstdata['label'].view(-1, 1)
        elif name == 'CiteSeer-CL':
            custom_download_path = f'{args.ori_data_path}'
            download_path = f'{args.ori_data_path}/citeseer.zip'
            if not os.path.exists(download_path):
                download('https://data.dgl.ai/dataset/citeseer.zip', path=download_path)
            data = CiteseerGraphDataset(custom_download_path)
            graph, label = data[0], data[0].dstdata['label'].view(-1, 1)
        elif name == 'Computers-CL':
            custom_download_path = f'{args.ori_data_path}'
            data = AmazonCoBuyComputerDataset(custom_download_path)
            graph, label = data[0], data[0].dstdata['label'].view(-1, 1)
        elif name == 'SLAP-CL':
            # NOTE(review): no `data` object exists on this path, so
            # default_split=True cannot work for SLAP-CL.
            graph, label = read_slap()
        else:
            # Previously this only printed a message and then failed later
            # with an unrelated NameError on `label`.
            raise ValueError(f'invalid data name: {name}')

        # Collect the node IDs of every class once, keeping only classes with
        # at least 2 examples (smaller ones cannot be split into train/val/test).
        flat_label = label.squeeze()
        cls_id_map = {}
        for c in range(args.n_cls):
            ids_of_c = list((flat_label == c).nonzero().squeeze().view(-1, ).numpy())
            if len(ids_of_c) >= 2:
                cls_id_map[c] = ids_of_c
        cls = list(cls_id_map)
        n_cls = len(cls)

        if default_split:
            split_idx = data.get_idx_split()
            # Build the three ID sets once instead of once per class.
            train_set = set(split_idx["train"].tolist())
            valid_set = set(split_idx["valid"].tolist())
            test_set = set(split_idx["test"].tolist())
            tr_va_te_split = {c: [list(set(cls_id_map[c]).intersection(train_set)),
                                  list(set(cls_id_map[c]).intersection(valid_set)),
                                  list(set(cls_id_map[c]).intersection(test_set))] for c in cls}
        else:
            split_name = f'{args.data_path}/tr{round(1 - ratio_valid_test[0] - ratio_valid_test[1], 2)}_va{ratio_valid_test[0]}_te{ratio_valid_test[1]}_split_{name}.pkl'
            try:
                # Reuse a cached split across experiments for consistency.
                # NOTE: pickle executes arbitrary code on load; only point
                # `args.data_path` at trusted locations.
                with open(split_name, 'rb') as f:
                    tr_va_te_split = pickle.load(f)
            except Exception:
                # Missing or unreadable cache: create the split and store it.
                # (`Exception` rather than a bare except, so Ctrl-C and
                # SystemExit are no longer swallowed.)
                if ratio_valid_test[1] > 0:
                    tr_va_te_split = {c: train_valid_test_split(cls_id_map[c], ratio_valid_test=ratio_valid_test) for c
                                      in cls}
                    print(f'splitting is {ratio_valid_test}')
                elif ratio_valid_test[1] == 0:
                    tr_va_te_split = {c: [cls_id_map[c], [], []] for c in cls}
                with open(split_name, 'wb') as f:
                    pickle.dump(tr_va_te_split, f)
        super().__init__([[graph, label], tr_va_te_split], n_cls)


def record_net_data_stats(y_train, net_dataidx_map):
    """Count how many samples of each class every client/agent holds.

    Args:
        y_train: 1-D array-like of integer class labels supporting ``.max()``
            and fancy indexing.
        net_dataidx_map: mapping ``client id -> indices into y_train``.

    Returns:
        A float ``numpy`` array of shape ``(len(net_dataidx_map), num_classes)``
        where ``num_classes = int(y_train.max()) + 1``. Row ``r`` corresponds
        to the ``r``-th entry of ``net_dataidx_map`` in iteration order and
        column ``c`` holds the number of that client's samples labelled ``c``.
    """
    num_classes = int(y_train.max()) + 1
    # Allocate the result once instead of growing it with np.concatenate.
    counts = np.zeros((len(net_dataidx_map), num_classes))
    for row, dataidx in enumerate(net_dataidx_map.values()):
        present, present_cnt = np.unique(y_train[dataidx], return_counts=True)
        counts[row, present] = present_cnt
    return counts



def read_slap(input_folder='/slap/'):
    """Load the SLAP dataset from disk as a DGL graph.

    Args:
        input_folder: directory containing ``X.csv`` (node features),
            ``y.csv`` (node labels) and ``graph.graphml`` (topology).
            Defaults to the previously hard-coded ``'/slap/'``.

    Returns:
        ``(dgl_graph, labels)`` where ``dgl_graph.ndata`` carries ``'feat'``
        (float32) and ``'label'`` (int64), and ``labels`` is the label tensor
        reshaped to ``(-1, 1)``.
    """
    # assumes the CSV rows are ordered by node id — TODO confirm
    X = pd.read_csv(os.path.join(input_folder, 'X.csv'))
    y = pd.read_csv(os.path.join(input_folder, 'y.csv'))

    networkx_graph = nx.read_graphml(os.path.join(input_folder, 'graph.graphml'))
    # Relabel the string node keys "0", "1", ... to the matching integers.
    networkx_graph = nx.relabel_nodes(networkx_graph, {str(i): i for i in range(len(networkx_graph))})

    dgl_graph = dgl.from_networkx(networkx_graph)
    dgl_graph.ndata['feat'] = torch.tensor(X.values, dtype=torch.float32)

    labels = torch.tensor(y.values.squeeze(), dtype=torch.long)
    dgl_graph.ndata['label'] = labels
    return dgl_graph, labels.view(-1, 1)
