import torch
import torch.distributions
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from torch.utils.data import Sampler

from utils.datasets.tinyImages import _load_cifar_exclusion_idcs
from utils.datasets.paths import get_CIFAR10_path, get_CIFAR100_path

import os
import numpy as np
import matplotlib.pyplot as plt
import pathlib

from utils.datasets.cifar import get_CIFAR10_labels, get_CIFAR100_labels
from utils.datasets.cifar_augmentation import get_cifar10_augmentation
from .cifar_train_val import CIFAR10TrainValidationSplit, CIFAR100TrainValidationSplit
from utils.datasets.paths import get_tiny_images_files
from utils.datasets.tinyImages import TINY_LENGTH

from .cifar_semi_tiny_partition import plot_class_conf_histograms
from utils.datasets.tinyImages import _load_tiny_image, _preload_tiny_images
from .loading_utils import load_teacher_data
from .cifar_subsets import CIFARSubset
from .tiny_images_subset import TinyImagesSubset
from utils.datasets.tinyImages import get_80MTinyImages
from torchvision.utils import save_image
from .cifar_subsets_od import get_CIFAR10_subset_plus_OD
from .cifar_subsets import get_CIFAR10_subset
from utils.datasets.cifar import get_CIFAR10, get_CIFAR10_labels
from .cifar_semi_tiny_partition import BalancedSampler, TrainSemiRatioSampler

