import json
import os
import numpy as np
from typing import Dict, List, Tuple, Any
import matplotlib.pyplot as plt
import sys
import random
from pathlib import Path
from collections import defaultdict

class CircuitProcessor:
    """统一的电路处理器，包含图构建、矩阵转换和特征向量提取功能"""
    
    def __init__(self, scale_factor=100):
        """
        初始化电路处理器
        
        Args:
            scale_factor: 缩放因子，用于将电路坐标转换为矩阵索引
        """
        self.scale_factor = scale_factor
        self.net_mapping = {}
        self.component_types = {
            'cell': 1,
            'pin': 2,
            'wire': 3,
            'via': 4,
        }
        
    def load_circuit_data(self, file_path):
        """从JSON文件加载电路数据"""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return data
        except FileNotFoundError:
            print(f"错误：找不到文件 {file_path}")
            return None
        except json.JSONDecodeError as e:
            print(f"错误：JSON文件格式错误 - {e}")
            return None
        except Exception as e:
            print(f"错误：读取文件时发生错误 - {e}")
            return None
    
    def analyze_circuit_features(self, data):
        """
        分析电路数据，提取track数、wirelength和via数量
        
        参数:
            data: 电路数据（字典格式）
        
        返回:
            包含[track数, wirelength, via数]的numpy数组
        """
        # 初始化统计变量
        y_coordinates = set()  # 存储所有不同的y坐标
        total_wirelength = 0   # 总线长
        total_vias = 0         # via总数
        
        # 遍历所有nets
        if 'nets' in data:
            nets = data['nets']
            
            for net in nets:
                # 统计wires信息
                if 'wires' in net:
                    for wire in net['wires']:
                        if 'location' in wire:
                            locations = wire['location']
                            # 提取y坐标
                            # location格式: [[x1, y1], [x2, y2]]
                            if len(locations) >= 2:
                                y1 = locations[0][1]
                                y2 = locations[1][1]
                                # 添加y坐标到集合中
                                y_coordinates.add(y1)
                                if y1 != y2:  # 如果wire不是水平的，也添加y2
                                    y_coordinates.add(y2)
                                
                                # 计算wire长度（曼哈顿距离）
                                x1, x2 = locations[0][0], locations[1][0]
                                wire_length = abs(x2 - x1) + abs(y2 - y1)
                                total_wirelength += wire_length
                
                # 统计vias数量
                if 'vias' in net:
                    total_vias += len(net['vias'])
        
        # track数就是不同y坐标的数量
        num_tracks = len(y_coordinates)
        
        # 创建结果向量
        result_vector = np.array([num_tracks, total_wirelength, total_vias], dtype=np.float32)
        
        # 打印统计结果
        print(f"  特征向量分析:")
        print(f"    - Track数量: {num_tracks}")
        print(f"    - 总线长(Wirelength): {total_wirelength}")
        print(f"    - Via数量: {total_vias}")
        
        return result_vector
    
    def build_circuit_graph(self, data):
        """构建电路图结构"""
        if not data or not isinstance(data, dict):
            print("错误：数据格式不正确")
            return None
        
        cells_data = data.get("cells", [])
        nets_data = data.get("nets", [])
        
        # 创建图结构
        graph = {
            "nodes": [],
            "edges": []
        }
        
        # 1. 添加cell节点
        cell_nodes = {}
        for cell in cells_data:
            node = {
                "id": cell["id"],
                "type": "cell",
                "location": cell["location"],
                "rotation": cell["rotation"],
                "width": cell["width"],
                "height": cell["height"],
                "pins": cell.get("pins", [])  # 使用get方法，如果没有pins则返回空列表
            }
            graph["nodes"].append(node)
            cell_nodes[cell["id"]] = node
        
        # 2. 添加net节点
        net_nodes = {}
        for net in nets_data:
            node = {
                "id": net["name"],
                "type": "net"
            }
            graph["nodes"].append(node)
            net_nodes[net["name"]] = node
        
        # 3. 建立cell-net连接（基于pin连接）
        for cell in cells_data:
            # 检查cell是否有pins
            if "pins" in cell and cell["pins"]:
                for pin in cell["pins"]:
                    if "net" in pin:
                        edge = {
                            "source": cell["id"],
                            "target": pin["net"],
                            "type": "cell-net",
                            "pin_name": pin["name"],
                            "pin_offset": pin.get("offset", [0, 0])  # 使用get方法处理可能缺失的offset
                        }
                        graph["edges"].append(edge)
        
        # 4. 检查相邻的cells并建立连接
        def are_cells_adjacent(cell1, cell2):
            """判断两个cell是否相邻"""
            x1, y1 = cell1["location"]
            w1, h1 = cell1["width"], cell1["height"]
            x2, y2 = cell2["location"]
            w2, h2 = cell2["width"], cell2["height"]
            
            # 检查是否在水平方向相邻
            if y1 == y2 and h1 == h2:  # 同一水平线上
                if x1 + w1 == x2 or x2 + w2 == x1:  # 左右相邻
                    return True
            
            # 检查是否在垂直方向相邻
            if x1 == x2 and w1 == w2:  # 同一垂直线上
                if y1 + h1 == y2 or y2 + h2 == y1:  # 上下相邻
                    return True
            
            return False
        
        # 建立相邻cell之间的连接
        for i, cell1 in enumerate(cells_data):
            for j, cell2 in enumerate(cells_data):
                if i < j:  # 避免重复边
                    if are_cells_adjacent(cell1, cell2):
                        edge = {
                            "source": cell1["id"],
                            "target": cell2["id"],
                            "type": "cell-cell",
                            "relationship": "adjacent"
                        }
                        graph["edges"].append(edge)
        
        return graph
    
    def get_circuit_bounds(self, data: Dict) -> Tuple[int, int, int, int]:
        """获取电路的边界"""
        min_x, min_y = float('inf'), float('inf')
        max_x, max_y = float('-inf'), float('-inf')
        
        # 检查cells
        if 'cells' in data:
            for cell in data['cells']:
                x, y = cell['location']
                width = cell['width']
                height = cell['height']
                min_x = min(min_x, x)
                min_y = min(min_y, y)
                max_x = max(max_x, x + width)
                max_y = max(max_y, y + height)
        
        # 检查nets中的wires和vias
        if 'nets' in data:
            for net in data['nets']:
                # 检查wires
                if 'wires' in net:
                    for wire in net['wires']:
                        for point in wire['location']:
                            min_x = min(min_x, point[0])
                            min_y = min(min_y, point[1])
                            max_x = max(max_x, point[0])
                            max_y = max(max_y, point[1])
                
                # 检查vias
                if 'vias' in net:
                    for via in net['vias']:
                        x, y = via['location']
                        min_x = min(min_x, x)
                        min_y = min(min_y, y)
                        max_x = max(max_x, x)
                        max_y = max(max_y, y)
        
        return int(min_x), int(min_y), int(max_x), int(max_y)
    
    def create_net_mapping(self, data: Dict) -> Dict[str, int]:
        """创建网络名称到唯一数值的映射"""
        net_names = set()
        
        # 从cells中收集net名称
        if 'cells' in data:
            for cell in data['cells']:
                # 检查cell是否有pins
                if 'pins' in cell and cell['pins']:
                    for pin in cell['pins']:
                        if 'net' in pin:
                            net_names.add(pin['net'])
        
        # 从nets中收集net名称
        if 'nets' in data:
            for net in data['nets']:
                net_names.add(net['name'])
        
        # 创建映射，从10开始避免与组件类型值冲突
        net_mapping = {name: idx + 10 for idx, name in enumerate(sorted(net_names))}
        return net_mapping
    
    def coord_to_matrix_idx(self, x: int, y: int, min_x: int, min_y: int) -> Tuple[int, int]:
        """将电路坐标转换为矩阵索引"""
        matrix_x = (x - min_x) // self.scale_factor
        matrix_y = (y - min_y) // self.scale_factor
        return int(matrix_y), int(matrix_x)
    
    def draw_line_on_matrix(self, matrix: np.ndarray, start: List[int], end: List[int], 
                           value: int, min_x: int, min_y: int):
        """在矩阵上绘制线段"""
        start_row, start_col = self.coord_to_matrix_idx(start[0], start[1], min_x, min_y)
        end_row, end_col = self.coord_to_matrix_idx(end[0], end[1], min_x, min_y)
        
        # 使用Bresenham算法绘制线段
        points = self.bresenham_line(start_row, start_col, end_row, end_col)
        for row, col in points:
            if 0 <= row < matrix.shape[0] and 0 <= col < matrix.shape[1]:
                # 如果该位置已有值，选择较大的值（优先级）
                if matrix[row, col] == 0 or matrix[row, col] == 1:
                    matrix[row, col] = value
    
    def bresenham_line(self, y0: int, x0: int, y1: int, x1: int) -> List[Tuple[int, int]]:
        """Bresenham线段算法"""
        points = []
        dx = abs(x1 - x0)
        dy = abs(y1 - y0)
        sx = 1 if x0 < x1 else -1
        sy = 1 if y0 < y1 else -1
        err = dx - dy
        
        while True:
            points.append((y0, x0))
            
            if x0 == x1 and y0 == y1:
                break
                
            e2 = 2 * err
            if e2 > -dy:
                err -= dy
                x0 += sx
            if e2 < dx:
                err += dx
                y0 += sy
                
        return points
    
    def convert_to_matrix(self, data: Dict) -> np.ndarray:
        """将电路数据转换为矩阵"""
        # 获取边界
        min_x, min_y, max_x, max_y = self.get_circuit_bounds(data)
        print(f"  Circuit bounds: ({min_x}, {min_y}) to ({max_x}, {max_y})")
        
        # 计算矩阵大小
        matrix_width = (max_x - min_x) // self.scale_factor + 1
        matrix_height = (max_y - min_y) // self.scale_factor + 1
        print(f"  Matrix size: {matrix_height} x {matrix_width}")
        
        # 初始化矩阵
        matrix = np.zeros((matrix_height, matrix_width), dtype=int)
        
        # 创建net映射
        self.net_mapping = self.create_net_mapping(data)
        
        # 处理cells
        if 'cells' in data:
            for cell in data['cells']:
                x, y = cell['location']
                width = cell['width']
                height = cell['height']
                
                # 标记cell区域
                start_row, start_col = self.coord_to_matrix_idx(x, y, min_x, min_y)
                end_row, end_col = self.coord_to_matrix_idx(x + width, y + height, min_x, min_y)
                
                for r in range(start_row, min(end_row + 1, matrix_height)):
                    for c in range(start_col, min(end_col + 1, matrix_width)):
                        if matrix[r, c] == 0:
                            matrix[r, c] = self.component_types['cell']
                
                # 标记pins（如果存在）
                if 'pins' in cell and cell['pins']:
                    for pin in cell['pins']:
                        if 'offset' in pin:
                            pin_x = x + pin['offset'][0]
                            pin_y = y + pin['offset'][1]
                            pin_row, pin_col = self.coord_to_matrix_idx(pin_x, pin_y, min_x, min_y)
                            
                            if 0 <= pin_row < matrix_height and 0 <= pin_col < matrix_width:
                                if 'net' in pin and pin['net'] in self.net_mapping:
                                    matrix[pin_row, pin_col] = self.net_mapping[pin['net']] + self.component_types['pin'] * 100
        
        # 处理nets
        if 'nets' in data:
            for net in data['nets']:
                net_value = self.net_mapping.get(net['name'], 0)
                
                # 处理wires
                if 'wires' in net:
                    for wire in net['wires']:
                        points = wire['location']
                        wire_value = net_value + self.component_types['wire'] * 100
                        
                        # 绘制wire路径
                        for i in range(len(points) - 1):
                            self.draw_line_on_matrix(matrix, points[i], points[i+1], 
                                                    wire_value, min_x, min_y)
                
                # 处理vias
                if 'vias' in net:
                    for via in net['vias']:
                        via_x, via_y = via['location']
                        via_row, via_col = self.coord_to_matrix_idx(via_x, via_y, min_x, min_y)
                        
                        if 0 <= via_row < matrix_height and 0 <= via_col < matrix_width:
                            via_value = net_value + self.component_types['via'] * 100
                            matrix[via_row, via_col] = via_value
        
        return matrix
    
    def convert_to_shape_matrix(self, data: Dict) -> np.ndarray:
        """
        将电路数据转换为形状矩阵，其中每个元素是一个二元组
        第一个数字表示metal区域（0表示不在metal区域，非0表示对应的net）
        第二个数字表示track区域（0或1，1表示在track区域）
        """
        # 获取边界
        min_x, min_y, max_x, max_y = self.get_circuit_bounds(data)
        
        # 计算矩阵大小
        matrix_width = (max_x - min_x) // self.scale_factor + 1
        matrix_height = (max_y - min_y) // self.scale_factor + 1
        
        # 初始化形状矩阵（高度×宽度×2）
        shape_matrix = np.zeros((matrix_height, matrix_width, 2), dtype=int)
        
        # 创建net映射
        self.net_mapping = self.create_net_mapping(data)
        
        # 处理cells中的metal和track区域
        if 'cells' in data:
            for cell in data['cells']:
                x, y = cell['location']
                
                # 处理pins
                if 'pins' in cell and cell['pins']:
                    for pin in cell['pins']:
                        net_value = self.net_mapping.get(pin.get('net', ''), 0)
                        
                        # 处理metal区域
                        if 'metal' in pin and pin['metal']:
                            for metal_rect in pin['metal']:
                                if len(metal_rect) >= 2:
                                    # 获取矩形的左下角和右上角坐标
                                    ll_x, ll_y = metal_rect[0][0] + x, metal_rect[0][1] + y
                                    ur_x, ur_y = metal_rect[1][0] + x, metal_rect[1][1] + y
                                    
                                    # 转换为矩阵索引
                                    ll_row, ll_col = self.coord_to_matrix_idx(ll_x, ll_y, min_x, min_y)
                                    ur_row, ur_col = self.coord_to_matrix_idx(ur_x, ur_y, min_x, min_y)
                                    
                                    # 标记metal区域
                                    for r in range(max(0, ll_row), min(ur_row + 1, matrix_height)):
                                        for c in range(max(0, ll_col), min(ur_col + 1, matrix_width)):
                                            shape_matrix[r, c, 0] = net_value
                        
                        # 处理track区域
                        if 'track' in pin and pin['track']:
                            for track_rect in pin['track']:
                                if len(track_rect) >= 2:
                                    # 获取矩形的左下角和右上角坐标
                                    ll_x, ll_y = track_rect[0][0] + x, track_rect[0][1] + y
                                    ur_x, ur_y = track_rect[1][0] + x, track_rect[1][1] + y
                                    
                                    # 转换为矩阵索引
                                    ll_row, ll_col = self.coord_to_matrix_idx(ll_x, ll_y, min_x, min_y)
                                    ur_row, ur_col = self.coord_to_matrix_idx(ur_x, ur_y, min_x, min_y)
                                    
                                    # 标记track区域
                                    for r in range(max(0, ll_row), min(ur_row + 1, matrix_height)):
                                        for c in range(max(0, ll_col), min(ur_col + 1, matrix_width)):
                                            shape_matrix[r, c, 1] = 1
        
        return shape_matrix
    
    def visualize_matrix(self, matrix: np.ndarray, save_path: str = None):
        """可视化矩阵"""
        plt.figure(figsize=(15, 10))
        
        # 创建自定义颜色映射
        im = plt.imshow(matrix, cmap='tab20', interpolation='nearest', aspect='auto')
        plt.colorbar(im, label='Component/Net Value')
        
        plt.title('Circuit Layout Matrix Representation')
        plt.xlabel('X (scaled)')
        plt.ylabel('Y (scaled)')
        
        # 添加网格
        plt.grid(True, alpha=0.3, linewidth=0.5)
        
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            plt.close()  # 关闭图形以节省内存
    
    def visualize_shape_matrix(self, shape_matrix: np.ndarray, save_path: str = None):
        """可视化形状矩阵"""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))
        
        # 可视化metal层
        im1 = ax1.imshow(shape_matrix[:, :, 0], cmap='tab10', interpolation='nearest', aspect='auto')
        ax1.set_title('Metal Layer (Net Values)')
        ax1.set_xlabel('X (scaled)')
        ax1.set_ylabel('Y (scaled)')
        plt.colorbar(im1, ax=ax1, label='Net Value')
        
        # 可视化track层
        im2 = ax2.imshow(shape_matrix[:, :, 1], cmap='binary', interpolation='nearest', aspect='auto')
        ax2.set_title('Track Layer (0 or 1)')
        ax2.set_xlabel('X (scaled)')
        ax2.set_ylabel('Y (scaled)')
        plt.colorbar(im2, ax=ax2, label='Track Presence')
        
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            plt.close()  # 关闭图形以节省内存
    
    def print_graph_summary(self, graph):
        """打印图结构的摘要信息"""
        if not graph:
            return
        
        # 统计节点
        cell_nodes = [n for n in graph["nodes"] if n["type"] == "cell"]
        net_nodes = [n for n in graph["nodes"] if n["type"] == "net"]
        
        print(f"  节点总数: {len(graph['nodes'])}")
        print(f"    - Cell节点: {len(cell_nodes)}")
        print(f"    - Net节点: {len(net_nodes)}")
        
        # 统计边
        cell_net_edges = [e for e in graph["edges"] if e["type"] == "cell-net"]
        cell_cell_edges = [e for e in graph["edges"] if e["type"] == "cell-cell"]
        
        print(f"  边总数: {len(graph['edges'])}")
        print(f"    - Cell-Net连接: {len(cell_net_edges)}")
        print(f"    - Cell-Cell连接: {len(cell_cell_edges)}")
    
    def print_matrix_summary(self, matrix: np.ndarray):
        """打印矩阵摘要信息"""
        unique_values = np.unique(matrix)
        non_zero = np.count_nonzero(matrix)
        total = matrix.shape[0] * matrix.shape[1]
        
        print(f"  矩阵形状: {matrix.shape[0]} × {matrix.shape[1]}")
        print(f"  非零元素: {non_zero}/{total} ({non_zero/total*100:.2f}%)")
        print(f"  唯一值数量: {len(unique_values)}")
    
    def print_shape_matrix_summary(self, shape_matrix: np.ndarray):
        """打印形状矩阵摘要信息"""
        metal_layer = shape_matrix[:, :, 0]
        track_layer = shape_matrix[:, :, 1]
        
        metal_non_zero = np.count_nonzero(metal_layer)
        track_non_zero = np.count_nonzero(track_layer)
        total = shape_matrix.shape[0] * shape_matrix.shape[1]
        
        print(f"  形状矩阵形状: {shape_matrix.shape[0]} × {shape_matrix.shape[1]} × {shape_matrix.shape[2]}")
        print(f"  Metal层非零元素: {metal_non_zero}/{total} ({metal_non_zero/total*100:.2f}%)")
        print(f"  Track层非零元素: {track_non_zero}/{total} ({track_non_zero/total*100:.2f}%)")
        print(f"  Metal层唯一值: {np.unique(metal_layer)}")
        print(f"  Track层唯一值: {np.unique(track_layer)}")


