import torch
import torch.distributions
from torchvision import transforms

import torch.utils.data as data_utils

from utils.datasets.cifar_augmentation import get_cifar10_augmentation
from utils.datasets.paths import get_tiny_images_files

from utils.datasets.tinyImages import _load_tiny_image
from ssl_utils.tiny_images_od_bias import normalize_class_weights

def get_semi_tiny_images_train_loader(model_confidences, dataset, shuffle=True, class_weights=None,
                                      class_weights_normalize=True, min_conf=0.0, max_conf=0.12, soft_labels=True,
                                      batch_size=100, augm_type='default', cutout_window=16, size=32, num_workers=1):
    """Create a DataLoader over tiny images selected by a confidence interval.

    Wraps a ``TinyImageThresholdConfidences`` dataset in a
    ``torch.utils.data.DataLoader``.

    Args:
        model_confidences: Per-sample model confidences, forwarded to the dataset.
        dataset: In-distribution dataset name, forwarded to the dataset.
        shuffle: Forwarded to the DataLoader.
        class_weights: Optional class weights, forwarded to the dataset.
        class_weights_normalize: Forwarded to the dataset.
        min_conf: Lower confidence bound, forwarded to the dataset.
        max_conf: Upper confidence bound, forwarded to the dataset.
        soft_labels: Forwarded to the dataset.
        batch_size: Forwarded to the DataLoader.
        augm_type: Forwarded to ``get_cifar10_augmentation``.
        cutout_window: Forwarded to ``get_cifar10_augmentation``.
        size: Passed to ``get_cifar10_augmentation`` as ``out_size``.
        num_workers: Forwarded to the DataLoader; must not exceed 1.

    Returns:
        The configured ``torch.utils.data.DataLoader``.

    Raises:
        ValueError: If ``num_workers`` is greater than 1.
    """
    # Reject multi-worker loading up front, before any dataset is built.
    if num_workers > 1:
        raise ValueError('Bug in the current multithreaded tinyimages implementation')

    augmentation = get_cifar10_augmentation(augm_type, cutout_window=cutout_window, out_size=size)
    thresholded_dataset = TinyImageThresholdConfidences(
        model_confidences,
        dataset,
        min_conf=min_conf,
        max_conf=max_conf,
        transform_base=augmentation,
        soft_labels=soft_labels,
        class_weights=class_weights,
        class_weights_normalize=class_weights_normalize,
    )
    return torch.utils.data.DataLoader(
        thresholded_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )



