import numpy as np
from copy import deepcopy
from datasets.bases import MergedDatasetMask

def subsample_instances(dataset, prop_indices_to_subsample=0.8, is_random=False):
    """Select a fraction of a dataset's instances and return their indices.

    Two modes are supported:

    * ``is_random=True``: NumPy's global RNG is re-seeded with 0 and
      positions from ``range(len(dataset))`` are drawn without replacement,
      so the same positions come back on every call.
    * ``is_random=False`` (default): the values in ``dataset.uq_idxs`` are
      grouped by the matching entry of ``dataset.targets`` and the fraction
      is drawn, without replacement, from each group in turn. The global
      RNG is used as-is (no re-seeding).

    Args:
        dataset: Object providing ``__len__`` (random mode) or ``targets``
            and ``uq_idxs`` (per-class mode).
        prop_indices_to_subsample: Fraction of instances to keep; the count
            is truncated towards zero with ``int``.
        is_random: Choose between the two modes described above.

    Returns:
        A NumPy array holding the selected indices.
    """
    if is_random:
        # Fixed seed: this branch is deterministic across calls.
        np.random.seed(0)
        n_keep = int(prop_indices_to_subsample * len(dataset))
        drawn = np.random.choice(range(len(dataset)), replace=False, size=(n_keep,))
        return np.array(drawn)

    # Group unique indices by class label, preserving first-seen class order.
    per_class = {}
    for label, uq_idx in zip(dataset.targets, dataset.uq_idxs):
        per_class.setdefault(label, []).append(uq_idx)

    # Draw the requested fraction from every class independently.
    picked = []
    for members in per_class.values():
        n_keep = int(prop_indices_to_subsample * len(members))
        picked.extend(np.random.choice(members, replace=False, size=n_keep))
    return np.array(picked)

def merge_dataset(dataset1, dataset2):
    """Append the samples of ``dataset2`` to ``dataset1`` and return the result.

    ``dataset1`` is modified in place: its ``data`` is stacked row-wise with
    ``dataset2.data``, its ``targets`` list is extended, and its ``uq_idxs``
    are concatenated. ``dataset2`` is left untouched.

    ``None`` stands for an empty dataset (that is what ``subsample_dataset``
    returns for an empty selection), so merging with ``None`` simply returns
    the other operand.

    Args:
        dataset1: Dataset to extend, or ``None``.
        dataset2: Dataset whose samples are appended, or ``None``.

    Returns:
        ``dataset1`` after extension; the other operand if one is ``None``;
        ``None`` if both are ``None``.
    """
    # Treat None as "no samples" instead of failing on attribute access.
    if dataset2 is None:
        return dataset1
    if dataset1 is None:
        return dataset2

    dataset1.data = np.vstack((dataset1.data, dataset2.data))
    # In-place extend: ``targets`` is expected to be a list here.
    dataset1.targets.extend(dataset2.targets)
    dataset1.uq_idxs = np.hstack((dataset1.uq_idxs, dataset2.uq_idxs))
    return dataset1

def subsample_dataset(dataset, idxs):
    """Restrict ``dataset`` in place to the entries selected by ``idxs``.

    ``data``, ``targets`` and ``uq_idxs`` are all indexed with ``idxs``;
    ``targets`` is stored back as a plain Python list.

    Args:
        dataset: Object exposing ``data``, ``targets`` and ``uq_idxs``.
        idxs: Sized collection of indices to keep.

    Returns:
        The same (now reduced) ``dataset`` object, or ``None`` when ``idxs``
        is empty.
    """
    # An empty selection is allowed and is signalled by returning None.
    if len(idxs) == 0:
        return None

    dataset.data = dataset.data[idxs]
    # Go through an array so that list-typed targets accept fancy indexing.
    kept_targets = np.array(dataset.targets)[idxs]
    dataset.targets = kept_targets.tolist()
    dataset.uq_idxs = dataset.uq_idxs[idxs]
    return dataset

