import torch
import torch.distributions
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from torch.utils.data import Sampler

from utils.datasets.tinyImages import _load_cifar_exclusion_idcs
from utils.datasets.paths import get_CIFAR10_path, get_CIFAR100_path

import os
import numpy as np
import matplotlib.pyplot as plt
import pathlib

from utils.datasets.cifar import get_CIFAR10_labels, get_CIFAR100_labels
from utils.datasets.cifar_augmentation import get_cifar10_augmentation
from .cifar_train_val import CIFAR10TrainValidationSplit, CIFAR100TrainValidationSplit
from utils.datasets.paths import get_tiny_images_files
from utils.datasets.tinyImages import TINY_LENGTH

from utils.datasets.tinyImages import _load_tiny_image, _preload_tiny_images
from .loading_utils import load_teacher_data
from .cifar_subsets import CIFARSubset
from .tiny_images_subset import TinyImagesSubset
from utils.datasets.tinyImages import get_80MTinyImages
from torchvision.utils import save_image


def get_tiny_cifar_partition(dataset_classifications_path, teacher_model, dataset, samples_per_class,
                             cifarTrainValSplit, selection_model=None, semi_ratio=0,
                             class_tpr_min=None, od_exclusion_threshold=None,
                             calibrate_temperature=False, verbose_exclude=False, soft_labels=True, batch_size=100,
                             augm_type='default',
                             cutout_window=16, aa_magnitude=1, size=32, num_workers=8, exclude_cifar=False,
                             exclude_cifar10_1=False, id_config_dict=None, od_config_dict=None,
                             ssl_config=None):
    """Create the top-k / bottom-k loader pair for CIFAR plus 80M Tiny Images.

    The top loader iterates the CIFAR train set together with the Tiny Images selected per class
    by the selection model. The bottom loader iterates the remaining Tiny Images. Their targets are
    either the temperature-scaled teacher softmax or a uniform distribution.

    Args:
        dataset_classifications_path, teacher_model, selection_model, class_tpr_min,
        od_exclusion_threshold, calibrate_temperature, ssl_config: forwarded to ``load_teacher_data``.
        dataset: name of the in-distribution dataset (``'CIFAR10'`` / ``'CIFAR100'``).
        samples_per_class: maximum number of Tiny Images selected per class.
        cifarTrainValSplit: use the train part of a train/validation split instead of the full train set.
        semi_ratio: ``<= 0`` selects ``BalancedSampler``; otherwise ``TrainSemiRatioSampler`` with this ratio.
        verbose_exclude: keep every threshold-passing image (not only the used top-k) out of the bottom set.
        id_config_dict, od_config_dict: if given, filled with a description of the setup.
        The remaining arguments configure augmentation, labels, batching and CIFAR exclusion.

    Returns:
        ``(top_loader, bottom_loader)``, both ``torch.utils.data.DataLoader`` instances.
    """
    (teacher_logits, selection_logits,
     class_thresholds, temperature) = load_teacher_data(dataset_classifications_path, teacher_model,
                                                        selection_model=selection_model,
                                                        class_tpr_min=class_tpr_min,
                                                        od_exclusion_threshold=od_exclusion_threshold,
                                                        calibrate_temperature=calibrate_temperature,
                                                        ssl_config=ssl_config)

    # augm_config is handed to the augmentation factory as config_dict and later stored in the config dicts.
    augm_config = {}
    transform = get_cifar10_augmentation(augm_type, cutout_window, out_size=size,
                                         magnitude_factor=aa_magnitude, config_dict=augm_config)

    top_dataset = CIFARPlusTinyImageTopKPartition(
        teacher_logits, selection_logits, dataset,
        samples_per_class=samples_per_class,
        transform_base=transform,
        cifarTrainValSplit=cifarTrainValSplit,
        min_conf=class_thresholds,
        temperature=temperature,
        soft_labels=soft_labels,
        exclude_cifar=exclude_cifar,
        exclude_cifar10_1=exclude_cifar10_1,
        teacher_model=teacher_model,
    )

    if semi_ratio <= 0:
        sampler, sampler_description = BalancedSampler(top_dataset, subdivide_epochs=False), 'Balanced Sampler'
    else:
        sampler, sampler_description = (TrainSemiRatioSampler(top_dataset, batch_size, ratio=semi_ratio),
                                        f'Fixed Ratio {semi_ratio}:1')

    top_loader = torch.utils.data.DataLoader(top_dataset, sampler=sampler, batch_size=batch_size,
                                             num_workers=num_workers)

    # Tiny Images claimed by the top partition are removed from the bottom partition.
    used_tiny_indices = top_dataset.get_used_semi_indices(verbose_exclude)
    bottom_dataset = CIFARTinyImageBottomKPartition(teacher_logits, used_tiny_indices, transform_base=transform,
                                                    temperature=temperature, soft_labels=soft_labels,
                                                    exclude_cifar=exclude_cifar, exclude_cifar10_1=exclude_cifar10_1)
    bottom_loader = torch.utils.data.DataLoader(bottom_dataset, shuffle=True, batch_size=batch_size, num_workers=1)

    if id_config_dict is not None:
        id_entries = {
            'Dataset': 'Cifar-SSL',
            'Train validation split': cifarTrainValSplit,
            'Sampler': sampler_description,
            'Batch out_size': batch_size,
            'SSL Samples per class': samples_per_class,
            'Soft labels': soft_labels,
            'Exclude CIFAR': exclude_cifar,
            'Exclude CIFAR10.1': exclude_cifar10_1,
            'Augmentation': augm_config,
        }
        for key, value in id_entries.items():
            id_config_dict[key] = value

    if od_config_dict is not None:
        od_entries = {
            'Dataset': 'TinyImagesPartition',
            'Batch out_size': batch_size,
            'Verbose exclude': verbose_exclude,
            'Exclude CIFAR': dataset in ('CIFAR10', 'CIFAR100'),
            'Augmentation': augm_config,
        }
        for key, value in od_entries.items():
            od_config_dict[key] = value

    return top_loader, bottom_loader


