import torch
import torch.distributions
from torch.utils.data import Dataset
from torchvision import datasets, transforms


from utils.datasets.pets import Pets, get_pets_path, get_pets_labels
from utils.datasets.stanford_cars import StanfordCars, get_stanford_cars_path, get_stanford_cars_labels
from utils.datasets.food_101 import Food101, get_food_101_path, get_food_101_labels
from utils.datasets.flowers import Flowers, get_flowers_path, get_flowers_labels
from utils.datasets.openimages import OpenImages, get_openimages_path
from utils.datasets.cifar import get_CIFAR10_labels, get_CIFAR10_path

from utils.datasets.imagenet_augmentation import get_imageNet_augmentation
from utils.datasets.cifar_augmentation import get_cifar10_augmentation

from .loading_utils import load_teacher_data
from .cifar_semi_tiny_partition import BalancedSampler

def get_openImages_dataset_partition(dataset_classifications_path, teacher_model,
                             dataset, samples_per_class,
                             selection_model=None,
                             class_tpr_min=None, od_exclusion_threshold=None,
                             calibrate_temperature=False, verbose_exclude=False, soft_labels=True, batch_size=100,
                             augm_type='default', subdivide_epochs=False,
                             size=32, num_workers=8,
                             id_config_dict=None, od_config_dict=None, ssl_config=None):
    """Build the two OpenImages-based loaders used for semi-supervised training on ``dataset``.

    The first ("top") loader iterates a ``TrainSetPlusOpenImagesTopKPartition``: the labelled train split
    of ``dataset`` plus the OpenImages samples selected per class, drawn through a ``BalancedSampler``.
    The second ("bottom") loader shuffles an ``OpenImagesBottomKPartition`` holding the OpenImages samples
    that the top partition did not claim.

    Args:
        dataset_classifications_path, teacher_model, selection_model, class_tpr_min,
        od_exclusion_threshold, calibrate_temperature, ssl_config: forwarded to ``load_teacher_data``.
        dataset: one of 'cifar10', 'pets', 'cars', 'flowers', 'food-101'.
        samples_per_class: number of OpenImages samples the top partition tries to keep per class.
        verbose_exclude: forwarded to ``get_used_semi_indices``; decides which OpenImages indices are
            withheld from the bottom partition.
        soft_labels: forwarded to both partitions.
        batch_size, num_workers: used for both ``DataLoader``s.
        augm_type, size: forwarded to ``get_imageNet_augmentation`` (``size`` as ``out_size``).
        subdivide_epochs: forwarded to ``BalancedSampler``.
        id_config_dict, od_config_dict: if not None, updated in place with a description of the
            top / bottom loader respectively.

    Returns:
        ``(top_loader, bottom_loader, epoch_subdivs)`` where ``epoch_subdivs`` is the sampler's
        ``num_epoch_subdivs``.

    Raises:
        NotImplementedError: if ``dataset`` is not a supported name.
    """
    teacher_logits, selection_logits, class_thresholds, temperature = load_teacher_data(
        dataset_classifications_path, teacher_model,
        selection_model=selection_model,
        class_tpr_min=class_tpr_min,
        od_exclusion_threshold=od_exclusion_threshold,
        calibrate_temperature=calibrate_temperature,
        ssl_config=ssl_config,
    )

    # handed to the augmentation factory as config_dict and later logged in both config dicts
    augm_config = {}
    transform = get_imageNet_augmentation(augm_type, out_size=size, config_dict=augm_config)

    def _build_cifar10():
        class_names = get_CIFAR10_labels()
        return datasets.CIFAR10(get_CIFAR10_path(), train=True, transform=transform), class_names

    # each builder returns (train_dataset, class_labels)
    builders = {
        'cifar10': _build_cifar10,
        'pets': lambda: (Pets(get_pets_path(), 'train', transform=transform), get_pets_labels()),
        'cars': lambda: (StanfordCars(get_stanford_cars_path(), True, transform=transform),
                         get_stanford_cars_labels()),
        'flowers': lambda: (Flowers(get_flowers_path(), 'train', transform=transform), get_flowers_labels()),
        'food-101': lambda: (Food101(get_food_101_path(), 'train', transform=transform), get_food_101_labels()),
    }
    build = builders.get(dataset)
    if build is None:
        raise NotImplementedError()
    train_dataset, labels = build()

    top_dataset = TrainSetPlusOpenImagesTopKPartition(
        teacher_logits, selection_logits, train_dataset, samples_per_class, transform,
        min_conf=class_thresholds, class_labels=labels, exclude_dataset=dataset,
        temperature=temperature, soft_labels=soft_labels,
    )
    sampler = BalancedSampler(top_dataset, subdivide_epochs=subdivide_epochs)
    epoch_subdivs = sampler.num_epoch_subdivs
    top_loader = torch.utils.data.DataLoader(top_dataset, sampler=sampler,
                                             batch_size=batch_size, num_workers=num_workers)

    # OpenImages samples already claimed by the top partition are kept out of the bottom partition
    claimed_indices = top_dataset.get_used_semi_indices(verbose_exclude)
    bottom_dataset = OpenImagesBottomKPartition(
        teacher_logits, claimed_indices,
        transform=transform, temperature=temperature, soft_labels=soft_labels, exclude_dataset=dataset,
    )
    bottom_loader = torch.utils.data.DataLoader(bottom_dataset, shuffle=True,
                                                batch_size=batch_size, num_workers=num_workers)

    if id_config_dict is not None:
        id_config_dict.update({
            'Dataset': f'{dataset} SSL',
            'Batch out_size': batch_size,
            'Samples per class': samples_per_class,
            'Epoch subdivs': epoch_subdivs,
            'Soft labels': soft_labels,
            'Exclude dataset': dataset,
            'Augmentation': augm_config,
        })

    if od_config_dict is not None:
        od_config_dict.update({
            'Dataset': 'OpenImages Partition',
            'Batch out_size': batch_size,
            'Verbose exclude': verbose_exclude,
            'Exclude dataset': dataset,
            'Augmentation': augm_config,
        })

    return top_loader, bottom_loader, epoch_subdivs


