from argparse import ArgumentParser

import yaml

from sde.dataset_generators import (
    BezierDatasetGenerator,
    CircleRectDatasetGenerator,
    GaussianDatasetGenerator,
    ModuloDatasetGenerator,
    MultiPatchSumDataset,
    SingleColorModuloDatasetGenerator,
)
from sde.utils import DictAction, merge_from_options, read_config, set_random_seed, setup_logger


def parse_args():
    """Parse the command-line arguments for the dataset-generation script.

    Returns:
        argparse.Namespace with:
            ``config``: path of the config file to load (positional).
            ``seed``: int random seed, defaults to 42.
            ``cfg_options``: overrides collected by ``DictAction`` from
                ``--cfg-options``, or ``None`` when the flag is not given.
    """
    # The first positional parameter of ArgumentParser is ``prog`` (the program
    # name shown in the usage line), so the human-readable text must be passed
    # as ``description`` to appear in ``--help`` instead of replacing the name.
    parser = ArgumentParser(description='Make a synthetic dataset.')

    parser.add_argument('config', help='Config to use')
    parser.add_argument('--seed', type=int, default=42, help='Random seed.')
    parser.add_argument(
        '--cfg-options', nargs='+', action=DictAction, help='Key value pairs xxx=yyy to override config options.')

    return parser.parse_args()


# Maps the ``generator.type`` config value to the dataset generator class it selects.
_GENERATOR_TYPES = {
    'gaussian': GaussianDatasetGenerator,
    'modulo': ModuloDatasetGenerator,
    'bezier': BezierDatasetGenerator,
    'single_color_modulo': SingleColorModuloDatasetGenerator,
    'multi_color_sum': MultiPatchSumDataset,
    'training': CircleRectDatasetGenerator,
}


def build_generator(generator_cfg):
    """Instantiate the dataset generator described by ``generator_cfg``.

    Args:
        generator_cfg: mapping with a required ``type`` key naming the
            generator (a key of ``_GENERATOR_TYPES``); every other key is
            forwarded to the generator's constructor as a keyword argument.
            The mapping is not modified.

    Returns:
        The constructed generator instance.

    Raises:
        KeyError: if ``generator_cfg`` has no ``type`` key.
        TypeError: if ``type`` does not name a supported generator.
    """
    # Copy so the caller's config (e.g. the one that was just logged) keeps its 'type'.
    options = dict(generator_cfg)
    try:
        generator_type = options.pop('type')
    except KeyError:
        raise KeyError("Generator config is missing required key 'type'") from None
    try:
        generator_cls = _GENERATOR_TYPES[generator_type]
    except (KeyError, TypeError):  # TypeError: unhashable value such as a YAML list
        supported = ', '.join(sorted(_GENERATOR_TYPES))
        raise TypeError(
            f'Unsupported dataset generator type: {generator_type} (supported: {supported})') from None
    return generator_cls(**options)


def main():
    """Load the config, apply CLI overrides, seed the RNGs and build the dataset."""
    args = parse_args()
    cfg = read_config(args.config)
    if args.cfg_options is not None:
        cfg = merge_from_options(cfg, args.cfg_options)

    logger = setup_logger("sde")
    logger.info(f'Using config:\n{yaml.dump(cfg, indent=4, sort_keys=False)}\n' + '-' * 60)
    # ``--seed`` is declared with an integer default, so from the CLI this guard
    # is always taken; it only matters if ``args.seed`` can ever be None.
    if args.seed is not None:
        set_random_seed(args.seed)
        logger.info(f'Using seed: {args.seed}')

    generator = build_generator(cfg['generator'])
    generator.make_dataset()


if __name__ == '__main__':
    main()
