import random
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from ogb.nodeproppred import DglNodePropPredDataset
import dgl
from dgl.data import CoraFullDataset, RedditDataset, AmazonCoBuyComputerDataset
import copy
import collections

class semi_task_manager:
    """Small bookkeeping container for per-task information.

    ``task_info`` is an ordered mapping from a task index to the value that
    was registered for it through ``add_task``. A sentinel entry
    ``-1 -> 0`` is present from construction. ``g``, ``newg`` and ``degree``
    are placeholders that this class itself never fills in.
    """

    def __init__(self):
        # Ordered mapping seeded with the sentinel entry for "task -1".
        self.task_info = collections.OrderedDict([(-1, 0)])
        self.g, self.newg, self.degree = None, [], None

    def add_task(self, task_i, masks):
        """Register (or overwrite) the value stored for task ``task_i``."""
        self.task_info[task_i] = masks

    def get_label_offset(self, task_i, original=False):
        """Return a ``(lower, upper)`` pair for task ``task_i``.

        ``upper`` is always the value stored for ``task_i``. ``lower`` is the
        value stored for the previous task when ``original`` is true and
        ``task_i > 0``; in every other case it is ``0``.
        """
        # The previous task's entry is looked up first, then the current one.
        lower = self.task_info[task_i - 1] if original and task_i > 0 else 0
        return lower, self.task_info[task_i]

def accuracy(logits, labels, cls_balance=True, ids_per_cls=None):
    """Compute classification accuracy from per-class scores.

    Args:
        logits: 2-D tensor of scores; the prediction for a row is the index of
            its maximum along dimension 1.
        labels: tensor of target class indices, compared element-wise with the
            predictions.
        cls_balance: if True, return the unweighted mean of the per-class
            accuracies (each class contributes equally); otherwise return the
            plain fraction of correct predictions over ``len(labels)``.
        ids_per_cls: iterable of index collections, one per class, selecting
            the rows that belong to each class. Required when ``cls_balance``
            is True and ignored otherwise.

    Returns:
        float: the accuracy.
    """
    # Only the arg-max indices are needed. No host copy of the logits is made,
    # so this also works on tensors that require grad.
    predictions = torch.max(logits, dim=1)[1]
    hits = predictions == labels

    if cls_balance:
        # Accuracy inside each class, then an unweighted mean across classes.
        acc_per_cls = [torch.sum(hits[ids]) / len(ids) for ids in ids_per_cls]
        return sum(acc_per_cls).item() / len(acc_per_cls)

    return torch.sum(hits).item() * 1.0 / len(labels)

def evaluate(model, g, features, labels, mask, label_offset1, label_offset2, cls_balance=True, ids_per_cls=None):
    """Run ``model`` in eval mode on ``(g, features)`` and score its predictions.

    The first element returned by the model is sliced to the columns
    ``label_offset1:label_offset2`` before being scored with ``accuracy``.

    With ``cls_balance`` the sliced scores and ``labels`` are passed whole,
    together with ``ids_per_cls``. Without it, scores and labels are first
    restricted to the nodes listed in ``ids_per_cls`` (all classes
    concatenated in order). The incoming ``mask`` argument is not consulted
    in either case.
    """
    model.eval()
    with torch.no_grad():
        scores, _aux = model(g, features)
        class_scores = scores[:, label_offset1:label_offset2]

        if not cls_balance:
            # Flatten the per-class id lists into one selection of nodes.
            selected = []
            for cls_ids in ids_per_cls:
                selected.extend(cls_ids)
            return accuracy(class_scores[selected], labels[selected], cls_balance=cls_balance, ids_per_cls=ids_per_cls)

        return accuracy(class_scores, labels, cls_balance=cls_balance, ids_per_cls=ids_per_cls)

def get_nhop_neighborhood(graph, nodes, n_hops, device):
    """Collect the nodes reachable from ``nodes`` by following out-edges.

    Starting from ``nodes``, the set is expanded ``n_hops`` times with the
    destinations of all out-edges of the current set, de-duplicating after
    each expansion. The start nodes stay in the result.

    ``nodes`` is converted to a tensor and moved to ``device`` only when it is
    not already a tensor; a tensor argument is used as given.

    Returns:
        list: the collected node ids.
    """
    reached = nodes if isinstance(nodes, torch.Tensor) else torch.tensor(nodes).to(device)

    for _ in range(n_hops):
        # Only the destination endpoints of the out-edges are of interest.
        _sources, destinations = graph.out_edges(reached)
        reached = torch.cat([reached, destinations.unique()]).unique()

    return reached.tolist()

