"""
train.py
GRPO (Group Relative Policy Optimization) 训练脚本
基于SFT模型，使用组内相对优势进行强化学习优化
"""

import json
import os
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
import numpy as np
from tqdm import tqdm
from datetime import datetime
from typing import List, Dict
import torch.nn.functional as F

# Put the Hugging Face libraries into offline mode so models/tokenizers are
# resolved from local files only.
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_DATASETS_OFFLINE"] = "1"

# Path configuration (everything is resolved relative to the parent of this
# script's directory, except the base model which is an absolute path).
PROJECT_ROOT = Path(__file__).parent.parent
DATA_DIR = PROJECT_ROOT / "data"
GRPO_DIR = PROJECT_ROOT / "GRPO"
SFT_MODEL_PATH = PROJECT_ROOT / "SFT" / "output" / "final_model"
BASE_MODEL_PATH = "/research/d7/gds/yhhan25/.cache/modelscope/hub/models/Qwen/Qwen3-14B"
OUTPUT_DIR = GRPO_DIR / "output"

# GRPO hyper-parameters
GRPO_CONFIG = {
    'learning_rate': 1e-6,  # smaller learning rate than the one used for SFT
    'num_epochs': 3,
    'batch_size': 1,  # GRPO typically uses a small batch; the training loop reads only element [0] of each batch
    'gradient_accumulation_steps': 4,
    'max_grad_norm': 0.5,
    'kl_coef': 0.1,  # KL coefficient: keeps the policy from drifting too far from the SFT model
    'clip_range': 0.2,  # PPO-style clipping
    'value_clip_range': 0.2,  # NOTE(review): no reader of this key is visible in this file
}

class GRPODataset(Dataset):
    """Flat list of (prompt, response, reward) records used for GRPO training.

    One record is produced for every sample that has (a) a circuit JSON file,
    (b) an entry in ``rewards_data`` and (c) an existing generated script on
    disk. Anything that misses one of these is silently skipped.

    Args:
        samples_metadata: Mapping ``circuit_name -> list of sample dicts``.
            Each sample dict must provide ``'sample_id'`` and ``'script_path'``.
        rewards_data: Mapping ``circuit_name -> list of reward dicts``. Each
            reward dict must provide ``'name'`` (matched against
            ``sample_id``), ``'reward'`` and ``'distance'``.
        template: Template source code embedded verbatim into every prompt.
        data_dir: Directory holding ``<circuit_name>.json`` files. Defaults to
            the module-level ``DATA_DIR`` when ``None``.

    Each item returned by ``__getitem__`` is a dict with the keys ``prompt``,
    ``response``, ``reward``, ``distance``, ``circuit_name`` and
    ``sample_name``.
    """

    def __init__(self, samples_metadata, rewards_data, template, data_dir=None):
        self.data = []
        self.template = template
        circuit_dir = DATA_DIR if data_dir is None else Path(data_dir)

        for circuit_name, samples in samples_metadata.items():
            if circuit_name not in rewards_data:
                continue

            circuit_json = circuit_dir / f"{circuit_name}.json"
            if not circuit_json.exists():
                continue

            # Explicit UTF-8: the data and generated scripts may contain
            # non-ASCII text, and the platform default encoding is not
            # guaranteed to be UTF-8.
            with open(circuit_json, 'r', encoding='utf-8') as f:
                circuit_data = json.load(f)

            # Same prompt wording as the SFT stage; shared by all samples of
            # this circuit.
            prompt = self._build_prompt(template, circuit_data)

            # Index rewards by sample name for O(1) lookup below.
            circuit_rewards = {r['name']: r for r in rewards_data[circuit_name]}

            for sample in samples:
                sample_name = sample['sample_id']
                reward_info = circuit_rewards.get(sample_name)
                if reward_info is None:
                    continue

                # The generated script is the response being scored.
                script_path = Path(sample['script_path'])
                if not script_path.exists():
                    continue

                with open(script_path, 'r', encoding='utf-8') as f:
                    response = f.read()

                self.data.append({
                    'prompt': prompt,
                    'response': response,
                    'reward': reward_info['reward'],
                    'distance': reward_info['distance'],
                    'circuit_name': circuit_name,
                    'sample_name': sample_name
                })

        print(f"构建GRPO数据集，共 {len(self.data)} 个样本")

    @staticmethod
    def _build_prompt(template, circuit_data):
        """Return the user prompt embedding the template and the circuit JSON."""
        return f"""请根据以下模板和电路数据生成代码：

模板代码：
```python
{template}
```

电路数据：
```json
{json.dumps(circuit_data, indent=2)}
```

要求生成完整的Python代码来处理该电路数据。具体来说，order_list和rotation_list应根据cells的数量进行调整，以实现合理的布局和旋转。rotation_list中的旋转选项包括'R0'和'MY'。只输出Python代码，不要包含任何解释、说明或思考过程。"""

    def __len__(self):
        """Return the number of usable training records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the record dict at position ``idx``."""
        return self.data[idx]

