import click
import logging
from pathlib import Path
from os.path import abspath, join, exists, dirname
import sys

# Add parent directory to path to find teacher_student package
sys.path.insert(0, dirname(dirname(abspath(__file__))))

from teacher_student.train_teacher_student import train_teacher_student
from teacher_student.exp_control import independent_sweep, combination_sweep, submit_parallel_sweep

# Root click group. It carries no options or logic of its own and only acts
# as the container that subcommands are attached to. Documented with comments
# rather than a docstring so the group's --help output stays unchanged.
@click.group()
def cli():
    return None

@click.command()
@click.option("--exp_id", help="Experiment folder containing the config file")
@click.option("--debug", default=False, help="Debug mode")
@click.option("--cluster", default=False, help="On the cluster?")
@click.option("--scratch_path", default=None, help="Scratch path")
@click.option("--sweep_mode", default="combination", help="sweep mode, either 'independent' or 'combination'")
@click.option("--parallel", default=False, help="Use parallel sweep with SLURM job arrays")
@click.option("--max_concurrent_jobs", default=6, help="Maximum concurrent SLURM jobs for parallel sweep")
@click.option("--walltime", default="5-00:00:00", help="SLURM walltime limit")
@click.option("--gpu_mem", default="24g", help="GPU memory requirement")
@click.option("--mem_per_cpu", default="10g", help="CPU memory requirement")
@click.option("--submit_jobs", default=True, help="Actually submit jobs (False for dry run)")
def train(exp_id: str, debug: bool, cluster: bool, scratch_path: Path, sweep_mode: str,
          parallel: bool, max_concurrent_jobs: int, walltime: str, gpu_mem: str,
          mem_per_cpu: str, submit_jobs: bool) -> None:
    """Run a teacher-student training sweep for one experiment.

    The experiment directory is --scratch_path when given, otherwise
    experiments/training/<exp_id> under the project root (the parent of this
    script's directory). With --parallel the sweep is handed to
    submit_parallel_sweep; otherwise it is run through the sweep decorator
    in this process.
    """
    # NOTE: click prints the docstring above as this command's --help text,
    # so it is kept short; per-option details live in the option help strings.
    # `debug` is accepted for CLI compatibility but is not used in this function.
    # `scratch_path` arrives from click as a plain string (or None).

    # folders
    ROOT_DIR = dirname(dirname(abspath(__file__)))

    # exp_id is only optional when a scratch path replaces the experiment
    # directory AND no home experiment directory is needed (sequential mode).
    # Fail with a usage error instead of a TypeError from os.path.join(None).
    if exp_id is None and (scratch_path is None or parallel):
        raise click.UsageError("Missing option '--exp_id'.")

    if scratch_path is None:
        exp_dir = join(ROOT_DIR, "experiments", "training", exp_id)
    else:
        exp_dir = scratch_path

    data_dir = join(ROOT_DIR, "data")
    print("Data dir:", data_dir)

    # loggers
    log_path = join(ROOT_DIR, "train_info.log")
    logger_info = logging.getLogger("logger_info")
    logger_info.setLevel(logging.INFO)
    # logging.getLogger returns the same object on every call, so only attach
    # the file handler once per process; otherwise repeated invocations would
    # duplicate every log line and leak open file handles.
    handler_present = any(
        isinstance(handler, logging.FileHandler)
        and handler.baseFilename == abspath(log_path)
        for handler in logger_info.handlers
    )
    if not handler_present:
        info_handler = logging.FileHandler(log_path)
        formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
        info_handler.setFormatter(formatter)
        logger_info.addHandler(info_handler)

    # Choose sweep approach based on parallel option
    if parallel:
        # Config files are always looked up in the home experiment directory,
        # even when exp_dir points at a scratch location.
        home_exp_dir = join(ROOT_DIR, "experiments", "training", exp_id)

        # Use direct function call for parallel sweep
        submit_parallel_sweep(
            exp_dir=exp_dir,
            home_exp_dir=home_exp_dir,
            data_dir=data_dir,
            cluster=cluster,
            scratch_path=scratch_path,
            mode=sweep_mode,
            max_concurrent_jobs=max_concurrent_jobs,
            walltime=walltime,
            gpu_mem=gpu_mem,
            mem_per_cpu=mem_per_cpu,
            submit_jobs=submit_jobs,
        )
    else:
        # Use decorator for sequential sweep; the decorator supplies the
        # (config, save_dir) arguments when run_sweep() is called.
        @combination_sweep(exp_dir, mode=sweep_mode)
        def run_sweep(config, save_dir):
            train_teacher_student(
                config=config,
                save_dir=save_dir,
                data_dir=data_dir,
                cluster=cluster,
            )

        run_sweep()
    
    
# Attach the train command to the root group so it can be invoked as a
# subcommand of this script.
cli.add_command(train)

# Dispatch to the click group only when executed as a script, not on import.
if __name__ == "__main__":
    cli()
