'''
    Utilities for preprocessing data (adding supernodes, encoding text features) and batching data
'''

import torch_geometric as pyg
import torch
import numpy as np
import random
from torch_scatter import scatter
from copy import deepcopy
from time import time


def embed_text_features(data, text_emb, text_dict):
    '''
    Replace the raw text features of a graph (or of every graph in a list) by rows of the text embedding matrix.

    data.x is built from data.x_text. data.edge_attr is built from data.edge_attr_text when the graph has both
    edge attributes and edge texts; otherwise every edge receives the embedding of the empty text ''.
    The consumed text attributes are deleted afterwards, so a second call on the same graph returns it unchanged.
    :param data: List of graphs or a single graph on which to embed text features (graphs are modified in place)
    :param text_emb: Text embeddings, indexed as text_emb[list_of_row_indices, :]
    :param text_dict: The dictionary mapping texts to indices in the text_emb matrix. It has to contain ''
                      whenever a graph has no usable edge texts.
    :return: the same graph object, or a new list holding the same graph objects when a list was passed
    '''
    if isinstance(data, list):
        return [embed_text_features(g, text_emb, text_dict) for g in data]

    has_node_text = "x_text" in data
    has_edge_text = "edge_attr_text" in data
    if not has_node_text and not has_edge_text:
        return data  # pass - this one has already been processed

    # Each text attribute is handled on its own, so a graph carrying only one of them no longer fails on the
    # attribute that is missing.
    if has_node_text:
        data.x = text_emb[[text_dict[item] for item in list(data.x_text)], :]
        del data.x_text

    if has_edge_text and data.edge_attr is not None:
        data.edge_attr = text_emb[[text_dict[item] for item in list(data.edge_attr_text)], :]
    else:
        # No usable edge texts: label every edge with the embedding of the empty text.
        data.edge_attr = text_emb[[text_dict['']] * data.edge_index.size(1), :]
    if has_edge_text:
        del data.edge_attr_text
    return data


def strip_uncommon_features(graph):
    '''
    Remove some task-specific features from the graph, keeping only the attributes on the allow-list below.
    :param graph: graph object that lists its attribute names through ``keys`` (either a property or a method)
                  and supports ``del graph[key]``. It is modified in place.
    :return: the same graph object
    '''
    to_keep = {"x", "edge_index", "edge_attr", "supernode", "edge_index_supernode", "edge_index_from_supernode",
               "num_nodes", "x_id", "y_task_labels"}
    # `keys` may be a property or a method depending on the graph class / library version; accept both.
    keys = graph.keys
    if callable(keys):
        keys = keys()
    # Snapshot the names first so that deleting never happens while iterating a live view.
    for key in list(keys):
        if key not in to_keep:
            del graph[key]
    return graph


def renumber_edge_index(edge_index):
    '''
    Rewire the edgelist of a bipartite graph whose node indices start from 0 on both sides: the nodes on the left
    side (row 0) keep their original indices and the right side (row 1) is shifted by max(row 0) + 1.
    :param edge_index: edgelist with two rows; it is modified in place
    :return: the same (modified) edgelist. An edgelist without edges is returned unchanged.
    '''
    if edge_index.shape[-1] == 0:
        # Nothing to shift, and max() of an empty row is undefined.
        return edge_index
    # Use the array's own reduction; the builtin max() would iterate over the elements one by one in Python.
    n_nodes_left_side = edge_index[0, :].max() + 1
    edge_index[1, :] += n_nodes_left_side
    return edge_index


