from typing import Callable, Iterable, Optional, Tuple

import torch
from torch.utils.data import Dataset


class TargetToTensor(Callable):
    """Callable transform that turns a target value into a 64-bit integer tensor."""

    def __init__(self):
        # Stateless transform: nothing to store beyond base-class initialisation.
        super().__init__()

    def __call__(self, target: int) -> torch.Tensor:
        """
        target: the target value to convert

        Returns a new ``torch.int64`` tensor holding ``target``.
        """
        as_tensor = torch.tensor(target, dtype=torch.int64)
        return as_tensor


class LoadToDevice(Callable):
    """Callable transform that moves a tensor to a device chosen at construction."""

    def __init__(self, device: str):
        """
        device: the device specifier later passed to ``Tensor.to``
        """
        super().__init__()
        self.device = device

    def __call__(self, data: torch.Tensor) -> torch.Tensor:
        """
        data: the tensor to move

        Returns the result of ``data.to(self.device)``.
        """
        destination = self.device
        moved = data.to(destination)
        return moved


class ExpandChannels(Callable):
    """Callable transform that expands an image tensor to a fixed channel count."""

    def __init__(self, num_channels: int):
        """
        num_channels: the size requested for the leading dimension of the output
        """
        super().__init__()
        self.num_channels = num_channels

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """
        image: the tensor to expand

        Returns ``image`` expanded to ``num_channels`` in the leading dimension.
        """
        # -1 asks ``expand`` to keep the existing size of the last two dimensions.
        requested_sizes = (self.num_channels, -1, -1)
        return image.expand(*requested_sizes)


def balance_prefix(train_data: Dataset, prefix: Iterable[int], ord_labels: Iterable[int],
                   num_samples: int = 300) -> Iterable[int]:
    """
    Select at most ``num_samples`` sample indices per label from a label-sorted
    sequence of indices.

    train_data: dataset of training samples; labels are read from ``train_data.targets``
    prefix: sample indices ordered by ascending label (e.g. an argsort of the
            targets); elements may be plain ints or 0-dim integer tensors
    ord_labels: the labels of the dataset; they are processed in ascending order
    num_samples: the maximum number of samples to draw for each label

    Returns the chosen indices as a list of plain ints, grouped by ascending
    label. A label with fewer than ``num_samples`` samples contributes only the
    samples it has; samples of other labels are never drawn in its place.
    """
    balanced = []
    targets = train_data.targets
    total = len(prefix)
    pos = 0
    # The single forward scan below only works for ascending labels; sorting is
    # a no-op for already-ordered input and protects against unordered callers.
    for label in sorted(ord_labels):
        taken = 0
        while taken < num_samples and pos < total:
            index = prefix[pos]
            target = targets[index]
            if target < label:
                # Left-over sample of an earlier label: skip it.
                pos += 1
                continue
            if target > label:
                # This label is exhausted. Stop without consuming the sample so
                # it stays available for the label it actually belongs to.
                break
            # int() accepts both Python ints and 0-dim integer tensors.
            balanced.append(int(index))
            taken += 1
            pos += 1
    return balanced


def pre_process_dataset(data: Dataset,
                        num_samples: int = -1) -> Tuple[Iterable[int], Optional[Iterable[int]]]:
    """
    Order the sample indices of a dataset by label and, optionally, split them
    into a per-label balanced training subset and the remaining samples.

    data: A training set as input; its labels are read from ``data.targets``
    num_samples: The samples to be drawn for each label if a subset is preferable
                to the whole dataset; any value <= 0 keeps the whole dataset

    Returns a pair ``(prefix_train, prefix_valid)``:
    - if ``num_samples`` <= 0: ``prefix_train`` is the argsort tensor of the
      targets and ``prefix_valid`` is ``None``
    - otherwise: ``prefix_train`` is the list produced by ``balance_prefix`` and
      ``prefix_valid`` is the list of all other indices, ordered by label and,
      within a label, by index
    """
    # as_tensor builds the tensor once and does not trigger the copy-construct
    # warning that torch.tensor emits when the targets already are a tensor.
    targets = torch.as_tensor(data.targets)
    prefix = torch.argsort(targets)
    if num_samples <= 0:
        return prefix, None

    targets_list = targets.tolist()
    # Set iteration order is arbitrary, but balance_prefix scans the sorted
    # prefix once and therefore needs the labels in ascending order.
    labels = sorted(set(targets_list))
    prefix_train = balance_prefix(data, prefix, labels, num_samples=num_samples)

    selected = set(prefix_train)
    # Indices are visited in ascending order and sorted() is stable, so the
    # result is ordered by label first and by index second.
    held_out = [i for i in range(len(targets_list)) if i not in selected]
    prefix_valid = sorted(held_out, key=lambda i: targets_list[i])
    return prefix_train, prefix_valid
