import random
import torch
import torchvision
from reconstruction_dataset.augment_pipe import AugmentPipe

def scale(data, mean, factors_to_hit):
    """Multiply each sample of ``data`` by its smallest positive factor and add ``mean``.

    ``factors_to_hit`` carries one candidate multiplier per element. Candidates
    that are not strictly positive (NaN included, since it fails the ``> 0``
    test) are replaced by +inf so they can never win the minimum. The minimum
    is taken over dims 1, 2 and 3, leaving one multiplier per index along
    dim 0; a sample with no positive candidate therefore gets +inf.
    """
    never_chosen = torch.full(factors_to_hit.shape, float("inf"), device=data.device)
    is_positive = factors_to_hit > 0
    candidates = torch.where(is_positive, factors_to_hit, never_chosen)
    # keepdim leaves singleton dims so the multiplier broadcasts against data.
    smallest = candidates.amin(dim=(1, 2, 3), keepdim=True)
    return data * smallest + mean


def scale_to_interval(data, mean):
    """Stretch every image in ``data`` about ``mean`` until it touches [0, 1].

    Each image (index along dim 0) is rescaled around ``mean`` so that either
    its largest pixel lands on 1 while its smallest stays at or above 0, or
    its smallest pixel lands on 0 while its largest stays below 1. Only one of
    the two is achievable for a given image, and that one is returned.
    """
    centered = data - mean
    # Per-pixel multipliers that would put that pixel exactly on the upper
    # bound (1 - mean) or the lower bound (0 - mean) of the centered range.
    to_upper = (1 - mean) / centered
    to_lower = (0 - mean) / centered
    touches_one = scale(centered, mean, factors_to_hit=to_upper)
    touches_zero = scale(centered, mean, factors_to_hit=to_lower)
    # If stretching to 1 pushes any pixel of an image below 0, that image
    # must use the stretch-to-0 variant instead.
    lowest_pixel = touches_one.amin(dim=(1, 2, 3), keepdim=True)
    undershoots = lowest_pixel < 0
    return torch.where(undershoots, touches_zero, touches_one)


def generate_random_data(size=1000, data_per_class_train=150, extraction_init_scale=0.001, datasets_dir=".cifar10cache"):
    """Yield dummy reconstruction records made of noise around CIFAR-10 batch means.

    Work proceeds in rounds. Each round samples ``2 * data_per_class_train``
    distinct indices from the CIFAR-10 training split, runs those images
    through ``AugmentPipe``, and takes the per-pixel mean of the augmented
    batch. Standard-normal noise multiplied by ``extraction_init_scale`` is
    added to that mean and the result is stretched with ``scale_to_interval``.
    Rounds repeat until ``size`` records have been yielded; the last round is
    truncated if it would overshoot.

    Args:
        size: Total number of records to yield. A list is replaced by its sum.
        data_per_class_train: Half the number of images sampled per round.
        extraction_init_scale: Multiplier applied to the Gaussian noise.
        datasets_dir: Root directory handed to ``torchvision.datasets.CIFAR10``
            (opened with ``download=True``).

    Yields:
        A dict with the keys ``"orig"`` (the sampled dataset index, not the
        image), ``"corrupted"`` (nested lists of ints obtained by rounding the
        scaled noise times 255), ``"epoch"`` (always 0), ``"run_id"`` (always
        ``"dummystrings"``) and ``"reconstruction_id"`` (random int in
        [0, 400]), merged with one entry per key of the label dict returned
        by the augment pipe.
    """
    if isinstance(size, list):
        size = sum(size)

    # Settings are forwarded verbatim to AugmentPipe; their meaning is defined there.
    augment_settings = dict(p=0.12, xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1)
    augment_pipe = AugmentPipe(**augment_settings)
    transforms = torchvision.transforms.ToTensor()
    cifar10 = torchvision.datasets.CIFAR10(datasets_dir, train=True, transform=transforms, target_transform=None, download=True)

    # Loop invariants: population to sample from and images drawn per round.
    num_images = len(cifar10)
    per_round = data_per_class_train * 2

    while size > 0:
        train_indices = random.sample(range(num_images), per_round)
        data_subset = torch.utils.data.Subset(cifar10, train_indices)
        # One loader batch covering the whole subset stacks it into a single tensor.
        images, _ = next(iter(torch.utils.data.DataLoader(data_subset, batch_size=len(data_subset))))
        images, _, dict_labels = augment_pipe(images)
        ds_mean = images.mean(dim=0, keepdim=True)

        # Noise is drawn with the default generator and then moved, keeping the
        # random stream independent of the device the mean lives on.
        noise = torch.randn(*images.shape).to(ds_mean.device) * extraction_init_scale
        scaled_x = scale_to_interval(noise + ds_mean, ds_mean)
        # Reverse the axis order of every label tensor -- presumably so the
        # per-sample axis comes first and rows line up with samples; confirm
        # against AugmentPipe.
        dict_labels = {key: value.permute(*reversed(range(value.ndim))) for key, value in dict_labels.items()}

        if size >= len(train_indices):
            size -= len(train_indices)
        else:
            # Final round: keep only as many samples as are still owed.
            train_indices = train_indices[:size]
            scaled_x = scaled_x[:size]
            dict_labels = {key: value[:size] for key, value in dict_labels.items()}
            size = 0

        # Turn the dict of per-label columns into one {label: value} dict per sample.
        label_rows = (dict(zip(dict_labels, row)) for row in zip(*dict_labels.values()))
        for index, sample, augment_label in zip(train_indices, scaled_x, label_rows):
            yield {
                "orig": index,
                "corrupted": torch.round(sample * 255).int().tolist(),
                "epoch": 0,
                "run_id": "dummystrings",
                "reconstruction_id": random.randint(0, 400),
            } | augment_label
