import hydra
import gc
import torch


"""
            |  agent state  |   adv state  | agent num | adv num | action | traj length
tag         |       14      |      16      |     1     |    3    |   5    |     25
push        |       19      |      8       |     1     |    1    |   5    |     25
spread      |       18      |      -       |     3     |    0    |   5    |     25
connect4    |     2*6*7     |     2*6*7    |     1     |    1    |   7    |     42 (21)
holdem      |       72      |      72      |     1     |    1    |   4    |     50 (25)
boxing      |    84*84*6    |    84*84*6   |     1     |    1    |   18   |     256
tennis      |    84*84*6    |    84*84*6   |     1     |    1    |   18   |     128
badminton   |       17      |      17      |     1     |    1    |   15   |     60 (30)
"""


# NOTE: Before training, manually set both the model name and the task name in the Hydra configuration (see `config_path` / `config_name` in the decorator below) to match your intended setup.
@hydra.main(config_path="./configs/ddgi", config_name="push", version_base=None)
def main(args):
    """Run the training stages selected in ``train_mode``.

    When 'emb' is selected, the embedding model is trained once, with
    ``EmbTrainer`` receiving the largest entry of ``datalen``. When 'policy'
    is selected, a ``PolicyTrainer`` is built and trained for every entry of
    ``datalen``, in list order.

    Args:
        args: Configuration object passed in by the ``hydra.main`` decorator.
            ``args.env_name`` is read here for the log banners, and the
            object itself is forwarded to the trainers.
    """
    # Collect unreachable Python objects and release cached CUDA memory
    # before any training starts.
    gc.collect()
    torch.cuda.empty_cache()

    # ------------------------- user settings ------------------------- #
    # Stages to run:
    #   ['emb']            -> train only the embedding model
    #   ['policy']         -> train only the DDGI policy
    #   ['emb', 'policy']  -> train both in this run (embedding first)
    train_mode = ['policy']

    # Dataset sizes (list), e.g. [10, 50, 100] or [50, 100, 250].
    datalen = [500]
    # ----------------------------------------------------------------- #

    separator = '======================================================\n'
    run_embedding = 'emb' in train_mode
    run_policy = 'policy' in train_mode

    if run_embedding:
        # Imported lazily so the module is only loaded when this stage runs.
        from emb.emb_trainer import EmbTrainer

        # Train/Valid
        print(f"--- Model: 'Embedding', Env: '{args.env_name}' --- ")
        print(separator)
        # Kept as a chained temporary so the trainer is not held alive
        # while the policy stage runs.
        EmbTrainer(args, max(datalen)).train(args)
        print(f"============ Finish Training 'Emb model' in env: {args.env_name} ============\n")

    if not run_policy:
        return

    from demo.policy_trainer import PolicyTrainer
    # from demo.tran_trainer import TranTrainer

    # One independent policy training run per requested dataset size.
    for size in datalen:
        # Train/Valid
        print(f"--- Model: 'Policy', Env: '{args.env_name}' --- ")
        print(separator)
        PolicyTrainer(args, size).train()
        print(f"============ Finish Training 'Policy model' in env: {args.env_name} ============\n")



# Script entry point: main() is called without arguments because the
# @hydra.main decorator builds the configuration and passes it in as `args`.
if __name__ == "__main__":
    main()