class TrainSetPlusOpenImagesTopKPartition(Dataset):
    """Labelled train set concatenated with the most confident OpenImages samples of every class.

    Flat indices ``[0, num_train_samples)`` address the train set, ordered class by class; indices
    ``[num_train_samples, length)`` address the selected OpenImages samples, also ordered class by class.
    For class ``i`` the OpenImages candidates are the samples whose selection-model argmax is ``i``, whose
    maximum selection softmax confidence is ``>= min_conf[i]`` and which are not in
    ``OpenImages.exclude_idcs``; up to ``samples_per_class`` of them are kept, most confident first.

    When ``soft_labels`` is True, train items get one-hot labels and OpenImages items get the
    temperature-scaled teacher softmax; otherwise both get integer class ids (teacher argmax for OpenImages).
    """

    def __init__(self, teacher_logits, selection_logits, train_dataset, samples_per_class, transform, min_conf,
                 class_labels, exclude_dataset = None,
                 temperature=1.0, soft_labels=True):
        """
        Args:
            teacher_logits: 2-D tensor of teacher logits, one row per OpenImages train sample.
            selection_logits: 2-D tensor of logits used to rank and select OpenImages samples.
            train_dataset: labelled dataset whose items are ``(img, int_label)`` pairs.
            samples_per_class: maximum number of OpenImages samples kept per class.
            transform: transform passed to ``OpenImages``.
            min_conf: per-class minimum selection confidence, indexed by class id.
            class_labels: class names; ``len(class_labels)`` is the number of classes.
            exclude_dataset: forwarded to ``OpenImages``, which exposes the resulting ``exclude_idcs``.
            temperature: softmax temperature for the soft teacher labels.
            soft_labels: return one-hot / soft labels instead of integer class ids.
        """
        self.open_images = OpenImages(get_openimages_path(), split='train', transform=transform,
                                      exclude_dataset=exclude_dataset)
        self.samples_per_class = samples_per_class
        self.soft_labels = soft_labels
        self.temperature = temperature

        assert len(self.open_images) == teacher_logits.shape[0]

        self.model_logits = teacher_logits
        predicted_max_conf, predicted_class = torch.max(torch.softmax(selection_logits, dim=1), dim=1)

        self.num_classes = len(class_labels)

        # mask of OpenImages samples that are not in the OpenImages exclusion list
        inclusion_idcs = torch.ones(self.model_logits.shape[0], dtype=torch.bool)
        for excluded_idx in self.open_images.exclude_idcs:
            inclusion_idcs[excluded_idx] = False

        self.train_dataset = train_dataset
        self.num_train_samples = len(self.train_dataset)

        targets_tensor = self._collect_targets(self.train_dataset)
        self.train_class_idcs = self._group_indices_by_class(targets_tensor, self.num_classes)
        self.train_per_class = torch.zeros(self.num_classes, dtype=torch.long)
        for i, train_i in enumerate(self.train_class_idcs):
            self.train_per_class[i] = len(train_i)

        self.in_use_indices = []
        self.valid_indices = []
        self.semi_per_class = torch.zeros(self.num_classes, dtype=torch.long)

        min_samples_per_class = int(1e13)
        max_samples_per_class = 0

        for i in range(self.num_classes):
            min_conf_flag = predicted_max_conf >= min_conf[i]
            included_correct_class_bool_idcs = (predicted_class == i) & inclusion_idcs & min_conf_flag

            included_correct_class_linear_idcs = torch.nonzero(included_correct_class_bool_idcs,
                                                               as_tuple=False).squeeze(dim=1)
            included_correct_class_confidences = predicted_max_conf[included_correct_class_bool_idcs]
            # most confident candidates first
            included_correct_class_sort_idcs = torch.argsort(included_correct_class_confidences, descending=True)

            num_samples_i = int(min(samples_per_class, torch.numel(included_correct_class_linear_idcs)))
            class_i_idcs = included_correct_class_linear_idcs[included_correct_class_sort_idcs[:num_samples_i]]

            # all candidates passing the filters, including those beyond the top-k cut (see get_used_semi_indices)
            self.valid_indices.append(included_correct_class_linear_idcs)
            self.in_use_indices.append(class_i_idcs)
            self.semi_per_class[i] = torch.numel(class_i_idcs)

            min_samples_per_class = min(min_samples_per_class, torch.numel(class_i_idcs))
            max_samples_per_class = max(max_samples_per_class, torch.numel(class_i_idcs))

            if num_samples_i < samples_per_class:
                print(f'Incomplete class {class_labels[i]} - Target count: {samples_per_class} - Found samples {torch.numel(class_i_idcs)}')

        self.num_semi_samples = torch.sum(self.semi_per_class)
        self.length = self.num_train_samples + self.num_semi_samples

        # internal [start, end) flat-index ranges per class for the train and OpenImages parts
        self.train_idx_ranges = []
        self.semi_idx_ranges = []

        train_idx_start = 0
        semi_idx_start = self.num_train_samples
        for i in range(self.num_classes):
            train_idx_next = train_idx_start + self.train_per_class[i]
            semi_idx_next = semi_idx_start + self.semi_per_class[i]
            self.train_idx_ranges.append((train_idx_start, train_idx_next))
            self.semi_idx_ranges.append((semi_idx_start, semi_idx_next))

            train_idx_start = train_idx_next
            semi_idx_start = semi_idx_next

        self.cum_train_lengths = torch.cumsum(self.train_per_class, dim=0)
        self.cum_semi_lengths = torch.cumsum(self.semi_per_class, dim=0)

        print(f'Top K -  Temperature {self.temperature} - Soft labels {soft_labels}'
              f'  -  Target Samples per class { self.samples_per_class} - Train Samples {self.num_train_samples}')
        print(f'Min Semi Samples {min_samples_per_class} - Max Semi samples {max_samples_per_class}'
              f' - Total semi samples {self.num_semi_samples} - Total length {self.length}')

    @staticmethod
    def _collect_targets(dataset):
        """Return a LongTensor with the integer label of every item of ``dataset``, in index order."""
        # there is no uniform interface to get all targets for different datasets, so every item is
        # loaded just to read its label (likely running its image transform as well)
        targets = torch.zeros(len(dataset), dtype=torch.long)
        for i in range(len(dataset)):
            _, target_i = dataset[i]
            targets[i] = target_i
        return targets

    @staticmethod
    def _group_indices_by_class(targets, num_classes):
        """Return, for each class id in ``range(num_classes)``, a 1-D LongTensor of positions holding it."""
        # squeeze(dim=1), not squeeze(): a class with exactly one sample must stay 1-D, otherwise
        # len() on the resulting 0-d tensor raises TypeError
        return [torch.nonzero(targets == c, as_tuple=False).squeeze(dim=1) for c in range(num_classes)]

    def get_used_semi_indices(self, verbose_exclude=False):
        """Return the OpenImages indices claimed by this partition as one 1-D tensor.

        With ``verbose_exclude`` every candidate that passes the class / confidence / exclusion filters is
        returned, including those outside the top-k range; otherwise only the selected top-k samples.
        """
        if verbose_exclude:
            return torch.cat(self.valid_indices)
        return torch.cat(self.in_use_indices)

    @staticmethod
    def _locate(cum_lengths, index):
        """Split a flat offset into ``(class_idx, offset within that class)`` using cumulative class sizes.

        Both results are plain ints so they index Python lists and arbitrary datasets cleanly.
        """
        # first class whose cumulative length exceeds the offset; empty classes are skipped automatically
        class_idx = int(torch.nonzero(cum_lengths > index, as_tuple=False)[0])
        class_start = int(cum_lengths[class_idx - 1]) if class_idx > 0 else 0
        return class_idx, int(index) - class_start

    def _load_train_image(self, class_idx, sample_idx):
        """Load the ``sample_idx``-th train item of class ``class_idx`` with a one-hot or integer label."""
        train_idx = int(self.train_class_idcs[class_idx][sample_idx])
        img, label = self.train_dataset[train_idx]
        if self.soft_labels:
            one_hot_label = torch.zeros(self.num_classes)
            one_hot_label[label] = 1.0
            return img, one_hot_label
        return img, label

    def _load_tiny_image(self, class_idx, tiny_lin_idx):
        """Load the ``tiny_lin_idx``-th selected OpenImages item of class ``class_idx`` with a teacher label."""
        valid_index = self.in_use_indices[class_idx][tiny_lin_idx].item()
        img, _ = self.open_images[valid_index]

        if self.soft_labels:
            label = torch.softmax(self.model_logits[valid_index, :] / self.temperature, dim=0)
        else:
            label = torch.argmax(self.model_logits[valid_index, :]).item()
        return img, label

    def __getitem__(self, index):
        if index < self.num_train_samples:
            class_idx, sample_idx = self._locate(self.cum_train_lengths, index)
            return self._load_train_image(class_idx, sample_idx)
        class_idx, sample_idx = self._locate(self.cum_semi_lengths, index - self.num_train_samples)
        return self._load_tiny_image(class_idx, sample_idx)

    def __len__(self):
        return int(self.length)