def get_tiny_cifar_subset_partition(dataset_classifications_path, teacher_model, dataset,
                                               labeled_per_class, samples_per_class, unlabeled_samples,
                             selection_model=None, semi_ratio=1,
                             class_tpr_min=None, od_exclusion_threshold=None,
                             calibrate_temperature=False, verbose_exclude=False, soft_labels=True, batch_size=100,
                             augm_type='default',
                             cutout_window=16, aa_magnitude=1, size=32, num_workers=8, exclude_cifar=False,
                             exclude_cifar10_1=False, id_config_dict=None, od_config_dict=None,
                             ssl_config=None):
    """Create the top-k / bottom-k loader pair for a labeled CIFAR subset plus 80M Tiny Images.

    This is the same construction as ``get_tiny_cifar_partition``, except that the top partition is
    built from a ``CIFARSubset`` with ``labeled_per_class`` labeled images per class.

    Args:
        labeled_per_class: labeled images per class; only 400 and 100 are supported.
        unlabeled_samples: if positive, the Tiny Images pool is restricted to a subset of this size.
        The remaining arguments are as in ``get_tiny_cifar_partition``.

    Returns:
        ``(top_loader, bottom_loader)``, both ``torch.utils.data.DataLoader`` instances.

    Raises:
        NotImplementedError: for any other ``labeled_per_class``.
    """
    (teacher_logits, selection_logits,
     class_thresholds, temperature) = load_teacher_data(dataset_classifications_path, teacher_model,
                                                        selection_model=selection_model,
                                                        class_tpr_min=class_tpr_min,
                                                        od_exclusion_threshold=od_exclusion_threshold,
                                                        calibrate_temperature=calibrate_temperature,
                                                        ssl_config=ssl_config)

    augm_config = {}
    transform = get_cifar10_augmentation(augm_type, cutout_window, out_size=size,
                                         magnitude_factor=aa_magnitude, config_dict=augm_config)

    # Per-class sizes of the CIFARSubset splits. The first entry equals the labeled count and each
    # triple sums to 5000; the meaning of the other entries is defined by CIFARSubset.
    split_table = ((400, [400, 500, 4100]),
                   (100, [100, 500, 4400]))
    for labeled, split_sizes in split_table:
        if labeled_per_class == labeled:
            samples_per_class_per_split = split_sizes
            break
    else:
        raise NotImplementedError()

    top_dataset = CIFARSubsetPlusTinyImageTopKPartition(
        teacher_logits, selection_logits, dataset,
        samples_per_class_per_split, unlabeled_samples,
        samples_per_class=samples_per_class,
        transform_base=transform,
        min_conf=class_thresholds,
        temperature=temperature,
        soft_labels=soft_labels,
        exclude_cifar=exclude_cifar,
        exclude_cifar10_1=exclude_cifar10_1,
        teacher_model=teacher_model,
    )

    if semi_ratio <= 0:
        sampler, sampler_description = BalancedSampler(top_dataset, subdivide_epochs=False), 'Balanced Sampler'
    else:
        sampler, sampler_description = (TrainSemiRatioSampler(top_dataset, batch_size, ratio=semi_ratio),
                                        f'Fixed Ratio {semi_ratio}:1')

    top_loader = torch.utils.data.DataLoader(top_dataset, sampler=sampler, batch_size=batch_size,
                                             num_workers=num_workers)

    # Tiny Images claimed by the top partition are removed from the bottom partition.
    used_tiny_indices = top_dataset.get_used_semi_indices(verbose_exclude)
    bottom_dataset = CIFARTinyImageBottomKPartition(teacher_logits, used_tiny_indices, transform_base=transform,
                                                    temperature=temperature, soft_labels=soft_labels,
                                                    exclude_cifar=exclude_cifar, exclude_cifar10_1=exclude_cifar10_1)
    bottom_loader = torch.utils.data.DataLoader(bottom_dataset, shuffle=True, batch_size=batch_size, num_workers=1)

    if id_config_dict is not None:
        id_entries = {
            'Dataset': 'CifarSubset-SSL',
            'Labeled samples per class': labeled_per_class,
            'Sampler': sampler_description,
            'Batch out_size': batch_size,
            'SSL Samples per class': samples_per_class,
            'Soft labels': soft_labels,
            'Exclude CIFAR': exclude_cifar,
            'Exclude CIFAR10.1': exclude_cifar10_1,
            'Augmentation': augm_config,
        }
        for key, value in id_entries.items():
            id_config_dict[key] = value

    if od_config_dict is not None:
        od_entries = {
            'Dataset': 'TinyImagesPartition',
            'Batch out_size': batch_size,
            'Verbose exclude': verbose_exclude,
            'Exclude CIFAR': dataset in ('CIFAR10', 'CIFAR100'),
            'Augmentation': augm_config,
        }
        for key, value in od_entries.items():
            od_config_dict[key] = value

    return top_loader, bottom_loader



