import torch
import numpy as np
from torch.utils.data.sampler import WeightedRandomSampler, Sampler
from PIL import Image

from .datasets import RealFakeDataset


def get_bal_sampler(dataset):
    """Build a ``WeightedRandomSampler`` weighted by inverse class frequency.

    Args:
        dataset: An iterable whose items each expose a ``targets`` sequence
            of non-negative integer class labels. The labels of all items
            are concatenated in iteration order.

    Returns:
        A ``WeightedRandomSampler`` holding one weight per collected label
        (``1 / count_of_that_label``) and drawing as many samples per pass
        as there are labels.
    """
    # Flatten the per-item label lists into a single list of class ids.
    all_targets = [label for part in dataset for label in part.targets]

    # class_counts[c] is the number of samples carrying label c.
    class_counts = np.bincount(all_targets)
    inverse_freq = 1. / torch.tensor(class_counts, dtype=torch.float)

    # Look up each sample's weight by its own label.
    per_sample_weight = inverse_freq[all_targets]
    return WeightedRandomSampler(weights=per_sample_weight,
                                 num_samples=len(per_sample_weight))


class BalancedBatchSampler(Sampler):
    """Sampler that yields indices in class-balanced chunks (labels 0 and 1).

    Indices are yielded one at a time (a flat stream, not lists of indices).
    The stream is organised in consecutive chunks of up to
    ``2 * (batch_size // 2)`` indices; every chunk contains the same number
    of label-0 and label-1 samples, shuffled together. On each pass the
    larger class is randomly truncated to the size of the smaller one, so a
    pass yields exactly ``2 * min(n_class0, n_class1)`` indices, which is
    also what ``len()`` reports.

    Args:
        dataset: Sized, indexable collection where ``dataset[i][1]`` is the
            label of sample ``i``. Only samples labelled 0 or 1 are drawn.
        batch_size: Target chunk size; each class contributes
            ``batch_size // 2`` indices per chunk. NOTE: for an odd
            ``batch_size`` the chunks hold ``batch_size - 1`` indices.

    Raises:
        ValueError: If ``batch_size`` is smaller than 2 (the per-class chunk
            size would be zero).
    """

    def __init__(self, dataset, batch_size):
        if batch_size < 2:
            raise ValueError(
                f"batch_size must be at least 2 to hold both classes, got {batch_size}")

        self.dataset = dataset
        self.batch_size = batch_size
        self.batch_size_per_class = batch_size // 2

        # Collect every label up front. This indexes each sample once, so it
        # costs one full pass over the dataset at construction time.
        self.labels = np.array([dataset[i][1] for i in range(len(dataset))])

        self.indices_class0 = np.where(self.labels == 0)[0]
        self.indices_class1 = np.where(self.labels == 1)[0]

    def __iter__(self):
        """Return an iterator over one freshly shuffled, balanced pass."""
        n_pairs = min(len(self.indices_class0), len(self.indices_class1))
        step = self.batch_size_per_class

        # Re-permute on every pass. Truncating both classes to the minority
        # size keeps the final, possibly shorter, chunk balanced as well.
        shuffled0 = np.random.permutation(self.indices_class0)[:n_pairs]
        shuffled1 = np.random.permutation(self.indices_class1)[:n_pairs]

        balanced_indices = []
        for start in range(0, n_pairs, step):
            chunk = np.concatenate((shuffled0[start:start + step],
                                    shuffled1[start:start + step]))
            # Mix the two classes inside the chunk.
            np.random.shuffle(chunk)
            # Emit plain Python ints rather than numpy scalars.
            balanced_indices.extend(chunk.tolist())

        return iter(balanced_indices)

    def __len__(self):
        """Return the number of indices produced by one pass of ``__iter__``."""
        return 2 * min(len(self.indices_class0), len(self.indices_class1))


def create_dataloader(opt, preprocess=None):
    """Build a ``DataLoader`` over a ``RealFakeDataset`` configured by ``opt``.

    Args:
        opt: Options object. This function reads ``isTrain``,
            ``serial_batches``, ``class_bal``, ``arch``, ``num_threads`` and
            ``batch_size``; the whole object is also passed to
            ``RealFakeDataset``.
        preprocess: Transform assigned to ``dataset.transform`` when
            ``'2b'`` occurs in ``opt.arch``; unused otherwise.

    Returns:
        A ``torch.utils.data.DataLoader`` that drops the last incomplete
        batch and shuffles only for training runs with neither
        ``serial_batches`` nor ``class_bal`` set.
    """
    # Shuffle only when training, not in serial order, and not class-balancing.
    shuffle = bool(opt.isTrain and not opt.class_bal and not opt.serial_batches)

    dataset = RealFakeDataset(opt)
    if '2b' in opt.arch:
        dataset.transform = preprocess
    print(dataset)

    # NOTE(review): opt.class_bal only disables shuffling here; no balanced
    # sampler is attached to the loader. Confirm whether one is intended.
    print(opt.class_bal)

    data_loader = torch.utils.data.DataLoader(dataset,
                                              num_workers=int(opt.num_threads),
                                              batch_size=opt.batch_size,
                                              # Previously computed but never
                                              # passed, so data was never shuffled.
                                              shuffle=shuffle,
                                              drop_last=True,
                                              )
    print(data_loader)

    return data_loader