def get_cifar_subset_plus_od_partition(teacher_model, dataset, od_dataset,
                                     labeled_samples_per_class, samples_per_class, unlabeled_samples,
                                     selection_model=None, semi_ratio=1,
                                     class_tpr_min=None, od_exclusion_threshold=None,
                                     calibrate_temperature=False, verbose_exclude=False, soft_labels=True, batch_size=100,
                                     augm_type='default',
                                     cutout_window=16, aa_magnitude=1, size=32, num_workers=0,
                                     id_config_dict=None, od_config_dict=None,
                                     ssl_config=None):
    """Build the "top-K" and "bottom-K" data loaders for a CIFAR subset plus OD data.

    The unlabeled pool is scored with teacher logits obtained through
    ``load_teacher_data``. The top loader serves the labeled training set
    together with the most confident pseudo-labeled samples per class
    (``TrainSetPlusODSSLTopK``); the bottom loader serves the remaining
    unlabeled samples (``TrainSetPlusODSSLBottomK``).

    Args:
        teacher_model: Forwarded to ``load_teacher_data``.
        dataset: In-distribution dataset name; only ``'cifar10'`` is supported.
        od_dataset: Forwarded to ``get_CIFAR10_subset_plus_OD``; also used in
            the classification path and the config description.
        labeled_samples_per_class: Labeled samples per class; a value ``<= 0``
            loads the CIFAR-10 training set through ``get_CIFAR10`` instead of
            ``get_CIFAR10_subset``.
        samples_per_class: Target number of pseudo-labeled samples per class.
        unlabeled_samples: Forwarded to ``get_CIFAR10_subset_plus_OD``.
        selection_model, class_tpr_min, od_exclusion_threshold,
        calibrate_temperature, ssl_config: Forwarded to ``load_teacher_data``.
        semi_ratio: ``<= 0`` selects ``BalancedSampler``; otherwise
            ``TrainSemiRatioSampler`` is used with this ratio.
        verbose_exclude: If true, every sample passing the class threshold is
            removed from the bottom set, not only the ones actually used.
        soft_labels: Forwarded to both dataset wrappers.
        batch_size: Batch size of both returned loaders.
        augm_type, cutout_window, aa_magnitude, size: Augmentation settings.
        num_workers: Worker count of both returned loaders.
        id_config_dict: Optional dict filled in place with the in-distribution
            configuration.
        od_config_dict: Optional dict filled in place with the OD configuration.

    Returns:
        Tuple ``(top_loader, bottom_loader)``.

    Raises:
        NotImplementedError: If ``dataset`` is not ``'cifar10'``.
    """
    # Only the dataset behind this loader is used; the loader itself is dropped.
    unlabeled_dataset = get_CIFAR10_subset_plus_OD(
        'unlabeled', labeled_samples_per_class, od_dataset, unlabeled_samples,
        augm_type=augm_type, num_workers=8, size=size,
    ).dataset

    if dataset != 'cifar10':
        raise NotImplementedError()

    class_labels = get_CIFAR10_labels()
    if labeled_samples_per_class > 0:
        labeled_loader = get_CIFAR10_subset('train', labeled_samples_per_class, augm_type=augm_type,
                                            shuffle=True, size=size)
    else:
        labeled_loader = get_CIFAR10(train=True, augm_type=augm_type, size=size)
    train_dataset = labeled_loader.dataset

    # Location handed to load_teacher_data, keyed by the data configuration.
    classifications_dir = os.path.join(
        'DatasetClassifications/',
        f'{dataset}_{labeled_samples_per_class * len(class_labels)}_{od_dataset}_{unlabeled_samples}',
    )
    teacher_logits, _selection_logits, class_thresholds, temperature = load_teacher_data(
        classifications_dir, teacher_model,
        selection_model=selection_model,
        class_tpr_min=class_tpr_min,
        od_exclusion_threshold=od_exclusion_threshold,
        calibrate_temperature=calibrate_temperature,
        ssl_config=ssl_config,
    )

    # Called only so that the augmentation settings get recorded in augm_config.
    augm_config = {}
    get_cifar10_augmentation(augm_type, cutout_window, out_size=size,
                             magnitude_factor=aa_magnitude, config_dict=augm_config)

    top_dataset = TrainSetPlusODSSLTopK(teacher_logits, train_dataset, unlabeled_dataset, samples_per_class,
                                        class_thresholds, class_labels, temperature=temperature,
                                        soft_labels=soft_labels)

    if semi_ratio > 0:
        sampler = TrainSemiRatioSampler(top_dataset, batch_size, ratio=semi_ratio)
        sampler_description = f'Fixed Ratio {semi_ratio}:1'
    else:
        sampler = BalancedSampler(top_dataset, subdivide_epochs=False)
        sampler_description = 'Balanced Sampler'

    top_loader = torch.utils.data.DataLoader(top_dataset, sampler=sampler,
                                             batch_size=batch_size, num_workers=num_workers)

    # Everything claimed by the top set is removed from the bottom set.
    excluded_indices = top_dataset.get_used_semi_indices(verbose_exclude)
    bottom_dataset = TrainSetPlusODSSLBottomK(teacher_logits, unlabeled_dataset, excluded_indices,
                                              temperature=temperature, soft_labels=soft_labels)
    bottom_loader = torch.utils.data.DataLoader(bottom_dataset, shuffle=True,
                                                batch_size=batch_size, num_workers=num_workers)

    if id_config_dict is not None:
        id_config_dict.update({
            'Dataset': f'CifarSubset-SSL + {od_dataset}',
            'Labeled samples per class': labeled_samples_per_class,
            'OD Samples': unlabeled_samples,
            'Sampler': sampler_description,
            'Batch out_size': batch_size,
            'SSL Samples per class': samples_per_class,
            'Soft labels': soft_labels,
            'Augmentation': augm_config,
        })

    if od_config_dict is not None:
        od_config_dict.update({
            'Dataset': 'ODPartition',
            'Batch out_size': batch_size,
            'Verbose exclude': verbose_exclude,
            # NOTE(review): `dataset` is lower-case 'cifar10' at this point, so this
            # upper-case membership test is always False - confirm whether intended.
            'Exclude CIFAR': dataset in ['CIFAR10', 'CIFAR100'],
            'Augmentation': augm_config,
        })

    return top_loader, bottom_loader