class CIFARTinyImageBottomKPartition(Dataset):
    """Tiny Images that were not taken into the top-k partition (and are not CIFAR duplicates).

    Items are ``(img, target)``. ``target`` is the temperature-scaled softmax of the teacher logits
    when ``soft_labels`` is True. Otherwise it is a uniform distribution over all classes.

    Args:
        teacher_logits: teacher logits whose rows are indexed by Tiny Images index
            (``num_classes = teacher_logits.shape[1]``).
        top_k_indices: Tiny Images indices already used by the top partition; they are skipped here.
        transform_base: augmentation applied after converting the raw image with ``ToPILImage``.
        temperature: softmax temperature for the soft targets.
        soft_labels: use teacher softmax targets instead of uniform targets.
        exclude_cifar, exclude_cifar10_1: additionally skip the indices returned by
            ``_load_cifar_exclusion_idcs`` for these flags.
    """

    def __init__(self, teacher_logits, top_k_indices, transform_base, temperature=1, soft_labels=True,
                 exclude_cifar=False, exclude_cifar10_1=False):
        self.data_location = get_tiny_images_files(False)
        self.fileID = open(self.data_location, "rb")
        self.soft_labels = soft_labels
        self.num_classes = teacher_logits.shape[1]
        self.temperature = temperature

        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transform_base])

        self.model_logits = teacher_logits
        num_tiny = self.model_logits.shape[0]

        # Mask of indices that are not on the CIFAR exclusion list.
        non_cifar = torch.ones(num_tiny, dtype=torch.bool)
        cifar_idxs = torch.LongTensor(_load_cifar_exclusion_idcs(exclude_cifar, exclude_cifar10_1))
        non_cifar[cifar_idxs] = 0

        # Valid = not used by the top partition AND not excluded as CIFAR.
        valid_bool_indices = torch.ones(num_tiny, dtype=torch.bool)
        valid_bool_indices[top_k_indices] = 0
        valid_bool_indices = valid_bool_indices & non_cifar
        # as_tuple=True always yields a 1-D index tensor. nonzero(...).squeeze() collapses to a 0-d
        # tensor when exactly one index is valid, which made len() below raise.
        self.valid_indices = torch.nonzero(valid_bool_indices, as_tuple=True)[0]

        self.length = len(self.valid_indices)

        print(f'Exclude Cifar {exclude_cifar} - Samples {self.length} - Temperature {self.temperature}')

    def __getitem__(self, index):
        valid_index = self.valid_indices[index]
        img = _load_tiny_image(valid_index, self.fileID)

        if self.transform is not None:
            img = self.transform(img)

        if self.soft_labels:
            model_prediction = torch.softmax(self.model_logits[valid_index, :] / self.temperature, dim=0)
        else:
            # Uniform target over all classes.
            model_prediction = (1. / self.num_classes) * torch.ones(self.num_classes)

        return img, model_prediction

    def __len__(self):
        return self.length


def plot_class_conf_histograms(predicted_class, predicted_max_conf, class_labels, min_conf, inclusion_idcs,
                               teacher_model):
    """Save one confidence histogram per predicted class to ``SSLResults/{teacher_model}_class_hists.png``.

    Each panel shows the max-confidences of included samples predicted as that class. It marks the
    class threshold ``min_conf[i]`` with a vertical line and reports how many samples reach it.

    Args:
        predicted_class: per-sample predicted class (1-D tensor).
        predicted_max_conf: per-sample max softmax confidence (1-D tensor, same length).
        class_labels: class names; their count defines the number of panels.
        min_conf: per-class confidence thresholds.
        inclusion_idcs: boolean mask of samples to consider.
        teacher_model: name used in the output filename.
    """
    num_classes = len(class_labels)
    scale_factor = 4
    # squeeze=False keeps axs a 2-D (1, num_classes) array, so indexing also works for a single class.
    fig, axs = plt.subplots(1, num_classes, figsize=(scale_factor * (1 + num_classes), scale_factor),
                            squeeze=False)

    for i in range(num_classes):
        ax = axs[0, i]
        class_bool_idcs = (predicted_class == i) & inclusion_idcs
        class_confs = predicted_max_conf[class_bool_idcs]

        accepted = torch.sum(class_confs >= min_conf[i])
        total = len(class_confs)

        # Reuse the already-masked confidences instead of indexing a second time.
        ax.hist(class_confs.detach().numpy(), bins=50)
        y_min, y_max = ax.get_ylim()
        ax.plot(np.array([min_conf[i], min_conf[i]]), np.array([y_min, y_max]))
        ax.set_title(f'{class_labels[i]}\n Acc. {accepted} Total {total}')

    res_dir = 'SSLResults/'
    filename = f'{teacher_model}_class_hists.png'
    path = os.path.join(res_dir, filename)
    pathlib.Path(res_dir).mkdir(parents=True, exist_ok=True)
    plt.savefig(path)
    plt.close(fig)