class OpenImagesBottomKPartition(Dataset):
    """OpenImages train samples that are neither excluded nor part of a given top-k selection.

    Items are ``(img, target)``. ``target`` is the temperature-scaled teacher softmax when ``soft_labels`` is
    True, and a uniform distribution over all ``num_classes`` classes otherwise.
    """

    def __init__(self, teacher_logits, top_k_indices, transform, temperature=1, soft_labels=True,
                 exclude_dataset=None):
        """
        Args:
            teacher_logits: 2-D tensor with one row of teacher logits per OpenImages train sample; its
                number of columns defines ``num_classes``.
            top_k_indices: OpenImages indices to leave out of this partition.
            transform: transform passed to ``OpenImages``.
            temperature: softmax temperature for the soft teacher labels.
            soft_labels: return teacher soft labels instead of a uniform target.
            exclude_dataset: forwarded to ``OpenImages``, which exposes the resulting ``exclude_idcs``.
        """
        self.open_images = OpenImages(get_openimages_path(), split='train', transform=transform,
                                      exclude_dataset=exclude_dataset)

        assert len(self.open_images) == teacher_logits.shape[0]

        self.soft_labels = soft_labels
        self.num_classes = teacher_logits.shape[1]
        self.temperature = temperature
        self.model_logits = teacher_logits

        # 1-D LongTensor of the OpenImages indices served by this partition
        self.valid_indices = self._remaining_indices(self.model_logits.shape[0], top_k_indices,
                                                     self.open_images.exclude_idcs)
        self.length = len(self.valid_indices)

        print(f'Exclude Dataset {exclude_dataset} - Samples {self.length} - Temperature {self.temperature}')

    @staticmethod
    def _remaining_indices(num_samples, top_k_indices, exclude_idcs):
        """Return a 1-D LongTensor of indices in ``range(num_samples)`` not in ``top_k_indices`` or ``exclude_idcs``."""
        valid_bool_indices = torch.ones(num_samples, dtype=torch.bool)
        for excluded_idx in exclude_idcs:
            valid_bool_indices[excluded_idx] = False
        valid_bool_indices[top_k_indices] = False
        # squeeze(dim=1), not squeeze(): exactly one surviving index must stay 1-D so len() works
        return torch.nonzero(valid_bool_indices, as_tuple=False).squeeze(dim=1)

    def __getitem__(self, index):
        valid_index = int(self.valid_indices[index])
        img, _ = self.open_images[valid_index]

        if self.soft_labels:
            model_prediction = torch.softmax(self.model_logits[valid_index, :] / self.temperature, dim=0)
        else:
            # without soft labels the target carries no class information: uniform over all classes
            model_prediction = (1. / self.num_classes) * torch.ones(self.num_classes)

        return img, model_prediction

    def __len__(self):
        return self.length
