import os
from argparse import ArgumentParser

from src.gflownet.tasks.gsk_synthesis import GskSynthesisTrainer
from src.gflownet.tasks.jnk_synthesis import JnkSynthesisTrainer
from src.gflownet.tasks.seh_synthesis import SehSynthesisTrainer


def parse_args():
    """Parse command-line options for an RxnFlow synthesis training run.

    Returns:
        argparse.Namespace with attributes ``out_dir``, ``num_oracles``, ``env_dir``,
        ``subsampling_ratio``, ``override``, ``device``, ``num_workers``, ``seed``,
        ``task`` and ``setup``.

    Raises:
        SystemExit: raised by argparse on missing required options or an
            unknown ``--task`` value (after printing a usage message).

    Note:
        This file's ``__main__`` block does not read ``env_dir`` or ``num_oracles``;
        the environment directory there is derived from ``--setup`` instead.
    """
    parser = ArgumentParser(
        "RxnFlow",
        description="Synthesis-oriented molecule generation with RxnFlow (tasks: gsk, jnk3, seh)",
    )

    run_cfg = parser.add_argument_group("Operation Config")
    run_cfg.add_argument("-o", "--out_dir", type=str, required=True, help="Output directory")
    run_cfg.add_argument(
        "-n",
        "--num_oracles",
        type=int,
        default=1000,
        help="Number of Oracles (64 molecules per oracle; default: 1000)",
    )
    run_cfg.add_argument("--env_dir", type=str, default="./data/envs/rgfn_original_aug", help="Environment Directory Path")
    run_cfg.add_argument(
        "--subsampling_ratio",
        type=float,
        default=0.01,
        help="Action Subsampling Ratio. Memory-variance trade-off (Smaller ratio increase variance; default: 0.01)",
    )
    run_cfg.add_argument(
        "--override",
        action="store_true",
        default=False,
        help="Overwrite an existing experiment at --out_dir",
    )
    run_cfg.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Compute device (default: cuda)",
    )
    run_cfg.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="Number of workers (default: 0)",
    )
    run_cfg.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Random seed (default: 0)",
    )
    run_cfg.add_argument(
        "--task",
        type=str,
        required=True,
        # Restrict to the tasks the __main__ block can dispatch, so a typo fails at
        # parse time with a usage message instead of after the config is built.
        choices=["gsk", "jnk3", "seh"],
        help="Target task",
    )
    run_cfg.add_argument(
        "--setup",
        type=str,
        required=True,
        help="Environment setup name (environment is read from ./data/envs/<setup>)",
    )
    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: build a Config from the CLI options and train the selected task.
    args = parse_args()
    # NOTE(review): the trainers are imported from `src.gflownet` at the top of the file,
    # while Config comes from `gflownet`; confirm both resolve to the same package,
    # otherwise the trainers could receive a Config of a different class.
    from gflownet.config import Config, init_empty

    config = init_empty(Config())
    # The environment directory is derived from --setup; --env_dir is not read here.
    config.env_dir = f"./data/envs/{args.setup}"
    config.num_training_steps = 4000 if args.setup == "synflow_128" else 5000

    config.algo.action_sampling.sampling_ratio_reactbi = args.subsampling_ratio
    config.overwrite_existing_exp = args.override
    config.device = args.device
    config.log_dir = args.out_dir
    config.print_every = 1
    config.num_workers_retrosynthesis = 8
    config.seed = args.seed
    # Honor the --num_workers flag (its default is 0).
    config.num_workers = args.num_workers
    config.num_final_gen_steps = 100
    config.algo.max_len = 5

    # Learning rates and training random-action probability
    config.opt.learning_rate = 1e-3
    config.algo.tb.Z_learning_rate = 1e-1
    config.algo.train_random_action_prob = 0.05

    # Replay buffer (disabled)
    config.replay.use = False
    config.replay.warmup = 0
    config.replay.capacity = 10_000

    # task name -> (trainer class, value assigned to config.cond.temperature.dist_params)
    task_registry = {
        "gsk": (GskSynthesisTrainer, [16, 48]),
        "jnk3": (JnkSynthesisTrainer, [16, 48]),
        "seh": (SehSynthesisTrainer, [16, 64]),
    }
    if args.task not in task_registry:
        raise ValueError(f"Unknown task: {args.task}")
    trainer_cls, temperature_dist_params = task_registry[args.task]
    config.cond.temperature.dist_params = temperature_dist_params
    trainer = trainer_cls(config)
    trainer.run()