def _load_base_model():
    """Load one float16 copy of the base causal LM from BASE_MODEL_PATH."""
    return AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_PATH,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True
    )


def load_models():
    """Load the tokenizer, the trainable policy model and the frozen reference.

    Both models are the base model at ``BASE_MODEL_PATH`` wrapped with the SFT
    adapter stored at ``SFT_MODEL_PATH``. Two independent copies of the base
    model are loaded: one for the policy (adapter trainable) and one for the
    reference (everything frozen, eval mode).

    Returns:
        tuple: ``(policy_model, reference_model, tokenizer)``.

    Raises:
        ValueError: If no parameter of the policy model can be made trainable.
    """
    print("\n" + "="*80)
    print("加载模型")
    print("="*80 + "\n")

    # Tokenizer; padding reuses the EOS token.
    print("⏳ 加载tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    print("✓ Tokenizer加载完成")

    # Base model that will carry the trainable adapter.
    print("\n⏳ 加载基础模型...")
    base_model = _load_base_model()
    print("✓ 基础模型加载完成")

    # Policy model: PEFT loads adapters frozen (inference mode) by default, so
    # request a trainable adapter explicitly instead of relying only on the
    # name-based un-freezing below.
    print("\n⏳ 加载SFT模型作为policy model...")
    policy_model = PeftModel.from_pretrained(
        base_model,
        str(SFT_MODEL_PATH),
        is_trainable=True
    )

    # Safety net kept from the original flow: make sure adapter-related
    # parameters (matched by name) have gradients enabled.
    policy_model.train()
    extra_keywords = ('modules_to_save', 'classifier', 'score')
    for name, param in policy_model.named_parameters():
        lowered = name.lower()
        if 'lora' in lowered or 'adapter' in lowered:
            param.requires_grad = True
        elif any(keyword in name for keyword in extra_keywords):
            param.requires_grad = True

    # Report how many parameters are trainable (and show the first five).
    trainable_params = 0
    all_params = 0
    lora_params_found = []
    for name, param in policy_model.named_parameters():
        all_params += param.numel()
        if not param.requires_grad:
            continue
        trainable_params += param.numel()
        if len(lora_params_found) < 5:
            lora_params_found.append(f"  可训练参数: {name}, shape: {param.shape}")

    for msg in lora_params_found:
        print(msg)

    print(f"✓ Policy模型加载完成")
    print(f"  可训练参数: {trainable_params:,} / {all_params:,} ({100 * trainable_params / all_params:.2f}%)")

    if trainable_params == 0:
        # Last-resort fallback: no adapter parameter was found, so enable
        # gradients on every parameter of the policy model.
        print("⚠️  警告：没有找到LoRA参数，尝试启用所有参数梯度...")
        for param in policy_model.parameters():
            param.requires_grad = True
        trainable_params = sum(p.numel() for p in policy_model.parameters() if p.requires_grad)
        print(f"  现在可训练参数: {trainable_params:,}")

        if trainable_params == 0:
            raise ValueError("❌ 错误：仍然没有可训练参数！请检查PEFT配置")

    # Reference model: a second, fully frozen copy of base + SFT adapter used
    # as the fixed anchor for the KL term.
    print("\n⏳ 加载SFT模型作为reference model...")
    base_model_ref = _load_base_model()
    reference_model = PeftModel.from_pretrained(base_model_ref, str(SFT_MODEL_PATH))
    reference_model.eval()
    for param in reference_model.parameters():
        param.requires_grad = False
    print("✓ Reference模型加载完成")

    return policy_model, reference_model, tokenizer

def compute_log_probs(model, input_ids, attention_mask, labels):
    """Return the summed log-probability of ``labels`` under ``model``.

    The model is run once on ``input_ids``; the logits at position ``t`` are
    scored against the label at position ``t + 1`` (standard causal shift).

    Args:
        model: Callable accepting ``input_ids`` and ``attention_mask`` keyword
            arguments and returning an object with a ``.logits`` tensor whose
            last dimension is the vocabulary.
        input_ids: Token ids fed to the model.
        attention_mask: 1 for real tokens, 0 for padding; padded target
            positions do not contribute to the sum.
        labels: Target token ids, same shape as ``input_ids``. Positions set
            to ``-100`` are ignored (in addition to padded positions).

    Returns:
        torch.Tensor: float32 tensor with one summed log-probability per
        sequence (the sequence dimension is reduced).
    """
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    # Tokens < n predict token n: drop the last logit and the first label.
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:]

    # Normalise in float32: half-precision log-probs summed over a long
    # sequence lose far too much resolution for a meaningful policy ratio.
    log_probs = F.log_softmax(shift_logits, dim=-1, dtype=torch.float32)

    # -100 is not a valid gather index; substitute 0 there and mask it out.
    ignored = shift_labels.eq(-100)
    safe_labels = shift_labels.masked_fill(ignored, 0)
    token_log_probs = log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)

    # Keep only real (non-padding, non-ignored) target positions.
    keep = attention_mask[..., 1:].to(token_log_probs.dtype)
    keep = keep * (~ignored).to(token_log_probs.dtype)

    return (token_log_probs * keep).sum(dim=-1)