class TrainSetPlusTinyImageTopKPartition(Dataset):
    """Concatenation of a labeled train set and the most confident Tiny Images per class.

    Indices ``[0, num_train_samples)`` address the train set, grouped class by class through
    ``train_class_idcs``. Indices ``[num_train_samples, length)`` address the selected Tiny Images,
    grouped by predicted class. ``train_idx_ranges`` and ``semi_idx_ranges`` store the per-class
    index ranges of this layout for the samplers below.

    Args:
        teacher_logits: teacher logits whose rows are indexed by Tiny Images index; used for labels.
        selection_logits: logits used to pick and rank Tiny Images per predicted class.
        train_dataset: labeled dataset exposing ``targets``; items are ``(img, label)``.
        samples_per_class: maximum number of Tiny Images selected per class.
        transform_base: augmentation applied (after ``ToPILImage``) to Tiny Images.
        min_conf: per-class minimum selection confidence.
        class_labels: class names; their count defines the number of classes.
        exclusion_idcs: Tiny Images indices that may never be selected (``None`` for none).
        temperature: softmax temperature for soft Tiny Images labels.
        soft_labels: return probability vectors (one-hot for train data) instead of integer labels.
        preload: read the selected Tiny Images into memory up front.
        teacher_model: if not None, a per-class confidence histogram is saved under this name.
    """

    def __init__(self, teacher_logits, selection_logits, train_dataset, samples_per_class, transform_base, min_conf,
                 class_labels, exclusion_idcs=None, temperature=1.0, soft_labels=True, preload=True,
                 teacher_model=None):
        self.data_location = get_tiny_images_files(False)
        self.fileID = open(self.data_location, "rb")
        self.samples_per_class = samples_per_class
        self.soft_labels = soft_labels
        self.preload = preload
        self.temperature = temperature
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transform_base])

        self.model_logits = teacher_logits
        predicted_max_conf, predicted_class = torch.max(torch.softmax(selection_logits, dim=1), dim=1)

        self.num_classes = len(class_labels)

        inclusion_idcs = torch.ones(self.model_logits.shape[0], dtype=torch.bool)
        if exclusion_idcs is not None:
            inclusion_idcs[exclusion_idcs] = 0

        self.train_dataset = train_dataset
        self.num_train_samples = len(self.train_dataset)
        self.train_per_class = torch.zeros(self.num_classes, dtype=torch.long)

        # train_class_idcs[i]: indices into train_dataset whose target is class i.
        # as_tuple=True keeps these 1-D even for a single match (squeeze() would give a 0-d tensor).
        self.train_class_idcs = []
        targets_tensor = torch.LongTensor(self.train_dataset.targets)
        for i in range(self.num_classes):
            train_i = torch.nonzero(targets_tensor == i, as_tuple=True)[0]
            self.train_class_idcs.append(train_i)
            self.train_per_class[i] = len(train_i)

        # in_use_indices[i]: Tiny Images indices actually used for class i (top-k by confidence).
        # valid_indices[i]: all Tiny Images indices of class i that pass the confidence threshold.
        self.in_use_indices = []
        self.valid_indices = []
        self.semi_per_class = torch.zeros(self.num_classes, dtype=torch.long)

        min_samples_per_class = int(1e13)
        max_samples_per_class = 0

        for i in range(self.num_classes):
            min_conf_flag = predicted_max_conf >= min_conf[i]
            class_bool_idcs = (predicted_class == i) & inclusion_idcs & min_conf_flag

            # nonzero and boolean masking both enumerate in ascending index order, so the argsort
            # positions of the confidences index class_linear_idcs consistently.
            class_linear_idcs = torch.nonzero(class_bool_idcs, as_tuple=True)[0]
            class_confidences = predicted_max_conf[class_bool_idcs]
            class_sort_idcs = torch.argsort(class_confidences, descending=True)

            num_samples_i = int(min(samples_per_class, len(class_linear_idcs)))
            class_i_idcs = class_linear_idcs[class_sort_idcs[:num_samples_i]]

            self.valid_indices.append(class_linear_idcs)

            self.in_use_indices.append(class_i_idcs)
            self.semi_per_class[i] = len(class_i_idcs)

            min_samples_per_class = min(min_samples_per_class, len(class_i_idcs))
            max_samples_per_class = max(max_samples_per_class, len(class_i_idcs))

            if num_samples_i < samples_per_class:
                print(f'Incomplete class {class_labels[i]} - Target count: {samples_per_class} - Found samples {len(class_i_idcs)}')

        self.num_semi_samples = torch.sum(self.semi_per_class)
        self.length = self.num_train_samples + self.num_semi_samples

        # Per-class [start, end) ranges in the global index space (train first, then semi).
        self.train_idx_ranges = []
        self.semi_idx_ranges = []

        train_idx_start = 0
        semi_idx_start = self.num_train_samples
        for i in range(self.num_classes):
            train_idx_next = train_idx_start + self.train_per_class[i]
            semi_idx_next = semi_idx_start + self.semi_per_class[i]
            self.train_idx_ranges.append((train_idx_start, train_idx_next))
            self.semi_idx_ranges.append((semi_idx_start, semi_idx_next))

            assert (semi_idx_next - semi_idx_start) == self.semi_per_class[i]

            train_idx_start = train_idx_next
            semi_idx_start = semi_idx_next

        self.cum_train_lengths = torch.cumsum(self.train_per_class, dim=0)
        self.cum_semi_lengths = torch.cumsum(self.semi_per_class, dim=0)

        print(f'Top K -  Temperature {self.temperature} - Soft labels {soft_labels}'
              f'  -  Target Samples per class { self.samples_per_class} - Train Samples {self.num_train_samples}')
        print(f'Min Semi Samples {min_samples_per_class} - Max Semi samples {max_samples_per_class}'
              f' - Total semi samples {self.num_semi_samples} - Total length {self.length}')

        if preload:
            print(f'Preloading images')
            self.class_data = []
            for class_idx in range(self.num_classes):
                self.class_data.append(_preload_tiny_images(self.in_use_indices[class_idx], self.fileID))

        if teacher_model is not None:
            plot_class_conf_histograms(predicted_class, predicted_max_conf, class_labels, min_conf, inclusion_idcs,
                                       teacher_model)

    def get_used_semi_indices(self, verbose_exclude=False):
        """Return the Tiny Images indices claimed by this partition.

        With ``verbose_exclude`` this includes every index that passed its class threshold, even if it
        fell outside the top-k. Otherwise only the indices actually used are returned.
        """
        if verbose_exclude:
            return torch.cat(self.valid_indices)
        else:
            return torch.cat(self.in_use_indices)

    def _load_train_image(self, class_idx, sample_idx):
        """Return the ``sample_idx``-th train item of class ``class_idx`` (one-hot label if soft)."""
        # A plain int index works with any map-style dataset.
        train_idx = int(self.train_class_idcs[class_idx][sample_idx])
        img, label = self.train_dataset[train_idx]
        if self.soft_labels:
            one_hot_label = torch.zeros(self.num_classes)
            one_hot_label[label] = 1.0
            return img, one_hot_label
        else:
            return img, label

    def _load_tiny_image(self, class_idx, tiny_lin_idx):
        """Return the ``tiny_lin_idx``-th selected Tiny Image of class ``class_idx`` with its teacher label."""
        valid_index = self.in_use_indices[class_idx][tiny_lin_idx].item()
        if self.preload:
            img = self.class_data[class_idx][tiny_lin_idx, :]
        else:
            # Module-level loader; not this method.
            img = _load_tiny_image(valid_index, self.fileID)

        if self.transform is not None:
            img = self.transform(img)

        if self.soft_labels:
            label = torch.softmax(self.model_logits[valid_index, :] / self.temperature, dim=0)
        else:
            label = torch.argmax(self.model_logits[valid_index, :]).item()
        return img, label

    @staticmethod
    def _locate(cum_lengths, offset):
        """Map an offset into a class-grouped range to ``(class_idx, idx_within_class)`` as Python ints.

        ``cum_lengths`` holds cumulative per-class counts. The owning class is the first one whose
        cumulative count exceeds ``offset``. An offset past the end raises IndexError.
        """
        class_idx = int(torch.nonzero(cum_lengths > offset, as_tuple=True)[0][0])
        class_start = int(cum_lengths[class_idx - 1]) if class_idx > 0 else 0
        return class_idx, int(offset) - class_start

    def __getitem__(self, index):
        # Scalar class/sample indices keep downstream indexing scalar. The previous nonzero(...)[0]
        # produced shape-(1,) tensors, so tensor-backed class_data returned a (1, ...) image.
        if index < self.num_train_samples:
            class_idx, sample_idx = self._locate(self.cum_train_lengths, index)
            return self._load_train_image(class_idx, sample_idx)
        else:
            class_idx, sample_idx = self._locate(self.cum_semi_lengths, index - self.num_train_samples)
            return self._load_tiny_image(class_idx, sample_idx)

    def __len__(self):
        return int(self.length)

