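"""
Submit a CRBench training + evaluation job to a PBS (qsub) queue.

Usage:
    $ uv run scripts/qsub_crbench.py --reserved_queue_name R327976 --crbench_task AdditionTaskSmall
    $ tail -f ~/lltm-{name}.o{job_id}

where {name} is the value of --name, or the lower-cased task name when --name is omitted.
"""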

import os 
import argparse 
import subprocess 
import importlib 
from omegaconf import OmegaConf 


def get_resources (resource_type :str ="rt_HF",reserved_queue_name :str |None =None ):
    if reserved_queue_name :
        return f"#PBS -q {reserved_queue_name }\n#PBS -v RTYPE={resource_type }"
    else :
        return f"#PBS -q {resource_type }"

def generate_sbatch(
    base: str,
    env_path: str,
    cache_dir: str,
    logdir: str,
    num_nodes: int,
    reserved_queue_name: str,
    resource_type: str,
    group_name: str,
    job_name: str,
    extra: str = "",
    walltime: str = "168:00:00",
):
    """Render the PBS job script that trains a model and then runs CRBench evaluation.

    Despite the name, the returned text is a PBS (qsub) script, not a Slurm
    sbatch script.

    Args:
        base: training config path, passed to ``main.py --base``.
        env_path: accepted for interface compatibility; it is not referenced in
            the generated script (commands run through ``uv run``).
        cache_dir: directory given to ``#PBS -o``.
        logdir: root log directory passed to ``main.py --logdir``; evaluation
            reads ``{logdir}/{job_name}/checkpoints/last.ckpt``.
        num_nodes: value for ``#PBS -l select=``.
        reserved_queue_name: reservation queue forwarded to ``get_resources``.
        resource_type: resource type forwarded to ``get_resources``.
        group_name: PBS project for ``#PBS -P``.
        job_name: ``--name`` for training; the PBS job is named ``lltm-{job_name}``.
        extra: extra arguments appended verbatim (unquoted, so they may contain
            several CLI flags) to the training command.
        walltime: value for ``#PBS -l walltime=``.

    Returns:
        The complete job-script text.
    """
    import shlex

    # Shell-quote interpolated paths so values with spaces or metacharacters
    # survive; shlex.quote returns plain paths (e.g. "configs/cr/x.yaml") as-is.
    q_base = shlex.quote(base)
    q_name = shlex.quote(job_name)
    q_logdir = shlex.quote(logdir)
    q_ckpt = shlex.quote(f"{logdir}/{job_name}/checkpoints/last.ckpt")
    # bash (not sh) is declared because the script uses the bash builtin `source`.
    return f"""#!/bin/bash
#PBS -P {group_name}
#PBS -N lltm-{job_name}
#PBS -k oe
#PBS -j oe
#PBS -o {cache_dir}
{get_resources(resource_type, reserved_queue_name)}
#PBS -l select={num_nodes}
#PBS -l walltime={walltime}

cd "$PBS_O_WORKDIR"

# Setup environment
source /etc/profile.d/modules.sh
module load cuda/12.1/12.1.1
module load cudnn/9.5/9.5.1
module load hpcx/2.20
module load nccl/2.23/2.23.4-1


echo "Starting the training script..."
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

uv run main.py --base {q_base} --name {q_name} --logdir {q_logdir} {extra}
# Evaluation
CKPT={q_ckpt}
uv run scripts/run_crbench.py --ckpt "$CKPT" --num_gpus=8 --num_examples=100
"""


def get_obj_from_str(string):
    """Resolve a dotted path such as ``"pkg.module.Attr"`` to the object it names.

    Everything before the last dot is imported as a module, and the final
    component is looked up on it as an attribute. A path without a dot raises
    ``ValueError``. A missing module or attribute raises ``ImportError`` or
    ``AttributeError``, respectively.
    """
    module_path, attr_name = string.rsplit(".", maxsplit=1)
    owner = importlib.import_module(module_path)
    return getattr(owner, attr_name)


def main(
    crbench_task: str = "AdditionTaskSmall",
    name: str | None = None,
    base: str = "configs/cr/ltm-0.5b-base-h200.yaml",
    logdir: str = "crbench_logs",
    group_name: str = "gcg51585",
    resource_type: str = "rt_HF",
    reserved_queue_name: str | None = None,
    num_nodes: int = 1,
    cache_dir: str = ".qsub_cache",
    extra: str = "",
    walltime: str = "168:00:00",
):
    """Build a task-specific config and PBS script for a CRBench task, then submit it.

    Steps:
      1. Validate the resource settings and the Python environment path
         (``$ENV_PATH``, defaulting to ``<repo>/.venv``).
      2. Check that ``crbench.<crbench_task>`` can be imported.
      3. Load ``base``, set ``data.params.crbench_task_config.target`` to
         ``crbench.<crbench_task>``, and save it to
         ``configs/cr/<name>-ltm-0.5b.yaml``.
      4. Write the job script to ``<cache_dir>/<name>.sh``.
      5. Submit it with ``qsub -P <group_name>``.

    Raises:
        ValueError: if ``resource_type`` is not ``"rt_HF"`` or ``num_nodes`` is not 1.
        FileNotFoundError: if the Python environment directory does not exist.
        ImportError: if the CRBench task cannot be resolved.
        subprocess.CalledProcessError: if ``qsub`` exits with a non-zero status.
    """
    # Explicit raises rather than asserts: asserts are stripped under `python -O`.
    if resource_type != "rt_HF":
        raise ValueError("Only rt_HF is currently supported for CRBench tasks.")
    if num_nodes != 1:
        raise ValueError("Only 1 node is currently supported for CRBench tasks.")
    work_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    env_path = os.environ.get("ENV_PATH", os.path.join(work_dir, ".venv"))
    if not os.path.exists(env_path):
        raise FileNotFoundError(f"Python environment {env_path} does not exist")

    crbench_string = f"crbench.{crbench_task}"
    try:
        get_obj_from_str(crbench_string)
    except (ImportError, AttributeError, ValueError) as err:
        # Narrow catch (no bare except): Ctrl-C and unrelated errors propagate,
        # and the original cause stays visible in the traceback via `from err`.
        raise ImportError(f"CRBench task {crbench_task} not found") from err
    if name is None:
        name = crbench_task.split(".")[-1].lower()

    config = OmegaConf.load(base)
    config.data.params.crbench_task_config.target = crbench_string

    config_path = f"configs/cr/{name}-ltm-0.5b.yaml"
    os.makedirs(os.path.dirname(config_path), exist_ok=True)
    OmegaConf.save(config, config_path)

    sbatch_script = os.path.join(cache_dir, f"{name}.sh")
    print(f"Writing qsub script to {sbatch_script}...")
    print("#" * 100)
    print(f"group_name: {group_name}")
    print(f"resource_type: {resource_type}")
    print(f"reserved_queue_name: {reserved_queue_name}")
    print(f"num_nodes: {num_nodes}")
    print(f"cache_dir: {cache_dir}")
    print(f"base: {config_path}")
    print(f"logdir: {logdir}")
    os.makedirs(cache_dir, exist_ok=True)
    with open(sbatch_script, "w", encoding="utf-8") as fw:
        fw.write(
            generate_sbatch(
                base=config_path,
                logdir=logdir,
                group_name=group_name,
                resource_type=resource_type,
                reserved_queue_name=reserved_queue_name,
                num_nodes=num_nodes,
                job_name=name,
                extra=extra,
                env_path=env_path,
                cache_dir=cache_dir,
                walltime=walltime,
            )
        )

    print("Submitting the job...")
    # Build argv directly instead of str.split() so paths with spaces stay intact.
    submit_command = ["qsub", "-P", group_name, sbatch_script]
    print(" ".join(submit_command))
    try:
        output = subprocess.run(
            submit_command, capture_output=True, text=True, check=True
        ).stdout
    except subprocess.CalledProcessError as err:
        # qsub reports its errors on stderr, which is otherwise captured and lost.
        print(err.stdout)
        print(err.stderr)
        raise
    print(output)


if __name__ == "__main__":
    # Each CLI option mirrors a keyword parameter of main(); only the task is mandatory.
    cli = argparse.ArgumentParser()
    cli.add_argument("--crbench_task", type=str, required=True)
    optional_flags = (
        ("--name", str, None),
        ("--base", str, "configs/cr/ltm-0.5b-base-h200.yaml"),
        ("--logdir", str, "crbench_logs"),
        ("--cache_dir", str, ".qsub_cache"),
        ("--group_name", str, "gcg51585"),
        ("--resource_type", str, "rt_HF"),
        ("--reserved_queue_name", str, None),
        ("--num_nodes", int, 1),
        ("--extra", str, ""),
        ("--walltime", str, "168:00:00"),
    )
    for flag, value_type, fallback in optional_flags:
        cli.add_argument(flag, type=value_type, required=False, default=fallback)
    main(**vars(cli.parse_args()))
