import argparse
import os
from runner import Runner
from utils import parse_command_line_args




def parse_args(argv=None):
    """Parse the model/dataset selection and GAL loss hyper-parameters.

    Only the options declared below are parsed. Every other argument on the
    command line is returned untouched, so a separate config parser can
    handle it.

    Args:
        argv: Argument strings to parse, excluding the program name.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            exactly as before. Passing an explicit list lets other scripts
            and tests call this function without touching ``sys.argv``.

    Returns:
        A ``(namespace, remaining)`` tuple from
        ``ArgumentParser.parse_known_args``:

        - ``namespace`` holds the declared options, filled with their
          defaults when absent.
        - ``remaining`` is the list of unrecognised argument strings.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): argparse matches unambiguous prefixes of long options by
    # default (allow_abbrev=True). An extra, unrelated option such as
    # ``--mode`` would therefore be consumed as ``--model`` rather than
    # returned in ``remaining``. Consider ``allow_abbrev=False`` if
    # pass-through config keys may collide with these prefixes.
    parser.add_argument('--model', type=str, default='SteerRec', help='Model name')
    parser.add_argument('--dataset_code', type=str, default='O', help="Dataset code (e.g., 'O' for Sports, 'B' for Beauty, 'T' for Toys).")

    # Loss hyper-parameters for the Guidance Alignment Triplet Loss (GAL).
    parser.add_argument('--mu', type=float, default=0.4, help="Loss balancing coefficient between reconstruction and GAL.")
    parser.add_argument('--margin', type=float, default=0.1, help='Margin for the Guidance Alignment Triplet Loss (GAL).')
    parser.add_argument('--neg_samples', type=int, default=64, help='Number of negative items to sample for constructing the negative condition.')

    # parse_known_args (not parse_args) so that extra options are passed
    # through instead of raising "unrecognized arguments".
    return parser.parse_known_args(argv)


if __name__ == '__main__':
    # Run from this script's own directory so that relative paths behave the
    # same regardless of where the script was launched from.
    script_dir = os.path.dirname(os.path.realpath(__file__))
    os.chdir(script_dir)

    known_args, extra_args = parse_args()

    # Start from the argparse options, then let any extra key/value
    # overrides from the command line take precedence.
    config_dict = dict(vars(known_args))
    config_dict.update(parse_command_line_args(extra_args))

    runner = Runner(model_name=config_dict['model'], config_dict=config_dict)
    runner.run()