import json
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# 配置路径
PROJECT_ROOT = Path(__file__).parent.parent
MODEL_PATH = "/modelscope/hub/models/Qwen/Qwen3-14B"  # 基础模型路径
LORA_PATH = PROJECT_ROOT / "SFT" / "output" / "final_model"  # LoRA模型路径

def load_model():
    """加载训练好的模型"""
    print("\n" + "🤖 "*40)
    print("加载推理模型")
    print("🤖 "*40 + "\n")
    
    print(f"📂 基础模型路径: {MODEL_PATH}")
    print(f"📂 LoRA模型路径: {LORA_PATH}")
    
    if not LORA_PATH.exists():
        print(f"\n❌ 错误: LoRA模型不存在！")
        print(f"   请先运行训练脚本: python train.py")
        exit(1)
    
    print("\n⏳ 加载tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    print("✓ Tokenizer加载完成")
    
    print("\n⏳ 加载基础模型（这可能需要几分钟）...")
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True
    )
    print("✓ 基础模型加载完成")
    
    print("\n⏳ 加载LoRA权重...")
    model = PeftModel.from_pretrained(base_model, str(LORA_PATH))
    model.eval()
    print("✓ LoRA权重加载完成")
    
    # 显示设备信息
    device = next(model.parameters()).device
    print(f"\n🖥️  模型设备: {device}")
    print(f"💾 模型数据类型: {next(model.parameters()).dtype}")
    
    print("\n" + "="*80)
    print("✓ 模型加载完成，准备推理！")
    print("="*80 + "\n")
    
    return model, tokenizer

def clean_generated_code(response):
    """清理生成的代码，只保留Python代码部分"""
    import re
    
    # 移除 <think> 标签及其内容
    response = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL)
    
    # 提取代码块
    code_blocks = re.findall(r'```python\s*(.*?)\s*```', response, re.DOTALL)
    
    if code_blocks:
        # 如果找到代码块，使用第一个代码块
        code = code_blocks[0].strip()
    else:
        # 如果没有代码块标记，尝试直接使用响应内容
        # 移除可能的解释性文本（以 ### 或 --- 开头的部分）
        lines = response.split('\n')
        code_lines = []
        skip_explanation = False
        
        for line in lines:
            # 跳过分隔符和说明部分
            if line.strip().startswith('###') or line.strip().startswith('---') or line.strip() == '```':
                skip_explanation = True
                continue
            
            # 如果遇到代码行，停止跳过
            if not skip_explanation or (line.strip() and not line.strip().startswith('#') and not line.strip().startswith('**')):
                code_lines.append(line)
        
        code = '\n'.join(code_lines).strip()
    
    # 移除可能的尾部说明（从最后一个空行后的内容）
    # 查找最后一个函数调用或代码语句的位置
    lines = code.split('\n')
    last_code_line = -1
    
    for i in range(len(lines) - 1, -1, -1):
        line = lines[i].strip()
        # 检查是否是代码行（不是注释，不是空行，不是Markdown格式）
        if line and not line.startswith('#') and not line.startswith('**') and not line.startswith('###'):
            # 检查是否看起来像代码
            if '=' in line or line.endswith(')') or line.endswith('}') or line.endswith(']'):
                last_code_line = i
                break
    
    if last_code_line >= 0:
        code = '\n'.join(lines[:last_code_line + 1])
    
    return code.strip()

def generate_code(model, tokenizer, circuit_data_path):
    """生成代码"""
    print("\n" + "🔮 "*40)
    print("开始代码生成")
    print("🔮 "*40 + "\n")
    
    # 读取模板
    template_path = PROJECT_ROOT / "main.py"
    print(f"📄 加载模板: {template_path}")
    with open(template_path, 'r') as f:
        template = f.read()
    print(f"✓ 模板加载完成，长度: {len(template)} 字符")
    
    # 读取电路数据
    print(f"\n📄 加载电路数据: {circuit_data_path}")
    with open(circuit_data_path, 'r') as f:
        circuit_data = json.load(f)
    num_cells = len(circuit_data.get('cells', []))
    num_nets = len(circuit_data.get('nets', []))
    print(f"✓ 电路数据加载完成")
    print(f"   Cells数量: {num_cells}")
    print(f"   Nets数量: {num_nets}")
    
    # 构造prompt
    print("\n🔧 构造prompt...")
    prompt = f"""请根据以下模板和电路数据生成代码：

模板代码：
```python
{template}
```

电路数据：
```json
{json.dumps(circuit_data, indent=2)}
```

要求生成完整的Python代码来处理该电路数据。具体来说，order_list和rotation_list应根据cells的数量进行调整，以实现合理的布局和旋转。rotation_list中的旋转选项包括'R0'和'MY'。
在代码中修改，从{circuit_data_path}加载json格式的文件。
只输出Python代码，不要包含任何解释、说明或思考过程。"""
    
    messages = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
    
    print("\n🔤 Tokenize输入...")
    inputs = tokenizer(messages, return_tensors="pt").to(model.device)
    input_length = inputs.input_ids.shape[1]
    print(f"✓ 输入token数: {input_length}")
    
    print("\n⏳ 生成代码...")
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=8196,
            temperature=0.2,
            top_p=0.9,
            do_sample=True
        )
    
    output_length = outputs.shape[1]
    generated_tokens = output_length - input_length
    print(f"\n✓ 生成完成")
    print(f"   生成token数: {generated_tokens}")
    print(f"   总token数: {output_length}")
    
    response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
    print(f"   原始生成长度: {len(response)} 字符")
    
    # 清理生成的代码
    print("\n🧹 清理生成的代码...")
    cleaned_code = clean_generated_code(response)
    print(f"✓ 清理完成，最终代码长度: {len(cleaned_code)} 字符")
    
    return cleaned_code

if __name__ == "__main__":
    import sys
    
    print("\n" + "="*80)
    print(" "*25 + "代码生成推理系统")
    print("="*80)
    
    if len(sys.argv) < 2:
        print("\n❌ 错误: 缺少参数")
        print("\n用法: python inference.py <circuit_data.json>")
        print("\n示例:")
        print("  python inference.py ../data/circuit_data.json")
        print("  python inference.py ../data/circuit_data_2.json")
        sys.exit(1)
    
    circuit_file = Path(sys.argv[1])
    print(f"\n📁 输入文件: {circuit_file}")
    
    if not circuit_file.exists():
        print(f"\n❌ 错误: 文件不存在！")
        print(f"   路径: {circuit_file.absolute()}")
        sys.exit(1)
    
    print(f"✓ 文件存在: {circuit_file.absolute()}")
    
    # 加载模型
    model, tokenizer = load_model()
    
    # 生成代码
    code = generate_code(model, tokenizer, circuit_file)
    
    # 输出结果
    print("\n" + "📝 "*40)
    print("生成的代码")
    print("📝 "*40)
    print("\n" + "="*80)
    print(code)
    print("="*80 + "\n")
    
    # 保存结果
    output_file = PROJECT_ROOT / "SFT" / "output" / f"{circuit_file.stem}_script.py"
    output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, 'w') as f:
        f.write(code)
    print(f"💾 代码已保存至: {output_file}")
    
    print("\n" + "🎉 "*40)
    print("推理完成！")
    print("🎉 "*40 + "\n")
