import numpy as np
import torch

from . import data_helper as helper

class DataGenerator:
    """Serve graph/label items for one training or validation epoch.

    On construction the dataset named by ``config.dataset_name`` is loaded
    and split into a training part and a validation part. Items can then be
    consumed in two ways:

    * ``initialize(is_train)`` followed by repeated ``next_batch()`` calls,
      which yields the items exactly as the helpers produced them;
    * iterating over the instance (``for graphs, labels in generator``),
      which yields the same items converted to torch tensors.

    ``self.iter`` always holds the iterator of the current epoch.
    """

    def __init__(self, root, config, is_train=True):
        """Load the dataset and remember the configuration.

        Args:
            root: location passed through to the ``data_helper`` functions.
            config: object exposing ``batch_size``, ``dataset_name`` and
                ``num_fold``.
            is_train: selects which split ``__iter__`` serves (training when
                true, validation otherwise).
        """
        self.is_train = is_train
        self.config = config
        self.batch_size = self.config.batch_size
        self.load_data(root)

    def load_data(self, root):
        """Load the dataset named in the config and split it into train/val.

        The split depends on ``config.num_fold``:

        * ``None``: the data is passed through ``helper.shuffle`` and the
          first ``len(graphs) // 10`` items become the validation set;
        * ``0``: indexes come from ``helper.get_parameter_split``;
        * anything else: indexes come from ``helper.get_train_val_indexes``
          for that fold.

        Sets ``train_graphs``, ``train_labels``, ``val_graphs``,
        ``val_labels``, ``train_size`` and ``val_size``.
        """
        dataset_name = self.config.dataset_name
        num_fold = self.config.num_fold
        graphs, labels = helper.load_dataset(dataset_name, root)

        if num_fold is None:
            # No fold requested: hold out the leading tenth after shuffling.
            graphs, labels = helper.shuffle(graphs, labels)
            idx = len(graphs) // 10
            self.train_graphs, self.train_labels = graphs[idx:], labels[idx:]
            self.val_graphs, self.val_labels = graphs[:idx], labels[:idx]
        else:
            if num_fold == 0:
                train_idx, test_idx = helper.get_parameter_split(dataset_name, root)
            else:
                train_idx, test_idx = helper.get_train_val_indexes(
                    num_fold, dataset_name, root)
            # NOTE(review): indexing with an index collection assumes graphs
            # and labels support it (e.g. numpy arrays) - confirm in helper.
            self.train_graphs, self.train_labels = graphs[train_idx], labels[train_idx]
            self.val_graphs, self.val_labels = graphs[test_idx], labels[test_idx]

        # Give every validation graph a leading axis of length 1
        # (presumably so that each one acts as a batch of a single graph).
        self.val_graphs = [np.expand_dims(g, 0) for g in self.val_graphs]
        self.train_size = len(self.train_graphs)
        self.val_size = len(self.val_graphs)

    def next_batch(self):
        """Return the next item of the current epoch iterator.

        Raises ``StopIteration`` once the epoch is exhausted. An epoch must
        have been started first via ``initialize`` or ``iter(self)``.
        """
        return next(self.iter)

    def initialize(self, is_train):
        """Start a new epoch by (re)creating ``self.iter``.

        Args:
            is_train: when true, rebuild and reshuffle the training batches;
                otherwise iterate over the validation graphs and labels.
        """
        if is_train:
            self.reshuffle_data()
        else:
            self.iter = zip(self.val_graphs, self.val_labels)

    def __iter__(self):
        """Start a new epoch for the split chosen by ``self.is_train``.

        Returns an iterator over ``(graphs, labels)`` pairs converted to
        ``float32`` / ``int64`` torch tensors. The same iterator is stored in
        ``self.iter`` so that ``next_batch`` keeps working afterwards
        (storing the list itself would make ``next(self.iter)`` fail).
        """
        self.initialize(self.is_train)
        batches = [(torch.tensor(x, dtype=torch.float32),
                    torch.tensor(y, dtype=torch.int64)) for x, y in self.iter]
        self.iter = iter(batches)
        return self.iter

    # reshuffle data iterator between epochs
    def reshuffle_data(self):
        """Rebuild the training batches and point ``self.iter`` at them.

        The training data goes through ``helper.group_same_size``,
        ``helper.shuffle_same_size`` and ``helper.split_to_batches``
        (presumably: group graphs of equal size, shuffle inside each group,
        cut the groups into batches of ``self.batch_size`` - confirm in
        data_helper), after which the resulting batches are shuffled.
        Also records the number of batches in ``num_iterations_train``.
        """
        graphs, labels = helper.group_same_size(self.train_graphs, self.train_labels)
        graphs, labels = helper.shuffle_same_size(graphs, labels)
        graphs, labels = helper.split_to_batches(graphs, labels, self.batch_size)
        self.num_iterations_train = len(graphs)
        graphs, labels = helper.shuffle(graphs, labels)
        self.iter = zip(graphs, labels)



if __name__ == '__main__':
    # Manual smoke test: build a generator from the example config and
    # prepare one training epoch.
    import sys
    sys.path.append('../utils/')
    # The module is found through the path appended above. It is imported
    # under an alias so the module and the parsed configuration object do
    # not share a name (the previous `utils.config...` call referenced an
    # undefined name `utils`).
    import config as config_utils
    config = config_utils.process_config('../configs/example.json')
    # DataGenerator takes the data root as its first argument. Read it from
    # the command line; the '.' fallback is only a placeholder default.
    # TODO(review): confirm the directory layout data_helper expects.
    root = sys.argv[1] if len(sys.argv) > 1 else '.'
    data = DataGenerator(root, config)
    data.initialize(True)


