import os
import subprocess
import sys
from socket import socket
import argparse
import json
import logging

# Root logger: emit INFO and above, each record prefixed with a timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Absolute directory containing this file, and the directory one level above it.
# Both are used below as bases for the script and config paths that get launched.
script_dir = os.path.abspath(os.path.dirname(__file__))
parent_dir = os.path.normpath(os.path.join(script_dir, os.pardir))

def run_cmd(cmd: str) -> bool:
    """Run a shell command, mirror its output into the log, and report success.

    The command's stdout and stderr are merged and logged line by line at INFO
    level as they arrive. Each line is also scanned, case-insensitively, for a
    small set of failure keywords; any hit is logged at ERROR level and marks
    the run as failed even when the exit status is zero.

    Args:
        cmd: Command line handed to the shell (``shell=True``).

    Returns:
        True when the exit status is 0 and no failure keyword appeared in the
        output; False otherwise, including when launching the command or
        reading its output raises.
    """
    # Stored lowercase because matching is done against the lowercased line,
    # so one spelling per keyword is enough.
    error_keywords = (
        "error",
        "unrecognized model",
        "failed",
        "exception",
        "traceback",
    )
    error_detected = False

    try:
        # The context manager closes the pipe and reaps the child even if
        # reading the output raises part-way through.
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # Merge stderr into stdout
            shell=True,
            universal_newlines=True,  # Ensure output is in text mode
            errors='replace',  # An undecodable byte must not abort the run
        ) as proc:
            # Read output as it is produced and scan every line for errors
            for line in proc.stdout:
                logging.info(line.rstrip())

                lowered = line.lower()
                if any(keyword in lowered for keyword in error_keywords):
                    error_detected = True
                    logging.error(f"Detected error in output: {line.strip()}")

            returncode = proc.wait()
    except Exception as e:
        logging.error(f"Unexpected error running command: {e}")
        return False

    if error_detected or returncode != 0:
        logging.error(
            f"Command failed (returncode={returncode}, "
            f"errors detected in output: {error_detected})"
        )
        return False

    return True

def _shell_join(args):
    """Join argv-style *args* into one shell command string, quoting each part."""
    if os.name == 'nt':
        # Same conversion Popen applies to a list when shell=True on Windows.
        return subprocess.list2cmdline(args)
    import shlex
    return ' '.join(shlex.quote(arg) for arg in args)


def _accelerate_args(script, config_path):
    """Build the `accelerate launch` argv for *script* (relative to script_dir)."""
    return [
        'accelerate', 'launch',
        '--config_file', os.path.join(parent_dir, 'configs/accelerate_config/multi_gpu.yaml'),
        os.path.join(script_dir, script),
        '--config', config_path,
    ]


def _python_args(script, config_path):
    """Build the plain `python` argv for *script* (relative to script_dir)."""
    return ['python', os.path.join(script_dir, script), '--config', config_path]


def _launch(args):
    """Log the command built from *args*, run it via run_cmd, and return its result."""
    cmd = _shell_join(args)
    logging.info(f"Running command: {cmd}")
    return run_cmd(cmd)


def _infer_then_train(train_script, config_path, need_infer=True):
    """Optionally run infer/infer_vllm.py, then launch *train_script* if it succeeded."""
    if need_infer and not _launch(_python_args('infer/infer_vllm.py', config_path)):
        logging.error("Infer failed, skipping training")
        return False
    return _launch(_accelerate_args(train_script, config_path))


def process(config_path, config):
    """Launch the sub-command(s) selected by ``config["job_type"]``.

    A relative ``config_path`` is resolved against ``parent_dir`` before it is
    passed to the launched scripts through their ``--config`` argument. Every
    command is run with ``run_cmd``; each argument is shell-quoted so paths
    containing spaces or shell metacharacters are passed through intact.

    Args:
        config_path: Path to the JSON config file (absolute, or relative to
            ``parent_dir``).
        config: Parsed config mapping. ``job_type`` is always read;
            ``dataset.trainset_path`` is read for the ``*_cl_local`` jobs.

    Returns:
        True if every launched command succeeded, False if one failed (an
        inference failure skips the training step that would follow it).

    Raises:
        KeyError: If a required config key is missing.
        SystemExit: With status 1 when the job type is not recognised.
    """
    if not os.path.isabs(config_path):
        config_path = os.path.join(parent_dir, config_path)

    job_type = config["job_type"]

    # Knowledge Distillation tasks
    if job_type in ('kd_train_only', 'sft'):
        return _launch(_accelerate_args('train/kd_sft.py', config_path))

    if job_type in ('kd_api', 'kd_local'):
        return _infer_then_train('train/kd_sft.py', config_path)

    if job_type in ('naive_cl_local', 'self_evolve_cl_local'):
        train_script = 'train/naive_cl.py' if 'naive' in job_type else 'train/self_evolve_cl.py'
        # Inference is only run when the training set file does not exist yet.
        trainset_missing = not os.path.exists(config['dataset']['trainset_path'])
        return _infer_then_train(train_script, config_path, need_infer=trainset_missing)

    if job_type == 'eval_model_hf':
        return _launch(_accelerate_args('eval/eval_model_hf.py', config_path))

    if job_type == 'eval_model_vllm':
        # NOTE(review): this launches eval_model_hf.py, the same script as the
        # 'eval_model_hf' job but via plain python. Confirm whether a
        # vLLM-specific evaluation script was intended here.
        return _launch(_python_args('eval/eval_model_hf.py', config_path))

    if job_type == 'generate_questions':
        return _launch(_python_args('generate/generate_vllm.py', config_path))

    logging.error(f"Unknown job type: {job_type}")
    sys.exit(1)

def main():
    """Parse ``--config``, load the JSON config file, and run the selected job.

    Exits with status 1 when ``process`` explicitly reports failure by
    returning False, so shell callers and schedulers can detect a failed job.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True, help='path to the json config file')
    args = parser.parse_args()
    config_path = args.config

    # Close the file deterministically instead of leaving the handle to the GC.
    with open(config_path, encoding='utf-8') as config_file:
        config = json.load(config_file)

    # Only an explicit False counts as failure; a None result is not an error.
    if process(config_path, config) is False:
        sys.exit(1)


if __name__ == '__main__':
    main()
