import dgl
from graph_learning.data_setting import DataSettingConfig, DataTransform
from graph_learning.dataset.graph import GLGraph
import numpy as np
import torch
import copy
from sklearn.model_selection import train_test_split

@DataSettingConfig.register(
    'static-node-random-split',
    help='Train/valid/test set split for nodes.',
)
class StaticNodeRandomSplitConfig(DataSettingConfig):
    """Data-setting config registered as ``static-node-random-split``.

    Declares the command-line options of the random node split and exposes
    :class:`StaticNodeRandomSplit` as the transform to build.
    """

    def __init__(self, args, context):
        """Forward ``args`` and ``context`` unchanged to the base config."""
        super().__init__(args, context)

    @property
    def builder(self):
        """Return the transform class associated with this config."""
        return StaticNodeRandomSplit

    @classmethod
    def define_parser(cls, parser):
        """Add the split options to ``parser`` after the base-class options.

        Options added here:
            --train-ratio  float, default 0.2
            --valid-ratio  float, default 0.3
            --no-stratify  flag, off unless given
        """
        super().define_parser(parser)

        # The two ratio options differ only in their flag and default value.
        ratio_options = (
            ('--train-ratio', 0.2),
            ('--valid-ratio', 0.3),
        )
        for flag, default_value in ratio_options:
            parser.add_argument(flag, default=default_value, type=float)

        parser.add_argument('--no-stratify', action='store_true')

class StaticNodeRandomSplit(DataTransform):
    """Randomly split the nodes of a graph into train/valid/test sets.

    The split is written to the graph as three boolean node masks:
    ``ndata['train_mask']``, ``ndata['val_mask']`` and ``ndata['test_mask']``.

    Two modes are supported:

    * ``no_stratify=True``: every node is shuffled uniformly and the
      permutation is cut according to the ratios.
    * ``no_stratify=False`` (default mode): only nodes whose entry in
      ``ndata['labels']`` is not NaN are split, and the split is stratified
      on those labels. Unlabeled nodes end up in none of the masks.
      NOTE(review): this mode assumes ``ndata['labels']`` is one value per
      node (1-D) -- confirm against the datasets in use.

    Args:
        train_ratio: Fraction of nodes used for training.
        valid_ratio: Fraction of nodes used for validation.
        no_stratify: If true, ignore labels and split uniformly at random.

    The test fraction is whatever remains: ``1 - (train_ratio + valid_ratio)``.
    """

    def __init__(self, train_ratio, valid_ratio, no_stratify):
        self.train_ratio = train_ratio
        self.valid_ratio = valid_ratio
        self.test_ratio = 1 - (train_ratio + valid_ratio)
        self.no_stratify = no_stratify

        # False when train + valid already cover everything; in that case
        # no test split is drawn and the test mask stays empty.
        self._use_test = (train_ratio + valid_ratio < 1)

    def transform(self, graph: GLGraph):
        """Attach the sampled train/val/test masks to ``graph`` and return it.

        The masks are moved to ``graph.device`` before being stored.
        """
        train_mask, valid_mask, test_mask = self._sample_mask(graph)
        graph.ndata['train_mask'] = train_mask.to(graph.device)
        graph.ndata['val_mask'] = valid_mask.to(graph.device)
        graph.ndata['test_mask'] = test_mask.to(graph.device)
        return graph

    def _create_mask(self, idx, l):
        """Return a boolean tensor of length ``l`` that is True at ``idx``.

        Args:
            idx: Integer index array of the positions to set (may be empty).
            l: Total length of the mask.
        """
        # Build the mask as bool from the start instead of converting a
        # float64 array through the legacy ``torch.BoolTensor`` constructor.
        mask = np.zeros(l, dtype=bool)
        mask[idx] = True
        return torch.from_numpy(mask)

    def _sample_mask(self, graph):
        """Sample the node split and return ``(train, val, test)`` masks.

        Raises:
            ValueError: Propagated from ``train_test_split`` when a requested
                split is infeasible (for example a class with too few
                members to stratify, or an empty resulting partition).
        """
        if self.no_stratify:
            node_ids = graph.nodes()
            # NumPy cannot read a tensor that lives on a non-CPU device.
            if torch.is_tensor(node_ids):
                node_ids = node_ids.detach().cpu().numpy()
            indices = np.random.permutation(node_ids)
            nsize = graph.number_of_nodes()

            train_end = int(nsize * self.train_ratio)
            val_end = int(nsize * (self.train_ratio + self.valid_ratio))
            train_indices = indices[:train_end]
            val_indices = indices[train_end:val_end]
            test_indices = indices[val_end:]
        else:
            labels = graph.ndata['labels'].detach().cpu().numpy()
            labeled_mask = ~np.isnan(labels)
            nsize = len(labeled_mask)
            labeled = np.where(labeled_mask)[0]
            lsize = len(labeled)

            # Stratify on the labels of the nodes actually being split.
            # Passing the full label vector fails as soon as one node is
            # unlabeled, because the lengths no longer match.
            holdout_size = int(lsize * (self.valid_ratio + self.test_ratio))
            train_indices, val_test_indices = train_test_split(
                labeled,
                test_size=holdout_size,
                stratify=labels[labeled],
            )

            if self._use_test:
                val_indices, test_indices = train_test_split(
                    val_test_indices,
                    test_size=int(lsize * self.test_ratio),
                    stratify=labels[val_test_indices],
                )
            else:
                # No room for a test set: the whole hold-out is validation.
                # Slicing keeps the integer dtype, which an empty index
                # array needs to be usable for indexing.
                val_indices = val_test_indices
                test_indices = val_test_indices[:0]

        train_mask = self._create_mask(train_indices, nsize)
        val_mask = self._create_mask(val_indices, nsize)
        test_mask = self._create_mask(test_indices, nsize)

        return train_mask, val_mask, test_mask