class TinyImageThresholdConfidences(torch.utils.data.Dataset):
    """Tiny-images dataset restricted to samples inside a confidence interval.

    A sample is kept when the maximum entry of its row in
    ``model_confidences`` lies in ``[min_conf, max_conf]`` (both ends
    inclusive) and, for the CIFAR datasets, when its index is not listed in
    ``./utils/80mn_cifar_idxs.txt``.

    Args:
        model_confidences: 2-D tensor with one row per image and one column
            per class. Row ``i`` belongs to the image loaded at index ``i``.
        dataset: Name of the in-distribution dataset; ``'CIFAR10'`` and
            ``'CIFAR100'`` enable the CIFAR exclusion list.
        min_conf: Inclusive lower bound on the maximum confidence.
        max_conf: Inclusive upper bound on the maximum confidence.
        transform_base: Transform applied after the loaded image has been
            converted with ``transforms.ToPILImage``.
        soft_labels: If true the target is the sample's row of
            ``model_confidences``; otherwise a uniform vector over classes.
        class_weights: Optional per-class weights. When given, the weight of
            the sample's predicted class is appended to the target.
        class_weights_normalize: If true, ``class_weights`` is passed through
            ``normalize_class_weights`` before use.
        balance_intervals: Stored on the instance; not used by this class.
    """

    def __init__(self, model_confidences, dataset, min_conf, max_conf, transform_base, soft_labels=True,
                 class_weights=None, class_weights_normalize=True, balance_intervals=True):
        self.data_location = get_tiny_images_files(False)
        # Kept open for the lifetime of the dataset; release it with close().
        self.fileID = open(self.data_location, "rb")
        self.balance_intervals = balance_intervals
        self.min_conf = min_conf
        self.max_conf = max_conf
        self.soft_labels = soft_labels
        self.num_classes = model_confidences.shape[1]
        self.class_weighting = class_weights is not None

        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transform_base])

        self.model_confidences = model_confidences
        predicted_max_conf, predicted_class = torch.max(self.model_confidences, dim=1)

        exclude_cifar = dataset in ['CIFAR10', 'CIFAR100']
        non_cifar = self._non_cifar_mask(self.model_confidences.shape[0], exclude_cifar)

        # valid_indices is a 1-D tensor of row indices that pass both filters.
        self.valid_indices = self._select_valid_indices(predicted_max_conf, min_conf, max_conf, non_cifar)
        self.length = len(self.valid_indices)

        if self.class_weighting:
            if class_weights_normalize:
                self.class_weights = normalize_class_weights(class_weights, model_confidences)
            else:
                self.class_weights = class_weights

            self.valid_predicted_classes = predicted_class[self.valid_indices]

        print(f'[{min_conf}, {max_conf}], Exclude Cifar {exclude_cifar} - Samples {self.length} - Class weighting {self.class_weighting}')

    @staticmethod
    def _non_cifar_mask(num_samples, exclude_cifar):
        """Return a bool mask of length ``num_samples`` that is False for listed CIFAR indices."""
        non_cifar = torch.ones(num_samples, dtype=torch.bool)

        if exclude_cifar:
            with open('./utils/80mn_cifar_idxs.txt', 'r') as idxs:
                # indices in file take the 80mn database to start at 1, hence "- 1";
                # blank lines (e.g. a trailing newline) are skipped instead of crashing int().
                cifar_idxs = [int(idx) - 1 for idx in idxs if idx.strip()]

            non_cifar[torch.LongTensor(cifar_idxs)] = False

        return non_cifar

    @staticmethod
    def _select_valid_indices(predicted_max_conf, min_conf, max_conf, non_cifar):
        """Return the 1-D index tensor of samples inside the interval and allowed by ``non_cifar``."""
        valid_confidences = (predicted_max_conf >= min_conf) & (predicted_max_conf <= max_conf)
        # nonzero() on a 1-D mask yields shape (k, 1). squeeze(1) keeps the
        # result 1-D even for k == 1, where a bare squeeze() would produce a
        # 0-d tensor that supports neither len() nor indexing.
        return torch.nonzero(valid_confidences & non_cifar).squeeze(1)

    def close(self):
        """Close the tiny-images file handle; items cannot be loaded afterwards."""
        self.fileID.close()

    def __getitem__(self, index):
        """Return ``(img, target)`` for the ``index``-th valid sample.

        ``target`` holds ``num_classes`` entries (soft labels or a uniform
        vector). With class weighting enabled it has one extra trailing entry:
        the class weight of the sample's predicted class.
        """
        valid_index = self.valid_indices[index]
        img = _load_tiny_image(valid_index, self.fileID)

        if self.transform is not None:
            img = self.transform(img)

        if self.soft_labels:
            model_prediction = self.model_confidences[valid_index, :]
        else:
            model_prediction = (1./self.num_classes) * torch.ones(self.num_classes)

        if not self.class_weighting:
            return img, model_prediction

        # Pack the per-sample weight behind the class targets in one vector.
        sample_weight = self.class_weights[self.valid_predicted_classes[index]]
        conf_weight = torch.zeros(model_prediction.shape[0] + 1)
        conf_weight[0:-1] = model_prediction
        conf_weight[-1] = sample_weight
        return img, conf_weight

    def __len__(self):
        """Return the number of valid samples."""
        return self.length