def add_pooling_supernode(data: pyg.data.Data, pool_type: str, separate_edgelist_for_supernode=False):
    '''
    To the given data object, add a supernode connected in a way specified by pool_type.
    The supernode is a new all-zero feature row appended to data.x; its index is stored in data.supernode.
    The graph is modified in place. A graph that already has a "supernode" attribute is returned untouched.
    :param data: The PyG graph to which to add the supernode. A 1-D data.edge_attr is reshaped to a float column.
    :param pool_type: Either "entire_graph", "node", or "link". Ignored for choosing the linked nodes when the graph
                      has a "node_pooling" attribute (then the first row of data.node_pooling is used).
                      "node" links data.center_node, "link" links nodes 0 and 1, "entire_graph" links every node.
    :param separate_edgelist_for_supernode: If True, the supernode edges are stored separately in
                      data.edge_index_supernode (node -> supernode) and data.edge_index_from_supernode
                      (supernode -> node), and data.num_nodes is updated. If False, both edge sets are appended to
                      data.edge_index and matching rows are appended to data.edge_attr (column 1 set to 1 for
                      node -> supernode edges, column 2 set to 1 for supernode -> node edges), which requires
                      data.edge_attr to have at least 3 columns.
    :return: pyg.data.Data
    '''

    assert pool_type in ["node", "link", "entire_graph"]

    if "supernode" in data:
        return data  # pass - supernode has already been added

    n_nodes = data.x.shape[0]
    n_feats = data.x.shape[1]

    if len(data.edge_attr.shape) == 1:
        data.edge_attr = data.edge_attr.reshape((-1, 1)).float()

    supernode_idx = n_nodes
    if hasattr(data, "node_pooling"):
        sn_links = data.node_pooling[0, :]
    elif pool_type == "node":
        sn_links = data.center_node
    elif pool_type == "link":
        # for link prediction, pool together source (0) and tail (1)
        sn_links = torch.LongTensor([0, 1])
    elif pool_type == "entire_graph":
        sn_links = torch.arange(start=0, end=n_nodes)
    else:
        # Only reachable when assertions are disabled.
        raise NotImplementedError
    sn_links = sn_links.reshape(-1)  # also accepts a 0-dim index tensor
    n_links = sn_links.size(0)

    supernode_row = torch.full((1, n_links), supernode_idx, dtype=torch.long, device=sn_links.device)
    graph_to_supernode = torch.cat((sn_links.reshape(1, -1), supernode_row)).long()
    supernode_to_graph = graph_to_supernode.flip(0)

    if not separate_edgelist_for_supernode:
        # The marker attributes are only needed when the supernode edges are merged into the main edge list.
        # They are built before the graph is touched, so a too-narrow edge_attr fails without leaving a
        # half-updated graph behind.
        n_edge_feats = data.edge_attr.size(1)
        graph_to_supernode_attr = torch.zeros(n_links, n_edge_feats, device=data.edge_attr.device)
        graph_to_supernode_attr[:, 1] = 1
        supernode_to_graph_attr = torch.zeros(n_links, n_edge_feats, device=data.edge_attr.device)
        supernode_to_graph_attr[:, 2] = 1

    supernode_attr = torch.zeros(1, n_feats, device=data.x.device)
    data.x = torch.cat((data.x, supernode_attr))
    data.supernode = torch.tensor([supernode_idx])

    if not separate_edgelist_for_supernode:
        data.edge_index = torch.cat((data.edge_index, graph_to_supernode, supernode_to_graph), dim=1)
        data.edge_attr = torch.cat((data.edge_attr, graph_to_supernode_attr, supernode_to_graph_attr), dim=0)
        # NOTE(review): data.num_nodes is not refreshed on this path (kept as in the original) - confirm that no
        # caller stores an explicit num_nodes before merging supernode edges.
        return data

    data.edge_index_supernode = graph_to_supernode
    data.edge_index_from_supernode = supernode_to_graph
    data.num_nodes = data.x.shape[0]
    return data


def add_pooling_supernode_lst(data, pool_type: str, separate_edgelist_for_supernode: bool = False):
    '''
    Apply add_pooling_supernode to a single graph or to each graph of a list.
    :param data: a pyg.data.Data object or a list of them
    :param pool_type: forwarded unchanged, see add_pooling_supernode.
    :param separate_edgelist_for_supernode: forwarded unchanged, see add_pooling_supernode.
    :return: the processed graph, or a new list of processed graphs when a list was given
    '''
    if not isinstance(data, list):
        return add_pooling_supernode(data, pool_type, separate_edgelist_for_supernode)
    processed = []
    for graph in data:
        processed.append(add_pooling_supernode(graph, pool_type, separate_edgelist_for_supernode))
    return processed


def is_support_set(graph: pyg.data.Data):
    '''
    Tell whether a graph has been marked as belonging to a support set.
    :param graph: the graph to inspect
    :return: True if the graph has a "support_set" attribute (whatever its value), False otherwise
    '''
    marked = hasattr(graph, "support_set")
    return marked


