"""
compute_rewards.py
计算GRPO奖励
调用Representation的inference.py计算embeddings，然后计算样本与ground truth之间的距离
"""

import json
import os
import sys
import numpy as np
import torch
from pathlib import Path
import subprocess
import shutil
from typing import Dict, List, Tuple

# 添加Representation到路径以便导入CircuitProcessor
sys.path.insert(0, str(Path(__file__).parent.parent / "Representation"))
from data_generator import CircuitProcessor

# Path configuration.
# Anchor on the absolute script path: on Python < 3.9 `__file__` can be
# relative (e.g. `python GRPO/compute_rewards.py`), in which case
# `.parent.parent` collapses to "." and every derived path becomes relative --
# which breaks the subprocess calls that run with a different `cwd`.
PROJECT_ROOT = Path(os.path.abspath(__file__)).parent.parent
GRPO_DIR = PROJECT_ROOT / "GRPO"
RESULTS_DIR = GRPO_DIR / "results"
REPRESENTATION_DIR = PROJECT_ROOT / "Representation"
BAYES_DIR = PROJECT_ROOT / "bayes"
DATA_DIR = PROJECT_ROOT / "data"

# Paths for Representation inference: the checkpoint passed to inference.py
# and the directory its embeddings are written to.
REPRESENTATION_MODEL = REPRESENTATION_DIR / "checkpoints" / "best_model.pt"
REPRESENTATION_OUTPUT = GRPO_DIR / "embeddings"

def process_circuit_json(json_path, output_dir):
    """
    Convert one routed circuit JSON file into Representation input files.

    Uses ``CircuitProcessor(scale_factor=1)`` to write ``graph.json``,
    ``shape.npy`` and ``vector.npy`` into *output_dir*.

    Args:
        json_path: Path of the routed circuit JSON file to read.
        output_dir: Directory (``str`` or ``Path``) to write the generated
            files into; it is created, including parents, if missing.

    Returns:
        True if all three files were written, False otherwise. Any exception
        is printed with its traceback and converted to False, so one bad
        circuit does not abort the caller's loop.
    """
    try:
        # Accept plain strings as well as Path objects.
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        processor = CircuitProcessor(scale_factor=1)

        circuit_data = processor.load_circuit_data(json_path)
        if circuit_data is None:
            print(f"    ⚠️ 无法加载电路数据: {json_path}")
            return False

        # Some files wrap the circuit in a list; use its first element.
        if isinstance(circuit_data, list) and circuit_data:
            circuit_data = circuit_data[0]

        if not isinstance(circuit_data, dict):
            print("    ⚠️ 电路数据格式不正确，应为字典或包含字典的数组")
            return False

        # Graph structure -> graph.json
        graph_data = processor.build_circuit_graph(circuit_data)
        if not graph_data:
            print("    ⚠️ 无法生成图结构")
            return False
        with open(output_dir / "graph.json", 'w', encoding='utf-8') as f:
            json.dump(graph_data, f, indent=2)
        print("      ✓ 生成graph.json")

        # Shape matrix -> shape.npy
        shape_matrix = processor.convert_to_shape_matrix(circuit_data)
        if shape_matrix is None:
            print("    ⚠️ 无法生成形状矩阵")
            return False
        np.save(output_dir / "shape.npy", shape_matrix)
        # np.shape works for ndarrays and plain (nested) lists alike; `.shape`
        # would raise after the file had already been saved successfully.
        print(f"      ✓ 生成shape.npy，形状: {np.shape(shape_matrix)}")

        # Feature vector -> vector.npy
        vector_data = processor.analyze_circuit_features(circuit_data)
        if vector_data is None:
            print("    ⚠️ 无法生成特征向量")
            return False
        np.save(output_dir / "vector.npy", vector_data)
        print(f"      ✓ 生成vector.npy，形状: {np.shape(vector_data)}")

        return True

    except Exception as e:
        # Batch boundary: report and carry on with the next circuit.
        print(f"    ❌ 处理电路JSON时出错: {e}")
        import traceback
        traceback.print_exc()
        return False

def _rewrite_output_path(script_code, output_path):
    """Return *script_code* with each ``output_file_path = '...'`` assignment pointing at *output_path*."""
    import re
    # A callable replacement stops backslashes in the path (e.g. Windows
    # paths) from being processed as regex escapes. repr() emits a valid
    # Python string literal even when the path contains quotes or backslashes.
    new_assignment = f"output_file_path={str(output_path)!r}"
    return re.sub(
        r"output_file_path\s*=\s*['\"](.+?)['\"]",
        lambda _match: new_assignment,
        script_code,
    )


