from typing import List
import os
import pandas as pd
import numpy as np
import torch
import dgl

from graph_learning.data_setting import DataTransform, DataSettingConfig
from graph_learning.config import Config
from graph_learning.dataset.graph import GLGraph, edge_batch

@DataSettingConfig.register(
    'pair-node-cls-sample-mask',
    help='Train/valid/test set construction for node-pair tasks.')
class PairNodeClsSampleMaskConfig(DataSettingConfig):
    """Data-setting config registered under ``pair-node-cls-sample-mask``.

    Declares the command-line options for the pair split and names
    ``PairNodeClsSampleMask`` as the transform to build.
    NOTE(review): how parsed option values are handed to the builder is
    decided by ``DataSettingConfig`` and is not visible in this file.
    """

    def __init__(self, args, context):
        # Pure pass-through; kept so the constructor signature is explicit.
        super().__init__(args, context)

    @classmethod
    def define_parser(cls, parser):
        """Add this setting's options on top of the parent class's options."""
        super().define_parser(parser)
        # Both split ratios share the same type and default value.
        for ratio_flag in ('--valid-ratio', '--test-ratio'):
            parser.add_argument(ratio_flag, type=float, default=0.1)
        parser.add_argument(
            '--task',
            choices=['node-pair', 'link'],
            help='tasks: node-pair classification, link prediction',
        )

    @property
    def builder(self):
        """Transform class that implements this data setting."""
        return PairNodeClsSampleMask

def deduplicate_edges(edges):
    """Keep one representative column per undirected edge.

    A column ``(u, v)`` is kept when ``u < v``; for self-loops ``(u, u)``
    only the first occurrence per node is kept. Columns with ``u > v`` are
    dropped (their mirrored ``u < v`` column is the representative).
    Repeated ``u < v`` columns are NOT collapsed.

    Args:
        edges: integer array of shape ``(2, E)``; row 0 holds source node
            ids and row 1 holds destination node ids.

    Returns:
        A new integer array of shape ``(2, K)`` with the kept columns in
        their original order. ``K`` is the exact number of kept columns, so
        the result never contains padding columns and never overflows when
        the input has self-loops or is not perfectly symmetric.
    """
    edges = np.asarray(edges)
    src, dst = edges[0], edges[1]

    # Forward-direction edges are always kept.
    keep = src < dst

    # Self-loops: keep only the first column seen for each looping node.
    loop_cols = np.flatnonzero(src == dst)
    if loop_cols.size:
        # return_index yields the first occurrence of every unique node.
        _, first_seen = np.unique(src[loop_cols], return_index=True)
        keep[loop_cols[first_seen]] = True

    # Boolean indexing already copies and preserves column order; the cast
    # keeps the integer dtype callers previously received.
    return edges[:, keep].astype(int, copy=False)