def create_batch_simple_encoder(task_lst: list, copy_to_device=None, add_original_y_to_graph=True,
                                add_supernodes=False):
    '''
    Create a very simple batch graph out of all support and query graphs of the given tasks.
    Used for debugging GNN data loaders (e.g. for arxiv).
    :param task_lst: list of task dicts with "support" and "query" graph lists
    :param copy_to_device: if not None, the batch is moved to this device
    :param add_original_y_to_graph: if True, every task is first passed through add_original_y
    :param add_supernodes: forwarded to add_original_y
    :return: a single batch containing, per task, its support graphs followed by its query graphs
    '''
    tasks = task_lst
    if add_original_y_to_graph:
        tasks = [add_original_y(task, add_supernodes=add_supernodes) for task in task_lst]

    graphs = []
    for task in tasks:
        graphs.extend(task["support"])
        graphs.extend(task["query"])

    # One loader batch that is as large as the whole graph list.
    batched = next(iter(pyg.loader.DataLoader(graphs, batch_size=len(graphs))))
    return batched if copy_to_device is None else batched.to(copy_to_device)

def sanity_check(graph_object):
    '''
    Assert that a graph is structurally consistent: every node index used in edge_index has a row in x, and there is
    exactly one edge_attr row per edge.
    :param graph_object: graph with x, edge_index and edge_attr
    :return: None; raises AssertionError when a check fails
    '''
    n_node_rows = graph_object.x.size(0)
    sources, targets = graph_object.edge_index[0], graph_object.edge_index[1]
    n_nodes_referenced = max(sources.max(), targets.max()) + 1
    assert n_node_rows >= n_nodes_referenced, "Number of nodes in x and max node idx do not match."
    n_edges = graph_object.edge_index.size(1)
    n_edge_rows = graph_object.edge_attr.size(0)
    assert n_edges == n_edge_rows, "Number of edges and edge attributes do not match."


