import argparse
import yaml
import pprint

from src.seip_package.training.pretrain import * 
from src.seip_package.training.train import *

# This script is the main entry point for the training pipeline.
# It takes the path to a YAML configuration file (--config_name) containing all the parameters needed to train a model.
# The script loads the configuration file, prints the parameters, and dispatches on --mode (pretrain or finetune).
# `pretrain` mode trains a model from scratch, while `finetune` mode calls `train`, which fine-tunes a pretrained model on a downstream task or trains a benchmark without pretraining.

# Command-line interface: both the config path and the run mode are mandatory.
# --config_name is required because process_main passes it straight to open();
# leaving it optional let a missing flag surface as an opaque TypeError.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--config_name', type=str,
    help='path to the YAML config file',
    required=True)
parser.add_argument(
    '--mode', type=str, choices=['pretrain', 'finetune'],
    help='mode to run: pretrain or finetune',
    required=True)

def process_main(config_name, mode):
    """Load a YAML training config and run the pipeline stage selected by ``mode``.

    Args:
        config_name: Path to the YAML configuration file.
        mode: ``'pretrain'`` (calls ``pretrain(params)``) or ``'finetune'``
            (calls ``train(params)``).

    Raises:
        ValueError: If ``mode`` is not a supported mode, or if the config
            file is empty.
        OSError: If the config file cannot be opened.
    """
    # Validate the mode before touching the filesystem so a bad value fails
    # fast instead of after the config has been loaded and printed.
    if mode not in ('pretrain', 'finetune'):
        raise ValueError("Invalid mode. Choose 'pretrain' or 'finetune'.")

    # Load script params. An explicit UTF-8 encoding keeps parsing independent
    # of the platform's default locale encoding.
    # NOTE(review): FullLoader accepts some python-specific YAML tags; if config
    # files can come from untrusted sources, prefer yaml.safe_load.
    with open(config_name, 'r', encoding='utf-8') as y_file:
        params = yaml.load(y_file, Loader=yaml.FullLoader)

    pprint.PrettyPrinter(indent=4).pprint(params)

    # An empty YAML document loads as None; report that clearly rather than
    # passing None into the training code.
    if params is None:
        raise ValueError(f"Config file {config_name!r} is empty.")

    # pretrain / train come from the star imports at the top of this file.
    if mode == 'pretrain':
        pretrain(params)
    else:
        train(params)

# Script entry point: parse the CLI flags and hand them to process_main.
if __name__ == '__main__':
    cli_args = parser.parse_args()
    process_main(cli_args.config_name, cli_args.mode)
