import os
import json
import argparse
import importlib
import torch
import socket
import ray
from ray.runtime_env import RuntimeEnv
import asyncio
from datetime import datetime
from omegaconf.listconfig import ListConfig

from transformers import AutoModelForCausalLM
from thinker_task.ppo.utils import create_vllm_engines
from thinker_task.exp_engine.parallels.orz_distributed_c10d import orz_init_process_group
from playground.ppo_base import PPOExp

def update_vllm_weight(pretrain, vllm_engines, model_update_group):
    """Load the checkpoint at ``pretrain`` and send its parameters to the vLLM engines.

    The model is loaded with ``AutoModelForCausalLM.from_pretrained`` in
    bfloat16. For each named parameter, every engine's ``update_weight`` is
    invoked remotely with the parameter's name, dtype and shape, after which
    this process broadcasts the tensor as ``src=0`` on ``model_update_group``
    and waits for all remote calls to complete.

    Args:
        pretrain: Checkpoint path or identifier handed to ``from_pretrained``.
        vllm_engines: Iterable of Ray actor handles exposing ``update_weight``.
        model_update_group: Process group used for the broadcast.
    """
    torch.cuda.empty_cache()

    hf_model = AutoModelForCausalLM.from_pretrained(pretrain, torch_dtype=torch.bfloat16)
    total = len(list(hf_model.named_parameters()))
    for index, (param_name, tensor) in enumerate(hf_model.named_parameters(), start=1):
        # The empty_cache flag is raised only for the final parameter.
        is_last = index == total

        # Kick off the remote update on every engine before broadcasting, so
        # the receivers are already waiting when the tensor is sent.
        pending = []
        for engine in vllm_engines:
            pending.append(
                engine.update_weight.remote(
                    param_name,
                    dtype=tensor.dtype,
                    shape=tensor.shape,
                    empty_cache=is_last,
                )
            )
        torch.distributed.broadcast(tensor.data, src=0, group=model_update_group)
        ray.get(pending)

def _release_tuple(version):
    """Return the leading numeric release segment of ``version`` as ints.

    ``"0.10.0"`` becomes ``(0, 10, 0)``. Parsing stops at the first component
    that is not purely ASCII digits, so pre/post/dev/local suffixes are
    ignored: ``"0.6.3.post1"`` -> ``(0, 6, 3)``, ``"0.5.0rc1"`` -> ``(0, 5, 0)``.
    An unparsable string yields ``()``.
    """
    release = []
    for piece in version.split("."):
        digits = ""
        for char in piece:
            if char not in "0123456789":
                break
            digits += char
        if not digits:
            break
        release.append(int(digits))
        if len(digits) != len(piece):
            # A trailing suffix such as "rc1" or "+cu118" ends the release segment.
            break
    return tuple(release)


def init_vllm_engines_actor_group(cfg, vllm_engines):
    """Create the process group used to push weights from this process to vLLM.

    This process joins as rank 0. Engine ``i`` is told to join starting at rank
    ``i * cfg.vllm_tensor_parallel_size + 1``, and the world size is
    ``cfg.vllm_num_engines * cfg.vllm_tensor_parallel_size + 1``.

    The backend is ``cfg.vllm_sync_backend`` (default ``"nccl"``), but is
    forced to ``"gloo"`` when the installed vLLM is newer than 0.4.2 and
    ``NCCL_P2P_DISABLE`` is unset or ``"0"``
    (see https://github.com/OpenRLHF/OpenRLHF/issues/313).

    Args:
        cfg: Config exposing ``vllm_num_engines``, ``vllm_tensor_parallel_size``
            and optionally ``vllm_sync_backend``.
        vllm_engines: Ray actor handles exposing ``init_process_group``.

    Returns:
        The process group returned by ``orz_init_process_group``.
    """
    master_address = ray._private.services.get_node_ip_address()
    # Bind to port 0 so the OS picks a free port. The socket is closed on
    # leaving the block, so the port is only probably still free afterwards.
    with socket.socket() as sock:
        sock.bind(("", 0))
        master_port = sock.getsockname()[1]

    vllm_num_engines, vllm_tensor_parallel_size = (
        cfg.vllm_num_engines,
        cfg.vllm_tensor_parallel_size,
    )
    world_size = vllm_num_engines * vllm_tensor_parallel_size + 1

    backend = getattr(cfg, "vllm_sync_backend", "nccl")
    # https://github.com/OpenRLHF/OpenRLHF/issues/313
    import vllm

    # Compare versions numerically: as plain strings "0.10.0" sorts before
    # "0.4.2", which would wrongly skip the gloo fallback on recent vLLM.
    vllm_is_newer = _release_tuple(vllm.__version__) > (0, 4, 2)
    if vllm_is_newer and os.getenv("NCCL_P2P_DISABLE", "0") == "0":
        backend = "gloo"

    # Start the engines joining first; they wait until rank 0 (below) joins.
    refs = [
        engine.init_process_group.remote(
            master_address,
            master_port,
            i * vllm_tensor_parallel_size + 1,
            world_size,
            "openrlhf",
            backend=backend,
        )
        for i, engine in enumerate(vllm_engines)
    ]
    model_update_group = orz_init_process_group(
        backend=backend,
        init_method=f"tcp://{master_address}:{master_port}",
        world_size=world_size,
        rank=0,
        group_name="openrlhf",
    )

    ray.get(refs)
    return model_update_group
    

