import torch
import torch.distributions
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from torch.utils.data import Sampler

from utils.datasets.tinyImages import _load_cifar_exclusion_idcs
from utils.datasets.paths import get_CIFAR10_path, get_CIFAR100_path

import numpy as np

from utils.datasets.cifar import get_CIFAR10_labels, get_CIFAR100_labels
from utils.datasets.cifar_augmentation import get_cifar10_augmentation
from .cifar_train_val import CIFAR10TrainValidationSplit, CIFAR100TrainValidationSplit
from utils.datasets.paths import get_tiny_images_files

from utils.datasets.tinyImages import _load_tiny_image, _preload_tiny_images
from .loading_utils import load_teacher_data

def get_tiny_top_only(dataset_classifications_path, teacher_model, dataset, samples_per_class,
                    selection_model=None, class_tpr_min=None, od_exclusion_threshold=None,
                     calibrate_temperature=False, soft_labels=True,
                     batch_size=100, shuffle=True, augm_type='default',
                     cutout_window=16, size=32, num_workers=8, exclude_cifar=False, exclude_cifar10_1=False,
                     id_config_dict=None, ssl_config=None):
    """Build a DataLoader over the top-confidence teacher-labelled samples per class.

    Teacher outputs are obtained through ``load_teacher_data`` and wrapped in a
    ``TinyImageTopKPartition`` that keeps at most ``samples_per_class`` samples
    for every class of ``dataset``.

    Args:
        dataset_classifications_path: Forwarded to ``load_teacher_data``.
        teacher_model: Forwarded to ``load_teacher_data``.
        dataset: Either ``'cifar10'`` or ``'cifar100'``; selects the class labels.
        samples_per_class: Upper bound on the samples kept for each class.
        selection_model: Unused by this function; kept for interface compatibility.
        class_tpr_min: Forwarded to ``load_teacher_data``.
        od_exclusion_threshold: Forwarded to ``load_teacher_data``.
        calibrate_temperature: Forwarded to ``load_teacher_data``.
        soft_labels: If true the dataset yields softmax label vectors, otherwise
            the argmax class index.
        batch_size: Batch size of the returned loader.
        shuffle: Whether the returned loader shuffles.
        augm_type: Augmentation name passed to ``get_cifar10_augmentation``.
        cutout_window: Cutout window passed to ``get_cifar10_augmentation``.
        size: Output size passed to ``get_cifar10_augmentation`` as ``out_size``.
        num_workers: Number of loader worker processes.
        exclude_cifar: Forwarded to ``_load_cifar_exclusion_idcs``.
        exclude_cifar10_1: Forwarded to ``_load_cifar_exclusion_idcs``.
        id_config_dict: Optional dict that is filled in place with a description
            of the created loader.
        ssl_config: Forwarded to ``load_teacher_data``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected samples.

    Raises:
        NotImplementedError: If ``dataset`` is neither ``'cifar10'`` nor ``'cifar100'``.
    """
    # Resolve the labels first so that an unsupported dataset name is rejected
    # before the teacher data is loaded.
    if dataset == 'cifar10':
        class_labels = get_CIFAR10_labels()
    elif dataset == 'cifar100':
        class_labels = get_CIFAR100_labels()
    else:
        raise NotImplementedError(
            f"Unsupported dataset {dataset!r}; expected 'cifar10' or 'cifar100'")

    model_confidences, _, class_thresholds, temperature = load_teacher_data(
        dataset_classifications_path, teacher_model,
        class_tpr_min=class_tpr_min,
        od_exclusion_threshold=od_exclusion_threshold,
        calibrate_temperature=calibrate_temperature,
        ssl_config=ssl_config)

    # get_cifar10_augmentation receives this dict as config_dict; it is stored
    # under 'Augmentation' in id_config_dict below.
    augm_config = {}
    transform = get_cifar10_augmentation(augm_type, cutout_window, out_size=size, config_dict=augm_config)

    exclusion_idcs = _load_cifar_exclusion_idcs(exclude_cifar, exclude_cifar10_1)

    top_dataset = TinyImageTopKPartition(model_confidences, samples_per_class=samples_per_class,
                                         transform_base=transform, temperature=temperature,
                                         min_conf=class_thresholds, class_labels=class_labels,
                                         soft_labels=soft_labels,
                                         exclusion_idcs=exclusion_idcs)
    top_loader = torch.utils.data.DataLoader(top_dataset, shuffle=shuffle, batch_size=batch_size,
                                             num_workers=num_workers)

    if id_config_dict is not None:
        id_config_dict['Dataset'] = 'Cifar-SSL Additional Samples'
        id_config_dict['Batch out_size'] = batch_size
        id_config_dict['Samples per class'] = samples_per_class
        id_config_dict['Soft labels'] = soft_labels
        id_config_dict['Exclude CIFAR'] = exclude_cifar
        id_config_dict['Exclude CIFAR10.1'] = exclude_cifar10_1
        id_config_dict['Augmentation'] = augm_config

    return top_loader