def _ensure_ground_truth(circuit_name, gt_script_path, gt_result_path):
    """Run the circuit's ground-truth script, if needed, so that *gt_result_path* exists; return True on success."""
    if gt_result_path.exists():
        print("  ✓ Ground truth已存在")
        return True

    # Python source files are UTF-8 by default, so read and write them as such
    # rather than with the platform's locale encoding.
    gt_code = gt_script_path.read_text(encoding='utf-8')
    temp_gt_script = GRPO_DIR / f"temp_gt_{circuit_name}.py"
    temp_gt_script.write_text(
        _rewrite_output_path(gt_code, gt_result_path), encoding='utf-8'
    )

    print("  ⏳ 运行ground truth脚本...")
    try:
        result = subprocess.run(
            [sys.executable, str(temp_gt_script)],
            capture_output=True,
            text=True,
            errors='replace',
            timeout=300,
            cwd=PROJECT_ROOT
        )
    except Exception as e:
        print(f"  ❌ 运行ground truth失败: {e}")
        return False

    if result.returncode == 0 and gt_result_path.exists():
        print("  ✓ Ground truth生成成功")
        return True
    if result.returncode != 0:
        # The tail of stderr holds the actual exception of a Python traceback.
        print(f"  ⚠️  Ground truth生成失败: {result.stderr[-200:]}")
    else:
        # Clean exit but no file: the script probably has no
        # `output_file_path = '...'` assignment for the rewrite to target.
        print(f"  ⚠️  Ground truth生成失败: 未找到输出文件 {gt_result_path}")
    return False


def _process_samples(samples, circuit_dir):
    """Convert each sample's routed JSON into a sub-directory of *circuit_dir*; return the prepared directories."""
    sample_dirs = []
    for sample in samples:
        sample_path = Path(sample['result_path'])
        if not sample_path.exists():
            print(f"  ⚠️  样本结果不存在，跳过: {sample_path}")
            continue
        print(f"  处理样本: {sample['sample_id']}")
        # str() so that numeric sample ids from the metadata JSON work too.
        sample_dir = circuit_dir / str(sample['sample_id'])
        sample_dir.mkdir(exist_ok=True)
        if process_circuit_json(sample_path, sample_dir):
            print(f"    ✓ 已处理样本: {sample_dir}")
            sample_dirs.append(str(sample_dir))
        else:
            print("    ❌ 处理样本失败")
            # inference.py receives the whole circuit directory as --test_dir,
            # so do not leave a half-populated sample directory behind.
            shutil.rmtree(sample_dir, ignore_errors=True)
    return sample_dirs


def prepare_circuits_for_representation(samples_metadata):
    """
    Prepare circuit data for Representation inference.

    For every circuit in *samples_metadata*, this function:

    1. Makes sure the ground-truth routed JSON exists. If it does not, it runs
       ``BAYES_DIR/<circuit>_script.py`` with its output path redirected.
    2. Converts the ground truth and every sample into graph/shape/vector files.

    Output goes to ``GRPO_DIR/temp_circuits/<circuit>/{ground_truth,<sample_id>}``.
    Circuits whose ground truth cannot be produced or processed are skipped.

    Args:
        samples_metadata: Mapping of circuit name -> list of sample dicts, each
            with at least ``'sample_id'`` and ``'result_path'``.

    Returns:
        Tuple ``(temp_circuits_dir, circuit_paths)``. ``circuit_paths`` maps
        circuit name -> ``{'ground_truth': str, 'samples': [str, ...]}``, which
        are the prepared directory paths.
    """
    print("\n" + "="*80)
    print("准备电路数据用于Representation推理")
    print("="*80 + "\n")

    temp_circuits_dir = GRPO_DIR / "temp_circuits"
    temp_circuits_dir.mkdir(parents=True, exist_ok=True)

    circuit_paths = {}

    for circuit_name, samples in samples_metadata.items():
        print(f"\n处理电路: {circuit_name}")

        gt_script_path = BAYES_DIR / f"{circuit_name}_script.py"
        if not gt_script_path.exists():
            print(f"  ⚠️  未找到ground truth脚本: {gt_script_path}")
            continue

        gt_result_path = RESULTS_DIR / circuit_name / "ground_truth_routed.json"
        gt_result_path.parent.mkdir(parents=True, exist_ok=True)
        if not _ensure_ground_truth(circuit_name, gt_script_path, gt_result_path):
            continue

        circuit_dir = temp_circuits_dir / circuit_name
        gt_dir = circuit_dir / "ground_truth"
        gt_dir.mkdir(parents=True, exist_ok=True)

        print("  处理ground truth...")
        if not process_circuit_json(gt_result_path, gt_dir):
            print("    ❌ 处理ground truth失败")
            continue
        print(f"    ✓ 已处理ground truth: {gt_dir}")

        circuit_paths[circuit_name] = {
            'ground_truth': str(gt_dir),
            'samples': _process_samples(samples, circuit_dir)
        }

    print(f"\n{'='*80}")
    print(f"✓ 电路数据准备完成: {temp_circuits_dir}")
    print(f"{'='*80}\n")
    return temp_circuits_dir, circuit_paths

