"""
generate_samples.py
为GRPO训练生成样本数据
从data文件夹中的每个JSON文件生成5个Python代码样本，并运行得到对应的JSON结果
"""

import json
import os
import sys
import subprocess
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import re
import shutil

# Path configuration.
# PROJECT_ROOT is the parent of the directory that contains this file.
PROJECT_ROOT = Path(__file__).parent.parent
DATA_DIR = PROJECT_ROOT / "data"
# NOTE: absolute, machine-specific path to the base model; adjust per environment.
MODEL_PATH = "/research/d7/gds/yhhan25/.cache/modelscope/hub/models/Qwen/Qwen3-14B"
# LoRA adapter directory; load_model() aborts if it does not exist.
LORA_PATH = PROJECT_ROOT / "SFT" / "output" / "final_model"
GRPO_DIR = PROJECT_ROOT / "GRPO"
# Generated scripts go under SAMPLES_DIR, their JSON outputs under RESULTS_DIR.
SAMPLES_DIR = GRPO_DIR / "samples"
RESULTS_DIR = GRPO_DIR / "results"

# Make sure the output directories exist (this runs at import time).
SAMPLES_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

def load_template(template_path=None):
    """Read the code template that is embedded in the generation prompt.

    Args:
        template_path: Optional path of the template file. When omitted, the
            project's ``main.py`` (``PROJECT_ROOT / "main.py"``) is used.

    Returns:
        The complete template source as a string.

    Raises:
        OSError: If the template file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    if template_path is None:
        template_path = PROJECT_ROOT / "main.py"
    print(f"加载代码模板: {template_path}")
    # Python source files are UTF-8 by default; decode explicitly instead of
    # relying on the locale's preferred encoding, which may not be UTF-8.
    with open(template_path, 'r', encoding='utf-8') as f:
        return f.read()

def load_model():
    """Load the base model with the SFT LoRA adapter applied.

    The tokenizer and base model are loaded from ``MODEL_PATH`` (float16,
    ``device_map="auto"``); the LoRA weights are loaded from ``LORA_PATH``
    and the resulting model is switched to evaluation mode.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        SystemExit: With exit status 1 when ``LORA_PATH`` does not exist.
    """
    print("\n" + "="*80)
    print("加载SFT模型用于生成样本")
    print("="*80 + "\n")

    print(f"基础模型路径: {MODEL_PATH}")
    print(f"LoRA模型路径: {LORA_PATH}")

    if not LORA_PATH.exists():
        print("\n❌ 错误: LoRA模型不存在！")
        print("   请先运行训练脚本: python SFT/train.py")
        # sys.exit() rather than the bare exit(): the latter is injected by
        # the ``site`` module for interactive use and is missing when the
        # interpreter is started with -S.
        sys.exit(1)

    print("⏳ 加载tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

    print("⏳ 加载基础模型...")
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True
    )

    print("⏳ 加载LoRA权重...")
    model = PeftModel.from_pretrained(base_model, str(LORA_PATH))
    # Inference only: put the model in evaluation mode.
    model.eval()

    print("✓ 模型加载完成\n")
    return model, tokenizer

def clean_generated_code(response):
    """Extract the Python source from a raw model response.

    Processing steps:
      1. Remove complete ``<think>...</think>`` reasoning blocks.
      2. If a closing ``</think>`` is still present (its opening tag was not
         part of the decoded text), keep only the text after the last one.
      3. Return the first fenced block tagged ``python``/``python3``/``py``
         (case-insensitive); failing that, the first fenced block of any kind.
      4. Without any complete fenced block, return the response minus
         markdown heading (``###``), rule (``---``) and fence lines.

    Args:
        response: Decoded model output.

    Returns:
        The extracted code with surrounding whitespace stripped; may be an
        empty string.
    """
    # Drop complete reasoning blocks.
    response = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL)

    # An orphan closing tag means everything before it is reasoning text
    # (possibly containing draft code blocks) rather than the final answer.
    if '</think>' in response:
        response = response.rsplit('</think>', 1)[-1]

    # Prefer a block explicitly tagged as Python, whatever the tag's casing.
    fenced = re.search(
        r'```(?:python3?|py)\b\s*(.*?)\s*```',
        response,
        re.DOTALL | re.IGNORECASE,
    )
    if fenced is None:
        # Otherwise accept the first fenced block (bare or differently tagged).
        fenced = re.search(r'```[^\n`]*\n(.*?)\s*```', response, re.DOTALL)
    if fenced is not None:
        return fenced.group(1).strip()

    # No complete fence: treat the response itself as code, discarding
    # markdown decoration and stray fence lines (e.g. an opening "```python"
    # whose closing fence was cut off).
    kept_lines = [
        line for line in response.split('\n')
        if not line.strip().startswith(('###', '---', '```'))
    ]
    return '\n'.join(kept_lines).strip()

def generate_code_sample(model, tokenizer, prompt, temperature=0.8, top_p=0.9):
    """Sample one completion for *prompt* and return its cleaned code.

    The prompt is wrapped in a system + user chat, rendered with the
    tokenizer's chat template, and sampled with up to 2048 new tokens. Only
    the newly generated tokens are decoded before the text is passed through
    ``clean_generated_code``.

    Args:
        model: Generation model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching *model*.
        prompt: User message content.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.

    Returns:
        The cleaned code string for the single generated sample.
    """
    system_turn = {"role": "system", "content": "你是一个专业的Python代码生成助手，只输出代码，不要包含任何解释。"}
    user_turn = {"role": "user", "content": prompt}

    chat_text = tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True
    )
    encoded = tokenizer([chat_text], return_tensors="pt").to(model.device)

    sampling_options = dict(
        max_new_tokens=2048,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        full_sequences = model.generate(**encoded, **sampling_options)

    # Strip the prompt tokens so only the completion is decoded.
    completions = []
    for prompt_ids, sequence in zip(encoded.input_ids, full_sequences):
        completions.append(sequence[len(prompt_ids):])

    decoded = tokenizer.batch_decode(completions, skip_special_tokens=True)
    return clean_generated_code(decoded[0])

def run_python_script(script_path, timeout=300, cwd=None):
    """Run a Python script in a child interpreter and report the outcome.

    NOTE(security): callers in this file pass model-generated scripts; they
    are executed with the privileges of the current process and no
    sandboxing is applied here.

    Args:
        script_path: Path of the script to execute.
        timeout: Maximum run time in seconds before the child is killed.
        cwd: Working directory for the child. Defaults to ``PROJECT_ROOT``.

    Returns:
        A ``(success, stdout, stderr)`` tuple. ``success`` is True only when
        the script exits with status 0. On timeout or launch failure
        ``stdout`` is empty and ``stderr`` carries a short description.
    """
    if cwd is None:
        cwd = PROJECT_ROOT
    try:
        result = subprocess.run(
            [sys.executable, str(script_path)],
            capture_output=True,
            text=True,
            # Undecodable bytes in the child's output must not raise here and
            # turn a successful run into a reported failure.
            errors="replace",
            timeout=timeout,
            cwd=cwd,
        )
    except subprocess.TimeoutExpired:
        return False, "", "Script execution timeout"
    except Exception as e:
        # Best-effort boundary: any launch problem is reported, not raised.
        return False, "", str(e)
    return result.returncode == 0, result.stdout, result.stderr

def extract_output_json_path(code):
    """Return the string assigned to ``output_file_path`` in *code*.

    The first ``output_file_path = '<value>'`` (or double-quoted) occurrence
    is used. The closing quote must be the same character as the opening
    one, so a value containing the other quote character is returned intact.

    Args:
        code: Python source text to search.

    Returns:
        The quoted value, or None when no such assignment is found.
    """
    match = re.search(r"output_file_path\s*=\s*(['\"])(.+?)\1", code)
    if match is None:
        return None
    return match.group(2)

def _rewrite_output_path(code, output_path):
    """Point every ``output_file_path='...'`` assignment in *code* at *output_path*."""
    # repr() yields a valid Python string literal even if the path contains
    # quotes or backslashes. A callable replacement stops re.sub from
    # interpreting backslashes in the path as escapes or group references.
    replacement = f"output_file_path={str(output_path)!r}"
    return re.sub(
        r"output_file_path\s*=\s*(['\"])(.+?)\1",
        lambda _match: replacement,
        code,
    )


def generate_samples_for_circuit(model, tokenizer, json_file, template, num_samples=5):
    """Generate, save and execute several code samples for one circuit file.

    For each sample the model is prompted with the template and the circuit
    data, the generated script's output path is redirected into this
    circuit's results directory, the script is written to the samples
    directory and then executed.

    Args:
        model: Generation model passed through to ``generate_code_sample``.
        tokenizer: Tokenizer passed through to ``generate_code_sample``.
        json_file: ``pathlib.Path`` of the circuit JSON file.
        template: Template source code embedded in the prompt.
        num_samples: Number of samples to attempt.

    Returns:
        A list with one dict (``sample_id``, ``script_path``, ``result_path``,
        ``temperature``) per sample whose script exited with status 0 and
        produced its result JSON.
    """
    print(f"\n{'='*80}")
    print(f"处理电路数据: {json_file.name}")
    print(f"{'='*80}\n")

    # Read the circuit data (JSON text is UTF-8, independent of the locale).
    with open(json_file, 'r', encoding='utf-8') as f:
        circuit_data = json.load(f)

    circuit_name = json_file.stem

    # Per-circuit directories for generated scripts and their results.
    circuit_samples_dir = SAMPLES_DIR / circuit_name
    circuit_samples_dir.mkdir(parents=True, exist_ok=True)

    circuit_results_dir = RESULTS_DIR / circuit_name
    circuit_results_dir.mkdir(parents=True, exist_ok=True)

    # Build the prompt (same wording as used for SFT).
    prompt = f"""请根据以下模板和电路数据生成代码：

模板代码：
```python
{template}
```

电路数据：
```json
{json.dumps(circuit_data, indent=2)}
```

要求生成完整的Python代码来处理该电路数据。具体来说，order_list和rotation_list应根据cells的数量进行调整，以实现合理的布局和旋转。在代码中修改，从{json_file}加载json格式的文件。rotation_list中的旋转选项包括'R0'和'MY'。只输出Python代码，不要包含任何解释、说明或思考过程。"""

    successful_samples = []

    for i in range(num_samples):
        print(f"\n生成样本 {i+1}/{num_samples}...")

        # Raise the temperature per sample for diversity: 0.7, 0.8, 0.9, ...
        # Rounded so float noise (0.7999999999999999) stays out of the metadata.
        temperature = round(0.7 + i * 0.1, 2)

        # Per-sample best-effort boundary: one failing sample must not abort
        # the remaining ones.
        try:
            generated_code = generate_code_sample(model, tokenizer, prompt, temperature=temperature)

            sample_name = f"sample_{i+1}"
            output_json_path = circuit_results_dir / f"{sample_name}_routed.json"

            # Redirect the script's output file into the results directory.
            modified_code = _rewrite_output_path(generated_code, output_json_path)

            # Python source is UTF-8 by default, so write the script as UTF-8
            # regardless of the locale.
            sample_script_path = circuit_samples_dir / f"{sample_name}.py"
            with open(sample_script_path, 'w', encoding='utf-8') as f:
                f.write(modified_code)

            print(f"✓ 已保存脚本: {sample_script_path}")

            # A result left over from an earlier run must not make this
            # sample look successful.
            if output_json_path.exists():
                output_json_path.unlink()

            print("⏳ 运行脚本...")
            success, _stdout, stderr = run_python_script(sample_script_path)

            if success and output_json_path.exists():
                print(f"✓ 成功生成结果: {output_json_path}")
                successful_samples.append({
                    'sample_id': sample_name,
                    'script_path': str(sample_script_path),
                    'result_path': str(output_json_path),
                    'temperature': temperature
                })
            else:
                print(f"⚠️  脚本运行失败")
                print(f"   stderr: {stderr[:200]}")

        except Exception as e:
            print(f"❌ 生成样本 {i+1} 时出错: {e}")

    print(f"\n成功生成 {len(successful_samples)}/{num_samples} 个样本\n")
    return successful_samples

def main():
    """Generate samples for every circuit file and write the metadata index.

    Circuit JSON files are discovered in ``DATA_DIR`` and processed in sorted
    order. The per-circuit sample lists are written to
    ``GRPO_DIR / "samples_metadata.json"``. When no circuit files exist the
    function returns before loading the model and leaves any existing
    metadata file untouched.
    """
    print("\n" + "🎯 "*40)
    print("GRPO 样本生成器")
    print("🎯 "*40 + "\n")

    # Discover inputs first so the model is not loaded when there is nothing
    # to do. Sorted because glob() yields files in no guaranteed order.
    json_files = sorted(DATA_DIR.glob("*.json"))
    if not json_files:
        print(f"⚠️  在 {DATA_DIR} 中未找到任何电路数据文件，已跳过模型加载")
        return

    # Load the template, then the model.
    template = load_template()
    model, tokenizer = load_model()

    print(f"找到 {len(json_files)} 个电路数据文件\n")

    # Generate samples for each circuit, keyed by file stem.
    all_samples = {}
    for json_file in json_files:
        samples = generate_samples_for_circuit(model, tokenizer, json_file, template, num_samples=5)
        all_samples[json_file.stem] = samples

    # Save the sample metadata.
    metadata_path = GRPO_DIR / "samples_metadata.json"
    with open(metadata_path, 'w', encoding='utf-8') as f:
        json.dump(all_samples, f, indent=2)

    print(f"\n{'='*80}")
    print(f"✓ 所有样本生成完成！")
    print(f"  样本代码目录: {SAMPLES_DIR}")
    print(f"  结果JSON目录: {RESULTS_DIR}")
    print(f"  元数据文件: {metadata_path}")
    print(f"{'='*80}\n")

if __name__ == "__main__":
    main()
