import datetime
import shutil
import sys
import json
import math
import os
import random
import tempfile

import torch
import torchvision
from datasets import Dataset, load_dataset, concatenate_datasets, disable_progress_bars

import datasets.config as config
try:
    from torch_utils import distributed as dist
except ImportError:
    pass # loading module falied

from reconstruction_dataset.augment_pipe import AugmentPipe
from reconstruction_dataset.random_noise_data import generate_random_data


# Library logging verbosity and progress bars are intentionally left at their defaults.
# Raise the size limit (in bytes, ~25 GB) below which `datasets` may load data into memory.
config.IN_MEMORY_MAX_SIZE = 25_000_000_000


def trans(batch, cifar, pipeline, device):
    """Collapse one batched ``Dataset.map`` input batch into a single output row.

    ``corrupted`` is moved to ``device`` and scaled by 1/255. The CIFAR images indexed by
    ``batch["orig"]`` are stacked, moved to ``device`` and passed through ``pipeline``
    together with the per-row augmentation parameters. Every output value is wrapped in a
    one-element list, so each input batch becomes exactly one output row.

    Returns:
        dict with keys ``epoch``, ``run_id``, ``reconstruction_id``, ``corrupted``,
        ``orig_img`` and ``augment_label``.
    """
    augment_keys = ("xflip", "yflip", "scale", "rotate_frac", "aniso_w", "aniso_r", "translate_frac")

    row = {name: [batch[name]] for name in ("epoch", "run_id", "reconstruction_id")}
    row["corrupted"] = [batch["corrupted"].to(device) / 255]

    augment_params = {}
    for key in augment_keys:
        augment_params[key] = batch[key].to(device)

    source_images = torch.utils.data.Subset(cifar, batch["orig"])
    stacked = torch.stack([img.to(device) for img, _label in source_images])
    augmented, augment_label, _dict_labels = pipeline(stacked, augment_params)

    row["orig_img"] = [augmented]
    row["augment_label"] = [augment_label]
    return row


def _torch_utils_dist_loaded():
    """Return True if the optional ``torch_utils.distributed`` import at the top of this file succeeded."""
    return "torch_utils.distributed" in sys.modules


def _is_main_process():
    """Return True on rank 0, or unconditionally when ``torch_utils.distributed`` is unavailable."""
    return not _torch_utils_dist_loaded() or dist.get_rank() == 0


def _barrier():
    """Synchronise all ranks when a ``torch.distributed`` process group exists; otherwise do nothing."""
    # torch.distributed.barrier() raises when no default process group has been initialised,
    # e.g. a single-process run on a machine where torch_utils happens to be importable.
    if (
        _torch_utils_dist_loaded()
        and torch.distributed.is_available()
        and torch.distributed.is_initialized()
    ):
        torch.distributed.barrier()


def _load_or_compute_json(filename, compute):
    """Return the JSON stored in ``filename``, writing ``compute()`` to it first if the file is missing.

    The value is always read back from disk so a freshly computed result has exactly the same
    types as a cached one (JSON turns integer dict keys into strings, for example).
    """
    if not os.path.isfile(filename):
        value = compute()
        with open(filename, "w") as f:
            json.dump(value, f)
    with open(filename, "r") as f:
        return json.load(f)


