"""
Qwen2.5-7B base model + PPO training.

Debug run on a single node:
python -m playground.orz_7b_ppo

Multi-node Training:

On the master node, first run `ray start --head`.
Then, on every other node, run `ray start --address='<master-node-ip>:<master-node-port>'`.
Finally, on the master node, run `NUM_NODE=4 python -m playground.orz_7b_ppo` (NUM_NODE is the total number of nodes).

"""

import asyncio
import json
import os
from dataclasses import dataclass
from functools import cached_property
from typing import Any, Awaitable, Callable, List, Optional, Tuple

from loguru import logger
from omegaconf.listconfig import ListConfig
from typing_extensions import override

from thinker_task.exps.examples.ppo.ppo_base_exp import BasePPOExp, BasePPOExpConfig
from thinker_task.ppo.gen_data import CustomRewardTrainer
from playground.zero_setting_base import CustomDataset, EvalCustomDataset

# Node count for this run, taken from the NUM_NODE environment variable
# (falls back to one node). The config class below scales its GPU/engine counts by it.
NUM_NODE = int(os.environ.get("NUM_NODE", "1"))

@dataclass
class PPOExpConfig(BasePPOExpConfig):
    """Configuration for the Qwen2.5-7B PPO experiment.

    Covers summary/verify stages, multi-attempt prompting, an optional actor
    value head, resource layout, paths, wandb, data, PPO, and generation settings.

    NOTE: several defaults are computed from other class attributes while the
    class body executes (``run_name`` -> paths / wandb run name,
    ``total_num_nodes`` -> node / engine counts, ``use_grpo`` ->
    ``gpu_memory_utilization``). Passing a different value for one of those
    source attributes when constructing an instance does NOT recompute the
    derived defaults; override the derived fields explicitly as well.
    """

    # Unannotated, so this class does not declare it as a new dataclass field;
    # it is read while the class body executes to build the defaults below.
    run_name = "test"

    use_compute_reward_fn: bool = True
    use_orm_score: bool = False

    # summary setting
    summary: bool = False

    # inference setting
    summary_min_token: int = 300  # minimum tokens for summary
    summary_max_token: int = 1000  # maximum tokens for summary / fast response
    verify_max_token: int = 6000  # maximum tokens for verify response
    slow_max_token: int = 6000  # maximum tokens for slow response
    summary_temperature: float = 0.6  # temperature for sampling summary
    summary_skip: bool = False  # skip summary stage
    verify_skip: bool = False  # skip verify stage

    # summary reward setting
    summary_reward_coef: float = 1  # multiplier for the summary reward
    fast_reward_coef: float = 1  # multiplier for the fast-response reward
    summary_consist_coef: float = 0.001  # multiplier for the consistency-score based reward
    summary_consist_end: bool = False  # add all consistency reward at the end
    summary_consist_mean: bool = False  # mean over summary length
    summary_nonstop_discount: float = 1.  # penalize non-stop response
    reward_right_format: float = 0.25  # reward for right format
    verify_reweight: bool = False  # reweight verify reward based on fast-response accuracy

    # summary sft setting
    summary_buffer_size: int = 512
    summary_packing_max_len: int = 48000
    # split each SFT training sample into chunks of at most summary_sft_max_len; -1 disables
    summary_sft_max_len: int = 18000
    summary_policy_update_steps: int = -1  # number of steps to update summary policy; -1 disables

    # multi-attempt setting
    multi_attempt: bool = False
    min_attempt: int = 1
    max_attempt: int = 3
    repeat_question: bool = False

    # add critic head to actor
    actor_value_coef: float = 0.
    # variable-length tuple of discount factors (was annotated as a 1-tuple)
    actor_value_gammas: Tuple[float, ...] = (1.0, 0.9999, 0.999, 0.99, 0.97)

    # Conditional settings with production values first
    # 8 per node, scaled by the NUM_NODE environment variable
    total_num_nodes: int = 8 * NUM_NODE

    # resource related settings (counts below copy total_num_nodes at class-definition time)
    ref_num_nodes: int = total_num_nodes
    ref_num_gpus_per_node: int = 1
    actor_num_nodes: int = total_num_nodes
    actor_num_gpus_per_node: int = 1
    critic_num_nodes: int = total_num_nodes
    critic_num_gpus_per_node: int = 1
    colocate_all: bool = True
    colocate_critic_reward: bool = True
    colocate_actor_ref: bool = True
    vllm_num_engines: int = total_num_nodes
    vllm_tensor_parallel_size: int = 1
    param_offload: bool = False
    adam_offload: bool = False
    zero_stage: int = 3
    use_grpo: bool = False

    grad_accum_dtype: str = "bf16"
    zpg: int = 1

    # path related settings
    pretrain: Optional[str] = "Qwen/Qwen2.5-7B"  # TODO: or put your downloaded model path here!
    critic_pretrain: Optional[str] = ""
    ref_pretrain: Optional[str] = ""
    reward_pretrain: Optional[str] = None
    save_path: str = f"large_data/checkpoints/{run_name}"
    save_interval: int = 50
    ckpt_path: str = f"large_data/checkpoints/{run_name}"
    ckpt_interval: int = 10
    max_ckpt_num: int = 1
    max_ckpt_mem: int = 1_000_000_000  # was the float literal 1e9 despite the int annotation
    load_checkpoint: bool = True
    tensorboard_log_dir: str = f"large_data/logs/{run_name}"

    # wandb setting
    use_wandb: bool = True
    wandb_api_key: Optional[str] = None
    wandb_entity: Optional[str] = None
    wandb_group: Optional[str] = None
    wandb_project: str = "duality"
    wandb_run_name: str = run_name

    # MathTrain dataset and Math500 eval dataset
    # data related settings
    prompt_data: ListConfig = ListConfig(
        [
            "data/clean_orz_129k.json",
        ]
    )
    eval_prompt_data: ListConfig = ListConfig(
        [
            "data/eval_data/math500.json",
            "data/eval_data/aime24.json",
            "data/eval_data/aime25.json",
            "data/eval_data/gpqa_diamond.json",
            "data/eval_data/minerva_math.json",
            "data/eval_data/amc23.json",
            "data/eval_data/olympiadbench.json",
        ]
    )
    prompt_data_probs: ListConfig = ListConfig([1.0])

    # ppo related settings
    actor_learning_rate: float = 1e-6
    critic_learning_rate: float = 5e-6
    num_warmup_steps: int = 50
    prompt_max_len: int = 2048
    enable_prefix_caching: bool = True
    update_ref_every_epoch: bool = True
    advantage_normalize: bool = True

    num_episodes: int = 20
    rollout_batch_size: int = 16
    n_samples_per_prompt: int = 2
    micro_rollout_batch_size: int = 128

    policy_update_steps: int = 1
    critic_update_steps: int = 1
    micro_train_batch_size: int = 1
    micro_forward_batch_size: int = 1
    freezing_actor_steps: int = -1

    # kl loss
    init_kl_coef: float = 0
    kl_loss_coef: float = 0.0
    use_kl_loss: bool = True
    use_kl_estimator_k3: bool = True

    entropy_coef: float = 0.0  # entropy regularization

    enable_eval: bool = True
    eval_interval: int = 10

    # score settings
    score_for_wrong_format: float = 0.  # score for correct format, wrong answer

    # generate related settings
    packing_max_len: int = 16384
    generate_max_len: int = 8000  # TODO: change to larger later
    max_len: int = 8192  # TODO: change to larger later
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1
    stop: ListConfig = ListConfig(["User:", "Human:", "Assistant:", "</answer>"])

    # GPU memory utilization; chosen from the class-level use_grpo default
    # (0.75 if use_grpo else 0.5) when the class is defined, not per instance.
    gpu_memory_utilization: float = 0.75 if use_grpo else 0.5

    gamma: float = 1.0
    lambd: float = 1.0

    # resume setting
    resume_global_step: int = -1

