import os
import subprocess
import itertools
import argparse
import time
from pathlib import Path
from colorama import Fore
from pathlib import Path
from multiprocessing import Process, Queue, Lock


def get_default_cache_dir():
    """Return the Hugging Face cache directory.

    The value of the ``HF_HOME`` environment variable is returned as-is when
    the variable is set; otherwise the user-expanded
    ``~/.cache/huggingface`` path is returned.
    """
    fallback_dir = os.path.expanduser("~/.cache/huggingface")
    return os.environ.get("HF_HOME", fallback_dir)


def get_available_gpus():
    """Return the indices ``0..N-1`` of the CUDA devices counted by torch."""
    # torch is imported at call time rather than at module import time.
    import torch

    device_total = torch.cuda.device_count()
    return [device_index for device_index in range(device_total)]


def get_log_path(log_dir, model, dataset_id):
    """Build the extension-less log path for a model/dataset pair.

    Every ``/`` in ``model`` is replaced by ``--`` so the model name cannot
    introduce extra directory levels; the result is
    ``<log_dir>/<model>_<dataset_id>`` as a ``Path``.
    """
    flattened_model = "--".join(model.split("/"))
    file_stem = f"{flattened_model}_{dataset_id}"
    return Path(log_dir).joinpath(file_stem)


