import os
import json
import logging
from datasets import DatasetDict, Dataset

def _load_image_file_names(image_dir):
    """
    Load image file names from the specified directory.

    Only regular files directly inside ``image_dir`` are considered;
    sub-directories are ignored. File names are sorted before ids are
    assigned, because ``os.listdir`` returns entries in arbitrary order and
    the ids would otherwise differ between runs and machines.

    Args:
        image_dir: Path of the directory to scan.

    Returns:
        A list of ``{"id": <int>, "image": <file name>}`` dicts, where ``id``
        is the zero-based position of the file name in sorted order. The list
        is empty when the directory contains no regular files.

    Raises:
        FileNotFoundError: If ``image_dir`` does not exist (from ``os.listdir``).
    """
    image_files = sorted(
        name
        for name in os.listdir(image_dir)
        if os.path.isfile(os.path.join(image_dir, name))
    )
    return [{"id": idx, "image": img_file} for idx, img_file in enumerate(image_files)]

def _split_dataset(dataset, config):
    """
    Split a single dataset into train, validation, and test splits.

    The proportions are taken from the optional ``train_split``,
    ``val_split`` and ``test_split`` entries of ``config``. They default to
    0.8 / 0.1 / 0.1, which yields the same 80/10/10 split as the previously
    hard-coded fractions. Only the ratios matter: the three values are
    normalized by their sum.

    Args:
        dataset: Object exposing ``train_test_split`` (e.g. a
            ``datasets.Dataset``).
        config: Dict with a required ``seed`` entry and the optional split
            fractions described above.

    Returns:
        A ``DatasetDict`` with the keys ``'train'``, ``'validation'`` and
        ``'test'``.

    Raises:
        ValueError: If any of the split fractions is not strictly positive.
    """
    train_frac = config.get('train_split', 0.8)
    val_frac = config.get('val_split', 0.1)
    test_frac = config.get('test_split', 0.1)
    if min(train_frac, val_frac, test_frac) <= 0:
        raise ValueError(
            "train_split, val_split and test_split must all be positive, got "
            f"{train_frac}, {val_frac}, {test_frac}"
        )

    # First carve off the combined validation+test portion, then divide that
    # portion between validation and test according to their relative weights.
    holdout_frac = val_frac + test_frac
    holdout_share = holdout_frac / (train_frac + holdout_frac)
    test_share_of_holdout = test_frac / holdout_frac

    train_valid_test = dataset.train_test_split(
        test_size=holdout_share, shuffle=True, seed=config['seed']
    )
    test_valid_split = train_valid_test['test'].train_test_split(
        test_size=test_share_of_holdout, shuffle=True, seed=config['seed']
    )

    return DatasetDict({
        'train': train_valid_test['train'],
        'validation': test_valid_split['train'],
        'test': test_valid_split['test'],
    })

def main(config):
    """
    Build file-name datasets from the COCO14 directories and save them.

    For each of the ``train2014``, ``val2014`` and ``test2014``
    sub-directories of ``config['coco_dir']`` that exists, a ``Dataset`` with
    the columns ``id`` and ``image`` (file name) is created. Missing
    sub-directories are skipped with a warning. If exactly one split is
    found, it is divided into train/validation/test via ``_split_dataset``.
    The result is saved to ``config['output_datasets_dir']`` together with a
    ``meta.json`` file listing the size of every split.

    Args:
        config: Dict providing ``coco_dir`` and ``output_datasets_dir`` (and
            whatever ``_split_dataset`` reads when a single split is found).

    Raises:
        FileNotFoundError: If none of the expected split directories exists.
    """
    # Output split name -> sub-directory of the COCO14 dataset.
    split_dirs = {
        'train': 'train2014',
        'validation': 'val2014',
        'test': 'test2014',
    }

    # Create Dataset objects for every split directory that is present.
    splits = {}
    for split_name, sub_dir in split_dirs.items():
        image_dir = os.path.join(config['coco_dir'], sub_dir)
        if not os.path.isdir(image_dir):
            logging.warning(f"Split directory not found, skipping: {image_dir}")
            continue
        records = _load_image_file_names(image_dir)
        splits[split_name] = Dataset.from_dict({
            "id": [record["id"] for record in records],
            "image": [record["image"] for record in records],
        })

    if not splits:
        raise FileNotFoundError(
            f"No COCO14 split directories found under {config['coco_dir']}"
        )

    dataset_dict = DatasetDict(splits)

    # Split datasets if necessary: when only one split directory was found,
    # derive train/validation/test from it.
    if len(dataset_dict) == 1:
        only_split = next(iter(dataset_dict))
        dataset_dict = _split_dataset(dataset_dict[only_split], config)

    # Save the dataset to disk
    output_dir = config['output_datasets_dir']
    os.makedirs(output_dir, exist_ok=True)
    dataset_dict.save_to_disk(output_dir)

    # Save metadata: one "<split>_size" entry per split actually present.
    meta_info = {
        "description": "COCO14 dataset split into train, validation, and test sets.",
    }
    for split_name, split_dataset in dataset_dict.items():
        meta_info[f"{split_name}_size"] = len(split_dataset)

    with open(os.path.join(output_dir, 'meta.json'), 'w', encoding='utf-8') as f:
        json.dump(meta_info, f, indent=4)

    logging.info(f"Datasets saved to {output_dir}")


if __name__ == "__main__":
    # The root logger defaults to WARNING, which would silently drop the
    # INFO message emitted at the end of main(); enable INFO for script runs.
    logging.basicConfig(level=logging.INFO)

    # Settings for a stand-alone run of this script.
    config = {
        "coco_dir": "/XXXX-5/home-XXXX-3/data/MSD/datasets/original/COCO14",
        "output_datasets_dir": "/XXXX-5/home-XXXX-3/data/MSD/datasets/COCO2014",
        "train_split": 0.8,
        "val_split": 0.1,
        "test_split": 0.1,
        "seed": 2024,
        "tiny_data": False
    }
    main(config)



"""
# Contents of the COCO14 source directory (output of `ls`):
root@speculative-decoding-0:/XXXX-5/home-XXXX-3/data/MSD/datasets/COCO14# ls
test2014  test2014.zip  train2014  train2014.zip  val2014  val2014.zip

- Reference snippet that uses the Hugging Face `datasets` package:
"
def load_datasets(config, tokenizer, drf_image_processor) -> Dict[str, Dataset]:
    path_dataset = os.path.join(config['input_datasets_dir'], 'meta.json')
    map_datasets = load_dataset("json", data_files=path_dataset)

    # If there's only one dataset, split it into train, validation, and test
    if len(map_datasets) == 1:
        split_single = list(map_datasets.keys())[0]
        full_train_dataset = map_datasets[split_single]
        map_datasets = _split_dataset(full_train_dataset, config)

    # Apply tiny data filter if debugging
    if config['tiny_data']:
        tiny_map = {'train': 80, 'validation': 10, 'test': 10}
        map_datasets = _apply_tiny_data_filter(map_datasets, tiny_map)

    # Log dataset info
    for split, dataset in map_datasets.items():
        logging.info(f"[Dataset] {split} dataset: {len(dataset)} samples")

    # Wrap datasets with MLLMDataset
    map_datasets = _wrap_with_mllm_dataset(map_datasets, config, tokenizer, drf_image_processor)

    return map_datasets
"

"""