import bisect

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import WeightedRandomSampler, Sampler
import torch.nn.functional as F

from datasets.pds_dataset import PDSSubDataset


class PseudoLabelDataset(Dataset):
    """Pair every item of an unlabeled dataset with a pseudo label.

    Item ``i`` of this dataset is the tuple ``(dataset[i], labels[i])``.
    ``labels`` must have exactly one entry per item of ``dataset``.
    """

    def __init__(self, dataset: Dataset, labels):
        self.dataset, self.labels = dataset, labels
        # Items and labels are matched by position, so the lengths must agree.
        assert len(labels) == len(dataset)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample = self.dataset[index]
        return sample, self.labels[index]


class ConcatBatchSampler(Sampler):
    """
    A cyclic batch sampler for a list of samplers.
    In Pytorch, if multiple DataLoaders are used at the same time with
    pin_memory = True and num_workers > 1, then a process deadlock could
    happen. To avoid such a deadlock, we only construct one DataLoader,
    and use it with ConcatBatchSampler and ConcatDataset.

    samplers: A list of samplers (any iterables of integer indices).
    batch_size: int, or list/tuple of int. Batch size of each sampler.
    offsets: list[int] or None. Offset added to every index drawn from the
        corresponding sampler, e.g. the start of that dataset inside a
        concatenated dataset. None means no offset.
    index_change: list[int] or None. Number of consecutive batches drawn from
        each sampler before moving on to the next one. Defaults to 1 each.

    For example, if there are 3 samplers and index_change = [1,2,1], then
    [Dataset 0] [Dataset 1] [Dataset 1] [Dataset 2] [Dataset 0] [Dataset 1] ...

    Iteration stops as soon as the current sampler cannot supply a full
    batch; that incomplete batch is discarded. An empty list of samplers
    yields no batches.
    """
    def __init__(self, samplers, batch_size, offsets=None, index_change=None):
        self.samplers = samplers
        self.n_samplers = len(samplers)
        if isinstance(batch_size, list):
            self.batch_sizes = batch_size
        elif isinstance(batch_size, tuple):
            # A tuple is a per-sampler sequence too, not a single batch size.
            self.batch_sizes = list(batch_size)
        else:
            self.batch_sizes = [batch_size] * self.n_samplers
        self.offsets = offsets
        self.index_change = index_change if index_change is not None else ([1] * self.n_samplers)

    def __iter__(self):
        if self.n_samplers == 0:
            # Nothing to cycle through (avoids indexing empty lists below).
            return
        iter_samplers = [iter(sampler) for sampler in self.samplers]
        index = 0
        cnt = 0
        while True:
            n = self.batch_sizes[index]
            offset = 0 if self.offsets is None else self.offsets[index]
            current = iter_samplers[index]
            try:
                batch = [next(current) + offset for _ in range(n)]
            except StopIteration:
                # The current sampler is exhausted: drop the partial batch and stop.
                return
            yield batch
            cnt += 1
            if cnt == self.index_change[index]:
                cnt = 0
                index = (index + 1) % self.n_samplers


class ConcatDataset(Dataset):
    """
    A concatenation of multiple datasets.

    Item ``i`` of the concatenation is item ``i - offsets[k]`` of
    ``datasets[k]``, where ``offsets[k]`` is the global index of the first
    item of ``datasets[k]``. Empty datasets occupy no indices. Negative
    indices count from the end; out-of-range indices raise ``IndexError``
    (which also lets ``for x in dataset`` / ``list(dataset)`` terminate).
    """
    def __init__(self, datasets):
        self.datasets = datasets
        # Running start index of each dataset; the final running total is the
        # overall length and is split off below.
        self.offsets = [0]
        n = 0
        for d in self.datasets:
            n += len(d)
            self.offsets.append(n)
        self.n_total = self.offsets[-1]
        self.offsets = self.offsets[:-1]

    def __getitem__(self, index):
        original_index = index
        if index < 0:
            index += self.n_total
        if not 0 <= index < self.n_total:
            raise IndexError(
                f"index {original_index} out of range for ConcatDataset of length {self.n_total}"
            )
        # offsets is non-decreasing; bisect_right picks the last dataset whose
        # start is <= index, skipping empty datasets that share a start.
        k = bisect.bisect_right(self.offsets, index) - 1
        return self.datasets[k][index - self.offsets[k]]

    def __len__(self):
        return self.n_total


