"""
inference.py
使用GRPO训练后的模型进行推理
"""

import json
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import re
import sys

# ---- Path configuration ----
# Project root is the grandparent directory of this file.
PROJECT_ROOT = Path(__file__).parent.parent
# Base model weights (absolute path on the host machine).
MODEL_PATH = "/modelscope/hub/models/Qwen/Qwen3-14B"
# Everything GRPO-related lives under <root>/GRPO.
_GRPO_DIR = PROJECT_ROOT / "GRPO"
# Adapter weights produced by the GRPO training run.
GRPO_MODEL_PATH = _GRPO_DIR / "output" / "final_model"
# Circuit JSON inputs.
DATA_DIR = PROJECT_ROOT / "data"
# Where generated scripts are written.
OUTPUT_DIR = _GRPO_DIR / "inference_output"

# Created at import time so later writes can assume the directory exists.
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

def load_model():
    """Load the base model and apply the GRPO-trained PEFT adapter for inference.

    The tokenizer and base weights are read from ``MODEL_PATH``. The adapter is
    read from ``GRPO_MODEL_PATH`` and wrapped with ``PeftModel``. The model is
    put in eval mode.

    Returns:
        ``(model, tokenizer)``: the ``PeftModel`` and the base model's tokenizer.

    Raises:
        SystemExit: with status 1 if ``GRPO_MODEL_PATH`` does not exist.
    """
    print("\n" + "🤖 "*40)
    print("加载GRPO推理模型")
    print("🤖 "*40 + "\n")

    print(f"📂 基础模型路径: {MODEL_PATH}")
    print(f"📂 GRPO模型路径: {GRPO_MODEL_PATH}")

    if not GRPO_MODEL_PATH.exists():
        print("\n❌ 错误: GRPO模型不存在！")
        print("   请先运行训练脚本: python GRPO/train.py")
        # sys.exit rather than the site-injected exit() builtin, which is
        # meant for the interactive shell and is not guaranteed to exist.
        sys.exit(1)

    print("\n⏳ 加载tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    print("✓ Tokenizer加载完成")

    print("\n⏳ 加载基础模型（这可能需要几分钟）...")
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        device_map="auto",  # let accelerate place layers across available devices
        # NOTE(review): the adapter was presumably trained in a specific dtype;
        # confirm float16 matches it (bfloat16 is common for Qwen checkpoints).
        torch_dtype=torch.float16,
        trust_remote_code=True
    )
    print("✓ 基础模型加载完成")

    print("\n⏳ 加载GRPO权重...")
    model = PeftModel.from_pretrained(base_model, str(GRPO_MODEL_PATH))
    model.eval()
    print("✓ GRPO权重加载完成")

    # Report where/how the first parameter lives (with device_map="auto" other
    # layers may sit on different devices).
    first_param = next(model.parameters())
    print(f"\n🖥️  模型设备: {first_param.device}")
    print(f"💾 模型数据类型: {first_param.dtype}")

    print("\n" + "="*80)
    print("✓ 模型加载完成，准备推理！")
    print("="*80 + "\n")

    return model, tokenizer

# Fenced block tagged as Python: ```python / ```python3 / ```py (any case).
_PY_FENCE_RE = re.compile(r'```(?:python3?|py)\s*(.*?)\s*```', re.DOTALL | re.IGNORECASE)
# A complete <think>...</think> reasoning section.
_THINK_RE = re.compile(r'<think>.*?</think>', re.DOTALL)
# Leading identifier of a stripped line, used to spot statement keywords.
_LEADING_WORD_RE = re.compile(r'[A-Za-z_]\w*')
# Statement keywords that can legitimately end a script or block
# (`return x`, `pass`, `else:`, ...) without containing '=' or a closing bracket.
_STATEMENT_KEYWORDS = frozenset({
    'assert', 'async', 'await', 'break', 'class', 'continue', 'def', 'del',
    'elif', 'else', 'except', 'finally', 'for', 'from', 'global', 'if',
    'import', 'nonlocal', 'pass', 'raise', 'return', 'try', 'while',
    'with', 'yield',
})


