import random

import numpy as np
import torch


def group_images(dataset, labels):
    """Bucket dataset indices by the labels their targets contain.

    An item ``dataset[i]`` is an ``(image, target)`` pair. It is only
    considered when every distinct value in its target is one of ``labels``,
    0 or 255. Such an item is then recorded under each entry of ``labels``
    that appears in its target. Labels are used as given (not reordered).

    Args:
        dataset: Sized, indexable collection of ``(image, target)`` pairs;
            ``target`` must be convertible with ``np.array``.
        labels: List of label values to group by.

    Returns:
        Dict mapping every label in ``labels`` to a list of dataset indices,
        in increasing order (empty if no eligible item contains that label).
    """
    groups = {label: [] for label in labels}
    wanted = set(labels)
    admissible = set(labels + [0, 255])

    for index in range(len(dataset)):
        present = np.unique(np.array(dataset[index][1]))
        # Skip items containing any value outside labels / background / ignore.
        if not all(value in admissible for value in present):
            continue
        for value in present:
            if value in wanted:
                groups[value].append(index)
    return groups


def filter_images(dataset, labels, labels_old=None, overlap=True, create_ood=False):
    """Return indices of images whose target contains at least one label in ``labels``.

    Labels are used as given (not reordered). Background (0) is dropped from
    ``labels`` before filtering, so it never makes an image eligible by itself.

    Args:
        dataset: Sized, indexable collection of ``(image, target)`` pairs;
            ``target`` must be convertible with ``np.array``.
        labels: List of label values. The caller's list is not modified;
            a copy without 0 is used internally.
        labels_old: Optional list of previously seen labels, only consulted
            when ``overlap`` is False.
        overlap: If True, keep every image containing a label in ``labels``.
            If False, also require every class in the target to be in
            ``labels``, ``labels_old``, 0 or 255.
        create_ood: If True (only supported with ``overlap=True``), also
            collect images containing any class outside ``labels``. Up to
            ``len(idxs)`` of them are sampled at random and returned,
            shuffled together with the selected indices.

    Returns:
        List of dataset indices. The list is ascending when no OOD index was
        collected; otherwise it is a shuffled mix of selected and sampled
        OOD indices.

    Raises:
        NotImplementedError: If ``create_ood`` is True and ``overlap`` is False.
    """
    # Copy instead of `labels.remove(0)`, which mutated the caller's list.
    labels = [lab for lab in labels if lab != 0]

    print("Filtering images...")
    if labels_old is None:
        labels_old = []
    new_labels = set(labels)
    known_labels = set(labels + labels_old + [0, 255])

    if create_ood and not overlap:
        raise NotImplementedError()

    def has_new_label(classes):
        return any(c in new_labels for c in classes)

    def keep(classes):
        # Non-overlap mode also rejects images containing not-yet-seen classes.
        if overlap:
            return has_new_label(classes)
        return has_new_label(classes) and all(c in known_labels for c in classes)

    def is_ood(classes):
        # NOTE(review): 0 was removed from `labels`, so any target containing
        # background (0) or 255 counts as OOD here. An image can therefore end
        # up in both `idxs` and `idxs_ood`. Confirm this is intended.
        return any(c not in new_labels for c in classes)

    n_images = len(dataset)
    idxs = []
    idxs_ood = []
    for i in range(n_images):
        classes = np.unique(np.array(dataset[i][1]))
        selected = keep(classes)
        if selected:
            idxs.append(i)
        if not create_ood:
            if i % 1000 == 0:
                print(f"\t{i}/{n_images} ...")
            continue
        if selected and i % 1000 == 0:
            print(f"train data:\t{i}/{n_images} ...")
        if is_ood(classes):
            idxs_ood.append(i)
            if i % 1000 == 0:
                print(f"ood_data:\t{i}/{n_images} ...")

    if not idxs_ood:
        return idxs
    # random.sample raises ValueError when k exceeds the population. Capping k
    # means having fewer OOD images than selected ones no longer crashes.
    random_ood = random.sample(idxs_ood, k=min(len(idxs), len(idxs_ood)))
    mixed = random_ood + idxs
    random.shuffle(mixed)
    return mixed


class Subset(torch.utils.data.Dataset):
    """
    Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
        transform (callable): joint transform called as
            ``transform(sample, target)`` and returning ``(sample, target)``
        target_transform (callable): transform applied to the target only,
            after ``transform``
    """

    def __init__(self, dataset, indices, transform=None, target_transform=None):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform
        self.target_transform = target_transform

    def _apply_transforms(self, sample, target):
        """Apply ``transform`` and then ``target_transform``, each if set."""
        if self.transform is not None:
            sample, target = self.transform(sample, target)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __getitem__(self, idx):
        """Return the transformed ``(sample, target)`` at subset position ``idx``.

        Raises:
            IndexError: if ``idx`` (or the mapped index) is out of range. It is
                re-raised with context and chained to the original error.
            Exception: for any other lookup failure, with the same context
                message, chained to the original error.
        """
        try:
            sample, target = self.dataset[self.indices[idx]]
        except Exception as e:
            msg = (
                f"dataset = {len(self.dataset)}, indices = {len(self.indices)}, "
                f"idx = {idx}, msg = {str(e)}"
            )
            # Keep IndexError as IndexError: Python's fallback iteration
            # protocol (`for x in subset`, `list(subset)`) stops on it.
            if isinstance(e, IndexError):
                raise IndexError(msg) from e
            raise Exception(msg) from e

        return self._apply_transforms(sample, target)

    def viz_getter(self, idx):
        """Return ``(image_path, raw_image, sample, target)`` for visualisation.

        Delegates to the wrapped dataset's ``viz_getter`` and applies the same
        transforms as ``__getitem__`` to ``sample`` and ``target``.
        """
        image_path, raw_image, sample, target = self.dataset.viz_getter(self.indices[idx])
        sample, target = self._apply_transforms(sample, target)
        return image_path, raw_image, sample, target

    def __len__(self):
        return len(self.indices)


class MaskLabels:
    """
    Use this class to mask labels that you don't want in your dataset.

    Arguments:
        labels_to_keep (list): The list of labels to keep in the target images
        mask_value (int): The value to replace ignored values (def: 0)
    """

    def __init__(self, labels_to_keep, mask_value=0):
        self.labels = labels_to_keep
        self.value = torch.tensor(mask_value, dtype=torch.uint8)

    def __call__(self, sample):
        """Replace, in place, every element of ``sample`` not in ``self.labels``.

        Args:
            sample: Tensor of label values, modified in place. Works on any
                device.

        Returns:
            The same tensor object. Elements not in ``labels_to_keep`` are set
            to ``mask_value``.

        Raises:
            AssertionError: if ``sample`` is not a ``torch.Tensor``.
        """
        # sample must be a tensor
        assert isinstance(sample, torch.Tensor), "Sample must be a tensor"

        # The former element-wise `apply_` version could not run: `apply_`
        # hands Python scalars to the callable, which have no `.apply_`, and it
        # is CPU-only. Build a boolean keep-mask instead and fill the rest.
        keep = torch.zeros_like(sample, dtype=torch.bool)
        for label in self.labels:
            keep |= sample == label
        sample.masked_fill_(~keep, self.value.item())

        return sample
