import asyncio
import os
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional

from loguru import logger
from omegaconf.listconfig import ListConfig

# NOTE(review): the next line rebinds BasePPOExpConfig, so this first import is
# effectively unused beyond its import-time side effects — confirm which base is intended.
from thinker_task.exps.examples.ppo.ppo_base_exp import BasePPOExpConfig
from playground.ppo_base import PPOExp, PPOExpConfig as BasePPOExpConfig

# Global debug flag: enabled whenever the DEBUG_MODE environment variable holds
# anything other than the exact string "False" (an unset variable counts as "False").
DEBUG_MODE = os.environ.get("DEBUG_MODE", "False") != "False"
# Base name of this script without its extension, prefixed with "debug_" in debug mode.
file_name = ("debug_" if DEBUG_MODE else "") + os.path.splitext(os.path.basename(__file__))[0]

@dataclass
class PPOExpConfig_(BasePPOExpConfig):
    """PPO experiment config for the ``thinker_q1_5b`` run.

    Overrides selected defaults of ``BasePPOExpConfig``; any field not listed
    here keeps whatever default the base class defines.
    """

    run_name: str = "thinker_q1_5b"
    # ``pretrain`` and ``critic_pretrain`` point at the same base checkpoint.
    pretrain: Optional[str] = "large_data/base/Qwen/Qwen2.5-1.5B"
    critic_pretrain: Optional[str] = "large_data/base/Qwen/Qwen2.5-1.5B"

    # Built per instance via default_factory: a plain ``ListConfig(...)`` default
    # would be one object shared by every instance, so mutating one config's
    # prompt list would silently change all of them.
    prompt_data: ListConfig = field(
        default_factory=lambda: ListConfig(
            [
                "data/clean_orz_129k.json",
            ]
        )
    )

    # NOTE: these f-strings run once, when the class body executes, using the
    # class-level default of ``run_name`` above. Passing a different ``run_name``
    # to the constructor does NOT update them; override these fields explicitly too.
    ckpt_path: str = f"large_data/checkpoints/{run_name}"
    save_path: str = f"large_data/checkpoints/{run_name}"
    tensorboard_log_dir: str = f"large_data/logs/{run_name}"
    wandb_run_name: str = run_name

    num_episodes: int = 240
    rollout_batch_size: int = 128
    n_samples_per_prompt: int = 32

    actor_learning_rate: float = 1e-6
    critic_learning_rate: float = 5e-6

    policy_update_steps: int = 1
    critic_update_steps: int = 12

    max_len: int = 16000

    prompt_type: int = 0
    summary: bool = True
    verify_reweight: bool = True

    ref_pretrain: Optional[str] = None
    use_kl_loss: bool = True
    # NOTE(review): with a coefficient of 0 the KL loss presumably contributes
    # nothing even though use_kl_loss is True — confirm this is intentional.
    kl_loss_coef: float = 0.
    use_kl_estimator_k3: bool = True
    update_ref_every_epoch: bool = False

    num_warmup_steps: int = 50

    # enable_eval: bool = False
    # summary_buffer_size: int = 16


if __name__ == "__main__":
    exp = PPOExp().set_cfg(PPOExpConfig_())
    logger.info(exp.get_cfg_as_str(exp.cfg))
    # Create every output directory up front, in the original order.
    # ``exist_ok=True`` already tolerates an existing directory, so a separate
    # os.path.exists() pre-check is redundant and racy.
    for output_dir in (exp.cfg.save_path, exp.cfg.tensorboard_log_dir, exp.cfg.ckpt_path):
        os.makedirs(output_dir, exist_ok=True)
    # exp.run() returns a coroutine; drive it to completion on a fresh event loop.
    asyncio.run(exp.run())
