import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 
import logging
from dataclasses import dataclass, field
from typing import Optional

import torch
import transformers
from transformers import HfArgumentParser, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model

# TRL and our own project-local modules
from trl import GRPOConfig
from grpo_trainer import GRPOTrainer 
from twnm.models.twnm_pretrained_model import TWNM, TWNMConfig
from rewards import format_reward, result_reward, length_reward
from twnm.data.grpo_dataset import AudioDataset

import os

@dataclass
class ScriptArguments:
    """Command-line arguments specific to this training script.

    Each field carries its command-line help text in ``metadata["help"]``.
    Only ``data_file`` has no default, so it is the one required argument.
    """

    # --- Core arguments ---

    # Path to the training data file (the help text states jsonl format).
    data_file: str = field(metadata={"help": "训练数据文件的路径 (jsonl 格式)"})

    # Optional checkpoint path for the spatial encoder; None when not given.
    spatial_encoder_ckpt_path: Optional[str] = field(
        default=None,
        metadata={"help": "空间编码器(Spatial Encoder)的检查点路径"},
    )

    # --- Optional arguments ---

    # Whether to log to WandB; off by default.
    use_wandb: bool = field(
        default=False,
        metadata={"help": "是否使用 WandB 来记录日志"},
    )

    # WandB project name.
    wandb_project_name: str = field(
        default="xtwnm-grpo-training",
        metadata={"help": "WandB 项目名称"},
    )


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from keys.

    Keys that do not start with ``module.`` are kept unchanged. The prefix is
    presumably left behind by a DataParallel/DDP-style wrapper at save time.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _output_dir_basename(output_dir):
    """Return the last component of *output_dir*, ignoring trailing separators."""
    return os.path.basename(os.path.normpath(output_dir))


def main():
    """Run GRPO training of the TWNM decoder with a fresh LoRA ``policy`` adapter.

    Steps, in order: parse ``ScriptArguments`` and ``GRPOConfig`` from the
    command line, build a 4-bit quantized TWNM model, load the SFT checkpoint
    and merge its LoRA weights into the decoder, attach a new trainable
    ``policy`` adapter, train with ``GRPOTrainer``, and finally save the
    adapter and tokenizer under ``<output_dir>/lora_policy_adapter``.
    """
    warnings.filterwarnings("ignore", message="Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.")
    warnings.filterwarnings("ignore", message="Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.")

    # 1. Parse command-line arguments.
    parser = HfArgumentParser((ScriptArguments, GRPOConfig))
    script_args, training_args = parser.parse_args_into_dataclasses()

    # 2. Configure logging.
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    logging.info("Script Arguments: %s", script_args)
    logging.info("Training Arguments: %s", training_args)

    # 3. Define the quantization and LoRA configs.
    logging.info("--- Defining Quantization and LoRA Configs ---")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    # Hyper-parameters shared by both adapters; only `inference_mode` differs.
    lora_kwargs = dict(
        target_modules=["q_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
    )
    # Frozen adapter passed to TWNM at construction; it is merged into the
    # decoder below after the SFT checkpoint has been loaded.
    grpo_lora_config = LoraConfig(inference_mode=True, **lora_kwargs)
    # Trainable adapter attached after the merge; this is what GRPO updates.
    policy_config = LoraConfig(inference_mode=False, **lora_kwargs)

    # 4. Initialise the TWNM model.
    # TWNMConfig now only covers the non-LLM parts; the decoder path points at
    # the merged SFT model.
    twnm_config = TWNMConfig()
    model = TWNM(config=twnm_config, peft_config=grpo_lora_config, quantization_config=quantization_config)

    # NOTE(review): this path is hard-coded and `script_args.spatial_encoder_ckpt_path`
    # is never read in this function -- confirm whether that is intended.
    sft_checkpoint_path = "assets/checkpoints/sft2_checkpoint-2502/pytorch_model.bin"
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (consider `weights_only=True` where the installed
    # torch version supports it).
    sft_state = _strip_module_prefix(torch.load(sft_checkpoint_path, map_location="cpu"))

    # strict=False tolerates key mismatches; report them so a checkpoint that
    # matches nothing does not go unnoticed.
    load_result = model.load_state_dict(sft_state, strict=False)
    del sft_state  # release the CPU copy of the checkpoint before training
    missing_keys = list(getattr(load_result, "missing_keys", ()))
    unexpected_keys = list(getattr(load_result, "unexpected_keys", ()))
    if missing_keys or unexpected_keys:
        logging.warning(
            "SFT checkpoint loaded non-strictly: %d missing key(s), %d unexpected key(s). "
            "First missing: %s; first unexpected: %s",
            len(missing_keys),
            len(unexpected_keys),
            missing_keys[:5],
            unexpected_keys[:5],
        )

    # Fold the loaded LoRA weights into the base decoder, then wrap the result
    # with the new trainable adapter named "policy".
    model.decoder = model.decoder.merge_and_unload()
    model.decoder = get_peft_model(model.decoder, policy_config, adapter_name="policy")

    tokenizer = model.tokenizer
    logging.info("TWNM model with new GRPO adapters initialized successfully.")
    model.decoder.print_trainable_parameters()

    # 5. Define the reward functions.
    reward_funcs = [format_reward, result_reward, length_reward]

    # 6. Load the dataset.
    train_dataset = AudioDataset(script_args.data_file)

    # 7. Configure WandB (if enabled).
    if script_args.use_wandb:
        # Assign a list rather than a bare string: the str -> list
        # normalisation of `report_to` happens when the arguments object is
        # constructed, which has already taken place by this point.
        training_args.report_to = ["wandb"]
        # basename(normpath(...)) keeps the run name meaningful when
        # output_dir is given with a trailing slash.
        run_suffix = _output_dir_basename(training_args.output_dir)
        training_args.run_name = f"{script_args.wandb_project_name}-{run_suffix}"
        # NOTE(review): `wandb_project_name` only feeds the run name here; the
        # WandB project itself is not set by this script -- confirm intent.
    else:
        training_args.report_to = []

    # 8. Create the optimizer over the parameters that require gradients.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=training_args.learning_rate,
        betas=(training_args.adam_beta1, training_args.adam_beta2),
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )

    # 9. Initialise GRPOTrainer.
    logging.info("--- Initializing GRPOTrainer ---")
    trainer = GRPOTrainer(
        model=model,
        ref_model=None,  # Key point: a separate ref_model is no longer needed.
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=None,
        processor=tokenizer,
        optimizers=(optimizer, None),
    )

    # 10. Train.
    logging.info("--- Starting Training ---")
    trainer.train()
    logging.info("--- Training Finished ---")

    # 11. Save the final model.
    logging.info("--- Saving Final Model ---")
    adapter_dir = os.path.join(training_args.output_dir, "lora_policy_adapter")
    os.makedirs(adapter_dir, exist_ok=True)

    # Save only the LoRA adapter: model.decoder is the PeftModel created by
    # get_peft_model above. This writes the adapter config and weights.
    # NOTE(review): the adapter is named "policy" rather than "default", so
    # PEFT may place the files in a "policy" subdirectory -- verify on disk.
    model.decoder.save_pretrained(adapter_dir)
    tokenizer.save_pretrained(adapter_dir)
    logging.info(f"Model saved to {training_args.output_dir}")


# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()