class incremental_graph_trans_(nn.Module):
    """Node-classification graph that can be revealed incrementally (transductive setting).

    The full graph (with self-loops added) is kept in ``self.graph``.
    ``update_subgraph`` grows ``self.current_subgraph`` by appending batches of
    nodes, while ``get_graph`` extracts a class-induced subgraph in one shot.
    """

    def __init__(self, dataset, n_cls):
        """
        Args:
            dataset: ``[[graph, labels], tr_va_te_split]`` where ``tr_va_te_split`` maps a class key
                to ``[train_ids, val_ids, test_ids]`` (node ids in the full graph).
            n_cls: number of classes, stored as ``self.n_cls``.
        """
        super().__init__()
        # transductive setting
        self.graph, self.labels = dataset[0]
        # Only self-loops are added; reverse edges are not.
        self.graph = dgl.add_self_loop(self.graph)
        self.graph.ndata['label'] = self.labels
        self.d_data = self.graph.ndata['feat'].shape[1]
        self.n_cls = n_cls
        self.n_nodes = self.labels.shape[0]
        self.tr_va_te_split = dataset[1]
        # node id -> split code (0 = train, 1 = val, 2 = test)
        self.id_to_split = {}
        for splits in self.tr_va_te_split.values():
            train_ids, val_ids, test_ids = splits
            for split_code, split_ids in enumerate((train_ids, val_ids, test_ids)):
                for node_id in split_ids:
                    self.id_to_split[node_id] = split_code
        self.current_subgraph = None
        # full-graph node id -> position in ``current_subgraph`` (-1 while the node is absent)
        self.old_to_new = torch.full((self.n_nodes,), -1, dtype=torch.int64)
        # True for every full-graph node that has been added to ``current_subgraph``
        self.mask = np.full((self.n_nodes,), False, dtype=bool)
        # evaluation node positions accumulated over all ``update_subgraph`` calls
        self.ids_test = []

    @staticmethod
    def _first_positions(original_ids):
        """Map each original node id to the first row at which it occurs in ``original_ids``."""
        positions = {}
        for row, original_id in enumerate(original_ids.tolist()):
            positions.setdefault(original_id, row)
        return positions

    def update_subgraph(self, node_ids=[], tasks_to_retain=[], valid=True, device=None):
        """
        Appends a batch of nodes (and their edges to nodes already present) to the current subgraph.

        Args:
            node_ids (list): Full-graph ids of the nodes to add. (The list defaults are never mutated.)
            tasks_to_retain (list): Class labels; for each one the positions of all subgraph nodes
                carrying that label are returned.
            valid (bool): If True, nodes of this batch belonging to the validation split (code 1) are
                appended to the accumulated evaluation ids; otherwise nodes of the test split (code 2).
            device: Passed to ``.to(device)`` for the subgraph / the tensors attached to new nodes.

        Returns:
            tuple: A tuple containing the following elements:
                - current_subgraph (DGLGraph): The grown subgraph.
                - node_ids_per_task_reordered (list): One list of subgraph node positions per entry of
                  ``tasks_to_retain``.
                - ids_test (list): ``self.ids_test``, i.e. evaluation positions accumulated over all calls.
                - train_ids_current_batch (list): Subgraph positions of this batch's training nodes.
        """
        prev_num_nodes = 0 if self.current_subgraph is None else self.current_subgraph.num_nodes('_N')
        # New nodes are appended after the existing ones, in the order given.
        self.old_to_new[node_ids] = torch.arange(prev_num_nodes, prev_num_nodes + len(node_ids), dtype=torch.int64)
        self.mask[node_ids] = True
        if self.current_subgraph is None:
            self.current_subgraph = dgl.node_subgraph(self.graph, node_ids, store_ids=True).to(device)
        else:
            node_ids_t = torch.tensor(node_ids)
            self.current_subgraph.add_nodes(len(node_ids), {'feat': self.graph.ndata['feat'][node_ids_t].to(device),
                                                            'label': self.graph.ndata['label'][node_ids_t].to(device),
                                                            '_ID': node_ids_t.to(device)})
            # Out-edges of the new nodes whose destination is already part of the subgraph
            # (this batch's nodes count as present, see ``self.mask`` above).
            edges = self.graph.out_edges(node_ids)
            dst_present = self.mask[edges[1]]
            edges = (edges[0][dst_present], edges[1][dst_present])
            u, v = self.old_to_new[edges[0]].to(device), self.old_to_new[edges[1]].to(device)
            self.current_subgraph.add_edges(u, v)
            # For edges that point at previously existing nodes, also add the reverse direction.
            points_to_old = v < prev_num_nodes
            self.current_subgraph.add_edges(v[points_to_old], u[points_to_old])

        old_ids = self.current_subgraph.ndata['_ID'].cpu()
        splt = 1 if valid else 2
        test_ids = self.old_to_new[torch.tensor([node_id for node_id in node_ids if self.id_to_split[node_id]==splt], dtype=torch.int64)]
        self.ids_test.extend(test_ids.tolist())
        # Position of the first subgraph node whose stored original id matches each training node.
        train_ids_current_batch = [(old_ids == i).nonzero()[0][0].item() for i in node_ids if self.id_to_split[i]==0]
        node_ids_per_task_reordered = []
        for c in tasks_to_retain:
            ids = (self.current_subgraph.ndata['label'] == c).nonzero()[:, 0].view(-1).tolist()
            node_ids_per_task_reordered.append(ids)

        return self.current_subgraph, node_ids_per_task_reordered, self.ids_test, train_ids_current_batch


    def get_graph(self, tasks_to_retain=[]):
        """
        Retrieves a partial graph based on the specified classes. Used only for the joint baseline

        Args:
            tasks_to_retain (list): List of classes to retain in the partial graph. Must not be empty.

        Returns:
            tuple: A tuple containing the following elements:
                - subgraph (DGLGraph): The partial graph, with self-loops added.
                - node_ids_per_task_reordered (list): One list of subgraph node positions per retained class.
                - [ids_train, ids_val, ids_test] (list): A single list holding the subgraph node positions
                  for training, validation, and testing, respectively.

        Raises:
            ValueError: If ``tasks_to_retain`` is empty (there would be no subgraph to build).
        """
        if len(tasks_to_retain) == 0:
            raise ValueError('tasks_to_retain must contain at least one class')

        node_ids_retained = []
        ids_train_old, ids_valid_old, ids_test_old = [], [], []
        # retain nodes according to classes
        for t in tasks_to_retain:
            ids_train_old.extend(self.tr_va_te_split[t][0])
            ids_valid_old.extend(self.tr_va_te_split[t][1])
            ids_test_old.extend(self.tr_va_te_split[t][2])
            node_ids_retained.extend(self.tr_va_te_split[t][0] + self.tr_va_te_split[t][1] + self.tr_va_te_split[t][2])
        subgraph = dgl.node_subgraph(self.graph, node_ids_retained, store_ids=True)

        # One pass over the stored original ids instead of one full tensor comparison per node.
        position_of = self._first_positions(subgraph.ndata['_ID'].cpu())
        ids_train = [position_of[int(i)] for i in ids_train_old]
        ids_val = [position_of[int(i)] for i in ids_valid_old]
        ids_test = [position_of[int(i)] for i in ids_test_old]
        node_ids_per_task_reordered = []
        for c in tasks_to_retain:
            ids = (subgraph.ndata['label'] == c).nonzero()[:, 0].view(-1).tolist()
            node_ids_per_task_reordered.append(ids)
        subgraph = dgl.add_self_loop(subgraph)

        return subgraph, node_ids_per_task_reordered, [ids_train, ids_val, ids_test]
       