def _looks_like_code_line(stripped):
    """Return True if an already-stripped line plausibly is Python code, not prose."""
    if not stripped or stripped.startswith(('#', '**')):
        return False
    if '=' in stripped or stripped.endswith((')', '}', ']')):
        return True
    word = _LEADING_WORD_RE.match(stripped)
    return word is not None and word.group(0) in _STATEMENT_KEYWORDS


def _strip_trailing_prose(code):
    """Drop every line after the last code-looking line; unchanged if none looks like code."""
    lines = code.split('\n')
    for idx, line in reversed(list(enumerate(lines))):
        if _looks_like_code_line(line.strip()):
            return '\n'.join(lines[:idx + 1])
    return code


def clean_generated_code(response):
    """Extract the Python source from a raw model response.

    Steps:
      1. Remove ``<think>...</think>`` sections. If an unmatched ``</think>``
         remains (its opening tag was not part of the response), keep only the
         text after the last one.
      2. If a ``python``/``python3``/``py`` fenced block exists, return the first
         one stripped. The fence already delimits the code, so nothing is trimmed.
      3. Otherwise, drop markdown separator lines (``###``, ``---``) and any
         fence line. This includes an opening fence left unclosed when
         generation was truncated. Then trim trailing lines that do not look
         like code.

    Args:
        response: Decoded model output text.

    Returns:
        The extracted code as a string (possibly empty).
    """
    response = _THINK_RE.sub('', response)
    if '</think>' in response:
        response = response.rsplit('</think>', 1)[1]

    code_blocks = _PY_FENCE_RE.findall(response)
    if code_blocks:
        return code_blocks[0].strip()

    kept = [
        line for line in response.split('\n')
        if not line.strip().startswith(('###', '---', '```'))
    ]
    code = '\n'.join(kept).strip()
    return _strip_trailing_prose(code)

def load_template():
    """Return the text of the code template ``PROJECT_ROOT / "main.py"``.

    The file is decoded as UTF-8, the default encoding for Python source files.
    The result therefore does not depend on the platform locale; without this,
    a file containing non-ASCII text could fail or mis-decode on e.g. GBK systems.

    Raises:
        OSError: if the template cannot be read (e.g. ``FileNotFoundError``).
    """
    template_path = PROJECT_ROOT / "main.py"
    return template_path.read_text(encoding='utf-8')

def generate_code(model, tokenizer, circuit_data, circuit_path, template):
    """Prompt the model for a script that processes one circuit; return the cleaned code.

    Args:
        model: Loaded causal LM (as returned by ``load_model``).
        tokenizer: Matching tokenizer; must support ``apply_chat_template``.
        circuit_data: Parsed circuit JSON, embedded in the prompt with ``indent=2``.
        circuit_path: Path of the circuit file, quoted in the prompt as the file
            the generated script should load.
        template: Source text of the template script, embedded verbatim.

    Returns:
        The sampled completion after ``clean_generated_code`` post-processing.
    """
    circuit_json = json.dumps(circuit_data, indent=2)
    # The prompt is intended to match the one used during training; keep in sync.
    prompt = f"""请根据以下模板和电路数据生成代码：

模板代码：
```python
{template}
```

电路数据：
```json
{circuit_json}
```

要求生成完整的Python代码来处理该电路数据。具体来说，order_list和rotation_list应根据cells的数量进行调整，以实现合理的布局和旋转。从{circuit_path}加载json格式的文件。rotation_list中的旋转选项包括'R0'和'MY'。只输出Python代码，不要包含任何解释、说明或思考过程。"""

    chat = [
        {"role": "system", "content": "你是一个专业的Python代码生成助手，只输出代码，不要包含任何解释。"},
        {"role": "user", "content": prompt},
    ]
    chat_text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    encoded = tokenizer([chat_text], return_tensors="pt").to(model.device)
    prompt_len = len(encoded.input_ids[0])

    print(f"📝 输入长度: {prompt_len} tokens")
    print("⏳ 生成中...")

    sampling_kwargs = dict(
        max_new_tokens=2048,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        output_ids = model.generate(**encoded, **sampling_kwargs)

    # Single-prompt batch: strip the echoed prompt tokens, keep only the completion.
    completion_ids = [row[prompt_len:] for row in output_ids]
    response = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)[0]

    print("✓ 生成完成")
    print(f"📊 生成长度: {len(response)} 字符")

    return clean_generated_code(response)

