import argparse
import importlib
import os
import random
import torch
import numpy as np
import logging
import time
import subprocess


def load_args():
    """Parse ``sys.argv`` and return a namespace with a ``cfg_path`` attribute.

    Returns:
        argparse.Namespace: ``cfg_path`` holds the single required positional
        argument (path to the config file).
    """
    cli = argparse.ArgumentParser(description="Load config and print seq_path variable")
    cli.add_argument("cfg_path", type=str, help="Path to the config file")
    return cli.parse_args()


def set_seed(seed: int):
    """Seed every RNG this project relies on with the same value.

    Covers PyTorch (``torch.manual_seed`` plus ``torch.cuda.manual_seed``),
    NumPy's global RNG and Python's ``random`` module, then reports the seed.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)
    print(f"Set seed to {seed} (torch, numpy, random)")


def import_config(config_path: str):
    """Execute the Python file at *config_path* and return it as a module.

    Args:
        config_path: Path to the config file. Its suffix must be one the
            import system recognises (e.g. ``.py``), otherwise no spec can
            be built.

    Returns:
        The executed module object, named ``"config"``. It is not inserted
        into ``sys.modules``.

    Raises:
        ImportError: If no import spec or loader can be created for the path
            (for example, an unrecognised file extension).
        Any exception raised while executing the file propagates unchanged.
    """
    # ``import importlib`` alone does not guarantee that the ``importlib.util``
    # submodule is loaded, so import it explicitly.
    import importlib.util

    spec = importlib.util.spec_from_file_location("config", config_path)
    # Raise rather than ``assert`` so the check still runs under ``python -O``.
    if spec is None or spec.loader is None:
        raise ImportError(f"Failed to load config from {config_path}")
    config = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config)
    return config


def load_config(config_path: str = None):
    """Load the config module at *config_path*, or at the CLI path if omitted.

    Args:
        config_path: Path to the config file. When ``None``, the path is taken
            from the command line via ``load_args()``.

    Returns:
        The module object produced by ``import_config``.

    Raises:
        FileNotFoundError: If the resolved path does not exist.
    """
    path = config_path if config_path is not None else load_args().cfg_path
    if os.path.exists(path):
        return import_config(path)
    raise FileNotFoundError(f"Config file {path} does not exist.")


def _query_nvidia_smi_gpus():
    """Return ``[(uuid, free_mb), ...]`` for every GPU, in nvidia-smi order."""
    result = subprocess.run(
        [
            "nvidia-smi",
            "--query-gpu=uuid,memory.free",
            "--format=csv,nounits,noheader",
        ],
        stdout=subprocess.PIPE,
        encoding="utf-8",
        check=True,
        # Guard against a wedged driver freezing callers that poll this.
        timeout=30,
    )
    gpus = []
    for line in result.stdout.splitlines():
        if not line.strip():
            continue
        gpu_uuid, free_mb = line.split(",")
        gpus.append((gpu_uuid.strip(), int(free_mb)))
    return gpus


def _cuda_visible_indices(spec, uuids):
    """Translate a CUDA_VISIBLE_DEVICES string into indices into *uuids*.

    Entries may be integer ordinals or (unique prefixes of) GPU UUIDs. Like
    CUDA itself, parsing stops at the first entry that names no device, so
    ``"-1"`` or ``""`` yield no visible devices.
    """
    indices = []
    for raw in spec.split(","):
        entry = raw.strip()
        if entry.isdecimal():
            idx = int(entry)
            if idx >= len(uuids):
                break
        else:
            matches = [
                i for i, gpu_uuid in enumerate(uuids)
                if entry and gpu_uuid.startswith(entry)
            ]
            if len(matches) != 1:
                break
            idx = matches[0]
        indices.append(idx)
    return indices


def get_free_gpu_memory():
    """Return the free memory (MB) of each GPU visible to this process.

    Visibility follows ``CUDA_VISIBLE_DEVICES`` when it is set (integer
    ordinals or GPU UUIDs; a set-but-empty or invalid value means no visible
    GPUs and yields ``[]``). When it is unset, every GPU is reported.

    NOTE(review): nvidia-smi enumerates GPUs in PCI bus order, while CUDA's
    default ordering is "fastest first" unless ``CUDA_DEVICE_ORDER=PCI_BUS_ID``
    is set, so integer ordinals may map to different physical GPUs — confirm
    for multi-GPU hosts.

    Returns:
        list[int]: Free memory per visible GPU. If nvidia-smi cannot be run or
        its output cannot be parsed, returns ``[9999999]`` so callers proceed
        as if memory were plentiful.
    """
    try:
        gpus = _query_nvidia_smi_gpus()
    except (OSError, ValueError, subprocess.SubprocessError) as e:
        print(f"Error getting free GPU memory: {e}")
        return [9999999]

    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        return [free_mb for _, free_mb in gpus]

    uuids = [gpu_uuid for gpu_uuid, _ in gpus]
    return [gpus[i][1] for i in _cuda_visible_indices(visible_devices, uuids)]


def wait_for_gpu_memory(required_free_memory: float, time_least_wait: int):
    """Block until every visible GPU has enough free memory for a sustained stretch.

    Polls ``get_free_gpu_memory()`` and prints a single status line that is
    rewritten in place (carriage return, no newline) on each poll.

    Args:
        required_free_memory: Minimum free memory (MB) every visible GPU must
            report for a poll to count as successful.
        time_least_wait: Number of *consecutive* successful polls required
            before returning. Successful polls are spaced about 0.1 s apart
            (plus query time); a failed poll resets the streak and waits 1 s.
    """
    start_time = time.monotonic()
    consecutive_ok = 0
    while True:
        free_memory = get_free_gpu_memory()
        elapsed_time = time.monotonic() - start_time
        print(
            f"\rFree GPU Memory: {free_memory}MB | Elapsed Time: {int(elapsed_time)}s",
            end="",
        )

        if all(memory >= required_free_memory for memory in free_memory):
            consecutive_ok += 1
            if consecutive_ok >= time_least_wait:
                print(
                    "\nSufficient GPU memory available. Proceeding with the program..."
                )
                return
            time.sleep(0.1)
        else:
            # Memory was taken again: only an uninterrupted run of good
            # polls should count, so restart the streak.
            consecutive_ok = 0
            time.sleep(1)


def setup_logger(output_dir: str):
    """Route root-logger output to ``<output_dir>/log.txt`` and the console.

    Creates *output_dir* if needed. On a fresh process this configures the
    root logger via ``logging.basicConfig`` (INFO level, file + stream
    handlers). If the root logger already has handlers — where
    ``basicConfig`` would silently do nothing — the file handler is attached
    directly (once per file) and the root level is lowered to INFO if it was
    higher.

    Args:
        output_dir: Directory that will hold ``log.txt``.

    Returns:
        logging.Logger: The logger named after this module.
    """
    os.makedirs(output_dir, exist_ok=True)
    log_filename = os.path.join(output_dir, "log.txt")
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    root = logging.getLogger()

    if not root.handlers:
        logging.basicConfig(
            level=logging.INFO,
            format=log_format,
            handlers=[
                logging.FileHandler(log_filename, encoding="utf-8"),
                logging.StreamHandler(),
            ],
        )
    else:
        # basicConfig() is a no-op once the root logger has handlers (second
        # call, a library configured logging, a test runner...). Attach the
        # file handler ourselves so log.txt is actually written; build it only
        # when needed so no file handle is opened and leaked.
        target = os.path.abspath(log_filename)
        already_attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == target
            for h in root.handlers
        )
        if not already_attached:
            file_handler = logging.FileHandler(log_filename, encoding="utf-8")
            file_handler.setFormatter(logging.Formatter(log_format))
            root.addHandler(file_handler)
        if root.level > logging.INFO:
            root.setLevel(logging.INFO)

    return logging.getLogger(__name__)