def get_density_exclusion_tiny_images_train_loader(model_confidences, dataset, class_densities, shuffle=True, class_weights=None,
                                      class_weights_normalize=True, soft_labels=True,
                                      batch_size=100, augm_type='default', cutout_window=16, size=32, num_workers=1):
    """Create a DataLoader over tiny images selected by density-based exclusion.

    Wraps a ``TinyImageDensityExclusion`` dataset in a
    ``torch.utils.data.DataLoader``.

    Args:
        model_confidences: Per-sample model confidences, forwarded to the dataset.
        dataset: In-distribution dataset name, forwarded to the dataset.
        class_densities: Per-class densities, forwarded to the dataset.
        shuffle: Forwarded to the DataLoader.
        class_weights: Optional class weights, forwarded to the dataset.
        class_weights_normalize: Forwarded to the dataset.
        soft_labels: Forwarded to the dataset.
        batch_size: Forwarded to the DataLoader.
        augm_type: Forwarded to ``get_cifar10_augmentation``.
        cutout_window: Forwarded to ``get_cifar10_augmentation``.
        size: Passed to ``get_cifar10_augmentation`` as ``out_size``.
        num_workers: Forwarded to the DataLoader; must not exceed 1.

    Returns:
        The configured ``torch.utils.data.DataLoader``.

    Raises:
        ValueError: If ``num_workers`` is greater than 1.
    """
    # Reject multi-worker loading up front, before any dataset is built.
    if num_workers > 1:
        raise ValueError('Bug in the current multithreaded tinyimages implementation')

    augmentation = get_cifar10_augmentation(augm_type, cutout_window=cutout_window, out_size=size)
    exclusion_dataset = TinyImageDensityExclusion(
        model_confidences,
        dataset,
        class_densities,
        transform_base=augmentation,
        soft_labels=soft_labels,
        class_weights=class_weights,
        class_weights_normalize=class_weights_normalize,
    )
    return torch.utils.data.DataLoader(
        exclusion_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


class TinyImageDensityExclusion(torch.utils.data.Dataset):
    """Tiny-images dataset that drops the most confident samples of each class.

    For every class ``c``, the samples whose predicted class (argmax of their
    row in ``model_confidences``) is ``c`` are ranked by maximum confidence,
    and the top ``int(N * class_densities[c])`` of them are excluded, where
    ``N`` is the number of rows in ``model_confidences``. For the CIFAR
    datasets, indices listed in ``./utils/80mn_cifar_idxs.txt`` are excluded
    as well.

    Args:
        model_confidences: 2-D tensor with one row per image and one column
            per class. Row ``i`` belongs to the image loaded at index ``i``.
        dataset: Name of the in-distribution dataset; ``'CIFAR10'`` and
            ``'CIFAR100'`` enable the CIFAR exclusion list.
        class_densities: Indexable by class index; entry ``c`` is the
            fraction of ``N`` to exclude from class ``c``.
        transform_base: Transform applied after the loaded image has been
            converted with ``transforms.ToPILImage``.
        soft_labels: If true the target is the sample's row of
            ``model_confidences``; otherwise a uniform vector over classes.
        class_weights: Optional per-class weights. When given, the weight of
            the sample's predicted class is appended to the target.
        class_weights_normalize: If true, ``class_weights`` is passed through
            ``normalize_class_weights`` before use.
        balance_intervals: Stored on the instance; not used by this class.
    """

    def __init__(self, model_confidences, dataset, class_densities, transform_base, soft_labels=True,
                 class_weights=None, class_weights_normalize=True, balance_intervals=True):
        self.data_location = get_tiny_images_files(False)
        # Kept open for the lifetime of the dataset; release it with close().
        self.fileID = open(self.data_location, "rb")
        self.balance_intervals = balance_intervals
        self.soft_labels = soft_labels
        self.num_classes = model_confidences.shape[1]
        self.class_weighting = class_weights is not None

        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transform_base])

        self.model_confidences = model_confidences
        predicted_max_conf, predicted_class = torch.max(self.model_confidences, dim=1)

        exclude_cifar = dataset in ['CIFAR10', 'CIFAR100']
        non_cifar = self._non_cifar_mask(self.model_confidences.shape[0], exclude_cifar)

        # valid_indices is a 1-D tensor of row indices that survive both the
        # CIFAR filter and the per-class density exclusion.
        self.valid_indices = self._select_valid_indices(
            predicted_max_conf, predicted_class, class_densities, non_cifar, self.num_classes)
        self.length = len(self.valid_indices)

        if self.class_weighting:
            if class_weights_normalize:
                self.class_weights = normalize_class_weights(class_weights, model_confidences)
            else:
                self.class_weights = class_weights

            self.valid_predicted_classes = predicted_class[self.valid_indices]

        print(f'Density based exclusion OD, Exclude Cifar {exclude_cifar} - Samples {self.length} - Class weighting {self.class_weighting}')

    @staticmethod
    def _non_cifar_mask(num_samples, exclude_cifar):
        """Return a bool mask of length ``num_samples`` that is False for listed CIFAR indices."""
        non_cifar = torch.ones(num_samples, dtype=torch.bool)

        if exclude_cifar:
            with open('./utils/80mn_cifar_idxs.txt', 'r') as idxs:
                # indices in file take the 80mn database to start at 1, hence "- 1";
                # blank lines (e.g. a trailing newline) are skipped instead of crashing int().
                cifar_idxs = [int(idx) - 1 for idx in idxs if idx.strip()]

            non_cifar[torch.LongTensor(cifar_idxs)] = False

        return non_cifar

    @staticmethod
    def _select_valid_indices(predicted_max_conf, predicted_class, class_densities, non_cifar, num_classes):
        """Return the 1-D index tensor left after per-class exclusion of the most confident samples."""
        total_length = predicted_max_conf.shape[0]
        keep = non_cifar.clone()

        for class_idx in range(num_classes):
            class_mask = predicted_class == class_idx
            # squeeze(1), not squeeze(): with exactly one sample in the class a
            # bare squeeze() gives a 0-d tensor and len() below raises.
            class_linear_idcs = torch.nonzero(class_mask).squeeze(1)
            class_confidences = predicted_max_conf[class_mask]
            # Positions within the class subset, most confident first. The
            # ranking covers every sample of the class, CIFAR-listed or not.
            class_sort_idcs = torch.argsort(class_confidences, descending=True)

            num_samples_i = min(int(total_length * class_densities[class_idx]), len(class_linear_idcs))
            keep[class_linear_idcs[class_sort_idcs[:num_samples_i]]] = False

        # Same reasoning as above: stay 1-D when a single sample remains.
        return torch.nonzero(keep).squeeze(1)

    def close(self):
        """Close the tiny-images file handle; items cannot be loaded afterwards."""
        self.fileID.close()

    def __getitem__(self, index):
        """Return ``(img, target)`` for the ``index``-th valid sample.

        ``target`` holds ``num_classes`` entries (soft labels or a uniform
        vector). With class weighting enabled it has one extra trailing entry:
        the class weight of the sample's predicted class.
        """
        valid_index = self.valid_indices[index]
        img = _load_tiny_image(valid_index, self.fileID)

        if self.transform is not None:
            img = self.transform(img)

        if self.soft_labels:
            model_prediction = self.model_confidences[valid_index, :]
        else:
            model_prediction = (1./self.num_classes) * torch.ones(self.num_classes)

        if not self.class_weighting:
            return img, model_prediction

        # Pack the per-sample weight behind the class targets in one vector.
        sample_weight = self.class_weights[self.valid_predicted_classes[index]]
        conf_weight = torch.zeros(model_prediction.shape[0] + 1)
        conf_weight[0:-1] = model_prediction
        conf_weight[-1] = sample_weight
        return img, conf_weight

    def __len__(self):
        """Return the number of valid samples."""
        return self.length