def process_all_circuits(input_dir="./input_circuits", output_dir="./data", scale_factor=200):
    """
    处理指定目录下的所有电路JSON文件
    
    Args:
        input_dir: 输入文件目录
        output_dir: 输出文件目录
        scale_factor: 矩阵转换的缩放因子
    """
    
    # 确保输出目录存在
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    
    # 获取所有JSON文件
    json_files = list(Path(input_dir).glob("*.json"))
    
    if not json_files:
        print(f"在 {input_dir} 目录下没有找到JSON文件")
        return
    
    print(f"找到 {len(json_files)} 个JSON文件待处理")
    print("=" * 60)
    
    # 处理每个文件
    for idx, json_file in enumerate(json_files, 1):
        # 生成随机数作为文件夹名
        random_id = random.randint(100000, 999999)
        circuit_dir = Path(output_dir) / f"circuit_{random_id}"
        circuit_dir.mkdir(parents=True, exist_ok=True)
        
        print(f"\n[{idx}/{len(json_files)}] 处理文件: {json_file.name}")
        print(f"输出目录: {circuit_dir}")
        print("-" * 40)
        
        try:
            # 创建处理器
            processor = CircuitProcessor(scale_factor=scale_factor)
            
            # 加载数据
            print("加载电路数据...")
            data = processor.load_circuit_data(json_file)
            
            if data is None:
                print(f"跳过文件 {json_file.name}（加载失败）")
                continue
            
            # 1. 构建并保存图结构
            print("构建图结构...")
            graph = processor.build_circuit_graph(data)
            
            if graph:
                graph_path = circuit_dir / "graph.json"
                with open(graph_path, 'w', encoding='utf-8') as f:
                    json.dump(graph, f, indent=2, ensure_ascii=False)
                print(f"  图结构已保存到: {graph_path}")
                processor.print_graph_summary(graph)
            
            # 2. 生成并保存矩阵
            print("生成矩阵表示...")
            matrix = processor.convert_to_matrix(data)
            
            if matrix is not None:
                matrix_path = circuit_dir / "matrix.npy"
                np.save(matrix_path, matrix)
                print(f"  矩阵已保存到: {matrix_path}")
                processor.print_matrix_summary(matrix)
                
                # 3. 生成可视化
                print("生成可视化...")
                vis_path = circuit_dir / "visualization.png"
                processor.visualize_matrix(matrix, save_path=str(vis_path))
                print(f"  可视化已保存到: {vis_path}")
            
            # 4. 生成特征向量
            print("提取特征向量...")
            feature_vector = processor.analyze_circuit_features(data)
            
            if feature_vector is not None:
                vector_path = circuit_dir / "vector.npy"
                np.save(vector_path, feature_vector)
                print(f"  特征向量已保存到: {vector_path}")
                print(f"    - 特征向量: {feature_vector}")
                
                # 验证保存的向量
                loaded_vector = np.load(vector_path)
                print(f"  验证: 从文件加载的向量 = {loaded_vector}")
            
            # 5. 生成形状矩阵（新增）
            print("生成形状矩阵...")
            shape_matrix = processor.convert_to_shape_matrix(data)
            
            if shape_matrix is not None:
                shape_path = circuit_dir / "shape.npy"
                np.save(shape_path, shape_matrix)
                print(f"  形状矩阵已保存到: {shape_path}")
                processor.print_shape_matrix_summary(shape_matrix)
                
                # 6. 生成形状矩阵可视化（新增）
                print("生成形状矩阵可视化...")
                shape_vis_path = circuit_dir / "shape.png"
                processor.visualize_shape_matrix(shape_matrix, save_path=str(shape_vis_path))
                print(f"  形状矩阵可视化已保存到: {shape_vis_path}")
            
            # 保存处理信息
            info = {
                "source_file": str(json_file),
                "random_id": random_id,
                "scale_factor": scale_factor,
                "processing_time": str(Path().cwd()),
                "graph_nodes": len(graph["nodes"]) if graph else 0,
                "graph_edges": len(graph["edges"]) if graph else 0,
                "matrix_shape": matrix.shape if matrix is not None else None,
                "shape_matrix_shape": shape_matrix.shape if shape_matrix is not None else None,
                "feature_vector": feature_vector.tolist() if feature_vector is not None else None,
                "features": {
                    "num_tracks": int(feature_vector[0]) if feature_vector is not None else 0,
                    "total_wirelength": int(feature_vector[1]) if feature_vector is not None else 0,
                    "total_vias": int(feature_vector[2]) if feature_vector is not None else 0
                },
                "net_mapping": processor.net_mapping
            }
            
            info_path = circuit_dir / "info.json"
            with open(info_path, 'w', encoding='utf-8') as f:
                json.dump(info, f, indent=2, ensure_ascii=False)
            
            print(f"✓ 成功处理: {json_file.name}")
            
        except Exception as e:
            print(f"✗ 处理文件 {json_file.name} 时出错: {e}")
            import traceback
            traceback.print_exc()
    
    print("\n" + "=" * 60)
    print("所有文件处理完成！")
    print(f"输出目录: {output_dir}")
    
    # 显示处理总结
    processed_dirs = list(Path(output_dir).glob("circuit_*"))
    print(f"成功处理的电路数量: {len(processed_dirs)}")
    
    if processed_dirs:
        print("\n生成的文件类型:")
        print("  - graph.json: 电路图结构")
        print("  - matrix.npy: 电路矩阵表示")
        print("  - vector.npy: 特征向量 [track count, wirelength, via count]")
        print("  - visualization.png: 电路可视化")
        print("  - shape.npy: 形状矩阵表示 (metal和track区域)")
        print("  - shape.png: 形状矩阵可视化")
        print("  - info.json: 处理信息和统计数据")


def main():
    """主函数"""
    # 配置参数
    input_directory = "./input_circuits"
    output_directory = "./data"
    scale = 1
    
    # 确保输入目录存在
    if not Path(input_directory).exists():
        print(f"错误：输入目录 {input_directory} 不存在")
        print("请创建该目录并将JSON文件放入其中")
        
        # 创建输入目录
        Path(input_directory).mkdir(parents=True, exist_ok=True)
        print(f"已创建输入目录: {input_directory}")
        print("请将电路JSON文件放入该目录后重新运行程序")
        return
    
    # 处理所有电路文件
    process_all_circuits(
        input_dir=input_directory,
        output_dir=output_directory,
        scale_factor=scale
    )


if __name__ == "__main__":
    main()