class PairNodeClsSampleMask(DataTransform):
    """Split positive node pairs into train/valid/test and sample negatives.

    For task ``'node-pair'`` the positives are the non-zero entries of
    ``graph.ndata['pair_labels']`` (which is removed from the graph); for
    task ``'link'`` they are the de-duplicated edges of the graph, and the
    returned graph keeps only the training edges (made bidirected).

    The transform stores six ``(2, k)`` tensors in ``graph.gdata`` --
    ``pair_{train,valid,test}_{pos,neg}`` -- and registers each of them with
    ``edge_batch`` as its batch schema.
    NOTE(review): ``gdata``, ``add_batch_schema`` and ``adapt`` are
    project-defined on the graph type; their semantics are not visible here.
    """

    def __init__(self, valid_ratio, test_ratio, task,):
        # Fractions of the positive pairs assigned to validation / test;
        # the remainder is used for training.
        self.valid_ratio = valid_ratio
        self.test_ratio = test_ratio
        # Either 'node-pair' or 'link'; anything else raises in _transform.
        self.task = task

    @staticmethod
    def _edge_ids_of_pairs(graph, pairs):
        """Return the edge IDs in ``graph`` of the ``(u, v)`` columns of ``pairs``.

        When several edges share the same endpoints, the first one listed by
        ``graph.edges(form='all')`` is used.

        Raises:
            KeyError: if a pair has no ``u -> v`` edge in ``graph``.
        """
        src, dst, eid = graph.edges(form='all')
        slot_of = {}
        for slot, endpoints in enumerate(zip(src.tolist(), dst.tolist())):
            slot_of.setdefault(endpoints, slot)
        slots = [slot_of[endpoints]
                 for endpoints in zip(pairs[0].tolist(), pairs[1].tolist())]
        # Index the original ID tensor so its dtype and device are kept.
        return eid[torch.as_tensor(slots, dtype=torch.long)]

    @staticmethod
    def _sample_negative_pairs(pair_pos, num_nodes, num_neg):
        """Draw ``num_neg`` pairs of distinct nodes that are not positives.

        A candidate is rejected when it matches a positive pair in either
        direction. The same negative pair may be drawn more than once.

        Args:
            pair_pos: ``(2, k)`` integer array of positive pairs to avoid.
            num_nodes: node ids are sampled from ``range(num_nodes)``.
            num_neg: number of negative pairs to return.

        Returns:
            ``(2, num_neg)`` array with the dtype of ``pair_pos``.

        Raises:
            ValueError: if negatives are requested but no admissible pair
                exists (rejection sampling would otherwise never finish).
        """
        heads, tails = pair_pos[0].tolist(), pair_pos[1].tolist()
        excluded = set(zip(heads, tails)) | set(zip(tails, heads))

        if num_neg > 0:
            # Only pairs of distinct nodes are ever proposed, so self-loops
            # in the positive set do not reduce the candidate pool.
            blocked = sum(1 for u, v in excluded if u != v)
            if num_nodes < 2 or blocked >= num_nodes * (num_nodes - 1):
                raise ValueError(
                    'cannot sample negative pairs: every pair of distinct '
                    'nodes is already a positive pair')

        pair_negative = np.zeros((2, num_neg), dtype=pair_pos.dtype)
        for col in range(num_neg):
            while True:
                candidate = tuple(
                    np.random.choice(num_nodes, size=(2,), replace=False).tolist())
                if candidate not in excluded:
                    pair_negative[:, col] = candidate
                    break
        return pair_negative

    def _transform(self, graph):
        """Attach the positive/negative pair splits to ``graph`` and return it.

        Raises:
            NotImplementedError: if ``self.task`` is not ``'node-pair'`` or
                ``'link'``.
        """
        if self.task == 'node-pair':
            pair_labels = graph.ndata.pop('pair_labels').nonzero().t().numpy()
        elif self.task == 'link':
            pair_labels = torch.stack(graph.edges()).numpy()
            pair_labels = deduplicate_edges(pair_labels)
        else:
            raise NotImplementedError

        e = pair_labels.shape[1]
        n = graph.number_of_nodes()

        # Shuffle once, then cut into valid | test | train.
        indices = np.random.permutation(e)
        pair_labels = pair_labels[:, indices]

        split1 = int(self.valid_ratio * e)
        split2 = int((self.valid_ratio + self.test_ratio) * e)
        pair_valid = pair_labels[:, :split1]
        pair_test = pair_labels[:, split1:split2]
        pair_train = pair_labels[:, split2:]

        if self.task == 'link':
            # Column positions in the de-duplicated pair list are not edge
            # IDs of the graph, so the training edges are looked up by their
            # endpoints. This guarantees held-out edges are removed from the
            # graph the model sees.
            train_eids = self._edge_ids_of_pairs(graph, pair_train)
            graph = graph.adapt(dgl.to_bidirected(
                graph.edge_subgraph(train_eids, preserve_nodes=True),
                copy_ndata=True))

        positives = {
            'pair_train_pos': pair_train,
            'pair_valid_pos': pair_valid,
            'pair_test_pos': pair_test,
        }
        for key, pairs in positives.items():
            graph.gdata[key] = torch.from_numpy(pairs)
        for key in positives:
            graph.add_batch_schema(key, edge_batch)

        # Training negatives only avoid training positives; validation and
        # test negatives avoid every known positive pair. One negative is
        # drawn per positive in each split.
        negatives = {
            'pair_train_neg': self._sample_negative_pairs(
                pair_train, n, pair_train.shape[1]),
            'pair_valid_neg': self._sample_negative_pairs(
                pair_labels, n, pair_valid.shape[1]),
            'pair_test_neg': self._sample_negative_pairs(
                pair_labels, n, pair_test.shape[1]),
        }
        for key, pairs in negatives.items():
            graph.gdata[key] = torch.from_numpy(pairs)
        for key in negatives:
            graph.add_batch_schema(key, edge_batch)

        return graph

