from src.packllm import PackLLM, PackLLMConfig

import argparse
import os
from typing import List
from src.utils import get_base_url, get_lora_pools


def parse_args():
    parser = argparse.ArgumentParser(description="PackLLM Search")
    parser.add_argument("--N", type=int, default=10)
    parser.add_argument("--method", type=str, default="simple")
    parser.add_argument("--tao", type=float, default=0.1)
    parser.add_argument("--topK", type=int, default=5)
    parser.add_argument("--test_tasks", type=str, nargs="+", required=True)
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--lora_dir", type=str, required=True)
    parser.add_argument("--tasks", type=str, nargs="+", required=True)
    parser.add_argument("--task_weights", type=float, nargs="+", help="Weights for each task (must match number of tasks)")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--early_stop", action="store_true")
    parser.add_argument("--early_stop_iter", type=int, default=5)
    parser.add_argument("--plot_enabled", action="store_true", help="Enable plotting")
    parser.add_argument("--combine_method", type=str, default="linear")
    parser.add_argument(
        '--ports', type=int, nargs='+', default=[9112],
    )
    return parser.parse_args()

def main():
    args = parse_args()
    
    config = PackLLMConfig(
        N=args.N,
        method=args.method,
        topK=args.topK,
        tao=args.tao,
        model_name_or_path=args.model_path,
        pools=get_lora_pools(args.lora_dir),
        llm_base_url=get_base_url(args.ports),
        tasks=args.tasks,
        test_tasks=args.test_tasks,
        task_weights=args.task_weights,
        seed=args.seed,
        early_stop=args.early_stop,
        early_stop_iter=args.early_stop_iter,
        plot_enabled=args.plot_enabled,
        combine_method=args.combine_method,
    )
    
    packllm = PackLLM(config=config)
    packllm.search()
    

if __name__ == '__main__':
    main()