def create_batch(graphs: list, group_starts: list = None, copy_to_device=None, keep_original_ys=False,
                 renumber_mg_edge_index=True, is_mixed_tasks=False):
    '''
    Create a batched graph together with the metagraph connecting every datapoint to the tasks of its group.
    All graphs should have .y set to the class number and .y_task_labels set (only row 0 of .y_task_labels is
    read; its number of columns is taken as the number of tasks of the group).
    If a list of graphs is passed as one element, all of its graphs are batched, but .y, .y_task_labels and the
    support-set marker are read from its first graph only.
    The graphs are modified in place: uncommon features are stripped and edge_index is cast to long.
    :param graphs: list of PyG data objects or lists of PyG data objects.
    :param group_starts: list of indices at which new groups start. A "group" just means that tasks from it should be considered separate
                         (e.g. first we are doing arxiv classification and then something else, and we want to have everything separated).
                         Defaults to a single group starting at index 0.
    :param copy_to_device: not used by this function.
    :param keep_original_ys: if True, task indices are not offset between groups and the running graph counter is
                             not advanced.
    :param renumber_mg_edge_index: if True, the task side of the metagraph edge index is shifted by
                                   renumber_edge_index so that it does not overlap with the datapoint side.
    :param is_mixed_tasks: currently has no effect - uncommon features are always stripped.
    :return: tuple of
             (batched graph,
              list with one index tensor per datapoint (second-level pooling mapping),
              numpy array with the offset class label of every datapoint,
              metagraph edge index (datapoint index, task index),
              metagraph edge attributes (is_query flag, label feature),
              boolean query-set mask per metagraph edge,
              true label per metagraph edge,
              datapoint index per metagraph edge,
              group index per metagraph edge)
    '''
    starts = frozenset([0]) if group_starts is None else frozenset(int(s) for s in group_starts)

    graph_list = []
    second_pooling_mapping = []
    pooling_ys = []
    # edges: tuples (start_idx, end_idx); edge_attr: edge values (typically 0 or 1 for classification)
    metagraph_edges, metagraph_edge_attr = [], []
    query_set_mask = []
    y_true = []
    scatter_tensor = []
    task_level_tensor = []  # tensor specifying which parts of the output belong to which task

    # NOTE(review): n is only advanced when keep_original_ys is False (kept as in the original), so with
    # keep_original_ys=True every second_pooling_mapping entry holds zeros - confirm this is intended.
    n = 0
    count_task_scatter = -1
    n_tasks = None
    add_to_tasks = 0  # This number is added to task idx

    for i, item in enumerate(graphs):
        if i in starts:
            if n_tasks is not None and not keep_original_ys:
                # Move past the task slots used by the previous group. The offset has to grow by the number of
                # tasks of the group that just ended (not of the one that starts), otherwise groups with
                # different numbers of tasks get overlapping or gapped task indices.
                add_to_tasks += n_tasks
            n_tasks = None
            count_task_scatter += 1

        # A single graph is handled like a one-element list of graphs.
        members = item if isinstance(item, list) else [item]
        head = members[0]
        if n_tasks is None:
            n_tasks = head.y_task_labels.size(1)

        pooled_indices = []
        for graph in members:
            graph_list.append(graph)
            pooled_indices.append(n)
            if not keep_original_ys:
                n += 1
        second_pooling_mapping.append(torch.tensor(pooled_indices))
        pooling_ys.append(head.y.item() + add_to_tasks)

        in_support = is_support_set(head)
        query_flag = 0 if in_support else 1
        for j in range(n_tasks):
            label = head.y_task_labels[0, j]
            # Support examples expose their label as -1 / +1; query examples get a neutral 0.
            feat_val = label * 2 - 1 if in_support else 0
            metagraph_edges.append((i, j + add_to_tasks))
            metagraph_edge_attr.append((query_flag, feat_val))
            query_set_mask.append(query_flag)
            y_true.append(label)
            scatter_tensor.append(i)  # index of the datapoint this metagraph edge belongs to
            task_level_tensor.append(count_task_scatter)

    # always strip uncommon (task-specific) features for now, regardless of is_mixed_tasks
    graph_list = [strip_uncommon_features(g) for g in graph_list]
    for g in graph_list:
        g.edge_index = g.edge_index.long()

    loader = pyg.loader.DataLoader(graph_list, batch_size=len(graph_list), num_workers=0)
    try:
        batched = next(iter(loader))
    except Exception as err:
        # Same exception type and message as before, but the original cause stays attached.
        raise Exception("Problem with dataloader") from err
    mg_edge_index = torch.tensor(metagraph_edges).T
    if renumber_mg_edge_index:
        mg_edge_index = renumber_edge_index(mg_edge_index)
    return batched, second_pooling_mapping, np.array(pooling_ys), mg_edge_index, \
           torch.FloatTensor(metagraph_edge_attr), torch.LongTensor(query_set_mask).bool(), torch.tensor(y_true), \
           torch.tensor(scatter_tensor), torch.tensor(task_level_tensor)


def create_debugging_batch(tasks: list, copy_to_device=None):
    '''
    Create a debugging batch with node features and classifier targets only.
    This is temporary; only the "query" set of every task is used.
    :param tasks: list of task dicts with "task_ids" and "query"
    :param copy_to_device: if not None, both outputs are moved to this device
    :return: (stacked feature rows of the center nodes, targets), where the targets are the task ids renumbered
             0, 1, 2, ... in order of first appearance
    '''
    features, targets = [], []
    dense_label_of = {}
    for task in tasks:
        id_lookup = task["task_ids"]
        for subgraph in task["query"]:
            original_id = id_lookup[subgraph.y.item()]
            # A task id seen for the first time receives the next free dense label.
            targets.append(dense_label_of.setdefault(original_id, len(dense_label_of)))
            features.append(subgraph.x[subgraph.center_node.item()])
    x_out, y_out = torch.stack(features), torch.tensor(targets)
    if copy_to_device is None:
        return x_out, y_out
    return x_out.to(copy_to_device), y_out.to(copy_to_device)
    # , torch.tensor([y_map_dict[yy] for yy in y])


def add_original_y(task, add_supernodes=False):
    '''
    Add the original y to the task. Used for arxiv when classification_only=True.
    For every query and support graph, y (an index local to the task) is replaced by the matching entry of
    task["task_ids"], and y_text is set to the matching entry of task["task_descriptions"].
    The graphs and the task dict are modified in place.
    :param task: task dict with "task_ids", "task_descriptions", "query" and "support"
    :param add_supernodes: if True, "node"-type supernodes with separate edge lists are added to all graphs
    :return: the same task dict
    '''
    class_labels = task["task_ids"]
    task_texts = task["task_descriptions"]
    for split in ("query", "support"):
        for subgraph in task[split]:
            local_y = subgraph.y.item()
            subgraph.y_text = task_texts[local_y]
            subgraph.y = torch.tensor(class_labels[local_y])
    if not add_supernodes:
        return task
    for split in ("query", "support"):
        task[split] = [add_pooling_supernode_lst(g, "node", True) for g in task[split]]
    return task


