from dataset import gen_random_loaders


class BaseNodes:
    """Hold a dataset configuration and the loaders built from it.

    On construction the configuration is stored on the instance and passed,
    positionally, to ``gen_random_loaders``. Its three return values become
    ``train_loaders``, ``val_loaders`` and ``test_loaders``.

    NOTE(review): these are presumably per-node collections of loaders.
    Confirm against ``dataset.gen_random_loaders``.
    """

    def __init__(
            self,
            data_name,
            data_path,
            n_nodes,
            n_train_nodes=None,
            batch_size=128,
            partition_type='by_class',
            classes_per_node=2,
            alpha_train=None,
            alpha_test=None,
            embedding_dir_path=None
    ):
        """Store the configuration and eagerly build the loaders.

        Args:
            data_name: Dataset identifier forwarded to ``gen_random_loaders``.
            data_path: Dataset location forwarded to ``gen_random_loaders``.
            n_nodes: Total number of nodes; also what ``len()`` reports.
            n_train_nodes: Number of training nodes; ``None`` means "same as
                ``n_nodes``".
            batch_size: Batch size forwarded to ``gen_random_loaders``.
            partition_type: Partitioning strategy name forwarded to
                ``gen_random_loaders``.
            classes_per_node: Forwarded to ``gen_random_loaders``.
            alpha_train: Forwarded to ``gen_random_loaders`` (meaning defined
                there).
            alpha_test: Forwarded to ``gen_random_loaders`` (meaning defined
                there).
            embedding_dir_path: Forwarded to ``gen_random_loaders``.
        """
        self.data_name = data_name
        self.data_path = data_path
        self.n_nodes = n_nodes
        # An unspecified training-node count defaults to the full node count.
        self.n_train_nodes = n_nodes if n_train_nodes is None else n_train_nodes
        self.partition_type = partition_type
        self.classes_per_node = classes_per_node
        self.alpha_train = alpha_train
        self.alpha_test = alpha_test
        self.embedding_dir_path = embedding_dir_path
        self.batch_size = batch_size

        # Placeholders so the attributes exist before the loaders are built.
        self.train_loaders = self.val_loaders = self.test_loaders = None
        self._init_dataloaders()

    def _init_dataloaders(self):
        """(Re)build the train/val/test loaders from the stored configuration."""
        # Order matters: gen_random_loaders receives these positionally.
        loader_args = (
            self.data_name,
            self.data_path,
            self.n_nodes,
            self.n_train_nodes,
            self.batch_size,
            self.partition_type,
            self.classes_per_node,
            self.alpha_train,
            self.alpha_test,
            self.embedding_dir_path,
        )
        built = gen_random_loaders(*loader_args)
        self.train_loaders, self.val_loaders, self.test_loaders = built

    def __len__(self):
        """Return the configured total number of nodes (``n_nodes``)."""
        return self.n_nodes
