from header import * 

def main(args):
    """Run federated GFlowNet training and save the resulting distributions.

    Calls ``federated_gflownets`` with the factories and hyperparameters
    listed below, then serializes its two return values as JSON into the
    directory ``<args.domain>/federated`` (created if missing):

    * ``clients_dist.json`` -- the first returned value (``out_clients``).
    * ``pooled_dist.json``  -- the second returned value (``out_pooled``).

    Args:
        args: Object exposing the attributes ``num_clients``, ``epochs``,
            ``num_batches_eval``, ``batch_size_train``, ``batch_size_eval``,
            ``lr`` and ``domain``.

    Returns:
        None. The results are only written to disk.

    Note:
        ``federated_gflownets``, ``create_gfn``, ``create_env``,
        ``unique_smp``, ``create_env_args``, ``flow_args``, ``json`` and
        ``pathlib`` are not defined in this function; presumably they are
        provided by ``from header import *`` -- TODO confirm.
    """
    # Train each of the local models
    out_clients, out_pooled = federated_gflownets(
        create_gfn,
        create_env,
        unique_smp,
        num_clients=args.num_clients,
        epochs=args.epochs,
        # The evaluation batch count is what gets forwarded as `num_batches`.
        num_batches=args.num_batches_eval,
        batch_size_train=args.batch_size_train,
        batch_size_eval=args.batch_size_eval,
        lr=args.lr,
        create_env_args=create_env_args,
        flow_args=flow_args,
        is_phylogeny=(args.domain == 'phylogenetics')
    )

    save_dir = f'{args.domain}/federated'
    pathlib.Path(save_dir).mkdir(exist_ok=True, parents=True)

    # Use context managers so each file is flushed and closed deterministically,
    # even if serialization raises, instead of relying on garbage collection.
    with open(f'{save_dir}/clients_dist.json', 'w') as stream:
        json.dump(out_clients, stream)

    with open(f'{save_dir}/pooled_dist.json', 'w') as stream:
        json.dump(out_pooled, stream)

# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported.
# NOTE(review): `args` is not defined in this file; presumably it is a
# module-level object exported by `from header import *` -- confirm.
if __name__ == '__main__':
    main(args)