# llmHMER/QwenHMER/config_utils.py

import os
import subprocess
from pathlib import Path
from typing import Any, Dict

import yaml
import time

class ConfigError(Exception):
    """Base class for the configuration errors defined in this module."""


class ConfigFileNotFoundError(ConfigError):
    """Raised when the requested config file does not exist."""


class ConfigParseError(ConfigError):
    """Raised when a config file cannot be parsed into a dictionary."""


def read_yaml_config(file_path: str) -> Dict[str, Any]:
    """
    read yaml config file
    :param file_path: yaml file path
    :return: dict config
    :raises ConfigFileNotFoundError: if the file does not exist
    :raises ConfigParseError: if the file is not valid YAML, or its top-level
        value is not a mapping (this includes an empty file, which loads as None)
    :raises ConfigError: for any other failure while opening or loading the file
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
    except FileNotFoundError as e:
        raise ConfigFileNotFoundError(f"Config file not found: {file_path}") from e
    except yaml.YAMLError as e:
        raise ConfigParseError(f"Failed to parse config file: {e}") from e
    except Exception as e:
        raise ConfigError(f"Unexpected error reading config: {e}") from e

    # Validate outside the try block: raised inside it, this ConfigParseError
    # would be caught by the generic handler above and downgraded to a plain
    # ConfigError, hiding it from callers that catch ConfigParseError.
    if not isinstance(config, dict):
        raise ConfigParseError(f"Config must be a dictionary, got {type(config)}")
    return config


def write_yaml_config(config: Dict[str, Any], file_path: str, sort_keys: bool = False) -> None:
    """
    write config to yaml file
    :param config: config to write
    :param file_path: output file path
    :param sort_keys: whether to sort keys
    :raises ConfigError: if the file cannot be written or the config cannot be
        serialized; the original exception is attached as ``__cause__``
    """
    try:
        with open(file_path, 'w', encoding='utf-8') as f:
            yaml.safe_dump(
                config,
                f,
                default_flow_style=False,
                sort_keys=sort_keys,
                allow_unicode=True,
                indent=2,
            )
    except (yaml.YAMLError, OSError) as e:
        # ConfigError keeps this consistent with read_yaml_config; it is still
        # an Exception subclass, so existing `except Exception` callers work.
        raise ConfigError(f"Error writing YAML file: {e}") from e
    
    

def check_existing_checkpoints(output_dir: str, required_checkpoints: int = 15) -> bool:
    """
    report whether the output directory contains at least one valid checkpoint

    A checkpoint is counted as valid when it is a sub-directory of
    ``output_dir`` whose name starts with ``checkpoint-`` and which contains
    at least one entry matching ``*.safetensors*``.

    Args:
        output_dir: output directory
        required_checkpoints: expected number of checkpoints, default is 15.
            It only selects which message is printed; having fewer valid
            checkpoints than this still yields True as long as one exists.

    Returns:
        bool: True if at least one valid checkpoint exists; False if the
        directory is missing, has no valid checkpoint, or cannot be scanned
    """
    if not os.path.exists(output_dir):
        print(f"Output directory '{output_dir}' does not exist.")
        return False

    try:
        # candidate folders: direct children named "checkpoint-*"
        candidates = []
        for entry in Path(output_dir).iterdir():
            if entry.is_dir() and entry.name.startswith("checkpoint-"):
                candidates.append(entry)

        if len(candidates) == 0:
            print(f"No checkpoint directories found in '{output_dir}'")
            return False

        # a candidate is valid once it holds any safetensors (or index) file
        valid_count = sum(
            1 for candidate in candidates
            if any(candidate.glob("*.safetensors*"))
        )

        if valid_count == 0:
            print(f"No valid checkpoints found in '{output_dir}'")
            return False

        if valid_count >= required_checkpoints:
            print(f"Found {valid_count} valid checkpoints in '{output_dir}'")
        else:
            print(f"Found only {valid_count} valid checkpoints, but {required_checkpoints} are required.")

        # one valid checkpoint is sufficient, even when below the requested count
        return True

    except Exception as e:
        print(f"Error checking checkpoints: {e}")
        return False


def execute_command(command: str, show_output: bool = True) -> Dict[str, Any]:
    """
    Execute a system command through the shell and measure its execution time.

    A non-zero exit status is reported through the returned dictionary, not
    raised. Exceptions raised by ``subprocess`` itself propagate unchanged.

    NOTE: the command is run with ``shell=True``, so it is interpreted by the
    shell; do not build it from untrusted input.

    Args:
        command (str): The command to execute
        show_output (bool): Whether to display command output in real-time
                           If False, output will be captured but not displayed

    Returns:
        dict: A dictionary containing:
            - success (bool): Whether the command exited with return code 0
            - return_code (int): The command's return code
            - execution_time (float): Time taken to execute in seconds
            - output (str): Captured stdout (only if show_output=False)
            - error (str): Captured stderr (only if show_output=False)
    """
    # perf_counter is monotonic, so the elapsed time cannot go negative or
    # jump when the system clock is adjusted while the command runs.
    start_time = time.perf_counter()
    result: Dict[str, Any] = {"success": False, "return_code": None, "execution_time": None}

    if show_output:
        # Child inherits stdout/stderr, so output appears on the console live.
        return_code = subprocess.call(command, shell=True)
    else:
        # Capture both streams as text instead of displaying them.
        completed = subprocess.run(
            command,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        return_code = completed.returncode
        result["output"] = completed.stdout
        result["error"] = completed.stderr

    result["return_code"] = return_code
    result["success"] = return_code == 0
    result["execution_time"] = time.perf_counter() - start_time

    print(f"Command: {command}")
    print(f"Return code: {result['return_code']}")
    print(f"Execution time: {result['execution_time']:.2f} seconds")
    if result["success"]:
        print("Command executed successfully.")

    return result

def run_command(command: str) -> int:
    """
    use subprocess to run shell command, and output log in real-time. if the command fails, raise an exception to terminate the program.

    stderr is merged into stdout, and each line is printed as soon as the
    command produces it.

    NOTE: the command is run with ``shell=True``, so it is interpreted by the
    shell; do not build it from untrusted input.

    Args:
        command: command to run

    Returns:
        int: return code of the command execution (always 0, since any other
        value raises)

    Raises:
        RuntimeError: if the command cannot be started or read, or if it
            exits with a non-zero return code
    """
    try:
        # Popen (not subprocess.run) so lines can be forwarded while the
        # command is still running instead of only after it has exited.
        with subprocess.Popen(
            command,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,  # line-buffered text pipe
        ) as process:
            for line in process.stdout:
                print(line, end="", flush=True)
        # leaving the `with` block waits for the process, so returncode is set
        return_code = process.returncode
    except Exception as e:
        # failures to launch or read from the command (not a non-zero exit)
        raise RuntimeError(f"Error executing command '{command}': {e}") from e

    # Checked outside the try block so this error is not re-wrapped by the
    # generic handler above.
    if return_code != 0:
        raise RuntimeError(f"Command '{command}' failed with return code {return_code}")

    return return_code
