import numpy as np
from torch.utils.data import Subset

from experiments.utils.active_learning_data import get_balanced_sample_indices, ActiveLearningData


def SubRangeDataset(dataset, begin: int, end: int):
    """Return a view of ``dataset`` restricted to the contiguous index range ``[begin, end)``.

    The result is a ``torch.utils.data.Subset`` whose indices are
    ``np.arange(begin, end)``; the underlying dataset is wrapped, not copied.

    Args:
        dataset: A sized, indexable dataset (must support ``len``).
        begin: First index to include; must satisfy ``0 <= begin <= len(dataset)``.
        end: One past the last index to include; must satisfy
            ``begin <= end <= len(dataset)``.

    Returns:
        A ``Subset`` of length ``end - begin`` (empty when ``begin == end``).

    Raises:
        ValueError: If ``begin`` or ``end`` is out of bounds, or ``end < begin``.
    """
    length = len(dataset)
    # Explicit checks rather than ``assert``: asserts are removed under ``python -O``,
    # and an unchecked negative ``begin`` would silently wrap around via negative
    # indexing instead of failing.
    if not 0 <= begin <= length:
        raise ValueError(f"begin must be in [0, {length}], got {begin}")
    if not begin <= end <= length:
        raise ValueError(f"end must be in [{begin}, {length}], got {end}")

    return Subset(dataset, np.arange(begin, end))


def train_validation_split(dataset, validation_size: int):
    """Split ``dataset`` positionally into a training part and a trailing validation part.

    No shuffling is performed: the last ``validation_size`` items become the
    validation set and everything before them becomes the training set.

    Args:
        dataset: A sized, indexable dataset.
        validation_size: Number of trailing items to hold out for validation.
            Values outside ``[0, len(dataset)]`` are rejected by ``SubRangeDataset``.

    Returns:
        ``(train_subset, validation_subset)`` as two ``SubRangeDataset`` views.
    """
    n_items = len(dataset)
    boundary = n_items - validation_size
    return (
        SubRangeDataset(dataset, 0, boundary),
        SubRangeDataset(dataset, boundary, n_items),
    )


def train_validation_split_different_transformer(num_classes, train_dataset, validation_dataset, validation_size):
    """Carve a class-balanced validation split out of two parallel datasets.

    HACK: ``train_dataset`` and ``validation_dataset`` are presumably the same
    underlying samples in the same order, differing only in their transforms
    (TODO confirm with callers). The held-out indices are chosen from
    ``train_dataset.targets``. They are removed from the training side's pool
    and extracted from the validation side, so each split carries its own
    transform.

    Args:
        num_classes: Number of classes passed to ``get_balanced_sample_indices``.
        train_dataset: Dataset exposing a ``targets`` attribute; the returned
            training set is its remaining pool.
        validation_dataset: Parallel dataset the validation samples are taken from.
        validation_size: Requested validation size. It is integer-divided by
            ``num_classes``, so any remainder is dropped. Presumably the actual
            size is ``num_classes * (validation_size // num_classes)``; verify
            against ``get_balanced_sample_indices``.

    Returns:
        ``(train_pool_dataset, validation_subset)``.
    """
    per_class = validation_size // num_classes
    held_out = get_balanced_sample_indices(train_dataset.targets, num_classes, per_class)

    # Training side: drop the held-out indices from the pool and keep what remains.
    train_pool = ActiveLearningData(train_dataset)
    train_pool.extract_dataset_from_pool_from_indices(held_out)
    remaining_train = train_pool.pool_dataset

    # Validation side: pull exactly those indices out of the differently-transformed copy.
    validation_pool = ActiveLearningData(validation_dataset)
    held_out_validation = validation_pool.extract_dataset_from_pool_from_indices(held_out)

    return remaining_train, held_out_validation