def create_batch_from_task_list(tasks: list, shuffle=False, copy_to_device=None, keep_original_ys=False,
                                renumber_mg_edge_index=True, is_mixed_tasks=False):
    '''
    Build one batch out of a list of tasks. Per task, the support datapoints come first, then the query datapoints;
    every task starts a new group. Supernodes (with separate edge lists) are added unless the first query element
    already has a "supernode" attribute.
    :param tasks: list of tasks (dicts outputted by the get_dataset function)
    :param shuffle: not used by this function
    :param copy_to_device: if not None, all tensor outputs are moved to this device
    :param keep_original_ys: if True, add_original_y is applied to every task; also forwarded to create_batch
    :param renumber_mg_edge_index: forwarded to create_batch
    :param is_mixed_tasks: if True, it means it contains different types of tasks; forwarded to create_batch
    :return: batch_graph, second_pooling_mapping, task_texts, metagraph_edge_index, metagraph_edge_attr,
             group_starts, query_set_mask, y_true, scatter_mask, task_mask, pooling_ys
             (pooling_ys is a numpy array when copy_to_device is None, otherwise a tensor on that device)
    '''

    def flag_support(entry):
        # Works for a single graph as well as for a list of graphs; hands back what it was given.
        members = entry if isinstance(entry, list) else [entry]
        for member in members:
            member.support_set = True
        return entry

    task_texts, pooled_data, group_starts = [], [], []

    for task in tasks:
        pool_type = task["pool_type"]
        if keep_original_ys:
            task = add_original_y(task)
        if hasattr(task["query"][0], "supernode"):
            # Supernodes are already there, take the datapoints as they are.
            queries = task["query"]
            supports = task["support"]
        else:
            queries = [add_pooling_supernode_lst(g, pool_type, True) for g in task["query"]]
            supports = [add_pooling_supernode_lst(g, pool_type, True) for g in task["support"]]
        supports = [flag_support(g) for g in supports]
        task_datapoints = supports + queries
        task_texts += list(task["task_descriptions"])
        group_starts.append(len(pooled_data))
        pooled_data += task_datapoints

    (batch_graph, second_pooling_mapping, pooling_ys, metagraph_edge_index, metagraph_edge_attr,
     query_set_mask, y_true, scatter_mask, task_mask) = create_batch(
        pooled_data,
        group_starts=group_starts, copy_to_device=copy_to_device, keep_original_ys=keep_original_ys,
        renumber_mg_edge_index=renumber_mg_edge_index, is_mixed_tasks=is_mixed_tasks)

    if copy_to_device is None:
        return (batch_graph, second_pooling_mapping, task_texts,
                metagraph_edge_index, metagraph_edge_attr, group_starts,
                query_set_mask, y_true, scatter_mask, task_mask, pooling_ys)

    device = copy_to_device
    return (batch_graph.to(device), [idx.to(device) for idx in second_pooling_mapping], task_texts,
            metagraph_edge_index.to(device), metagraph_edge_attr.to(device), group_starts,
            query_set_mask.to(device), y_true.to(device), scatter_mask.to(device), task_mask.to(device),
            torch.from_numpy(pooling_ys).to(device))


