from pathlib import Path
from time import perf_counter

from tabicl.config.config_pretrain import ConfigPretrain
from tabicl.core.enums import GeneratorName
from tabicl.core.trainer_pretrain_init import create_synthetic_dataset


def main(
    config_path: Path = Path("outputs/done/foundation_mix_600k_finetune/config_pretrain.yaml"),
    n_steps: int = 200,
) -> None:
    """Benchmark batch throughput of the Perlin synthetic-data pipeline.

    Loads a pretraining config, overrides the loader and generator settings
    so the Perlin generator is used, then draws ``n_steps`` batches from the
    synthetic dataset. It prints the support-set shape of each batch and the
    total wall-clock time.

    Args:
        config_path: Pretraining config YAML used as the base configuration.
        n_steps: Number of batches to draw and time.
    """
    cfg = ConfigPretrain.load(config_path)
    cfg.workers_per_gpu = 16
    cfg.optim.batch_size = 64
    cfg.preprocessing.use_quantile_transformer = False

    cfg.data.generator = GeneratorName.PERLIN
    cfg.data.generator_hyperparams = {
        'min_features': 3,
        'max_features': 100,
        'n_samples': 1000,
        'max_classes': 10,
        'min_complexity': 0.01,
        'max_complexity': 1.00,
        'n_octaves': 7,
    }

    synthetic_dataset = create_synthetic_dataset(cfg)
    synthetic_iterator = iter(synthetic_dataset)

    # The timed region starts after iterator creation. Any lazy start-up cost
    # paid on the first next() call is still included in the measurement.
    # The per-step print is also inside the timed region.
    tic = perf_counter()

    for i in range(n_steps):
        # next() is used deliberately: if the dataset runs out early, the
        # benchmark fails loudly instead of silently timing fewer batches.
        batch = next(synthetic_iterator)
        print(f"step {i}, support x: {batch['x_support'].shape}")

    toc = perf_counter()
    elapsed = toc - tic
    print(f"Time: {elapsed:.2f} seconds")
    if n_steps > 0:
        print(f"Per batch: {elapsed / n_steps:.4f} seconds")


if __name__ == "__main__":
    # Run the benchmark with its default settings only when executed as a script.
    main()