def plot_unlabeled_vs_od_conf_histograms(logits, class_labels, min_conf, exclusion_idcs,
                                         unlabeled_idcs, teacher_model):
    """Save per-class confidence histograms that separate the unlabeled samples from all others.

    Row 0 shows samples whose index is in ``unlabeled_idcs``. Row 1 (log-scale y axis) shows all
    remaining non-excluded samples. Each panel marks ``min_conf[i]`` and reports how many samples
    reach it. The figure is written to ``SSLResults/{teacher_model}_in_vs_od.png``.

    Args:
        logits: per-sample class logits; softmax is taken over dim 1.
        class_labels: class names; their count defines the number of columns.
        min_conf: per-class confidence thresholds.
        exclusion_idcs: sample indices left out of both rows (``None`` excludes nothing).
        unlabeled_idcs: indices plotted in row 0.
        teacher_model: name used in the output filename.
    """
    predicted_max_conf, predicted_class = torch.max(torch.softmax(logits, dim=1), dim=1)

    num_classes = len(class_labels)
    scale_factor = 4
    # squeeze=False keeps axs a 2-D (2, num_classes) array even when num_classes == 1.
    fig, axs = plt.subplots(2, num_classes, figsize=(scale_factor * (1 + num_classes), 2 * scale_factor),
                            squeeze=False)

    inclusion_idcs = torch.ones(logits.shape[0], dtype=torch.bool)
    if exclusion_idcs is not None:
        inclusion_idcs[exclusion_idcs] = 0

    unlabeled_bool = torch.zeros(logits.shape[0], dtype=torch.bool)
    unlabeled_bool[unlabeled_idcs] = 1

    for i in range(num_classes):
        for j, subset_bool in enumerate((unlabeled_bool, ~unlabeled_bool)):
            ax = axs[j, i]
            class_bool_idcs = (predicted_class == i) & inclusion_idcs & subset_bool
            class_confs = predicted_max_conf[class_bool_idcs]

            accepted = torch.sum(class_confs >= min_conf[i])
            total = len(class_confs)

            ax.hist(class_confs.detach().numpy(), bins=50)
            y_min, y_max = ax.get_ylim()
            ax.plot(np.array([min_conf[i], min_conf[i]]), np.array([y_min, y_max]))
            ax.set_title(f'{class_labels[i]}\n Acc. {accepted} Total {total}')

            if j == 1:
                # The non-unlabeled row may be orders of magnitude larger.
                ax.set_yscale('log')

    res_dir = 'SSLResults/'
    filename = f'{teacher_model}_in_vs_od.png'
    path = os.path.join(res_dir, filename)
    pathlib.Path(res_dir).mkdir(parents=True, exist_ok=True)
    plt.savefig(path)
    plt.close(fig)