def train_valid_test_split(ids,ratio_valid_test):
    """Randomly split ``ids`` into train, validation and test parts.

    Args:
        ids: the items to split.
        ratio_valid_test: ``[r_val, r_test]`` fractions of ``ids``; the train
            part receives the remainder.

    Returns:
        list: ``[train_ids, valid_ids, test_ids]``.

    A zero validation (or test) ratio yields an empty validation (or test)
    part; the held-out items then all go to the other part, because
    ``train_test_split`` cannot be asked for an empty side.
    """
    valid_ratio, test_ratio = ratio_valid_test[0], ratio_valid_test[1]
    va_te_ratio = sum(ratio_valid_test)
    # First carve off everything that is not training data ...
    train_ids, va_te_ids = train_test_split(ids, test_size=va_te_ratio)
    if valid_ratio == 0:
        return [train_ids, [], va_te_ids]
    if test_ratio == 0:
        return [train_ids, va_te_ids, []]
    # ... then divide the held-out part between validation and test.
    return [train_ids] + train_test_split(va_te_ids, test_size=test_ratio/va_te_ratio)

class NodeLevelDataset(incremental_graph_trans_):
    def __init__(self,name='ogbn-arxiv',IL='class',default_split=False,ratio_valid_test=None,args=None):
        r"""
        Load a node-classification dataset, build a per-class train/val/test split and initialise
        the incremental graph with it.

        name: name of the dataset
        IL: use task- or class-incremental setting (not read by this constructor)
        default_split: if True, each class is split according to the splitting of the original dataset, which may cause the train-val-test ratio of different classes greatly different
        ratio_valid_test: in form of [r_val,r_test] ratio of validation and test set, train set ratio is directly calculated by 1-r_val-r_test
        args: object providing ``ori_data_path`` (download root for OGB datasets) and ``data_path``
            (directory of the cached split file, used when ``default_split`` is False)

        Raises ValueError for an unknown ``name`` or a negative test ratio.
        """

        # return an incremental graph instance that can return required subgraph upon request
        if name.startswith('ogbn'):
            data = DglNodePropPredDataset(name, root=f'{args.ori_data_path}/ogb_downloaded')
            graph, label = data[0]
        elif name in ['CoraFullDataset', 'CoraFull','corafull', 'CoraFull-CL','Corafull-CL']:
            data = CoraFullDataset()
            graph, label = data[0], data[0].dstdata['label'].view(-1, 1)
        elif name in ['reddit','Reddit','Reddit-CL']:
            data = RedditDataset(self_loop=False)
            graph, label = data[0], data[0].dstdata['label'].view(-1, 1)
        elif name == 'Arxiv-CL':
            data = DglNodePropPredDataset('ogbn-arxiv', root=f'{args.ori_data_path}/ogb_downloaded')
            graph, label = data[0]
        elif name == 'AmazonComputer-CL':
            data = AmazonCoBuyComputerDataset()
            graph, label = data[0], data[0].dstdata['label'].view(-1, 1)
        else:
            # Fail here with a clear message instead of on the first use of ``data`` below.
            raise ValueError(f'invalid data name: {name}')

        flat_label = label.squeeze()
        all_cls_id_map = {i: list((flat_label == i).nonzero().squeeze().view(-1, ).numpy())
                          for i in range(data.num_classes)}
        # Sizes are recorded for every class, including the ones dropped below.
        cls_sizes = {c: len(all_cls_id_map[c]) for c in all_cls_id_map}
        self.cls_sizes = cls_sizes
        # remove classes with less than 2 examples, which cannot be split into train, val, test sets
        cls = [c for c in cls_sizes if cls_sizes[c] >= 2]
        cls_id_map = {c: all_cls_id_map[c] for c in cls}
        n_cls = len(cls)

        if default_split:
            split_idx = data.get_idx_split()
            # Build each split's id set once rather than once per class.
            train_set = set(split_idx["train"].tolist())
            valid_set = set(split_idx["valid"].tolist())
            test_set = set(split_idx["test"].tolist())
            tr_va_te_split = {c: [list(set(cls_id_map[c]).intersection(train_set)),
                                  list(set(cls_id_map[c]).intersection(valid_set)),
                                  list(set(cls_id_map[c]).intersection(test_set))] for c in cls}
        else:
            split_name = f'{args.data_path}/tr{round(1-ratio_valid_test[0]-ratio_valid_test[1],2)}_va{ratio_valid_test[0]}_te{ratio_valid_test[1]}_split_{name}.pkl'
            try:
                # could use same split across different experiments for consistency
                # NOTE: unpickling can execute arbitrary code; only keep trusted files in data_path.
                with open(split_name, 'rb') as f:
                    tr_va_te_split = pickle.load(f)
            except (OSError, EOFError, pickle.UnpicklingError, AttributeError, ImportError, IndexError):
                # No usable cached split: create one and store it for later runs.
                if ratio_valid_test[1] > 0:
                    tr_va_te_split = {c: train_valid_test_split(cls_id_map[c], ratio_valid_test=ratio_valid_test)
                                      for c in
                                      cls}
                    print(f'splitting is {ratio_valid_test}')
                elif ratio_valid_test[1] == 0:
                    # No test ratio requested: every node of each class goes to the train split.
                    tr_va_te_split = {c: [cls_id_map[c], [], []] for c in
                                      cls}
                else:
                    raise ValueError(f'test ratio must be non-negative, got {ratio_valid_test[1]}')
                with open(split_name, 'wb') as f:
                    pickle.dump(tr_va_te_split, f)
        super().__init__([[graph, label], tr_va_te_split], n_cls)