"""
This script takes the LMDB dataset, splits it into train/valid/test subsets,
and writes each subset to its own LMDB file.
"""

from dataclasses import dataclass
from pprint import pprint

import torch

from Utils.HyperParams import HP_parsed
from datasets.BlenderDataset.BlenderDataset import folder2lmdb
from datasets.dataloading import get_blender3


@dataclass
class Params:
    """Default settings for re-splitting a Blender LMDB dataset.

    The source LMDB (``dataset_root_folder`` / ``dataset_name``) is split into
    train / valid / test subsets, and each subset is written to its own LMDB
    file under ``lmbd_dest_path``. An instance of this class is fused with the
    SLURM config by ``HP_parsed`` before use.
    """

    # Source dataset.
    dataset_root_folder: str = "FULL DATASET BLENDER LMDB"  # folder holding the source LMDB
    dataset_name: str = "blender_dataset_32.lmdb"  # source LMDB file name

    # Split configuration. These are presumably the fractions used by
    # get_blender3 for the random split (TODO confirm); the defaults sum to 1.
    train_proportion: float = 0.6
    valid_proportion: float = 0.2
    test_proportion: float = 0.2
    # NOTE(review): -1 presumably means "keep the full subset"; verify in get_blender3.
    reduced_size_train: int = -1
    reduced_size_valid: int = -1
    reduced_size_test: int = -1
    return_params: bool = True  # consumed by get_blender3; meaning not visible here

    # Batch sizes are not used by this script.
    batch_size_train: int = 20
    batch_size_valid: int = 20
    batch_size_test: int = 20

    # Destination. "lmbd" in the next field name is a misspelling of "lmdb",
    # but the key is looked up under exactly this name, so it must stay.
    lmbd_dest_path: str = 'test_reverse'
    lmdb_dest_name_train: str = 'blender_dataset_train.lmdb'
    lmdb_dest_name_valid: str = 'blender_dataset_valid.lmdb'
    lmdb_dest_name_test: str = 'blender_dataset_test.lmdb'
    write_frequency: int = 100  # forwarded to folder2lmdb(write_frequency=...)


from params.slurm_params import cfg_slurm
# Fuse the defaults above with the SLURM config; the fused result is then
# looked up by key (params['...']) below.
params = Params()
parser = HP_parsed([params, cfg_slurm])
params = parser.parse_and_fuze()
pprint(params)

# Load the source dataset and let get_blender3 split it into train / valid /
# test subsets (random_split=True), with no sample/target transforms.
ds = get_blender3(
    params=params,
    sample_transforms=None,
    target_transforms=None,
    random_split=True,
)

# Report the combined size of the three subsets.
full_dataset = torch.utils.data.ConcatDataset([ds.train_ds, ds.valid_ds, ds.test_ds])
print(f'{len(full_dataset)=}')

# Write each subset to its own LMDB file under params['lmbd_dest_path'].
# Subset and file name are looked up per iteration, in train, valid, test order.
for subset_attr, name_key in (
    ('train_ds', 'lmdb_dest_name_train'),
    ('valid_ds', 'lmdb_dest_name_valid'),
    ('test_ds', 'lmdb_dest_name_test'),
):
    folder2lmdb(
        getattr(ds, subset_attr),
        params['lmbd_dest_path'],
        params[name_key],
        write_frequency=params['write_frequency'],
    )

# Optional: also export the unsplit dataset as a single LMDB file.
# folder2lmdb(full_dataset, params['lmbd_dest_path'], 'blender_dataset.lmdb', write_frequency=params['write_frequency'])