class CIFARSubsetPlusTinyImageTopKPartition(TrainSetPlusTinyImageTopKPartition):
    """Top-k partition over a labeled CIFAR-10 subset plus Tiny Images.

    The labeled ``'train'`` split of ``CIFARSubset`` is the train set. The Tiny Images copies of the
    ``'unlabeled'`` CIFAR split are removed from the exclusion set, so they can be selected as
    pseudo-labeled samples. Optionally, the Tiny Images pool is restricted to a random subset of
    size ``unlabeled_samples``.

    Requires ``exclude_cifar`` and ``exclude_cifar10_1``. Only CIFAR-10 is implemented.
    """

    def __init__(self, teacher_logits, selection_logits, dataset, samples_per_class_per_split, unlabeled_samples,
                 samples_per_class, transform_base, min_conf, temperature=1.0, soft_labels=True, preload=True,
                 exclude_cifar=False, exclude_cifar10_1=False, teacher_model=None):

        # Re-admitting the unlabeled CIFAR images only makes sense if CIFAR is excluded to begin with.
        assert exclude_cifar, 'CIFARSubsetPlusTinyImageTopKPartition requires exclude_cifar=True'
        assert exclude_cifar10_1, 'CIFARSubsetPlusTinyImageTopKPartition requires exclude_cifar10_1=True'

        if dataset.lower() == 'cifar10':
            class_labels = get_CIFAR10_labels()

            # Presumably line i holds the Tiny Images index of CIFAR-10 train image i (used below
            # to map CIFAR subset indices into Tiny Images indices).
            cifar_idcs = torch.zeros(50_000, dtype=torch.long)
            with open('./TinyImagesExclusionIdcs/80mn_cifar10_train_idxs.txt', 'r') as idxs:
                for i, idx in enumerate(idxs):
                    cifar_idcs[i] = int(idx)
        elif dataset.lower() == 'cifar100':
            # No CIFAR-100 -> Tiny Images index mapping is available here yet.
            raise NotImplementedError()
        else:
            raise NotImplementedError()

        cifar_dataset = CIFARSubset('train', dataset.lower(), samples_per_class_per_split, transform_base)

        unlabeled_dataset = CIFARSubset('unlabeled', dataset.lower(), samples_per_class_per_split, transform_base)
        unlabeled_indices = unlabeled_dataset.idcs

        exclusion_idcs = _load_cifar_exclusion_idcs(exclude_cifar, exclude_cifar10_1)
        exclusion_logical = torch.zeros(TINY_LENGTH, dtype=torch.bool)
        exclusion_logical[exclusion_idcs] = 1

        if unlabeled_samples is not None and unlabeled_samples > 0:
            tiny_subset = TinyImagesSubset(transform_base, unlabeled_samples,
                                           exclude_cifar=exclude_cifar, exclude_cifar10_1=exclude_cifar10_1)

            # Restrict the pool to the subset: exclude everything, then re-include the subset.
            exclusion_logical[:] = 1
            exclusion_logical[tiny_subset.included_indices] = 0
            print(f'Including a subset of {len(tiny_subset.included_indices)} samples')

        # Add the unlabeled CIFAR samples back into the pool by removing them from the exclusion set.
        unlabeled_tiny_idcs = cifar_idcs[unlabeled_indices]
        exclusion_logical[unlabeled_tiny_idcs] = 0
        # as_tuple=True keeps this 1-D; squeeze() would yield a 0-d tensor for a single exclusion.
        exclusion_idcs = torch.nonzero(exclusion_logical, as_tuple=True)[0]
        print(f'Force inclusion of additional {len(unlabeled_indices)} samples')
        print(f'Total samples: {TINY_LENGTH - len(exclusion_idcs)}')

        # Same guard as the base class; without it a "None_in_vs_od.png" file was written.
        if teacher_model is not None:
            plot_unlabeled_vs_od_conf_histograms(teacher_logits, class_labels, min_conf, exclusion_idcs,
                                                 unlabeled_tiny_idcs, teacher_model)

        super().__init__(teacher_logits, selection_logits, cifar_dataset, samples_per_class, transform_base, min_conf,
                         class_labels, exclusion_idcs=exclusion_idcs, temperature=temperature, soft_labels=soft_labels,
                         preload=preload, teacher_model=teacher_model)


class CIFARPlusTinyImageTopKPartition(TrainSetPlusTinyImageTopKPartition):
    """Top-k partition over the full CIFAR-10/100 train set (or its train/val split) plus Tiny Images.

    ``dataset`` is matched case-insensitively against ``'cifar10'`` and ``'cifar100'``. Any other
    value raises NotImplementedError. The Tiny Images exclusion set comes from
    ``_load_cifar_exclusion_idcs(exclude_cifar, exclude_cifar10_1)``.
    """

    def __init__(self, teacher_logits, selection_logits, dataset, samples_per_class, transform_base, cifarTrainValSplit,
                 min_conf, temperature=1.0, soft_labels=True, preload=True, exclude_cifar=False,
                 exclude_cifar10_1=False, teacher_model=None):
        name = dataset.lower()
        if name == 'cifar10':
            class_labels = get_CIFAR10_labels()
            path = get_CIFAR10_path()
            split_cls, full_cls = CIFAR10TrainValidationSplit, datasets.CIFAR10
        elif name == 'cifar100':
            class_labels = get_CIFAR100_labels()
            path = get_CIFAR100_path()
            split_cls, full_cls = CIFAR100TrainValidationSplit, datasets.CIFAR100
        else:
            raise NotImplementedError()

        dataset_cls = split_cls if cifarTrainValSplit else full_cls
        cifar_dataset = dataset_cls(path, train=True, transform=transform_base)

        exclusion_idcs = _load_cifar_exclusion_idcs(exclude_cifar, exclude_cifar10_1)

        super().__init__(teacher_logits, selection_logits, cifar_dataset, samples_per_class, transform_base, min_conf,
                         class_labels, exclusion_idcs=exclusion_idcs, temperature=temperature, soft_labels=soft_labels,
                         preload=preload, teacher_model=teacher_model)


