import sys
import os
sys.path.append(os.path.abspath('./'))
import hydra
import torch
import gc


# To proceed with training, you must manually modify both the model name and the task name in the Hydra configuration to match your intended setup.
@hydra.main(config_path="../configs/dd", config_name="box", version_base=None)
def main(args):
    """Train the model selected by ``args.model_name`` once per dataset size.

    The function reads two keys from the Hydra config object: ``model_name``
    (one of ``'bc'``, ``'dp'``, ``'dd'``, ``'dbc'``) and ``env_name`` (used
    only in the progress messages). For every entry of the hard-coded
    dataset-size list the matching trainer is built with ``(args, size)``.

    Args:
        args: Config object supplied through the ``@hydra.main`` decorator.

    Raises:
        ValueError: If ``args.model_name`` is not one of the supported names.
    """
    supported_models = ('bc', 'dp', 'dd', 'dbc')
    if args.model_name not in supported_models:
        # Previously an unknown name fell through every branch and the script
        # finished without training anything or reporting a problem.
        raise ValueError(
            f"Unknown model_name {args.model_name!r}; "
            f"expected one of {supported_models}"
        )

    # Free unreferenced Python objects and cached CUDA memory before training.
    gc.collect()
    torch.cuda.empty_cache()

    # Dataset sizes to sweep over; edit this list to change the sweep
    # (the original comment also listed 50, 100, 250 as an alternative).
    dataset_sizes = [10, 50, 100]

    for data_size in dataset_sizes:
        # Trainer modules are imported lazily so that only the package of the
        # selected model has to be importable.

        # BC
        if args.model_name == 'bc':
            from BC.trainer import BCTrainer
            print(f"--- Model: 'BC', Env: '{args.env_name}' --- ")
            print('======================================================\n')
            BCTrainer(args, data_size).train()
            print(f"============ Finish Training 'BC' in env: {args.env_name} ============\n")

        # DP
        elif args.model_name == 'dp':
            from DBC.dp import DPTrainer
            print(f"--- Model: 'DP', Env: '{args.env_name}' --- ")
            print('======================================================\n')
            # Fixed: the original called the undefined name `DBCTrainer`,
            # which raised NameError; the class imported above is DPTrainer.
            # NOTE(review): unlike the BC branch no .train() is called here -
            # presumably the constructor runs training; confirm in DBC.dp.
            DPTrainer(args, data_size)
            print(f"============ Finish Training 'DP' in env: {args.env_name} ============\n")

        # DD
        elif args.model_name == 'dd':
            from DBC.dd import DDTrainer
            print(f"--- Model: 'DD', Env: '{args.env_name}' --- ")
            print('======================================================\n')
            # NOTE(review): only the constructor is invoked, as for DP - confirm.
            DDTrainer(args, data_size)
            print(f"============ Finish Training 'DD' in env: {args.env_name} ============\n")

        # DBC (two steps: DDPM, then BC)
        elif args.model_name == 'dbc':
            from DBC2.ddpm import DDPMTrainer
            print(f"--- Model: 'DBC', Env: '{args.env_name}' --- ")
            print('====================== step 1 ========================\n')
            # NOTE: step 1 is currently disabled; re-enable the call below to
            # run the DDPM trainer before step 2.
            #DDPMTrainer(args, data_size)
            print('====================== step 2 ========================\n')
            from DBC2.bc import BCTrainer
            BCTrainer(args, data_size).train()
            print(f"============ Finish Training 'DBC' in env: {args.env_name} ============\n")
            

# Script entry point: main() is called without arguments here; its `args`
# parameter is expected to be supplied by the @hydra.main decorator.
if __name__ == "__main__":
    main()