def run_inference_on_all_circuits():
    """Generate code for every ``*.json`` circuit file in ``DATA_DIR``.

    The model and template are loaded once. Files are processed in sorted
    order. For each file, the generated script is written to
    ``OUTPUT_DIR / "<stem>_grpo_generated.py"`` and its first 10 lines are
    printed. A failure on one file (unreadable JSON, generation error, ...) is
    reported with a traceback, and processing continues with the next file.
    """
    import traceback

    print("\n" + "🎯 "*40)
    print("GRPO模型推理")
    print("🎯 "*40 + "\n")

    model, tokenizer = load_model()
    template = load_template()

    # Sorted so the processing order does not depend on filesystem listing order.
    json_files = sorted(DATA_DIR.glob("*.json"))
    total = len(json_files)
    print(f"找到 {total} 个电路数据文件\n")

    for idx, json_file in enumerate(json_files, 1):
        print(f"\n{'='*80}")
        print(f"处理电路 [{idx}/{total}]: {json_file.name}")
        print(f"{'='*80}\n")

        try:
            # Reading happens inside the try so one malformed file does not abort the batch.
            with open(json_file, 'r', encoding='utf-8') as f:
                circuit_data = json.load(f)

            print(f"电路包含 {len(circuit_data.get('cells', []))} 个cells")

            # circuit_path is required: the prompt tells the model which file to load.
            generated_code = generate_code(model, tokenizer, circuit_data, json_file, template)

            output_script = OUTPUT_DIR / f"{json_file.stem}_grpo_generated.py"
            with open(output_script, 'w', encoding='utf-8') as f:
                f.write(generated_code)

            print(f"\n✓ 生成的代码已保存到: {output_script}")

            # Preview the first lines of the generated script.
            print("\n生成的代码前10行:")
            print("-" * 80)
            for i, line in enumerate(generated_code.split('\n')[:10], 1):
                print(f"{i:3}: {line}")
            print("-" * 80)

        except Exception as e:
            # Top-level per-file boundary: report and move on to the next circuit.
            print(f"\n❌ 生成代码时出错: {e}")
            traceback.print_exc()

    print(f"\n{'='*80}")
    print("✓ 所有推理完成！")
    print(f"  输出目录: {OUTPUT_DIR}")
    print(f"{'='*80}\n")

def main():
    """Command-line entry point.

    With ``--circuit NAME``: generate code for ``DATA_DIR / NAME``, save it to
    ``OUTPUT_DIR / "<stem>_grpo_generated.py"`` and print it. Exits with status 1
    if the file does not exist. Without ``--circuit``: process every circuit
    via ``run_inference_on_all_circuits``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="GRPO模型推理")
    parser.add_argument("--circuit", type=str, default=None, help="指定电路JSON文件名（不含路径）")
    args = parser.parse_args()

    if not args.circuit:
        # No single circuit requested: run on every circuit in DATA_DIR.
        run_inference_on_all_circuits()
        return

    circuit_path = DATA_DIR / args.circuit
    if not circuit_path.exists():
        print(f"❌ 错误: 电路文件不存在: {circuit_path}")
        # Non-zero status so shell callers can detect the failure.
        sys.exit(1)

    model, tokenizer = load_model()
    template = load_template()

    with open(circuit_path, 'r', encoding='utf-8') as f:
        circuit_data = json.load(f)

    generated_code = generate_code(model, tokenizer, circuit_data, circuit_path, template)

    output_script = OUTPUT_DIR / f"{circuit_path.stem}_grpo_generated.py"
    # Explicit UTF-8: generated code may contain non-ASCII (e.g. Chinese) text.
    with open(output_script, 'w', encoding='utf-8') as f:
        f.write(generated_code)

    print(f"\n✓ 生成的代码已保存到: {output_script}")
    print(f"\n{generated_code}")

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