def subsample_dataset_ratio(dataset, test_sample_ratio):
    """Keep a fixed fraction of the samples of every class.

    The dataset is iterated once (each item is expected to unpack as
    ``(_, target, uq_idx)``) to collect the ``uq_idx`` values of each class.
    ``int(len(class) * test_sample_ratio)`` of them are then drawn at random
    per class, and ``dataset`` is reduced in place to that selection via
    ``subsample_dataset``. The surviving samples get fresh ``uq_idxs``
    ``0..len-1``.

    Args:
        dataset: Dataset exposing ``classes``, iteration, ``__len__`` and the
            attributes required by ``subsample_dataset``.
        test_sample_ratio: Fraction of each class to keep.

    Returns:
        The reduced ``dataset``, or ``None`` if no sample was selected.
    """
    from tqdm import tqdm

    # One bucket per class index; filled with the uq_idx of each sample.
    class_samples = {class_idx: [] for class_idx in range(len(dataset.classes))}
    for _, target, uq_idx in tqdm(dataset, desc='Subsample dataset via ratio'):
        class_samples[target].append(uq_idx)

    subsampled_indices = []
    for sample_idxs in class_samples.values():
        n_keep = int(len(sample_idxs) * test_sample_ratio)
        # Draw distinct samples whenever the class is large enough, so that a
        # subsample never contains duplicates. Only a ratio above 1 (which
        # asks for more samples than exist) falls back to drawing with
        # replacement, as that request cannot be met otherwise.
        subsampled_indices.extend(
            np.random.choice(sample_idxs, n_keep, replace=n_keep > len(sample_idxs))
        )

    # NOTE: the collected uq_idx values are used as positional indices here.
    subsampled_dataset = subsample_dataset(dataset, subsampled_indices)
    if subsampled_dataset is None:
        # Nothing selected (e.g. ratio too small for every class).
        return None
    subsampled_dataset.uq_idxs = np.arange(len(subsampled_dataset))
    return subsampled_dataset

def subsample_classes(dataset, include_classes=(0, 1, 8, 9)):
    """Reduce ``dataset`` in place to the samples of the given classes.

    Targets keep their original class ids; no remapping to contiguous
    labels is applied here.

    Args:
        dataset: Dataset exposing ``targets`` plus the attributes required
            by ``subsample_dataset``.
        include_classes: Collection of class ids to keep.

    Returns:
        The reduced ``dataset``, or ``None`` if no sample belongs to any of
        ``include_classes``.
    """
    # Positions of all samples whose target is one of the requested classes.
    cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes]
    return subsample_dataset(dataset, cls_idxs)

def get_train_val_indices(train_dataset, val_split=0.2):
    """Split sample positions into train and validation sets, class by class.

    For every distinct target, ``int(val_split * n_class)`` positions are
    drawn without replacement for validation; the remaining positions of
    that class go to training (in ascending order).

    Args:
        train_dataset: Object exposing ``targets`` (list or array).
        val_split: Fraction of each class assigned to validation.

    Returns:
        ``(train_idxs, val_idxs)``: two lists of positions into
        ``train_dataset.targets``.
    """
    # Convert once so that the comparison below is an element-wise array
    # comparison even when ``targets`` is a plain list.
    targets = np.asarray(train_dataset.targets)

    train_idxs = []
    val_idxs = []
    for cls in np.unique(targets):
        cls_idxs = np.where(targets == cls)[0]

        v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
        # Set membership keeps the complement computation linear in class size.
        held_out = set(v_.tolist())

        train_idxs.extend(x for x in cls_idxs if x not in held_out)
        val_idxs.extend(v_)

    return train_idxs, val_idxs


