import torch
import torch.distributions
from torchvision import datasets, transforms
from torchvision.datasets.vision import VisionDataset

from utils.datasets.paths import get_CIFAR10_path, get_CIFAR100_path
import torch.utils.data as data_utils
from multiprocessing import Lock
from .loading_utils import load_teacher_data

import numpy as np

from utils.datasets.cifar import get_CIFAR10_labels, get_CIFAR100_labels
from utils.datasets.cifar_augmentation import get_cifar10_augmentation
from .cifar_semi_tiny_partition import CIFARPlusTinyImageTopKPartition, BalancedSampler
from utils.datasets.tinyImages import _load_tiny_image, _preload_tiny_images

def get_tiny_cifar_partition_without_od(dataset_classifications_path, teacher_model, dataset, samples_per_class,
                                        cifarTrainValSplit, soft_labels=True, calibrate_temperature=False,
                                        class_tpr_min=None, od_exclusion_threshold=None, batch_size=100,
                                        augm_type='default', subdivide_epochs=False,
                                        cutout_window=16, size=32, num_workers=8,
                                        exclude_cifar=False, exclude_cifar10_1=False,
                                        id_config=None, ssl_config=None):
    """Create a balanced DataLoader over a ``CIFARPlusTinyImageTopKPartition``.

    Steps performed, in order:
      1. ``load_teacher_data`` is called with the classification path, the
         teacher model and the threshold / calibration / SSL options. It
         yields teacher logits, selection logits, class thresholds and a
         temperature.
      2. ``get_cifar10_augmentation`` builds the base transform. It receives
         a fresh dict as ``config_dict``.
      3. A ``CIFARPlusTinyImageTopKPartition`` is built from those pieces.
         The class thresholds are passed as ``min_conf``.
      4. A ``BalancedSampler`` wraps the partition, and a
         ``torch.utils.data.DataLoader`` is built on top of it.
      5. If ``id_config`` is given, it is updated in place with a summary of
         the settings used.

    Args:
        dataset_classifications_path: Forwarded to ``load_teacher_data``.
        teacher_model: Forwarded to ``load_teacher_data``.
        dataset: Forwarded to ``CIFARPlusTinyImageTopKPartition``.
        samples_per_class: Forwarded to the partition and recorded in
            ``id_config``.
        cifarTrainValSplit: Forwarded to the partition.
        soft_labels: Forwarded to the partition and recorded in ``id_config``.
        calibrate_temperature: Forwarded to ``load_teacher_data``.
        class_tpr_min: Forwarded to ``load_teacher_data``.
        od_exclusion_threshold: Forwarded to ``load_teacher_data``.
        batch_size: DataLoader batch size; also recorded in ``id_config``.
        augm_type: Augmentation name passed to ``get_cifar10_augmentation``.
        subdivide_epochs: Forwarded to ``BalancedSampler``.
        cutout_window: Passed to ``get_cifar10_augmentation``.
        size: Passed to ``get_cifar10_augmentation`` as ``out_size``.
        num_workers: DataLoader worker count.
        exclude_cifar: Forwarded to the partition and recorded in ``id_config``.
        exclude_cifar10_1: Forwarded to the partition and recorded in
            ``id_config``.
        id_config: Optional dict-like object that is mutated in place with
            run metadata. Nothing is recorded when it is ``None``.
        ssl_config: Forwarded to ``load_teacher_data``.

    Returns:
        tuple: ``(loader, epoch_subdivs)``. ``loader`` is the DataLoader and
        ``epoch_subdivs`` is the sampler's ``num_epoch_subdivs`` attribute.
    """
    teacher_outputs = load_teacher_data(
        dataset_classifications_path,
        teacher_model,
        class_tpr_min=class_tpr_min,
        od_exclusion_threshold=od_exclusion_threshold,
        calibrate_temperature=calibrate_temperature,
        ssl_config=ssl_config,
    )
    teacher_logits, selection_logits, class_thresholds, temperature = teacher_outputs

    # Passed to the augmentation factory as config_dict. Presumably it is
    # filled with the augmentation settings (verify in get_cifar10_augmentation).
    # The same object is stored under id_config['Augmentation'] below.
    augmentation_settings = {}
    base_transform = get_cifar10_augmentation(
        augm_type,
        cutout_window,
        out_size=size,
        config_dict=augmentation_settings,
    )

    partition = CIFARPlusTinyImageTopKPartition(
        teacher_logits,
        selection_logits,
        dataset,
        samples_per_class=samples_per_class,
        transform_base=base_transform,
        cifarTrainValSplit=cifarTrainValSplit,
        min_conf=class_thresholds,
        temperature=temperature,
        soft_labels=soft_labels,
        exclude_cifar=exclude_cifar,
        exclude_cifar10_1=exclude_cifar10_1,
    )
    balanced_sampler = BalancedSampler(partition, subdivide_epochs=subdivide_epochs)
    n_subdivs = balanced_sampler.num_epoch_subdivs

    loader = torch.utils.data.DataLoader(
        partition,
        sampler=balanced_sampler,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    if id_config is None:
        return loader, n_subdivs

    # Written one key at a time, in this order, via item assignment.
    # NOTE(review): 'Batch out_size' looks like a rename artifact of
    # 'Batch size'. It is kept verbatim because readers of id_config may
    # depend on the exact key.
    run_metadata = (
        ('Dataset', 'SemiCifar-SSL'),
        ('Batch out_size', batch_size),
        ('Samples per class', samples_per_class),
        ('Epoch subdivs', n_subdivs),
        ('Soft labels', soft_labels),
        ('Exclude CIFAR', exclude_cifar),
        ('Exclude CIFAR10.1', exclude_cifar10_1),
        ('Augmentation', augmentation_settings),
    )
    for meta_key, meta_value in run_metadata:
        id_config[meta_key] = meta_value

    return loader, n_subdivs