class TrainSetPlusODSSLTopK(Dataset):
    """Labeled training set followed by the most confident pseudo-labeled samples.

    Index layout: ``[0, num_train_samples)`` addresses the labeled samples
    grouped by class (all of class 0 first, then class 1, ...); the indices
    after that address the selected unlabeled samples, again grouped by
    predicted class and ordered by descending confidence within a class.

    An unlabeled sample qualifies for class ``i`` when the teacher predicts
    class ``i`` and its maximum softmax probability is ``>= min_conf[i]``.
    At most ``samples_per_class`` qualifying samples are used per class.
    """

    def __init__(self, unlabeled_logits, train_dataset, unlabeled_dataset, samples_per_class, min_conf,
                 class_labels, temperature=1.0, soft_labels=True):
        """Select the per-class top-K unlabeled samples and build the index tables.

        Args:
            unlabeled_logits: 2-D tensor of teacher logits, one row per
                unlabeled sample (softmax is taken over dim 1).
            train_dataset: Labeled dataset exposing ``targets``, ``__len__``
                and ``__getitem__`` returning ``(img, label)``.
            unlabeled_dataset: Dataset whose items are ``(img, _)`` and whose
                indices match the rows of ``unlabeled_logits``.
            samples_per_class: Maximum number of unlabeled samples per class.
            min_conf: Per-class minimum confidence, indexable by class id.
            class_labels: Sequence of class names; its length is the number
                of classes.
            temperature: Divides the logits when soft labels are produced.
                It does not influence the selection itself.
            soft_labels: Return probability vectors instead of class ids.
        """
        self.samples_per_class = samples_per_class
        self.soft_labels = soft_labels
        self.temperature = temperature

        self.model_logits = unlabeled_logits
        predicted_max_conf, predicted_class = torch.max(torch.softmax(unlabeled_logits, dim=1), dim=1)

        self.num_classes = len(class_labels)
        self.train_dataset = train_dataset
        self.unlabeled_dataset = unlabeled_dataset

        self.num_train_samples = len(self.train_dataset)
        self.train_per_class = torch.zeros(self.num_classes, dtype=torch.long)

        self.train_class_idcs = []
        targets_tensor = torch.LongTensor(self.train_dataset.targets)
        for i in range(self.num_classes):
            # squeeze(1) drops only the coordinate axis; a bare squeeze() would turn
            # a single match into a 0-dim tensor on which len() fails.
            train_i = torch.nonzero(targets_tensor == i, as_tuple=False).squeeze(1)
            self.train_class_idcs.append(train_i)
            self.train_per_class[i] = len(train_i)

        self.in_use_indices = []
        self.valid_indices = []
        self.semi_per_class = torch.zeros(self.num_classes, dtype=torch.long)

        min_samples_per_class = int(1e13)
        max_samples_per_class = 0

        for i in range(self.num_classes):
            qualifies = (predicted_class == i) & (predicted_max_conf >= min_conf[i])

            qualifying_idcs = torch.nonzero(qualifies, as_tuple=False).squeeze(1)
            confidence_order = torch.argsort(predicted_max_conf[qualifies], descending=True)

            num_samples_i = int(min(samples_per_class, len(qualifying_idcs)))
            class_i_idcs = qualifying_idcs[confidence_order[:num_samples_i]]

            self.valid_indices.append(qualifying_idcs)
            self.in_use_indices.append(class_i_idcs)
            self.semi_per_class[i] = len(class_i_idcs)

            min_samples_per_class = min(min_samples_per_class, len(class_i_idcs))
            max_samples_per_class = max(max_samples_per_class, len(class_i_idcs))

            if num_samples_i < samples_per_class:
                print(f'Incomplete class {class_labels[i]} - Target count: {samples_per_class} - Found samples {len(class_i_idcs)}')

        self.num_semi_samples = torch.sum(self.semi_per_class)
        self.length = self.num_train_samples + self.num_semi_samples

        # Internal (start, end) index ranges per class, end exclusive.
        self.train_idx_ranges = []
        self.semi_idx_ranges = []

        train_idx_start = 0
        semi_idx_start = self.num_train_samples
        for i in range(self.num_classes):
            train_idx_next = train_idx_start + self.train_per_class[i]
            semi_idx_next = semi_idx_start + self.semi_per_class[i]
            self.train_idx_ranges.append((train_idx_start, train_idx_next))
            self.semi_idx_ranges.append((semi_idx_start, semi_idx_next))

            train_idx_start = train_idx_next
            semi_idx_start = semi_idx_next

        # Cumulative class sizes, used by __getitem__ to map a flat index to a class.
        self.cum_train_lengths = torch.cumsum(self.train_per_class, dim=0)
        self.cum_semi_lengths = torch.cumsum(self.semi_per_class, dim=0)

        print(f'Top K -  Temperature {self.temperature} - Soft labels {soft_labels}'
              f'  -  Target Samples per class { self.samples_per_class} - Train Samples {self.num_train_samples}')
        print(f'Min Semi Samples {min_samples_per_class} - Max Semi samples {max_samples_per_class}'
              f' - Total semi samples {self.num_semi_samples} - Total length {self.length}')

    def get_used_semi_indices(self, verbose_exclude=False):
        """Return the unlabeled indices claimed by this dataset as one tensor.

        With ``verbose_exclude`` every sample that passed its class threshold
        is returned; otherwise only the samples actually in use.
        """
        if verbose_exclude:
            return torch.cat(self.valid_indices)
        else:
            return torch.cat(self.in_use_indices)

    def _load_train_image(self, class_idx, sample_idx):
        """Return the ``sample_idx``-th labeled sample of class ``class_idx``."""
        # .item() hands the wrapped dataset a plain int, as _load_unlabeled_image does.
        train_idx = self.train_class_idcs[class_idx][sample_idx].item()
        img, label = self.train_dataset[train_idx]
        if self.soft_labels:
            one_hot_label = torch.zeros(self.num_classes)
            one_hot_label[label] = 1.0
            return img, one_hot_label
        else:
            return img, label

    def _load_unlabeled_image(self, class_idx, sample_idx):
        """Return the ``sample_idx``-th selected unlabeled sample of class ``class_idx``.

        The label is the temperature-scaled softmax of the teacher logits when
        soft labels are enabled, otherwise the argmax class id.
        """
        valid_index = self.in_use_indices[class_idx][sample_idx].item()
        img, _ = self.unlabeled_dataset[valid_index]

        if self.soft_labels:
            label = torch.softmax(self.model_logits[valid_index, :] / self.temperature, dim=0)
        else:
            label = torch.argmax(self.model_logits[valid_index, :]).item()
        return img, label

    @staticmethod
    def _locate(cum_lengths, flat_idx):
        """Map a flat index to ``(class_idx, offset_in_class)`` using cumulative class sizes."""
        # First class whose cumulative size exceeds the index; empty classes are skipped
        # automatically. Raises IndexError when flat_idx is past the last class.
        class_idx = int(torch.nonzero(cum_lengths > flat_idx, as_tuple=False)[0])
        if class_idx > 0:
            flat_idx = flat_idx - int(cum_lengths[class_idx - 1])
        return class_idx, flat_idx

    def __getitem__(self, index):
        """Return ``(img, label)`` for a flat index into train + selected unlabeled samples."""
        index = int(index)
        if index < self.num_train_samples:
            class_idx, sample_idx = self._locate(self.cum_train_lengths, index)
            return self._load_train_image(class_idx, sample_idx)

        class_idx, sample_idx = self._locate(self.cum_semi_lengths, index - self.num_train_samples)
        return self._load_unlabeled_image(class_idx, sample_idx)

    def __len__(self):
        """Return the number of labeled plus selected unlabeled samples as a plain int."""
        return int(self.length)