def create_batch_from_task_list_encoder_only(tasks: list, shuffle=True, copy_to_device=None, keep_original_ys=False):
    '''
    Build one batch out of a list of tasks. Per task, the query datapoints come first, then the support
    datapoints; every task starts a new group. Supernodes are added with add_pooling_supernode_lst using its
    default edge handling (supernode edges merged into the main edge list).
    :param tasks: list of tasks (dicts outputted by the get_dataset function)
    :param shuffle: not used by this function (no shuffling is performed).
    :param copy_to_device: if not None, all tensor outputs are moved to this device
    :param keep_original_ys: if True, add_original_y is applied to every task; also forwarded to create_batch
    :return: batch_graph, second_pooling_mapping, task_texts, metagraph_edge_index, metagraph_edge_attr,
             group_starts, query_set_mask, y_true, scatter_mask, task_mask, pooling_ys
             (pooling_ys is a numpy array when copy_to_device is None, otherwise a tensor on that device)
    '''

    def flag_support(entry):
        # Works for a single graph as well as for a list of graphs; hands back what it was given.
        members = entry if isinstance(entry, list) else [entry]
        for member in members:
            member.support_set = True
        return entry

    task_texts, pooled_data, group_starts = [], [], []

    for task in tasks:
        pool_type = task["pool_type"]
        if keep_original_ys:
            task = add_original_y(task)
        queries = [add_pooling_supernode_lst(g, pool_type) for g in task["query"]]
        supports = [flag_support(g) for g in
                    [add_pooling_supernode_lst(g, pool_type) for g in task["support"]]]
        task_datapoints = queries + supports
        task_texts += list(task["task_descriptions"])
        group_starts.append(len(pooled_data))
        pooled_data += task_datapoints

    (batch_graph, second_pooling_mapping, pooling_ys, metagraph_edge_index, metagraph_edge_attr,
     query_set_mask, y_true, scatter_mask, task_mask) = create_batch(
        pooled_data,
        group_starts=group_starts, copy_to_device=copy_to_device, keep_original_ys=keep_original_ys)

    if copy_to_device is None:
        return (batch_graph, second_pooling_mapping, task_texts, metagraph_edge_index, metagraph_edge_attr,
                group_starts, query_set_mask, y_true, scatter_mask, task_mask, pooling_ys)

    device = copy_to_device
    return (batch_graph.to(device), [idx.to(device) for idx in second_pooling_mapping], task_texts,
            metagraph_edge_index.to(device), metagraph_edge_attr.to(device), group_starts,
            query_set_mask.to(device), y_true.to(device), scatter_mask.to(device), task_mask.to(device),
            torch.from_numpy(pooling_ys).to(device))


def obtain_supernode_embeddings(all_node_emb, supernode_edge_index, supernode_idx, aggr='mean'):
    '''
    A simple aggregator to obtain supernode embeddings: the embeddings of the edge source nodes are reduced onto
    the edge target nodes and the rows of the supernodes are picked from the result.
    :param all_node_emb: node embeddings, indexed by node along dim 0
    :param supernode_edge_index: edge index whose row 0 holds source nodes and row 1 the nodes to aggregate into
    :param supernode_idx: indices of the supernodes to return
    :param aggr: reduction passed to scatter as `reduce`
    :return: aggregated embeddings at the supernode indices
    '''
    source_nodes = supernode_edge_index[0]
    target_nodes = supernode_edge_index[1]
    pooled = scatter(src=all_node_emb[source_nodes], index=target_nodes, dim=0, reduce=aggr)
    return pooled[supernode_idx, :]


def get_texts(task):
    '''
    Extract the different texts from a task: everything found in x_text and edge_attr_text of the support and
    query graphs, plus the empty text.
    :param task: task dict with "support" and "query" graph lists
    :return: (list of unique texts, dict mapping every text to its position in that list)
    '''
    unique_texts = {""}  # the empty text is always included
    for graph in task["support"] + task["query"]:
        for field in ("x_text", "edge_attr_text"):
            if field in graph:
                unique_texts.update(getattr(graph, field))
    ordered = list(unique_texts)
    index_of = dict(zip(ordered, range(len(ordered))))
    return ordered, index_of


def preprocess_task(task, bert_encoder=None):
    '''
    Preprocess a task by (1) encoding text features (only if bert_encoder is given) and (2) adding pooling
    supernodes with separate supernode edge lists. The task dict is modified in place.
    :param task: The task as outputted by the get_dataset function.
    :param bert_encoder: None or the encoder model; its get_sentence_embeddings method is called once with all
                         texts of the task
    :return: the same task dict
    '''
    if bert_encoder is not None:
        texts, text_to_idx = get_texts(task)
        assert len(set(texts)) == len(texts)  # check that all texts are unique
        embeddings = bert_encoder.get_sentence_embeddings(texts)
        for split in ("query", "support"):
            task[split] = [embed_text_features(g, embeddings, text_to_idx) for g in task[split]]

    for split in ("support", "query"):
        task[split] = add_pooling_supernode_lst(task[split], task["pool_type"],
                                                separate_edgelist_for_supernode=True)
    return task

