import argparse
from datetime import datetime
import logging
from pathlib import Path
from dataclasses import replace
from pipeline.datagen.trajectory_generator import DataGenerator
from pipeline.datagen.config import DataGenConfig
from pipeline.datagen.save import save_npz

# Configure the root logger (INFO level, default stderr handler) so INFO-level
# records from this module are printed.
# NOTE(review): this runs at import time too; consider moving it into main()
# if this module is ever imported as a library.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def _split_overrides(base_config, split: str) -> dict:
    """Return the config fields that *split* overrides on top of *base_config*.

    The per-split output path is not included; the caller sets it.

    Args:
        base_config: the experiment-wide ``DataGenConfig``.
        split: one of ``'train'``, ``'val'``, ``'test'``.

    Returns:
        A dict of field overrides suitable for ``dataclasses.replace``.

    Raises:
        ValueError: if *split* is not a known split name.
    """
    if split == 'train':
        # Training uses the base configuration unchanged.
        return {}
    if split not in ('val', 'test'):
        raise ValueError(f"Unknown split: {split!r}")

    # val/test: a single sample with a 10x longer crop.
    overrides = {
        'n_samples': 1,
        'crop_length': base_config.crop_length * 10,
    }
    if split == 'val':
        # Build a fresh mapping rather than calling .update() on the base one:
        # dict.update() returns None (which would be stored here) and would also
        # mutate the dict shared with base_config, leaking into later splits.
        # NOTE(review): assumes crop_validator is a mapping (or None).
        overrides['crop_validator'] = {
            **(base_config.crop_validator or {}),
            'min_switches': 10,
        }
    return overrides


def main():
    """Generate and save the train/val/test datasets described by a JSON config.

    Reads ``--config`` (default ``./configs/datagen/datagen_config.json``) and
    writes ``train_data.npz``, ``val_data.npz`` and ``test_data.npz`` into
    ``<output_dir>/<experiment>/<version>/``, where ``<version>`` comes from
    ``generate_version_string``.
    """
    parser = argparse.ArgumentParser(description="Generate dynamical system dataset")
    parser.add_argument('--config', type=Path, default='./configs/datagen/datagen_config.json', help='Path to config JSON')
    args = parser.parse_args()

    base_config = DataGenConfig.from_json(args.config)

    # The version label is derived from the config as loaded (before any
    # parameter sampling below).
    exp_version = generate_version_string(base_config.__dict__)

    if base_config.parameter_mode == "shared":
        # Sample parameters once and store them on base_config so every split
        # derived from it below inherits the same values.
        gen_tmp = DataGenerator(base_config)
        params = gen_tmp.sample_parameters()
        base_config = replace(base_config, parameters=params.tolist())

    exp_dir = Path(base_config.output_dir) / base_config.experiment / exp_version

    for split in ('train', 'val', 'test'):
        output_path = exp_dir / f'{split}_data.npz'
        # dataclasses.replace returns a new config and leaves base_config (and
        # the objects it references) untouched for the remaining splits.
        # The config's output_dir field carries the full .npz path for the split.
        split_config = replace(
            base_config,
            **_split_overrides(base_config, split),
            output_dir=output_path,
        )
        gen = DataGenerator(split_config)

        data = gen.generate_dataset(progress=True)

        save_npz(data, gen.config, output_path)
        logger.info("Saved %s split to %s", split, output_path)

def generate_version_string(config: dict) -> str:
    """Build a unique experiment-version label.

    The label is the current local time formatted as ``YYYYMMDD_HHMMSS``,
    followed by a summary of the key generation settings:
    ``n<n_samples>_dt<dt>_ti<t_start>_tf<t_end>``.

    Args:
        config: mapping that must provide ``n_samples``, ``dt``, ``t_start``
            and ``t_end`` (a missing key raises ``KeyError``).

    Returns:
        The combined label, e.g. ``20240101_120000_n100_dt0.01_ti0_tf10``.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    summary = "n{n_samples}_dt{dt}_ti{t_start}_tf{t_end}".format_map(config)
    return "_".join((stamp, summary))

# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()