def run_representation_inference(circuit_dir, circuit_paths):
    """
    Run Representation/inference.py once per circuit to extract embeddings.

    Each circuit's prepared directory ``circuit_dir/<circuit>`` (holding the
    ground_truth and sample sub-directories) is passed as ``--test_dir``.
    Results are written to ``REPRESENTATION_OUTPUT/<circuit>``.

    Args:
        circuit_dir: Root directory returned by
            prepare_circuits_for_representation (``str`` or ``Path``).
        circuit_paths: Mapping whose keys are the circuit names to process.

    Returns:
        False if the model checkpoint is missing, or if inference failed for
        every requested circuit. True otherwise.
    """
    print("\n" + "="*80)
    print("运行Representation推理")
    print("="*80 + "\n")

    if not REPRESENTATION_MODEL.exists():
        print(f"❌ 错误: Representation模型不存在: {REPRESENTATION_MODEL}")
        print("   请先训练Representation模型")
        return False

    REPRESENTATION_OUTPUT.mkdir(parents=True, exist_ok=True)
    circuit_dir = Path(circuit_dir)

    succeeded = []
    for circuit_name in circuit_paths:
        print(f"\n处理电路: {circuit_name}")
        circuit_input_dir = circuit_dir / circuit_name
        circuit_output_dir = REPRESENTATION_OUTPUT / circuit_name

        # Drop embeddings left over from a previous run first. load_embeddings
        # reads whatever embedding.npy files exist under this directory, so a
        # failed or skipped run would otherwise silently reuse stale results.
        if circuit_output_dir.exists():
            shutil.rmtree(circuit_output_dir)

        if not circuit_input_dir.exists():
            print(f"  ⚠️  电路目录不存在: {circuit_input_dir}")
            continue

        # Process all of this circuit's data (ground_truth + samples) at once.
        cmd = [
            sys.executable,
            str(REPRESENTATION_DIR / "inference.py"),
            "--model_path", str(REPRESENTATION_MODEL),
            "--test_dir", str(circuit_input_dir),
            "--output_dir", str(circuit_output_dir),
            "--batch_size", "16"
        ]
        print(f"  执行命令: {' '.join(cmd)}")

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                errors='replace',
                timeout=600,
                cwd=REPRESENTATION_DIR
            )
        except Exception as e:
            print(f"  ❌ 运行推理时出错: {e}")
            continue

        if result.returncode != 0:
            print("  ❌ 推理失败:")
            # The tail of stderr holds the actual exception of a Python traceback.
            print(f"     {result.stderr[-500:]}")
            continue

        print("  ✓ 推理完成")
        # Echo only the last 5 non-empty lines of the inference output.
        for line in result.stdout.strip().split('\n')[-5:]:
            if line.strip():
                print(f"     {line}")
        succeeded.append(circuit_name)

    if circuit_paths and not succeeded:
        print("\n❌ 所有电路的Representation推理均失败\n")
        return False

    print(f"\n✓ Representation推理完成: {len(succeeded)}/{len(circuit_paths)} 个电路成功\n")
    return True