class TrainSetPlusODSSLBottomK(Dataset):
    """Unlabeled samples that were not claimed by the top-K selection.

    Every item is ``(img, prediction)``. With soft labels the prediction is
    the temperature-scaled softmax of the teacher logits; otherwise it is the
    uniform distribution over all classes.
    """

    def __init__(self, unlabeled_logits, unlabeled_dataset,
                 top_k_indices, temperature=1, soft_labels=True):
        """Collect the indices of all unlabeled samples outside ``top_k_indices``.

        Args:
            unlabeled_logits: 2-D tensor of teacher logits; dim 0 indexes the
                unlabeled samples, dim 1 the classes.
            unlabeled_dataset: Dataset whose items are ``(img, _)`` and whose
                indices match the rows of ``unlabeled_logits``.
            top_k_indices: Indices of unlabeled samples to exclude.
            temperature: Divides the logits before the softmax.
            soft_labels: Use the teacher prediction instead of a uniform target.
        """
        self.soft_labels = soft_labels
        self.num_classes = unlabeled_logits.shape[1]
        self.temperature = temperature

        self.unlabeled_dataset = unlabeled_dataset

        self.model_logits = unlabeled_logits

        keep_mask = torch.ones(self.model_logits.shape[0], dtype=torch.bool)
        keep_mask[top_k_indices] = False
        # squeeze(1) drops only the coordinate axis; a bare squeeze() would turn a
        # single remaining sample into a 0-dim tensor on which len() fails.
        self.valid_indices = torch.nonzero(keep_mask, as_tuple=False).squeeze(1)

        self.length = len(self.valid_indices)

    def __getitem__(self, index):
        """Return ``(img, prediction)`` for the ``index``-th remaining unlabeled sample."""
        # Hand the wrapped dataset a plain int, as TrainSetPlusODSSLTopK does.
        valid_index = self.valid_indices[index].item()
        img, _ = self.unlabeled_dataset[valid_index]

        if self.soft_labels:
            model_prediction = torch.softmax(self.model_logits[valid_index, :] / self.temperature, dim=0)
        else:
            model_prediction = (1./self.num_classes) * torch.ones(self.num_classes)

        return img, model_prediction

    def __len__(self):
        """Return the number of remaining unlabeled samples."""
        return self.length