def data_iterator(dataset_path, batch_size=64, pure_noise_proportion=0.1, in_memory=None, num_proc=None, device="cpu"):
    """Endlessly yield batches built from a reconstruction dataset and CIFAR-10 train images.

    Each pass of the outer ``while True`` loop:

    1. Subsamples the dataset so that every source image (column ``orig``) is kept about
       equally often. Row ``i`` is dropped with probability ``thresholds[i]``.
    2. Adds ``round(len(subset) * pure_noise_proportion)`` rows from ``generate_random_data``.
    3. Shuffles, then maps :func:`trans` over batches of ``batch_size`` rows.

    Per-image counts and per-row thresholds are cached in ``counts_<fingerprint>.json`` and
    ``thresholds_<fingerprint>.json`` in the current working directory.

    Args:
        dataset_path: A local directory loaded with ``Dataset.load_from_disk``. Any other
            value is passed to ``load_dataset`` and its ``train`` split is used.
        batch_size: Rows of the combined dataset per yielded batch.
        pure_noise_proportion: Size of the pure-noise part relative to the filtered
            dataset. A value that rounds to 0 rows disables it.
        in_memory: Forwarded as ``keep_in_memory`` to ``select``, ``shuffle`` and ``map``.
        num_proc: Worker processes for ``Dataset.from_generator`` and ``map``.
        device: Torch device forwarded to :func:`trans`.

    Yields:
        One torch-formatted dict per batch, as produced by :func:`trans`.
    """
    transforms = torchvision.transforms.ToTensor()
    cifar = torchvision.datasets.CIFAR10(".cifar10cache", train=True, transform=transforms, target_transform=None, download=True)

    print(f"{datetime.datetime.now()}: loading dataset")
    # A local directory is loaded from disk; anything else goes through load_dataset.
    if os.path.isdir(dataset_path):
        dataset = Dataset.load_from_disk(dataset_path)
    else:
        dataset = load_dataset(dataset_path, split="train", streaming=False)

    if _is_main_process():
        print("cleaning up cache files", dataset.cleanup_cache_files())

    # The ``orig`` column as a pandas Series. It is materialised at most once, and only when
    # one of the JSON cache files below is missing.
    orig_column = None

    def get_orig_column():
        nonlocal orig_column
        if orig_column is None:
            orig_column = dataset.remove_columns(["corrupted"]).to_pandas()["orig"]
        return orig_column

    print(f"{datetime.datetime.now()}: find counts")
    # Rows per source image. After the JSON round-trip the keys are strings.
    counts = _load_or_compute_json(
        f"counts_{dataset._fingerprint}.json",
        lambda: get_orig_column().value_counts().to_dict(),
    )

    smallest_count = min(counts.values())
    print("smallest count:", smallest_count)

    print(f"{datetime.datetime.now()}: find thresholds")
    # Per-row drop probability. A row is kept with probability smallest_count / counts[orig],
    # so each source image survives about ``smallest_count`` times per epoch.
    thresholds = _load_or_compute_json(
        f"thresholds_{dataset._fingerprint}.json",
        lambda: [max(0, 1 - smallest_count / counts[str(x)]) for x in get_orig_column()],
    )

    print(f"{datetime.datetime.now()}: thresholds found")

    # The expected size of the dataset after filtering is len(thresholds) - sum(thresholds).
    # Each epoch uses exactly 98% of that, so a single sampling round is usually enough.
    target_size = round(0.98 * (len(thresholds) - sum(thresholds)))

    while True:
        print(f"{datetime.datetime.now()}: filtering dataset")

        print(f"target size is {target_size}")
        indices = []
        while len(indices) < target_size:
            # Keep row i when a fresh uniform draw exceeds its drop threshold.
            indices = [i for i, threshold in enumerate(thresholds) if threshold < random.random()]
            if len(indices) < target_size:
                print(f"{datetime.datetime.now()}: not enough indices, trying again. Need {target_size}, have {len(indices)}")
        # Select exactly target_size of the surviving rows at random, kept in dataset order.
        indices = sorted(random.sample(indices, target_size))
        ds = dataset.select(indices, keep_in_memory=in_memory)

        print(f"{datetime.datetime.now()}: generating random dataset")
        random_noise_size = round(len(ds) * pure_noise_proportion)

        if random_noise_size:
            # The noise dataset is kept in memory (keep_in_memory=True), so its cache directory
            # is only needed during generation. The context manager removes it even if
            # generation raises. A list-valued gen_kwargs entry lets from_generator shard the
            # work across num_proc processes.
            with tempfile.TemporaryDirectory() as temp_dir:
                random_ds = Dataset.from_generator(
                    generate_random_data,
                    gen_kwargs={"size": [1] * random_noise_size},
                    features=ds.features,
                    keep_in_memory=True,
                    num_proc=num_proc,
                    cache_dir=temp_dir,
                )

        print(f"{datetime.datetime.now()}: ds length ", len(ds))

        ds_combined = concatenate_datasets([ds, random_ds]) if random_noise_size else ds
        print(f"{datetime.datetime.now()}: start shuffle")
        ds_combined = ds_combined.shuffle(keep_in_memory=in_memory).with_format("torch")
        print(f"{datetime.datetime.now()}: shuffle done")

        # Only a local dataset has a directory to list. For hub/script datasets,
        # os.listdir would raise.
        if os.path.isdir(dataset_path):
            print("dataset directory contents", os.listdir(dataset_path))

        batched_ds = ds_combined.map(
            trans,
            num_proc=num_proc,
            keep_in_memory=in_memory,
            fn_kwargs={"device": device, "cifar": cifar, "pipeline": AugmentPipe(p=0.12, xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1)},
            batched=True,
            batch_size=batch_size,
            remove_columns=["orig", "xflip", "yflip", "scale", "rotate_frac", "aniso_w", "aniso_r", "translate_frac"],
        )
        batched_ds = batched_ds.with_format(None).to_iterable_dataset().with_format("torch")  # this results in better performance

        print(f"{datetime.datetime.now()}: dataset ready")

        yield from batched_ds

        # Rank 0 cleans the dataset's cache files between two barriers. This ensures no rank is
        # still iterating while files are removed, and none starts the next epoch early.
        _barrier()
        if _is_main_process():
            print("cleaning up cache files", dataset.cleanup_cache_files())
        _barrier()

        del batched_ds
        del ds_combined
        if random_noise_size:
            del random_ds
        del ds


if __name__ == "__main__":
    # Smoke test: write the first batch to test.png. Corrupted images and their augmented
    # originals are concatenated along the batch dimension and laid out 64 per grid row.
    # The dataset path may be given as the first command-line argument. The "/dummy"
    # placeholder is kept as the default.
    dataset_path = sys.argv[1] if len(sys.argv) > 1 else "/dummy"
    test = data_iterator(dataset_path, pure_noise_proportion=0.1)
    for x in test:
        torchvision.utils.save_image(torch.cat([x["corrupted"], x["orig_img"]], dim=0), "test.png", nrow=64)
        break