def grpo_loss(policy_log_probs, reference_log_probs, rewards, kl_coef=0.1, clip_range=0.2):
    """Compute the clipped policy loss plus a KL-style penalty.

    Args:
        policy_log_probs: Sequence log-probabilities under the current policy
            (carries the gradient).
        reference_log_probs: Sequence log-probabilities under the reference
            policy; always detached here.
        rewards: Per-sequence rewards (tensor or anything ``torch.as_tensor``
            accepts). They are used directly as advantages.
            NOTE(review): this function performs no baseline subtraction or
            group normalisation — confirm the rewards are already
            group-relative when they are produced.
        kl_coef: Weight of the KL penalty.
        clip_range: PPO-style clipping half-width for the probability ratio.

    Returns:
        tuple: ``(total_loss, policy_loss, kl_penalty)`` as scalar float32
        tensors, where ``total_loss = policy_loss + kl_penalty``.
    """
    # Work in float32 so exp() of the log-ratio neither overflows nor gets
    # quantised when the inputs are half precision.
    policy_lp = policy_log_probs.to(torch.float32)
    reference_lp = reference_log_probs.detach().to(torch.float32)
    advantages = torch.as_tensor(
        rewards, dtype=torch.float32, device=policy_lp.device
    ).detach()

    # Importance ratio pi / pi_ref and its PPO-style clipped version.
    log_ratio = policy_lp - reference_lp
    ratio = torch.exp(log_ratio)
    clipped_ratio = torch.clamp(ratio, 1 - clip_range, 1 + clip_range)

    # Pessimistic (min) surrogate objective, negated to obtain a loss.
    policy_loss = -torch.min(
        ratio * advantages,
        clipped_ratio * advantages
    ).mean()

    # KL penalty via the "k3" estimator r - log(r) - 1 with r = pi_ref / pi.
    # Unlike the signed log-ratio, it is always >= 0 and has zero value and
    # zero gradient exactly when the policy matches the reference, so it
    # cannot be "minimised" by simply lowering the likelihood of the samples.
    # NOTE(review): if the samples were not drawn from the current policy this
    # is a proximity penalty rather than an unbiased KL estimate.
    kl = torch.exp(-log_ratio) + log_ratio - 1.0
    kl_penalty = kl_coef * kl.mean()

    total_loss = policy_loss + kl_penalty

    return total_loss, policy_loss, kl_penalty