class BalancedSampler(Sampler):
    """Sampler that draws the same number of indices from every class in each epoch.

    Every class contributes ``max_c(train_c + semi_c)`` indices, drawn from the union of its train
    range and semi range. Smaller classes are filled by repeated passes of sampling without
    replacement. The concatenation is shuffled.

    With ``subdivide_epochs`` the shuffled list is split into chunks of the train-set size.
    Successive ``__iter__`` calls return successive chunks, and a new list is drawn once all chunks
    are used.

    Args:
        tiny_top_k_dataset: object exposing ``semi_per_class``, ``train_per_class`` (1-D tensors) and
            ``train_idx_ranges`` / ``semi_idx_ranges`` (per-class ``(start, end)`` pairs).
        subdivide_epochs: split one balanced pass into train-set-sized pseudo-epochs.

    Raises:
        ValueError: if some class has no samples while others do (it could never be filled).
    """

    def __init__(self, tiny_top_k_dataset, subdivide_epochs=False):
        super().__init__(None)
        self.semi_per_class = tiny_top_k_dataset.semi_per_class
        self.train_per_class = tiny_top_k_dataset.train_per_class
        self.train_idx_ranges = tiny_top_k_dataset.train_idx_ranges
        self.semi_idx_ranges = tiny_top_k_dataset.semi_idx_ranges
        self.subdivide_epochs = subdivide_epochs

        self.total_per_class = self.semi_per_class + self.train_per_class
        self.samples_per_class = torch.max(self.total_per_class)
        min_per_class = torch.min(self.total_per_class)
        self.num_classes = len(self.semi_per_class)

        # An empty class can never collect samples_per_class indices; _get_idcs used to spin forever.
        if min_per_class == 0 and self.samples_per_class > 0:
            empty_classes = torch.nonzero(self.total_per_class == 0, as_tuple=True)[0].tolist()
            raise ValueError(f'BalancedSampler: classes {empty_classes} have no samples')

        if self.subdivide_epochs:
            self.length = torch.sum(self.train_per_class).item()
            self.total_length = self.num_classes * self.samples_per_class.item()
            self.num_epoch_subdivs = int(np.ceil(self.total_length / self.length))
            # Every pseudo-epoch must have the same length, so length has to divide total_length.
            assert (self.total_length % self.length) == 0
        else:
            self.length = self.num_classes * self.samples_per_class
            self.total_length = self.length
            self.num_epoch_subdivs = 1

        # Start "exhausted" so the first __iter__ draws a fresh index list.
        self.subdiv_i = self.num_epoch_subdivs

        print(f'Balanced Sampler: Max {self.samples_per_class} - Min {min_per_class}'
              f' - Length {self.total_length} - Epoch subdivs {self.num_epoch_subdivs}')

    def _get_idcs(self):
        """Draw ``samples_per_class`` indices per class and return them shuffled together."""
        n_per_class = int(self.samples_per_class)
        intra_class_idcs = []
        for i in range(self.num_classes):
            i_intra_idcs = torch.zeros(n_per_class, dtype=torch.long)

            i_train_start, i_train_end = self.train_idx_ranges[i]
            i_semi_start, i_semi_end = self.semi_idx_ranges[i]

            i_all_idcs = torch.cat([torch.arange(i_train_start, i_train_end, dtype=torch.long),
                                    torch.arange(i_semi_start, i_semi_end, dtype=torch.long)])

            assert len(i_all_idcs) == int(self.total_per_class[i])

            # Repeated passes without replacement until the class quota is reached.
            i_collected_samples = 0
            while i_collected_samples < n_per_class:
                samples_to_get = min(n_per_class - i_collected_samples, len(i_all_idcs))
                next_samples = i_all_idcs[torch.randperm(len(i_all_idcs))[:samples_to_get]]

                i_intra_idcs[i_collected_samples:(i_collected_samples + samples_to_get)] = next_samples
                i_collected_samples = i_collected_samples + samples_to_get

            intra_class_idcs.append(i_intra_idcs)

        idcs = torch.cat(intra_class_idcs)[torch.randperm(int(self.total_length))]
        return idcs

    def __iter__(self):
        if self.subdivide_epochs:
            if self.subdiv_i >= self.num_epoch_subdivs:
                idcs = self._get_idcs()
                self.subdiv_i = 0
                self.subdiv_idcs = []

                idx_i = 0
                for _ in range(self.num_epoch_subdivs):
                    samples_subdiv_i = min(self.total_length - idx_i, self.length)
                    idx_i_next = idx_i + samples_subdiv_i
                    self.subdiv_idcs.append(idcs[idx_i:idx_i_next])
                    idx_i = idx_i_next

            iterator = iter(self.subdiv_idcs[self.subdiv_i])
            self.subdiv_i += 1
            return iterator

        else:
            idcs = self._get_idcs()
            return iter(idcs)

    def __len__(self):
        return int(self.length)

