import os

from functools import partial
from datasets import load_dataset, load_from_disk, Dataset

SEED = 42

def gen_from_iterable_dataset(iterable_ds):
    """Yield each item of ``iterable_ds`` in iteration order.

    A plain generator function wrapping an arbitrary iterable; in this file
    it is bound with ``functools.partial`` and handed to
    ``Dataset.from_generator``.

    Args:
        iterable_ds: Any iterable object.

    Yields:
        The items of ``iterable_ds``, one at a time and unchanged.
    """
    for example in iterable_ds:
        yield example


def build_random_dataset_hf(seed, hf_name, buffer_size, hf_data_dir=None, num_examples=500,
                            split='train', is_save_to_disk=False, save_path='./dataset', is_shuffle=False):
    """Stream a Hugging Face dataset and keep its first ``num_examples`` rows.

    The dataset is opened in streaming mode. When ``is_shuffle`` is true the
    stream is shuffled (using ``seed`` and ``buffer_size``) before the first
    ``num_examples`` rows are taken; otherwise the rows are taken in stream
    order. When ``is_save_to_disk`` is true the sample is materialized into a
    regular ``Dataset`` and written to ``save_path``.

    Args:
        seed: Seed for the streaming shuffle; only used when ``is_shuffle``.
        hf_name: Dataset identifier, first argument of ``load_dataset``.
        buffer_size: Shuffle buffer size; only used when ``is_shuffle``.
        hf_data_dir: Forwarded as the SECOND POSITIONAL argument of
            ``load_dataset``. NOTE(review): in the ``datasets`` API that
            position is the configuration ``name``, not ``data_dir`` -
            confirm this matches the intent behind the parameter name.
        num_examples: Number of rows to take from the stream.
        split: Split to load.
        is_save_to_disk: Materialize the sample and save it to ``save_path``.
        save_path: Destination directory used when ``is_save_to_disk``.
        is_shuffle: Shuffle the stream before taking rows.

    Returns:
        The materialized ``Dataset`` when ``is_save_to_disk`` is true,
        otherwise the lazy sampled iterable dataset. Previously nothing was
        returned, which made a call with ``is_save_to_disk=False`` a no-op.
    """
    stream = load_dataset(hf_name, hf_data_dir, split=split, streaming=True)

    if is_shuffle:
        stream = stream.shuffle(seed=seed, buffer_size=buffer_size)
    sampled = stream.take(num_examples)

    if not is_save_to_disk:
        # Hand the lazy sample back so the caller can consume it.
        return sampled

    # Pull the sampled rows out of the stream into a regular Dataset,
    # reusing the stream's features, so that it can be written to disk.
    ds = Dataset.from_generator(
        partial(gen_from_iterable_dataset, sampled),
        features=sampled.features,
    )
    ds.save_to_disk(save_path)
    return ds


def split_train_test_and_save_dataset(full_dataset_path, train_size, out_split_dataset_path):
    """Load a saved dataset, split it into train/test without shuffling, and save the split.

    Args:
        full_dataset_path: Directory of a dataset previously saved to disk.
        train_size: Forwarded unchanged as ``train_size`` to ``train_test_split``.
        out_split_dataset_path: Directory the resulting split is written to.
    """
    # Shuffling is switched off so the split is deterministic; the rows are
    # expected to have been randomly sampled already when the dataset was built.
    load_from_disk(full_dataset_path).train_test_split(
        train_size=train_size,
        shuffle=False,
    ).save_to_disk(out_split_dataset_path)


def create_dataset(storage_root_dir, dataset_names):
    """Build and split the requested datasets under ``storage_root_dir``.

    For each recognised name present in ``dataset_names`` ('wikiart',
    'diffusiondb'), 5,500 examples are saved to ``<name>_all`` and then split
    without shuffling, with a train fraction of 5000/5500, into ``<name>``.

    Args:
        storage_root_dir: Directory under which all outputs are written.
        dataset_names: Container checked with ``in`` for the dataset names.
    """
    train_fraction = 5_000 / 5_500

    if 'wikiart' in dataset_names:
        # wikiart - only has 'train', shuffling needed
        print('Started WikiArt Download')
        wikiart_full_path = os.path.join(storage_root_dir, 'wikiart_all')
        build_random_dataset_hf(
            SEED,
            'huggan/wikiart',
            100_000,
            num_examples=5_500,
            is_save_to_disk=True,
            save_path=wikiart_full_path,
            is_shuffle=True,
        )
        split_train_test_and_save_dataset(
            wikiart_full_path,
            train_fraction,
            os.path.join(storage_root_dir, 'wikiart'),
        )

    if 'diffusiondb' in dataset_names:
        # diffusiondb - only has 'train', no shuffling
        print('Started DiffusionDB Download')
        diffusiondb_full_path = os.path.join(storage_root_dir, 'diffusiondb_all')
        build_random_dataset_hf(
            SEED,
            'poloclub/diffusiondb',
            100_000,
            num_examples=5_500,
            is_save_to_disk=True,
            save_path=diffusiondb_full_path,
            hf_data_dir='2m_random_100k',
        )
        split_train_test_and_save_dataset(
            diffusiondb_full_path,
            train_fraction,
            os.path.join(storage_root_dir, 'diffusiondb'),
        )


# TODO: Hacked up - requires full dataset download due to shuffle issues
def create_coco_dataset(storage_root_dir, num_train=5000, num_test=500):
    """Build a small train/test subset of COCO 2014 captions and save it.

    Both the 'train' and 'test' splits of ``HuggingFaceM4/COCO``
    ('2014_captions') are loaded without streaming. A seeded, shuffled
    subset of ``num_train`` rows is kept from 'train' and ``num_test`` rows
    from 'test'; the pair is saved to ``<storage_root_dir>/coco``.

    The subset sizes are passed to ``train_test_split`` as absolute integer
    counts. The previous ``count / len(ds)`` fraction is converted back to a
    count by the library, and that float round trip can come out one short
    of the requested size.

    Args:
        storage_root_dir: Directory under which the 'coco' folder is written.
        num_train: Number of rows kept from the 'train' split.
        num_test: Number of rows kept from the 'test' split.
    """
    print('Started COCO Train Download')

    save_path = os.path.join(storage_root_dir, 'coco')

    ds = load_dataset("HuggingFaceM4/COCO", '2014_captions', split='train')
    ds_split_train = ds.train_test_split(train_size=num_train, seed=SEED, shuffle=True)
    len_train, len_test = len(ds_split_train['train']), len(ds_split_train['test'])
    print(f'Size of COCO Train : Train: {len_train}, Test: {len_test}')

    # The split loaded below is 'test', so the message says "Test" (it used to say "Val").
    print('Started COCO Test Download')
    ds = load_dataset("HuggingFaceM4/COCO", '2014_captions', split='test')
    ds_split_test = ds.train_test_split(train_size=num_test, seed=SEED, shuffle=True)
    len_test = len(ds_split_test['train'])
    print(f'Size of COCO Test : {len_test}')

    # Replace the large remainder of the train split with the small sample
    # drawn from the real 'test' split, then save both together.
    ds_split_train['test'] = ds_split_test['train']
    ds_split_train.save_to_disk(save_path)



if __name__ == '__main__':
    # Script entry point. Only the storage root is set here; the dataset
    # creation call below is commented out, so running the script as-is
    # builds nothing.
    storage_root_dir = '/localhome/data/datasets/watermarking/training'
    # create_coco_dataset(storage_root_dir)