def train_grpo():
    """Run the full GRPO training pipeline.

    Steps:
        1. Read the code template (``PROJECT_ROOT / "main.py"``), the sample
           metadata and the rewards from ``GRPO_DIR``.
        2. Build a ``GRPODataset``; return early (after printing a hint) when
           it is empty.
        3. Load the policy / reference models and train for
           ``GRPO_CONFIG['num_epochs']`` epochs with gradient accumulation.
        4. Save a checkpoint after every epoch and the final model at the end,
           all under ``OUTPUT_DIR``.

    Returns:
        None.
    """
    print("\n" + "🚀 "*40)
    print("开始GRPO训练")
    print("🚀 "*40 + "\n")

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Explicit UTF-8 everywhere: these files may contain non-ASCII text and
    # the platform default encoding is not guaranteed to be UTF-8.
    with open(PROJECT_ROOT / "main.py", 'r', encoding='utf-8') as f:
        template = f.read()

    print("加载训练数据...")
    with open(GRPO_DIR / "samples_metadata.json", 'r', encoding='utf-8') as f:
        samples_metadata = json.load(f)

    with open(GRPO_DIR / "rewards.json", 'r', encoding='utf-8') as f:
        rewards_data = json.load(f)

    dataset = GRPODataset(samples_metadata, rewards_data, template)

    if len(dataset) == 0:
        print("❌ 数据集为空，请先运行 generate_samples.py 和 compute_rewards.py")
        return

    dataloader = DataLoader(
        dataset,
        batch_size=GRPO_CONFIG['batch_size'],
        shuffle=True,
        num_workers=0
    )

    policy_model, reference_model, tokenizer = load_models()

    # Only parameters with gradients enabled are optimised and clipped.
    trainable_params = [p for p in policy_model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=GRPO_CONFIG['learning_rate']
    )

    policy_model.train()
    global_step = 0
    accumulation_steps = GRPO_CONFIG['gradient_accumulation_steps']
    n_batches = len(dataloader)

    print("\n" + "="*80)
    print("开始训练")
    print("="*80 + "\n")

    for epoch in range(GRPO_CONFIG['num_epochs']):
        print(f"\nEpoch {epoch + 1}/{GRPO_CONFIG['num_epochs']}")
        print("-" * 80)

        epoch_stats = {
            'total_loss': 0,
            'policy_loss': 0,
            'kl_penalty': 0,
            'avg_reward': 0
        }

        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}")

        for batch_idx, batch in enumerate(progress_bar):
            # Only element [0] of the batch is used: the loop assumes
            # GRPO_CONFIG['batch_size'] == 1.
            messages = [
                {"role": "system", "content": "你是一个专业的Python代码生成助手，只输出代码，不要包含任何解释。"},
                {"role": "user", "content": batch['prompt'][0]},
                {"role": "assistant", "content": batch['response'][0]}
            ]

            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False
            )

            encodings = tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=4096
            )

            input_ids = encodings['input_ids'].to(policy_model.device)
            attention_mask = encodings['attention_mask'].to(policy_model.device)
            # NOTE(review): labels are a plain copy of input_ids, so the
            # system/user prompt tokens are scored together with the response.
            labels = input_ids.clone()

            rewards = torch.tensor([batch['reward'][0]], dtype=torch.float32).to(policy_model.device)

            # Policy forward pass (with gradients).
            policy_log_probs = compute_log_probs(policy_model, input_ids, attention_mask, labels)

            # Debug output on the first batch of every epoch.
            if batch_idx == 0:
                print(f"\n调试信息:")
                print(f"  policy_log_probs shape: {policy_log_probs.shape}")
                print(f"  policy_log_probs requires_grad: {policy_log_probs.requires_grad}")
                print(f"  policy_log_probs grad_fn: {policy_log_probs.grad_fn}")

            # Reference forward pass (no gradients).
            with torch.no_grad():
                reference_log_probs = compute_log_probs(reference_model, input_ids, attention_mask, labels)

            loss, policy_loss, kl_pen = grpo_loss(
                policy_log_probs,
                reference_log_probs,
                rewards,
                kl_coef=GRPO_CONFIG['kl_coef'],
                clip_range=GRPO_CONFIG['clip_range']
            )

            if batch_idx == 0:
                print(f"  loss requires_grad: {loss.requires_grad}")
                print(f"  loss grad_fn: {loss.grad_fn}")

            # Scale so that the accumulated gradient is an average.
            scaled_loss = loss / accumulation_steps
            scaled_loss.backward()

            # Step after every full accumulation group AND on the last batch
            # of the epoch; otherwise gradients of a trailing partial group
            # would leak into the next epoch (or be dropped after the final
            # one).
            is_last_batch = (batch_idx + 1) == n_batches
            if (batch_idx + 1) % accumulation_steps == 0 or is_last_batch:
                torch.nn.utils.clip_grad_norm_(
                    trainable_params,
                    GRPO_CONFIG['max_grad_norm']
                )

                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

            # Running sums for the epoch summary (unscaled loss).
            epoch_stats['total_loss'] += loss.item()
            epoch_stats['policy_loss'] += policy_loss.item()
            epoch_stats['kl_penalty'] += kl_pen.item()
            epoch_stats['avg_reward'] += rewards.mean().item()

            progress_bar.set_postfix({
                'loss': f"{scaled_loss.item():.4f}",
                'reward': f"{rewards.mean().item():.4f}",
                'kl': f"{kl_pen.item():.4f}"
            })

        # Per-epoch averages.
        print(f"\nEpoch {epoch + 1} 统计:")
        print(f"  平均总损失: {epoch_stats['total_loss'] / n_batches:.4f}")
        print(f"  平均策略损失: {epoch_stats['policy_loss'] / n_batches:.4f}")
        print(f"  平均KL惩罚: {epoch_stats['kl_penalty'] / n_batches:.4f}")
        print(f"  平均奖励: {epoch_stats['avg_reward'] / n_batches:.4f}")

        # Checkpoint after every epoch.
        checkpoint_dir = OUTPUT_DIR / f"checkpoint-epoch-{epoch + 1}"
        checkpoint_dir.mkdir(exist_ok=True)
        policy_model.save_pretrained(checkpoint_dir)
        tokenizer.save_pretrained(checkpoint_dir)
        print(f"✓ 已保存checkpoint: {checkpoint_dir}")

    # Final model.
    final_model_dir = OUTPUT_DIR / "final_model"
    final_model_dir.mkdir(exist_ok=True)
    policy_model.save_pretrained(final_model_dir)
    tokenizer.save_pretrained(final_model_dir)

    print("\n" + "="*80)
    print("✓ GRPO训练完成！")
    print(f"  最终模型保存在: {final_model_dir}")
    print("="*80 + "\n")

def main():
    """Script entry point: delegate to the GRPO training routine."""
    train_grpo()

# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
