from header import * 

def main(args):
    """Run the multi-client experiment and save its models, metrics and samples.

    Seeds torch and numpy with ``args.seed``, calls ``multiclient`` with the
    model/environment factories, then writes everything under
    ``{args.domain}/variational/{args.seed}/``:

    * ``gfn.pt``, ``var.pt``, ``gfn_central.pt`` -- the objects returned by
      ``multiclient`` under the ``'gfn'``, ``'var'`` and ``'gfn_central'`` keys;
    * ``l1.json`` -- the values under ``'<name>_l1agg'``, stored as ``'<name>_l1'``;
    * ``<name>_log_rewards_float16.pt`` -- the ``'<name>_log_rewards'`` values
      cast to ``torch.float16``;
    * ``datetime.txt`` -- a ``%Y%m%d%H%M%S`` timestamp taken after
      ``multiclient`` returns.

    Args:
        args: parsed options; this function reads ``seed``, ``domain``,
            ``num_clients``, ``epochs``, ``batch_size_train``,
            ``batch_size_eval``, ``num_batches_eval`` and ``lr``.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # NOTE(review): the factories below, ``multiclient``, ``create_env_args``
    # and ``flow_args`` are not defined in this function; presumably they come
    # from the ``header`` star import -- confirm there.
    multiclient_out = multiclient(
        create_gfn,
        create_var,
        create_env,
        create_var_prod,
        samples_from_state,
        unique_smp,
        num_clients=args.num_clients,
        epochs=args.epochs,
        batch_size_train=args.batch_size_train,
        batch_size_sampling=args.batch_size_eval,
        num_batches=args.num_batches_eval,
        lr=args.lr,
        create_env_args=create_env_args,
        flow_args=flow_args,
    )

    save_dir = f'{args.domain}/variational/{args.seed}'
    pathlib.Path(save_dir).mkdir(exist_ok=True, parents=True)

    # Captured after ``multiclient`` has returned, i.e. when saving starts.
    current_datetime = datetime.now().strftime('%Y%m%d%H%M%S')

    # Each name prefixes both its keys in ``multiclient_out`` and the files
    # written below; the tuple order fixes the write order and JSON key order.
    model_names = ('gfn', 'var', 'gfn_central')

    # save the models
    for name in model_names:
        torch.save(multiclient_out[name], f'{save_dir}/{name}.pt')

    # save the aggregated L1 distances. The payload is built before the file
    # is opened so a missing key cannot leave an empty l1.json behind, and
    # ``with`` guarantees the handle is flushed and closed.
    l1_distances = {
        f'{name}_l1': multiclient_out[f'{name}_l1agg'] for name in model_names
    }
    with open(f'{save_dir}/l1.json', 'w') as stream:
        json.dump(l1_distances, stream)

    # save the log-rewards, cast to float16 (whether these are a subset of
    # all rewards is decided inside ``multiclient``)
    for name in model_names:
        torch.save(
            multiclient_out[f'{name}_log_rewards'].to(dtype=torch.float16),
            f'{save_dir}/{name}_log_rewards_float16.pt',
        )

    with open(f'{save_dir}/datetime.txt', 'w') as stream:
        stream.write(current_datetime)

if __name__ == '__main__':
    # Script entry point. NOTE(review): ``args`` is expected to be supplied by
    # the ``header`` star import (it is not parsed here) -- confirm there.
    main(args=args)