def find_actor_checkpoints(actor_hf_dir, eval_global_step=None):
    """Return ``(global_step, path)`` pairs for checkpoints in ``actor_hf_dir``.

    Only entries whose name starts with ``global_step`` are considered; the
    step is the integer after the last underscore of the name. Entries whose
    suffix is not an integer are skipped with a printed notice.

    Args:
        actor_hf_dir: Directory that contains the ``global_step_*`` entries.
        eval_global_step: ``None`` to keep every step, or a single int, or a
            collection of ints, to keep only those steps.

    Returns:
        A list of ``(global_step, path)`` tuples sorted by ascending step.

    Raises:
        FileNotFoundError: If ``actor_hf_dir`` does not exist.
    """
    if eval_global_step is None:
        wanted = None
    elif isinstance(eval_global_step, int):
        wanted = {eval_global_step}
    else:
        wanted = set(eval_global_step)

    found = []
    for file_name in os.listdir(actor_hf_dir):
        if not file_name.startswith("global_step"):
            continue
        try:
            global_step = int(file_name.split("_")[-1])
        except ValueError:
            print(f"Skipping {file_name}: cannot parse a global step from its name")
            continue
        if wanted is not None and global_step not in wanted:
            continue
        found.append((global_step, os.path.join(actor_hf_dir, file_name)))
    found.sort(key=lambda item: item[0])
    return found


if __name__ == "__main__":

    # ---- evaluation settings -------------------------------------------
    run_name = "thinker_r1_5b"
    suffix = "_eval_1t"
    multi_iter = 1  # number of samples
    eval_data_ls = [
        "data/eval_data/math500.json",
        "data/eval_data/aime24.json",
        "data/eval_data/aime25.json",
        "data/eval_data/gpqa_diamond.json",
        "data/eval_data/amc23.json",
        #"data/eval_data/college_math.json",
        "data/eval_data/minerva_math.json",
        "data/eval_data/olympiadbench.json",
    ]
    temperature = 1  # not None to override
    only_last_step = False  # eval only last step
    eval_global_step = None  # an int or a collection of ints to restrict the evaluated steps
    override_ckps = None  # [(0, "large_data/base/Qwen/Qwen2.5-1.5B")]

    # The experiment config class lives in playground.<run_name>.
    module = importlib.import_module(f"playground.{run_name}")
    PPOExpConfig_ = getattr(module, "PPOExpConfig_")

    # initialize the ray cluster
    _temp_dir = os.environ.get("RAY_TEMP_DIR", None)
    use_ib0 = os.environ.get("USE_IB0", "").lower() == "true"
    env_vars = {
        "NCCL_DEBUG": "WARN",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:False",
    }
    if use_ib0:
        env_vars.update({
            "NCCL_NET_GDR_LEVEL": "2",
            "NCCL_SOCKET_IFNAME": "ib0",
            "NCCL_IB_DISABLE": "0",
            #"NCCL_IB_HCA": "mlx5_3",
        })

    # NOTE(review): address="auto" is passed to RuntimeEnv rather than to
    # ray.init itself - confirm this is intended.
    ray.init(
        runtime_env=RuntimeEnv(
            address="auto",
            env_vars=env_vars,
        ),
        _temp_dir=_temp_dir,
    )

    exp = PPOExp().set_cfg(PPOExpConfig_())
    exp.cfg.wandb_run_name = exp.cfg.wandb_run_name + suffix
    # add folder with format datetime_suffix
    exp.cfg.save_path = os.path.join(exp.cfg.save_path, datetime.now().strftime("%Y%m%d_%H%M%S") + suffix)
    exp.cfg.eval_prompt_data = ListConfig(eval_data_ls)
    if temperature is not None:
        exp.cfg.temperature = temperature
    exp.cfg.colocate_with_actor = False
    exp.cfg.vllm_num_engines = 8  # single node

    # create save path if it does not exist
    os.makedirs(exp.cfg.save_path, exist_ok=True)
    print("Save path: ", exp.cfg.save_path)

    # Only touch the checkpoint directory when no explicit override is given,
    # so an override works even if that directory is absent.
    if override_ckps is None:
        ckps = find_actor_checkpoints(
            os.path.join(exp.cfg.ckpt_path, "_actor_hf"), eval_global_step
        )
    else:
        ckps = override_ckps
    print(f"Checkpoints found: {ckps}")

    # Fail early with a clear message instead of starting the engines for nothing.
    if not ckps:
        raise SystemExit(
            "No checkpoints to evaluate; check ckpt_path, eval_global_step and override_ckps."
        )

    if only_last_step:
        ckps = [ckps[-1]]
    print(f"Evaluating on: {ckps}")

    trainer = exp.trainer
    asyncio.run(trainer.init_vllm_engines())
    model_update_group = init_vllm_engines_actor_group(
        exp.cfg, trainer.vllm_engines
    )
    # presumably restores each engine's state onto its GPU - confirm against the engine class
    backload_tasks = [engine.backload_to_gpu.remote() for engine in trainer.vllm_engines]
    ray.get(backload_tasks)

    # loop over all actors checkpoints in the checkpoint path
    for global_step, pretrain in ckps:
        # update the weights
        print(f"Loading checkpoint {pretrain}")
        update_vllm_weight(pretrain, trainer.vllm_engines, model_update_group)
        print(f"Loaded checkpoint {pretrain}")
        trainer.global_step = global_step
        status_eval = asyncio.run(trainer.eval(multi_iter=multi_iter))

        # Prefix metric names and log them against the checkpoint's step.
        status_eval = {f"eval/{k}": v for k, v in status_eval.items()}
        status_eval["train/global_step"] = trainer.global_step
        trainer._wandb.log(status_eval)

    trainer._wandb.finish()



   


