import copy
from .utils import compute_split_idx
def transductive_split(dataset, split_sizes, task, k_fold):
    """Create one deep copy of ``dataset`` per node-index split.

    The number of rows in ``dataset.data.node_feature`` is passed to
    ``compute_split_idx`` (with ``random=True``), and every index split it
    returns produces an independent deep copy of ``dataset``. On each copy,
    the first entry of ``my_data`` receives the split as
    ``node_label_index`` and has its ``node_label`` replaced by the labels
    selected with that split.

    Args:
        dataset: Dataset object exposing ``data.node_feature``, ``my_data``
            and integer indexing. It is never mutated; only copies are.
        split_sizes: Forwarded unchanged to ``compute_split_idx``.
        task: One of ``'node'``, ``'edge'`` or ``'link'``. Only ``'node'``
            is implemented.
        k_fold: Forwarded unchanged to ``compute_split_idx``.

    Returns:
        A list of dataset copies, one per index split, in the order the
        splits are produced by ``compute_split_idx``.

    Raises:
        AssertionError: If ``task`` is not one of the recognised names.
        NotImplementedError: If ``task`` is ``'edge'`` or ``'link'``.
    """
    assert task in ['node', 'edge', 'link']
    if task != 'node':
        # Previously these tasks fell through to ``return datasets`` with the
        # name never assigned, surfacing as an opaque UnboundLocalError.
        raise NotImplementedError(
            "transductive_split only supports task='node', got task=%r"
            % (task,))

    idx_splits = compute_split_idx(
        original_len=dataset.data.node_feature.shape[0],
        split_sizes=split_sizes,
        random=True,
        k_fold=k_fold)

    datasets = []
    for nodes_split_i in idx_splits:
        # Each split gets its own copy so the per-split attributes set below
        # never leak into the caller's dataset or into sibling splits.
        dataset_copy = copy.deepcopy(dataset)
        dataset_copy.my_data[0].node_label_index = nodes_split_i
        # NOTE(review): labels are read through ``dataset_copy[0]`` but
        # written through ``dataset_copy.my_data[0]``; presumably both refer
        # to the same graph -- confirm against the dataset class.
        dataset_copy.my_data[0].node_label = (
            dataset_copy[0].node_label[nodes_split_i])
        datasets.append(dataset_copy)

    return datasets


def transductive_split_data(data, split_sizes, task, k_fold):
    """Create one deep copy of ``data`` per node-index split.

    The number of rows in ``data.node_feature`` is passed to
    ``compute_split_idx`` (with ``random=True``), and every index split it
    returns produces an independent deep copy of ``data`` carrying that
    split as its ``node_label_index`` attribute. No other attribute of the
    copy is changed.

    Args:
        data: Object exposing ``node_feature`` with a ``shape`` attribute.
            It is never mutated; only copies are.
        split_sizes: Forwarded unchanged to ``compute_split_idx``.
        task: One of ``'node'``, ``'edge'`` or ``'link'``. Only ``'node'``
            is implemented.
        k_fold: Forwarded unchanged to ``compute_split_idx``.

    Returns:
        A list of copies of ``data``, one per index split, in the order the
        splits are produced by ``compute_split_idx``.

    Raises:
        AssertionError: If ``task`` is not one of the recognised names.
        NotImplementedError: If ``task`` is ``'edge'`` or ``'link'``.
    """
    assert task in ['node', 'edge', 'link']
    if task != 'node':
        # Previously these tasks fell through to ``return datasets`` with the
        # name never assigned, surfacing as an opaque UnboundLocalError.
        raise NotImplementedError(
            "transductive_split_data only supports task='node', got task=%r"
            % (task,))

    idx_splits = compute_split_idx(
        original_len=data.node_feature.shape[0],
        split_sizes=split_sizes,
        random=True,
        k_fold=k_fold)

    datasets = []
    for nodes_split_i in idx_splits:
        # Deep copy so that tagging one split never affects the caller's
        # object or the copies made for the other splits.
        data_i = copy.deepcopy(data)
        data_i.node_label_index = nodes_split_i
        datasets.append(data_i)

    return datasets