class TinyImageTopKPartition(Dataset):
    """Map-style dataset of the most confident samples per predicted class.

    For every class the samples whose predicted class matches, that are not
    excluded and whose maximum softmax confidence reaches the class threshold
    are ranked by confidence; the best ``samples_per_class`` of them are kept.
    Items are ordered class by class, each class sorted by descending
    confidence.
    """

    def __init__(self, model_logits, samples_per_class, transform_base, min_conf,
                 class_labels, exclusion_idcs = None,
                 temperature=1.0, soft_labels=True, preload=True):
        """Select the top samples per class and optionally preload their images.

        Args:
            model_logits: 2-D tensor of teacher logits, one row per candidate
                sample and one column per class.
            samples_per_class: Maximum number of samples kept for each class.
            transform_base: Transform applied after ``transforms.ToPILImage()``.
            min_conf: Indexable by class index; minimum (temperature-free)
                softmax confidence a sample needs to be eligible for a class.
            class_labels: Sequence of class names; its length fixes the number
                of classes.
            exclusion_idcs: Optional index into the rows of ``model_logits``
                marking samples that must never be selected.
            temperature: Divisor applied to the logits when soft labels are
                produced.
            soft_labels: If true items carry a softmax vector as label,
                otherwise the argmax class index.
            preload: If true the selected images are read up front through
                ``_preload_tiny_images``; otherwise each item is read on access.
        """
        self.data_location = get_tiny_images_files(False)
        # NOTE(review): the handle stays open for the lifetime of the object and
        # is used for on-demand reads when preload is False - confirm whether it
        # can be closed after preloading.
        self.fileID = open(self.data_location, "rb")
        self.samples_per_class = samples_per_class
        self.soft_labels = soft_labels
        self.preload = preload
        self.temperature = temperature
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transform_base])

        self.model_logits = model_logits
        predicted_max_conf, predicted_class = torch.max(torch.softmax(self.model_logits, dim=1), dim=1)

        self.num_classes = len(class_labels)

        inclusion_idcs = torch.ones(self.model_logits.shape[0], dtype=torch.bool)
        if exclusion_idcs is not None:
            inclusion_idcs[exclusion_idcs] = False

        self.in_use_indices = []
        self.valid_indices = []
        self.class_semi_counts = []

        for i in range(self.num_classes):
            meets_threshold = predicted_max_conf >= min_conf[i]
            eligible_mask = (predicted_class == i) & inclusion_idcs & meets_threshold

            # nonzero() yields shape (n, 1). Only the trailing axis is dropped:
            # a bare squeeze() would turn a single hit into a 0-d tensor, which
            # breaks len() here and torch.cat() in get_used_semi_indices().
            eligible_lin_idcs = torch.nonzero(eligible_mask, as_tuple=False).squeeze(1)
            eligible_confidences = predicted_max_conf[eligible_mask]
            ranking = torch.argsort(eligible_confidences, descending=True)

            num_samples_i = int(min(samples_per_class, len(eligible_lin_idcs)))
            class_i_idcs = eligible_lin_idcs[ranking[:num_samples_i]]

            self.valid_indices.append(eligible_lin_idcs)
            self.in_use_indices.append(class_i_idcs)
            self.class_semi_counts.append(len(class_i_idcs))

            if num_samples_i < samples_per_class:
                print(f'Incomplete class {class_labels[i]} - Target count: {samples_per_class} - Found samples {len(class_i_idcs)}')

        # The defaults reproduce the reported values when there are no classes.
        min_samples_per_class = min(self.class_semi_counts, default=int(1e13))
        max_samples_per_class = max(self.class_semi_counts, default=0)

        self.num_semi_samples = sum(self.class_semi_counts)
        self.length = self.num_semi_samples

        print(f'Top K -  Temperature {self.temperature} - Soft labels {soft_labels}'
              f'  -  Target Samples per class { self.samples_per_class}')
        print(f'Min Semi Samples {min_samples_per_class} - Max Semi samples {max_samples_per_class}'
              f' - Total semi samples {self.num_semi_samples} - Total length {self.length}')

        if preload:
            print(f'Preloading images')
            self.class_data = [
                _preload_tiny_images(self.in_use_indices[class_idx], self.fileID)
                for class_idx in range(self.num_classes)
            ]

    def get_used_semi_indices(self, verbose_exclude=False):
        """Return the concatenated per-class sample indices.

        With ``verbose_exclude`` every index that fulfils the confidence
        requirement is returned, including those outside of the top-k range;
        otherwise only the indices actually served by this dataset.
        """
        if verbose_exclude:
            return torch.cat(self.valid_indices)
        else:
            return torch.cat(self.in_use_indices)

    def _load_tiny_image(self, class_idx, tiny_lin_idx):
        """Return ``(image, label)`` for position ``tiny_lin_idx`` within class ``class_idx``."""
        valid_index = self.in_use_indices[class_idx][tiny_lin_idx]
        if self.preload:
            img = self.class_data[class_idx][tiny_lin_idx, :]
        else:
            img = _load_tiny_image(valid_index, self.fileID)

        if self.transform is not None:
            img = self.transform(img)

        if self.soft_labels:
            label = torch.softmax(self.model_logits[valid_index, :] / self.temperature, dim=0)
        else:
            label = torch.argmax(self.model_logits[valid_index, :]).item()
        return img, label

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at flat position ``index``.

        Negative indices count from the end.

        Raises:
            IndexError: If ``index`` lies outside ``[-len(self), len(self))``.
        """
        position = index + self.length if index < 0 else index
        if position < 0 or position >= self.length:
            raise IndexError(f'index {index} is out of range for dataset of length {self.length}')

        # Walk the per-class blocks until the one containing the position is found.
        class_start = 0
        for class_idx in range(self.num_classes):
            class_end = class_start + self.class_semi_counts[class_idx]
            if position < class_end:
                return self._load_tiny_image(class_idx, position - class_start)
            class_start = class_end

        # Only reachable if length and class_semi_counts disagree.
        raise IndexError(f'index {index} is out of range for dataset of length {self.length}')

    def __len__(self):
        """Return the total number of selected samples over all classes."""
        return self.length
