import json
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

def custom_placement(circuit_data, order_list, rotation_list=None):
    """
    根据用户指定的顺序和旋转角度放置电路单元
    
    参数:
    circuit_data: 原始电路数据
    order_list: 用户指定的cell顺序列表，如[1,3,2]表示将第二个和第三个cell位置互换
    rotation_list: 用户指定的旋转角度列表，如['R0','MX','R0']，如果不指定则全部使用R0
    
    返回:
    修改后的电路数据
    """
    
    # 创建数据的深拷贝
    modified_data = json.loads(json.dumps(circuit_data))
    
    # 获取所有单元
    cells = modified_data["cells"]
    
    # 验证order_list的有效性
    if len(order_list) != len(cells):
        raise ValueError(f"order_list长度({len(order_list)})必须等于cell数量({len(cells)})")
    
    if min(order_list) < 0 or max(order_list) >= len(cells):
        raise ValueError(f"order_list中的索引必须在0到{len(cells)-1}之间")
    
    if len(set(order_list)) != len(cells):
        raise ValueError("order_list中包含重复的索引")
    
    # 根据order_list重新排列cells
    reordered_cells = [cells[i] for i in order_list]
    
    # 处理旋转角度
    if rotation_list is not None:
        if len(rotation_list) != len(cells):
            raise ValueError(f"rotation_list长度({len(rotation_list)})必须等于cell数量({len(cells)})")
        
        for i, cell in enumerate(reordered_cells):
            cell["rotation"] = rotation_list[i]
            # 如果翻转，需要调整pin的offset
            if rotation_list[i] == "MX":
                for pin in cell["pins"]:
                    # 保持x坐标不变，y坐标变为高度减去原y坐标
                    pin["offset"][1] = cell["height"] - pin["offset"][1]
    else:
        # 如果不指定旋转角度，全部设为R0
        for cell in reordered_cells:
            cell["rotation"] = "R0"
    
    # 重新计算位置（紧挨着放置，无spacing）
    current_x = 0
    for cell in reordered_cells:
        # 更新位置
        cell["location"] = [current_x, 0]
        current_x += cell["width"]
    
    # 更新网络连接
    modified_data = update_nets(modified_data)
    
    return modified_data

def update_nets(modified_data):
    """
    根据新的单元位置更新网络连接
    """
    cells = modified_data["cells"]
    nets = modified_data["nets"]
    
    # 创建pin位置映射
    pin_positions = {}
    for cell in cells:
        cell_x, cell_y = cell["location"]
        for pin in cell["pins"]:
            pin_x, pin_y = pin["offset"]
            absolute_x = cell_x + pin_x
            absolute_y = cell_y + pin_y
            net_name = pin["net"]
            
            if net_name not in pin_positions:
                pin_positions[net_name] = []
            pin_positions[net_name].append((absolute_x, absolute_y))
    
    # 更新每个网络的wire和via
    for net in nets:
        net_name = net["name"]
        if net_name in pin_positions:
            positions = pin_positions[net_name]
            if len(positions) > 1:
                # 按x坐标排序
                positions.sort(key=lambda pos: pos[0])
                
                # 生成水平连线
                min_x = min(pos[0] for pos in positions)
                max_x = max(pos[0] for pos in positions)
                avg_y = sum(pos[1] for pos in positions) / len(positions)
                
                # 更新wire
                net["wires"] = [{
                    "id": "",
                    "location": [[min_x, avg_y], [max_x, avg_y]],
                    "layer": "METAL1"
                }]
                
                # 更新via
                net["vias"] = []
                for pos in positions:
                    net["vias"].append({
                        "id": "",
                        "location": [pos[0], avg_y],
                        "layer": ["METAL0", "METAL1"]
                    })
    
    return modified_data