def get_balanced_loader(datasets, n, **kwargs):
    """
    Get a reweighted balanced DataLoader from a number of datasets.

    Every non-empty dataset gets the same total sampling weight: each of its
    items is weighted ``1 / len(dataset)``. Small and large datasets are
    therefore drawn equally often on average. Samples are drawn with
    replacement. Empty datasets are skipped.

    datasets  - A list of datasets.
    n         - Number of samples in the loader.
    kwargs    - Forwarded to ``DataLoader``. Must not contain ``sampler``,
                which this function sets.

    Raises ValueError if no dataset is non-empty.
    """
    dts = [d for d in datasets if len(d) > 0]
    if not dts:
        # torch.cat([]) would otherwise fail with an unhelpful error.
        raise ValueError("get_balanced_loader requires at least one non-empty dataset")
    weights = torch.cat([torch.ones(len(d)) / len(d) for d in dts])
    dataset = ConcatDataset(dts)
    sampler = WeightedRandomSampler(weights, n, replacement=True)
    return DataLoader(dataset, sampler=sampler, **kwargs)


def build_pseudo_labels(dataset, model, config):
    """
    Builds pseudo labels for a dataset with a model.

    Runs ``model`` over ``dataset`` in order (no shuffling). For every item it
    records two things:
      - the arg-max class of the softmax output, as the pseudo label;
      - the margin between the largest and second-largest class probability,
        as the confidence.

    dataset - Dataset whose loader batches are passed to the model directly
              (presumably an unlabeled dataset yielding inputs only).
    model   - Module returning per-class scores along dim 1. It needs at
              least 2 classes, because ``topk(..., 2)`` requires them.
    config  - Provides ``batch_size``, ``loader_kwargs`` and ``device``.

    Returns a tuple ``(pseudo_labels, confidence)``. Both are CPU tensors of
    length ``len(dataset)``; the labels are int64 and the confidence is
    float32.

    Inference runs in eval mode without gradient tracking. The model's
    previous train/eval mode is restored afterwards.
    """
    loader = DataLoader(dataset, batch_size=config.batch_size,
                        shuffle=False, **config.loader_kwargs)
    pseudo_labels = torch.zeros(len(dataset), dtype=torch.long)
    confidence = torch.zeros(len(dataset), dtype=torch.float)
    was_training = model.training
    model.eval()
    k = 0
    try:
        # No autograd graph is needed for pure inference.
        with torch.no_grad():
            for x in loader:
                x = x.to(config.device)
                bsz = len(x)
                probs = F.softmax(model(x), dim=1)
                top2 = torch.topk(probs, 2, dim=1)
                pseudo_labels[k:k + bsz] = top2.indices[:, 0].flatten().cpu()
                margin = top2.values[:, 0] - top2.values[:, 1]
                confidence[k:k + bsz] = margin.flatten().cpu()
                k += bsz
    finally:
        # Do not leave the caller's model stuck in eval mode.
        model.train(was_training)
    assert k == len(dataset)
    return pseudo_labels, confidence


def build_pseudo_dataset(dataset_unlabeled: PDSSubDataset, pseudo_labels, confidence, gamma) -> PseudoLabelDataset:
    """Attach pseudo labels to unlabeled data, keeping only the most confident part.

    When ``gamma < 1``, only the ``int(gamma * len(dataset_unlabeled))`` items
    with the highest ``confidence`` are kept. They are taken from
    ``dataset_unlabeled`` via ``get_subset``, together with their labels.
    Otherwise, everything is kept unchanged.
    """
    if gamma < 1:
        keep = int(gamma * len(dataset_unlabeled))
        # Positions ordered from most to least confident, truncated to `keep`.
        ranked = confidence.argsort(descending=True)[:keep]
        subset = dataset_unlabeled.get_subset(ranked, False)
        return PseudoLabelDataset(subset, pseudo_labels[ranked])
    return PseudoLabelDataset(dataset_unlabeled, pseudo_labels)