def check_gpu_memory(gpu_id):
    """Report whether the given GPU currently looks idle.

    Runs ``nvidia-smi`` to query the used memory of ``gpu_id`` and returns
    ``True`` only when exactly one line of output is produced and the value
    it reports is at most 100 MiB.

    Args:
        gpu_id: GPU index passed to ``nvidia-smi --id``.

    Returns:
        ``True`` if the GPU reports <= 100 MiB in use; ``False`` if more
        memory is in use or if the query could not be run or parsed.
    """
    # An argv list (no shell) avoids quoting problems with the id value.
    query = [
        "nvidia-smi",
        "--query-gpu=memory.used",
        "--format=csv,noheader",
        f"--id={gpu_id}",
    ]
    try:
        result = subprocess.run(query, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except OSError as e:
        # Raised e.g. when the nvidia-smi binary is missing or not executable.
        print(f"Failed to check GPU memory: {e}")
        return False

    if result.returncode != 0:
        print(f"Failed to check GPU memory: {gpu_id}")
        return False

    lines = result.stdout.decode("utf-8", errors="replace").strip().split("\n")
    if len(lines) > 1:
        print(f"Failed to check GPU memory: {gpu_id}")
        print(lines)
        return False

    try:
        memory_used = float(lines[0].replace("MiB", ""))
    except ValueError:
        # Empty or non-numeric output must not crash the polling loop.
        print(f"Failed to check GPU memory: {gpu_id}")
        print(lines)
        return False

    if memory_used > 100:
        print(f"GPU memory is not released: {gpu_id} ({memory_used} MiB)")
        return False

    return True


def wait_for_gpu_memory_release(gpu_id):
    """Block until ``check_gpu_memory(gpu_id)`` returns True.

    The check is retried every 10 seconds, printing a notice before each
    sleep. There is no upper bound on the number of retries.
    """
    while not check_gpu_memory(gpu_id):
        print(f"Waiting for GPU memory to be released: {gpu_id}")
        time.sleep(10)


def run_command(command, timeout, gpu_id):
    """Run a shell command with a time limit and report whether it succeeded.

    The command runs through the shell in its own session, so that on timeout
    the whole process group (the shell and everything it spawned) can be
    killed. As a consequence the child no longer shares the caller's
    terminal process group. The child's stdout/stderr are discarded; callers
    that want the output should redirect it inside ``command``.

    Args:
        command: Shell command line to execute.
        timeout: Maximum run time in seconds.
        gpu_id: GPU index polled for memory release after a failed run.

    Returns:
        ``True`` if the command exited with status 0. ``False`` if it could
        not be started, exited non-zero, or timed out; in the latter two
        cases the function waits for the GPU memory to be released first.
    """
    import signal  # only needed for the timeout kill below

    print(f"Running command: {command}")
    try:
        process = subprocess.Popen(
            command,
            shell=True,
            # DEVNULL instead of PIPE: nothing ever read the pipes, so a
            # chatty child could fill them and block forever.
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            start_new_session=True,
        )
    except OSError as e:
        print(f"Command failed with error: {e}")
        return False

    try:
        returncode = process.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        print(f"Command timed out: {command}")
        # Kill the whole group; killing only the shell would leave the real
        # workload running and holding the GPU.
        try:
            os.killpg(process.pid, signal.SIGKILL)
        except ProcessLookupError:
            pass  # the group exited between the timeout and the kill
        process.wait()
        returncode = None

    if returncode == 0:
        return True

    if returncode is not None:
        print(f"Command failed with error: exit code {returncode}")

    wait_for_gpu_memory_release(gpu_id)

    return False


def nsys_worker(queue, gpu_id, lock):
    """Consume benchmark jobs from ``queue`` and run them on one GPU.

    Each job is an 11-tuple ``(log_output_path, model, label, dataset_id,
    pretrained_dir, depth, temperature, threshold, eagle, original,
    vanilla)``. A ``None`` item is the shutdown sentinel. For every job an
    ``evaluate.py`` invocation is launched under ``kernprof -lv`` (the status
    messages say "NSYS" regardless) with a 120-minute timeout, and its output
    is redirected to ``<log_output_path>.log``.

    Args:
        queue: Job queue shared with the producer; must support blocking
            ``get()`` and ``qsize()``.
        gpu_id: GPU index this worker is pinned to via CUDA_VISIBLE_DEVICES.
        lock: Lock shared between workers; used only to serialize the
            queue-size status line.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    while True:
        # A blocking get() replaces polling queue.empty() under the lock,
        # which made an idle worker sleep while holding the lock and stall
        # every other worker.
        data = queue.get()
        with lock:
            print("Queue size:", queue.qsize())
        if data is None:
            break

        log_output_path, model, label, dataset_id, pretrained_dir, depth, temperature, threshold, eagle, original, vanilla = data

        gpu_cmd = (
            f"CUDA_VISIBLE_DEVICES={str(gpu_id)} kernprof -lv evaluate.py --model_id {model} --ea-model-path {pretrained_dir/label/'final'} --output_file {log_output_path}.jsonl --dataset_id {dataset_id} --top_base 60 --top_draft 10 --top_node 100 --depth {depth} --temperature {temperature} --threshold {threshold} {'--eagle' if eagle else ''} {'--original' if original else ''} {'--vanilla' if vanilla else ''} > {log_output_path}.log 2>&1"
        )

        # One tag for all three messages so start/success/failure lines of the
        # same job can be matched, including the original/vanilla variants.
        job_name = (
            f"{model}-{label}-{dataset_id}_t-{temperature}_thr-{threshold}"
            f"{'_eagle' if eagle else ''}"
            f"{'_original' if original else ''}"
            f"{'_vanilla' if vanilla else ''}"
        )

        print(Fore.BLUE + f"[GPU {gpu_id}] Running NSYS for {job_name}" + Fore.RESET)

        success = run_command(gpu_cmd, 120 * 60, gpu_id)

        if success:
            print(Fore.GREEN + f"[GPU {gpu_id}] Successfully ran NSYS for {job_name}" + Fore.RESET)
        else:
            print(Fore.RED + f"[GPU {gpu_id}] Failed to run NSYS for {job_name}" + Fore.RESET)


# Non-zero thresholds for which the "chained-tree-3" runs are scheduled.
_TREE_THRESHOLDS = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)

# Labels whose zero-threshold, zero-temperature tree runs are scheduled.
_TREE_ZERO_THRESHOLD_LABELS = (
    "llama2-7b-chained-tree-3",
    "llama3_1-8b-chained-tree-3",
    "r1-distill-llama-8b-chained-tree-3",
)

# A result file with exactly this many lines is treated as already complete.
# NOTE(review): the same count is applied to every dataset -- confirm that
# each dataset really produces 80 records.
_COMPLETE_RESULT_LINES = 80


def _is_selected_config(temperature, threshold, model_label, eagle, original, vanilla):
    """Return True if this parameter combination should be benchmarked.

    ``model_label`` is the raw ``"<model>,<label>"`` string; ``eagle``,
    ``original`` and ``vanilla`` are the integer flags from the command line.
    """
    flags = (eagle, original, vanilla)

    if threshold == 0.0:
        # "ce-3" models: exactly one of the three mode flags is set.
        if "ce-3" in model_label and flags in ((0, 0, 1), (0, 1, 0), (1, 0, 0)):
            return True
        # Chained (and chained-rl) models: eagle mode only.
        if ("chained-3" in model_label or "chained-rl-3" in model_label) and flags == (1, 0, 0):
            return True
        # Tree models at temperature 0: with and without eagle.
        if (
            temperature == 0.0
            and flags in ((0, 0, 0), (1, 0, 0))
            and any(name in model_label for name in _TREE_ZERO_THRESHOLD_LABELS)
        ):
            return True
        return False

    # Threshold sweep for tree models, all mode flags off.
    return threshold in _TREE_THRESHOLDS and "chained-tree-3" in model_label and flags == (0, 0, 0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run model benchmarks')
    parser.add_argument(
        '--model',
        required=True,
        nargs='+',
        choices=[
            'meta-llama/Llama-2-7b-chat-hf,llama2-7b-ce-3',
            'meta-llama/Llama-2-7b-chat-hf,llama2-7b-chained-3',
            'meta-llama/Llama-2-7b-chat-hf,llama2-7b-chained-tree-3',
            # 'meta-llama/Llama-2-7b-chat-hf,llama2-7b-chained-5',
            # 'meta-llama/Llama-2-7b-chat-hf,llama2-7b-chained-tree-5',
            'meta-llama/Llama-3.1-8B-Instruct,llama3_1-8b-ce-3',
            'meta-llama/Llama-3.1-8B-Instruct,llama3_1-8b-chained-3',
            'meta-llama/Llama-3.1-8B-Instruct,llama3_1-8b-chained-tree-3',
            'deepseek-ai/DeepSeek-R1-Distill-Llama-8B,r1-distill-llama-8b-ce-3',
            'deepseek-ai/DeepSeek-R1-Distill-Llama-8B,r1-distill-llama-8b-chained-3',
            'deepseek-ai/DeepSeek-R1-Distill-Llama-8B,r1-distill-llama-8b-chained-rl-3',
            'deepseek-ai/DeepSeek-R1-Distill-Llama-8B,r1-distill-llama-8b-chained-tree-3',
            'deepseek-ai/DeepSeek-R1-Distill-Llama-8B,r1-distill-llama-8b-chained-tree-3-topk2',
        ],
        help='Select the model',
    )
    # Required: without it the parameter product below cannot be built.
    parser.add_argument('--dataset_id', required=True, nargs='+', choices=["mt_bench", "humaneval", "alpaca", "gsm8k", "qa", "sum"], help='Dataset IDs to evaluate')
    parser.add_argument('--output_dir', type=Path, default='paper/logs', help='Output directory')
    parser.add_argument('--pretrained_dir', type=Path, default='paper', help='Pretrained model directory')
    parser.add_argument('--depth', type=int, default=7, help='Depth of the model')
    # The default is resolved after parsing so torch is only imported when
    # the user did not pass --gpus.
    parser.add_argument('--gpus', nargs='+', type=int, default=None, help='GPU indices (default: all CUDA devices reported by torch)')
    parser.add_argument('--temperature', nargs='+', type=float, default=[0.0], help='Temperature values for sampling')
    parser.add_argument('--eagle', nargs='+', type=int, default=[0], help='use eagle or not')
    parser.add_argument('--original', nargs='+', type=int, default=[0], help='use original or not')
    parser.add_argument('--vanilla', nargs='+', type=int, default=[0], help='use vanilla or not')
    parser.add_argument('--threshold', nargs='+', type=float, default=[0.0], help='threshold for tals')

    args = parser.parse_args()

    gpus = args.gpus if args.gpus is not None else get_available_gpus()
    if not gpus:
        # With no workers, queued jobs would never be consumed.
        parser.error("no GPUs available; pass --gpus explicitly")

    print(f"Models: {args.model}")
    print(f"Datasets: {args.dataset_id}")
    print(f"Output directory: {args.output_dir}")
    print(f"Pretrained directory: {args.pretrained_dir}")
    print(f"GPUs: {gpus}")

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    queue = Queue()
    lock = Lock()

    # One worker process per GPU; they start consuming as soon as jobs arrive.
    processes = []
    for gpu_id in gpus:
        process = Process(target=nsys_worker, args=(queue, gpu_id, lock))
        process.start()
        processes.append(process)

    try:
        for gpu_id in gpus:
            wait_for_gpu_memory_release(gpu_id)

        cache = set()
        for temperature, threshold, dataset_id, model_label, eagle, original, vanilla in itertools.product(args.temperature, args.threshold, args.dataset_id, args.model, args.eagle, args.original, args.vanilla):
            if not _is_selected_config(temperature, threshold, model_label, eagle, original, vanilla):
                continue

            eagle = bool(eagle)
            original = bool(original)
            vanilla = bool(vanilla)
            model, label = model_label.split(",")
            log_output_path = get_log_path(args.output_dir, f"{model}--{label}_d-{args.depth}_t-{temperature:.1f}_thr-{threshold:.1f}{'_eagle' if eagle else ''}{'_original' if original else ''}{'_vanilla' if vanilla else ''}", dataset_id)
            log_output_path.parent.mkdir(parents=True, exist_ok=True)

            data = (log_output_path, model, label, dataset_id, args.pretrained_dir, args.depth, temperature, threshold, eagle, original, vanilla)
            result_file = f"{log_output_path}.jsonl"
            if os.path.exists(result_file):
                with open(result_file) as result_handle:
                    num_lines = sum(1 for _ in result_handle)
                if num_lines == _COMPLETE_RESULT_LINES:
                    print(f"Already evaluated: {log_output_path}.jsonl")
                    continue
                print(f"File exists but not complete: {log_output_path}.jsonl ({num_lines} lines)")
            # Skip duplicates caused by repeated command-line values.
            if data not in cache:
                cache.add(data)
                queue.put(data)
    finally:
        # Always send one sentinel per worker, even if scheduling failed, so
        # the workers exit (after draining already-queued jobs) instead of
        # blocking on the queue forever.
        for _ in processes:
            queue.put(None)

        for process in processes:
            process.join()