def visualize_circuit(circuit_data, filename="circuit_visualization.png"):
    """
    可视化电路布局
    """
    fig, ax = plt.subplots(1, 1, figsize=(15, 8))
    
    cells = circuit_data["cells"]
    nets = circuit_data["nets"]
    
    # 计算总宽度用于设置x轴范围
    total_width = max(cell["location"][0] + cell["width"] for cell in cells) + 10000
    max_height = max(cell["height"] for cell in cells) + 10000
    
    # 设置坐标轴范围
    ax.set_xlim(-5000, total_width)
    ax.set_ylim(-5000, max_height)
    
    # 绘制每个单元
    colors = plt.cm.Set3(np.linspace(0, 1, len(cells)))
    
    for i, cell in enumerate(cells):
        x, y = cell["location"]
        width = cell["width"]
        height = cell["height"]
        
        # 绘制单元边框
        rect = patches.Rectangle((x, y), width, height, 
                                linewidth=2, edgecolor='black', 
                                facecolor=colors[i], alpha=0.7,
                                label=cell["id"])
        ax.add_patch(rect)
        
        # 添加单元ID标签
        ax.text(x + width/2, y + height/2, cell["id"], 
               ha='center', va='center', fontsize=8, fontweight='bold')
        
        # 标记旋转状态
        rotation_text = f"Rot: {cell['rotation']}"
        ax.text(x + width/2, y + height + 500, rotation_text, 
               ha='center', va='bottom', fontsize=7, color='red')
        
        # 绘制pin
        for pin in cell["pins"]:
            pin_x = x + pin["offset"][0]
            pin_y = y + pin["offset"][1]
            ax.plot(pin_x, pin_y, 'ro', markersize=6)
            ax.text(pin_x + 200, pin_y + 200, f"{pin['name']}\n({pin['net']})", 
                   fontsize=6, ha='left', va='bottom')
    
    # 绘制网络连线
    net_colors = plt.cm.tab10(np.linspace(0, 1, len(nets)))
    
    for i, net in enumerate(nets):
        color = net_colors[i % len(net_colors)]
        
        # 绘制wire
        for wire in net["wires"]:
            start, end = wire["location"]
            ax.plot([start[0], end[0]], [start[1], end[1]], 
                   color=color, linewidth=2, label=net["name"])
        
        # 绘制via
        for via in net["vias"]:
            x, y = via["location"]
            ax.plot(x, y, 's', color=color, markersize=8, markeredgecolor='black')
    
    # 设置图例和标签
    ax.set_xlabel('X Coordinate')
    ax.set_ylabel('Y Coordinate')
    ax.set_title('Circuit Placement Visualization')
    ax.grid(True, alpha=0.3)
    
    # 简化图例（避免太多条目）
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys(), loc='upper right', bbox_to_anchor=(1.15, 1))
    
    plt.tight_layout()
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    plt.show()
    
    return fig

def save_circuit_json(circuit_data, filename="modified_circuit.json"):
    """
    保存修改后的电路数据为JSON文件
    """
    with open(filename, 'w') as f:
        json.dump(circuit_data, f, indent=2)
    
    print(f"电路数据已保存到 {filename}")

def print_circuit_info(circuit_data):
    """
    打印电路布局信息
    """
    print("-" * 80)
    for i, cell in enumerate(circuit_data["cells"]):
        print(f"{i}: {cell['id']} - 位置: {cell['location']}, 旋转: {cell['rotation']}, 宽度: {cell['width']}")
    print("-" * 80)

def main():
    """
    主函数：演示自定义placement功能
    """
    # 使用您提供的完整电路数据
    with open("../data/circuit_data_2.json", 'r') as f:
        circuit_data = json.load(f)
    
    # 打印原始电路信息
    print_circuit_info(circuit_data)
    
    # 示例1：简单的顺序调整 [1,3,2] 表示将第二个和第三个cell位置互换
    order_list = [0, 2, 1, 3, 4, 5, 6, 7, 8]  # 交换索引1和2的cell
    rotation_list = ['R0', 'MX', 'R0', 'R0', 'R0', 'R0', 'R0', 'R0', 'R0']
    
    modified_circuit1 = custom_placement(circuit_data, order_list, rotation_list)
    print_circuit_info(modified_circuit1)
    
    # 保存修改后的JSON文件
    save_circuit_json(modified_circuit1, "custom_placement.json")
    
    # 生成可视化结果
    print("\n生成可视化结果...")
    visualize_circuit(circuit_data, "original_circuit.png")
    visualize_circuit(modified_circuit1, "custom_placement.png")
    
    print("完成！")

if __name__ == "__main__":
    main()