class PPOExp(BasePPOExp):
    """PPO experiment that wires ``CustomRewardTrainer`` to this file's datasets."""

    @cached_property
    def trainer(self):
        """Build (once per instance) the ``CustomRewardTrainer`` for this experiment."""
        return CustomRewardTrainer(
            cfg=self.cfg,
            strategy=self.strategy,
            tokenizer=self.tokenizer,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            colocate_pg=self.get_colocate_pg,
        )

    @override
    @cached_property
    def train_dataset(self):
        """Load the training prompts and wrap them in a ``CustomDataset``.

        Each file in ``cfg.prompt_data`` is read line by line, with one JSON
        object per line. Blank or whitespace-only lines are skipped.
        """
        dialogues = []
        for file_path in self.cfg.prompt_data:
            # Explicit UTF-8 so decoding does not depend on the host locale.
            with open(file_path, "r", encoding="utf-8") as f:
                for line in f:
                    record = line.strip()
                    if not record:
                        continue  # json.loads("") would raise on a blank line
                    dialogues.append(json.loads(record))
        logger.info(f"Start processing {len(dialogues)} dialogues")
        prompts_dataset = CustomDataset(
            dialogues,
            self.tokenizer,
            self.cfg.prompt_max_len,
            self.strategy,
            pretrain_mode=False,
            num_processors=1,
            no_template=self.cfg.multi_attempt or self.cfg.summary,
            prompt_type=self.cfg.prompt_type,
        )
        logger.info(f"Finished processing {len(prompts_dataset)} dialogues")
        return prompts_dataset

    @override
    @cached_property
    def eval_dataset(self):
        """Load the evaluation prompts and wrap them in an ``EvalCustomDataset``.

        Each file in ``cfg.eval_prompt_data`` is parsed with ``json.load``.
        Every item gets a ``"file_name"`` key holding the file's base name
        without extension (e.g. ``"math500"``). Items are assumed to be dicts
        (TODO confirm against the data files).
        """
        dialogues = []
        for file_path in self.cfg.eval_prompt_data:
            # Same tag for every item of this file, so compute it once.
            file_name = os.path.splitext(os.path.basename(file_path))[0]
            with open(file_path, "r", encoding="utf-8") as f:
                loaded_data = json.load(f)
            for loaded_data_item in loaded_data:
                loaded_data_item["file_name"] = file_name
            dialogues.extend(loaded_data)
        logger.info(f"Start processing {len(dialogues)} dialogues")
        prompts_dataset = EvalCustomDataset(
            dialogues,
            self.tokenizer,
            self.cfg.prompt_max_len,
            self.strategy,
            pretrain_mode=False,
            num_processors=1,
            no_template=self.cfg.multi_attempt or self.cfg.summary,
            prompt_type=self.cfg.prompt_type,
        )
        logger.info(f"Finished processing {len(prompts_dataset)} dialogues")
        return prompts_dataset


if __name__ == "__main__":
    exp = PPOExp().set_cfg(PPOExpConfig())
    logger.info(exp.get_cfg_as_str(exp.cfg))
    # Create the checkpoint, tensorboard and ckpt directories (in that order) if missing.
    output_dirs = (exp.cfg.save_path, exp.cfg.tensorboard_log_dir, exp.cfg.ckpt_path)
    for output_dir in output_dirs:
        if os.path.exists(output_dir):
            continue
        os.makedirs(output_dir, exist_ok=True)
    asyncio.run(exp.run())