def get_split(
        whole_training_set, test_dataset,
        common_classes, private_classes,
        train_transforms, test_transforms,
        prop_train_labels,
        split_train_val,
        merge_mask
    ):
    """Build the labelled / unlabelled / validation splits of a training set.

    The labelled set consists of every sample of ``private_classes`` plus a
    per-class fraction ``prop_train_labels`` of the ``common_classes``
    samples. Every other sample of ``whole_training_set`` (by ``uq_idxs``)
    is unlabelled. ``whole_training_set`` itself is not modified; all
    subsets are built from deep copies.

    NOTE(review): sampled ``uq_idxs`` values are used as positional indices
    into copies of ``whole_training_set``, which presumably requires
    ``whole_training_set.uq_idxs`` to be ``0..N-1`` - confirm with callers.

    Args:
        whole_training_set: Full training dataset.
        test_dataset: Test dataset, passed through unchanged.
        common_classes: Class ids that are only partially labelled.
        private_classes: Class ids that are fully labelled.
        train_transforms: Not used by this function.
        test_transforms: Transform assigned to the validation split.
        prop_train_labels: Fraction of each common class that is labelled.
        split_train_val: If true, carve a validation split out of the
            labelled set; otherwise ``'val'`` is ``None``.
        merge_mask: If true, build a ``MergedDatasetMask`` from the labelled
            and unlabelled sets; otherwise ``'train_mask'`` is ``None``.

    Returns:
        Dict with the datasets (``'train'``, ``'train_mask'``,
        ``'train_labelled'``, ``'train_unlabelled'``, ``'val'``, ``'test'``)
        and class bookkeeping (``'num_known_classes'``,
        ``'num_unknown_classes'``, ``'common_classes'``,
        ``'private_classes'``, ``'unlabeled_classes'``).
    """
    unlabeled_classes = set(np.unique((whole_training_set.targets))) - set(private_classes)
    labeled_classes = set(common_classes).union(set(private_classes))

    # Labelled set: all private-class samples plus a sampled share of the
    # common-class samples.
    train_dataset_common = subsample_classes(deepcopy(whole_training_set), include_classes=common_classes)
    train_dataset_private = subsample_classes(deepcopy(whole_training_set), include_classes=private_classes)
    subsample_indices = subsample_instances(train_dataset_common, prop_indices_to_subsample=prop_train_labels)
    train_dataset_common = subsample_dataset(deepcopy(whole_training_set), subsample_indices)
    train_dataset_labelled = merge_dataset(deepcopy(train_dataset_private), deepcopy(train_dataset_common))

    # Unlabelled set: everything that did not end up in the labelled set.
    unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
    train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))

    # Combined (labelled + unlabelled) training view. Default to None so the
    # result dict can be built when no mask is requested.
    train_dataset_mask = None
    if merge_mask:
        train_dataset_mask = MergedDatasetMask(
            labelled_dataset=deepcopy(train_dataset_labelled),
            unlabelled_dataset=deepcopy(train_dataset_unlabelled),
        )

    # Optionally split the labelled set into training and validation parts.
    if split_train_val:
        train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
        train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
        val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
        # Validation data is evaluated with the test-time transform.
        val_dataset_labelled_split.transform = test_transforms
        train_dataset_labelled = train_dataset_labelled_split
        val_dataset_labelled = val_dataset_labelled_split
    else:
        val_dataset_labelled = None

    # Reset the uq_idxs so each returned training subset is indexed 0..len-1.
    # The mask above was built from copies and keeps the earlier indices.
    train_dataset_labelled.uq_idxs = np.array(range(len(train_dataset_labelled)))
    train_dataset_unlabelled.uq_idxs = np.array(range(len(train_dataset_unlabelled)))

    all_datasets = {
        'train': whole_training_set,
        'train_mask': train_dataset_mask,
        'train_labelled': train_dataset_labelled,
        'train_unlabelled': train_dataset_unlabelled,
        'val': val_dataset_labelled,
        'test': test_dataset,
        'num_known_classes': len(labeled_classes),
        'num_unknown_classes': len(unlabeled_classes),
        'common_classes': common_classes,
        'private_classes': private_classes,
        'unlabeled_classes': unlabeled_classes
    }

    return all_datasets