#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Run as (quote the --data_file pattern so the shell passes the glob through unexpanded):
# python train_single.py --data_file '../data/listops/List*(add_10,max*.pkl' --model 'NRGPT' --n_embed 128 --n_layer 4 --learning_rate 5e-4 --min_lr 5e-6 --alpha 1.0 --device 0
# or, with uv:
# uv run python train_single.py --data_file '../data/listops/List*(add_10,max*.pkl' --model 'NRGPT' --n_embed 128 --n_layer 4 --learning_rate 5e-4 --min_lr 5e-6 --alpha 1.0 --device 0

# import sys 
# sys.path.append("../src/")
from listops.training.train import  train
from listops.model.config import Config

import os
from glob import glob

import argparse

# Command-line interface for a single training run. Every option maps 1:1 onto
# a key of the config dict handed to `train` further down.
parser = argparse.ArgumentParser(description='Train a single model on a ListOps dataset.')
parser.add_argument('--data_file', type=str, default='../data/listops/List*.pkl',
                    help='Path or glob pattern for the pickled data file')
parser.add_argument('--model', type=str, default='NRGPT', help='Model type')
parser.add_argument('--n_embed', type=int, default=128, help='Embedding size')
parser.add_argument('--n_layer', type=int, default=4, help='Number of layers')
parser.add_argument('--learning_rate', type=float, default=5e-4, help='Learning rate')
parser.add_argument('--min_lr', type=float, default=5e-6, help='Minimum learning rate')
parser.add_argument('--alpha', type=float, default=1.0, help='Alpha parameter')
parser.add_argument('--max_iters', type=int, default=2000, help='Maximum iterations')
# Kept as a string: the value is forwarded verbatim to `train` (only replaced
# by 'cpu' when CUDA is unavailable).
parser.add_argument('--device', type=str, default='cuda',
                    help='Device to use (e.g., "cuda" or "cpu"); falls back to "cpu" when CUDA is unavailable')

args = parser.parse_args()

# NOTE(review): not referenced anywhere in this script; presumably intended for
# experiment tracking — wire it into `config` or remove it.
project_name = "ListOps-2026"

# Expand the --data_file glob. glob() returns matches in arbitrary (directory)
# order, so sort them to make the choice of data file reproducible.
data_files = sorted(glob(args.data_file))
print(data_files)
if not data_files:
    # Fail with a readable message instead of an IndexError on data_files[0].
    raise SystemExit(f"No data files match pattern: {args.data_file!r}")
if len(data_files) > 1:
    # Only one file is trained on; make the silent choice visible.
    print(f"{len(data_files)} data files matched; training on the first: {data_files[0]}")

from torch import cuda

config = {
    "data_file": data_files[0],
    "max_iters": args.max_iters,
    "early_stop": False,
    "model": args.model,
    "n_embed": args.n_embed,
    "n_layer": args.n_layer,
    "num_tests": 50,
    # Previously noted as good settings: learning_rate=4e-3, min_lr=5e-4.
    "learning_rate": args.learning_rate,
    "min_lr": args.min_lr,
    "alpha": args.alpha,
    # Use the requested device only when CUDA is available, else fall back to
    # the CPU. NOTE(review): `--device 0` (as in the usage examples) arrives
    # here as the string "0" — confirm that `train` accepts that form.
    "device": args.device if cuda.is_available() else 'cpu',
}
# Alternative: pass a typed Config object instead of a plain dict.
# config = Config(**config)

train(config=config)