def load_embeddings(circuit_paths):
    """
    Load the embeddings written by the Representation inference step.

    For each circuit in *circuit_paths*, this reads
    ``REPRESENTATION_OUTPUT/<circuit>/ground_truth/embedding.npy``. For every
    prepared sample directory listed under ``'samples'``, it also reads
    ``REPRESENTATION_OUTPUT/<circuit>/<sample dir name>/embedding.npy``.
    Missing files are reported and skipped.

    Returns:
        dict mapping circuit name -> ``{'ground_truth': ndarray or None,
        'samples': [{'name': str, 'embedding': ndarray}, ...]}``. Circuits for
        which nothing at all could be loaded are omitted.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("加载Embeddings")
    print(banner + "\n")

    loaded = {}

    for circuit_name, paths in circuit_paths.items():
        circuit_root = REPRESENTATION_OUTPUT / circuit_name

        gt_file = circuit_root / "ground_truth" / "embedding.npy"
        gt_embedding = None
        if gt_file.exists():
            gt_embedding = np.load(gt_file)
            print(f"✓ 加载 {circuit_name} ground truth embedding: {gt_embedding.shape}")
        else:
            print(f"⚠️  未找到 {circuit_name} ground truth embedding: {gt_file}")

        sample_entries = []
        for sample_dir in paths['samples']:
            sample_name = Path(sample_dir).name
            sample_file = circuit_root / sample_name / "embedding.npy"
            if not sample_file.exists():
                print(f"⚠️  未找到 {circuit_name}/{sample_name} embedding: {sample_file}")
                continue
            vec = np.load(sample_file)
            sample_entries.append({
                'name': sample_name,
                'embedding': vec
            })
            print(f"✓ 加载 {circuit_name}/{sample_name} embedding: {vec.shape}")

        # Keep the circuit if at least one embedding was found.
        if gt_embedding is not None or sample_entries:
            loaded[circuit_name] = {
                'ground_truth': gt_embedding,
                'samples': sample_entries
            }

    print(f"\n{banner}")
    print(f"✓ 加载了 {len(loaded)} 个电路的embeddings")
    print(f"{banner}\n")
    return loaded

def compute_distances(embeddings):
    """
    Compute the Euclidean distance from each sample embedding to its circuit's ground truth.

    Both embeddings are flattened before comparison. A ``(1, d)`` ground truth
    and a ``(d,)`` sample therefore compare element-wise instead of relying on
    NumPy broadcasting, which can silently yield a norm over a broadcast matrix
    when the shapes differ. Samples whose embedding size does not match the
    ground truth are reported and skipped. For equal shapes the result equals
    ``np.linalg.norm(gt - sample)``, which is the Frobenius norm for arrays.

    Args:
        embeddings: Mapping as produced by load_embeddings:
            circuit name -> ``{'ground_truth': array or None,
            'samples': [{'name': str, 'embedding': array}, ...]}``.

    Returns:
        dict mapping circuit name -> ``[{'name': str, 'distance': float}, ...]``
        in input order. Circuits without a ground-truth embedding are omitted.
    """
    print("\n" + "="*80)
    print("计算Embedding距离")
    print("="*80 + "\n")

    distances = {}

    for circuit_name, circuit_emb in embeddings.items():
        gt_raw = circuit_emb['ground_truth']
        if gt_raw is None:
            print(f"⚠️  跳过 {circuit_name}（缺少ground truth）")
            continue

        print(f"\n电路: {circuit_name}")
        print(f"Ground truth embedding shape: {np.shape(gt_raw)}")
        gt_vec = np.asarray(gt_raw, dtype=np.float64).ravel()

        sample_distances = []
        for sample in circuit_emb['samples']:
            sample_vec = np.asarray(sample['embedding'], dtype=np.float64).ravel()
            if sample_vec.size != gt_vec.size:
                print(f"  ⚠️  {sample['name']}: embedding维度不匹配 "
                      f"({sample_vec.size} vs {gt_vec.size})，跳过")
                continue
            distance = float(np.linalg.norm(gt_vec - sample_vec))
            sample_distances.append({
                'name': sample['name'],
                'distance': distance
            })
            print(f"  {sample['name']}: 距离 = {distance:.4f}")

        distances[circuit_name] = sample_distances

    print(f"\n✓ 距离计算完成\n")
    return distances

def compute_grpo_rewards(distances):
    """
    Turn per-circuit distances into group-relative GRPO rewards.

    Each circuit is one group. Within a group, samples are ranked by distance,
    smallest first, and rewards are spread linearly from +1.0 (best) to -1.0
    (worst). Samples with exactly equal distances share the average reward of
    the positions they occupy, so identical samples never get different
    rewards just because of their input order. Rewards in a group still sum
    to zero. A group with a single sample gets reward 0.0.

    Args:
        distances: Mapping circuit name -> ``[{'name': str, 'distance': float}, ...]``
            as produced by compute_distances.

    Returns:
        dict mapping circuit name -> list of ``{'name', 'distance', 'rank',
        'reward'}`` dicts ordered best-first. ``rank`` is 1-based; tied samples
        share the rank of the first position in their tie group (e.g. 1, 1, 3).
        Circuits with no samples are omitted.
    """
    from itertools import groupby
    from operator import itemgetter

    print("\n" + "="*80)
    print("计算GRPO奖励")
    print("="*80 + "\n")

    by_distance = itemgetter('distance')
    rewards = {}

    for circuit_name, sample_distances in distances.items():
        if not sample_distances:
            continue

        # Ascending distance: smaller is better.
        sorted_samples = sorted(sample_distances, key=by_distance)
        n_samples = len(sorted_samples)
        circuit_rewards = []

        print(f"\n电路: {circuit_name}")
        print(f"样本数: {n_samples}")

        position = 0
        for _, group in groupby(sorted_samples, key=by_distance):
            tied = list(group)
            first, last = position, position + len(tied) - 1
            if n_samples > 1:
                # Mean position of the tie group; equals the sample's own
                # position when it is not tied, matching the plain linear scheme.
                mean_position = (first + last) / 2
                reward = 1.0 - (2.0 * mean_position / (n_samples - 1))
            else:
                reward = 0.0

            for sample in tied:
                circuit_rewards.append({
                    'name': sample['name'],
                    'distance': sample['distance'],
                    'rank': first + 1,
                    'reward': reward
                })
                print(f"  {sample['name']}: 距离={sample['distance']:.4f}, 排名={first+1}/{n_samples}, 奖励={reward:.4f}")

            position = last + 1

        rewards[circuit_name] = circuit_rewards

    print(f"\n✓ GRPO奖励计算完成\n")
    return rewards

def _cleanup_temp_files(temp_circuits_dir):
    """Remove the temporary circuit directory (if any) and every temp_gt_*.py script in GRPO_DIR."""
    print("清理临时文件...")
    if temp_circuits_dir is not None and temp_circuits_dir.exists():
        shutil.rmtree(temp_circuits_dir)
        print(f"  ✓ 删除临时电路目录: {temp_circuits_dir}")
    for temp_file in GRPO_DIR.glob("temp_gt_*.py"):
        temp_file.unlink()
        print(f"  ✓ 删除临时脚本: {temp_file.name}")
    print("✓ 清理完成\n")


def main():
    """
    Compute GRPO rewards for the samples listed in GRPO_DIR/samples_metadata.json.

    The pipeline is: prepare circuit data -> Representation inference ->
    load embeddings -> distances to ground truth -> rewards. It writes
    ``rewards.json`` and ``distances.json`` into GRPO_DIR. Temporary files are
    removed even when a step fails or raises. Otherwise, leftover sample
    directories under temp_circuits could be fed to inference.py again on a
    later run.
    """
    print("\n" + "🎁 "*40)
    print("GRPO 奖励计算器")
    print("🎁 "*40 + "\n")

    metadata_path = GRPO_DIR / "samples_metadata.json"
    if not metadata_path.exists():
        print(f"❌ 错误: 样本元数据不存在: {metadata_path}")
        print("   请先运行 generate_samples.py")
        return

    with open(metadata_path, 'r', encoding='utf-8') as f:
        samples_metadata = json.load(f)

    print(f"加载了 {len(samples_metadata)} 个电路的样本元数据")
    for circuit_name, samples in samples_metadata.items():
        print(f"  - {circuit_name}: {len(samples)} 个样本")

    temp_circuits_dir = None
    try:
        temp_circuits_dir, circuit_paths = prepare_circuits_for_representation(samples_metadata)

        if not circuit_paths:
            print("❌ 没有成功准备的电路数据，无法继续")
            return

        if not run_representation_inference(temp_circuits_dir, circuit_paths):
            print("❌ Representation推理失败，无法继续")
            return

        embeddings = load_embeddings(circuit_paths)
        if not embeddings:
            print("❌ 没有成功加载的embeddings，无法继续")
            return

        distances = compute_distances(embeddings)
        rewards = compute_grpo_rewards(distances)

        rewards_path = GRPO_DIR / "rewards.json"
        with open(rewards_path, 'w', encoding='utf-8') as f:
            json.dump(rewards, f, indent=2)

        distances_path = GRPO_DIR / "distances.json"
        with open(distances_path, 'w', encoding='utf-8') as f:
            json.dump(distances, f, indent=2)

        print(f"\n{'='*80}")
        print(f"✓ 奖励计算完成！")
        print(f"  奖励文件: {rewards_path}")
        print(f"  距离文件: {distances_path}")
        print(f"{'='*80}\n")
    finally:
        # Runs on success, on every early return above, and on exceptions.
        _cleanup_temp_files(temp_circuits_dir)

if __name__ == "__main__":
    # Script entry point: run the full reward-computation pipeline.
    main()