class TrainSemiRatioSampler(Sampler):
    """Sampler that yields every train index once per epoch plus ``ratio`` times as many semi indices.

    Each emitted chunk of ``bs`` indices contains ``int(bs / (ratio + 1))`` train indices and the rest
    semi indices, shuffled within the chunk. Once either stream runs out, the final chunks contain
    only the remainder. Semi indices are drawn evenly per class. A class with too few semi samples
    repeats them, and if it has none it falls back to its own train samples.

    Args:
        tiny_top_k_dataset: object exposing ``train_idx_ranges`` / ``semi_idx_ranges`` (per-class
            ``(start, end)`` pairs) and ``semi_per_class`` / ``train_per_class`` tensors.
        bs: batch size the chunks are aligned to.
        ratio: number of semi samples per train sample.

    Raises:
        ValueError: if the batch composition or the per-class data would make ``__iter__`` never terminate.
    """

    def __init__(self, tiny_top_k_dataset, bs, ratio=1):
        super().__init__(None)
        self.train_idx_ranges = tiny_top_k_dataset.train_idx_ranges
        self.semi_idx_ranges = tiny_top_k_dataset.semi_idx_ranges
        self.semi_per_class = tiny_top_k_dataset.semi_per_class

        self.train_bs = int(bs / (ratio + 1))
        self.semi_bs = bs - self.train_bs

        assert self.train_bs + self.semi_bs == bs

        self.total_bs = bs

        self.train_per_epoch_total = torch.sum(tiny_top_k_dataset.train_per_class)
        self.total_semi = int(ratio * self.train_per_epoch_total)
        self.num_classes = len(self.train_idx_ranges)

        assert (self.total_semi % self.num_classes == 0)
        self.semi_per_epoch_per_class = int(self.total_semi / self.num_classes)

        self.length = self.train_per_epoch_total + self.total_semi

        # With zero train slots per chunk, the merge loop in __iter__ can never consume train samples.
        if self.train_bs == 0 and self.train_per_epoch_total > 0:
            raise ValueError(f'TrainSemiRatioSampler: batch size {bs} is too small for ratio {ratio} '
                             f'(no train sample fits into a batch)')

        # A class with neither train nor semi samples can never fill its semi quota.
        if self.semi_per_epoch_per_class > 0:
            for class_idx, ((t_start, t_end), (s_start, s_end)) in enumerate(
                    zip(self.train_idx_ranges, self.semi_idx_ranges)):
                if int(t_end - t_start) + int(s_end - s_start) == 0:
                    raise ValueError(f'TrainSemiRatioSampler: class {class_idx} has no samples')

        print(f'Train Semi Ratio Sampler: Ratio {ratio}'
              f' - Train samples per epoch: {self.train_per_epoch_total} - Semi sampler per epoch: {self.total_semi}'
              f' - Length {self.length}')
        print(f'Total bs {bs} - Train bs {self.train_bs} - Semi bs {self.semi_bs}')

    def _draw_semi_for_class(self, all_train_class_i, all_semi_class_i):
        """Draw ``semi_per_epoch_per_class`` indices for one class, topping up with repeats / train samples."""
        available_train = len(all_train_class_i)
        available_semi = len(all_semi_class_i)
        quota = self.semi_per_epoch_per_class

        semi_idcs_class_i = torch.zeros(quota, dtype=torch.long)

        samples_to_get = min(quota, available_semi)
        semi_idcs_class_i[:samples_to_get] = all_semi_class_i[torch.randperm(available_semi)[:samples_to_get]]
        collected = samples_to_get

        # Not enough semi samples: repeat them, and if still short also use the class's train samples.
        while collected < quota:
            samples_to_get = min(quota - collected, available_semi)
            semi_idcs_class_i[collected:(collected + samples_to_get)] = all_semi_class_i[
                torch.randperm(available_semi)[:samples_to_get]]
            collected += samples_to_get

            samples_to_get = min(quota - collected, available_train)
            semi_idcs_class_i[collected:(collected + samples_to_get)] = all_train_class_i[
                torch.randperm(available_train)[:samples_to_get]]
            collected += samples_to_get

        return semi_idcs_class_i

    def __iter__(self):
        train_idcs = []
        semi_idcs = []
        for (train_start, train_end), (semi_start, semi_end) in zip(self.train_idx_ranges, self.semi_idx_ranges):
            all_train_class_i = torch.arange(train_start, train_end, 1, dtype=torch.long)
            all_semi_class_i = torch.arange(semi_start, semi_end, 1, dtype=torch.long)

            # Every train index is used exactly once per epoch.
            train_idcs.append(all_train_class_i)
            semi_idcs.append(self._draw_semi_for_class(all_train_class_i, all_semi_class_i))

        train_idcs = torch.cat(train_idcs)[torch.randperm(int(self.train_per_epoch_total))]
        semi_idcs = torch.cat(semi_idcs)[torch.randperm(self.total_semi)]

        # Merge both streams into chunks of train_bs train + semi_bs semi indices.
        batches = []
        train_pos = 0
        semi_pos = 0
        emitted = 0
        while emitted < self.length:
            train_bs = min(self.train_bs, len(train_idcs) - train_pos)
            train_batch = train_idcs[train_pos:(train_pos + train_bs)]
            train_pos += train_bs

            semi_bs = min(self.semi_bs, len(semi_idcs) - semi_pos)
            semi_batch = semi_idcs[semi_pos:(semi_pos + semi_bs)]
            semi_pos += semi_bs

            emitted += train_bs + semi_bs
            batches.append(torch.cat([train_batch, semi_batch])[torch.randperm(train_bs + semi_bs)])

        if not batches:
            return iter(torch.zeros(0, dtype=torch.long))

        idcs = torch.cat(batches)
        assert len(idcs) == self.length
        return iter(idcs)

    def __len__(self):
        return int(self.length)
