"""
可视化模块
处理布线结果的图形化显示
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Rectangle, FancyBboxPatch, Circle
from matplotlib.lines import Line2D
import matplotlib.colors as mcolors


class RoutingVisualizer:
    """Render a routing result (cells, pins, tracks and routed nets) with matplotlib.

    The visualizer reads three attributes from the router it is given:

    * ``cells`` - iterable of dicts with the keys ``"location"`` (x, y),
      ``"width"``, ``"height"``, ``"id"`` and ``"pins"`` (each pin is a dict
      with an ``"offset"`` sequence whose first item is the x offset).
    * ``nets`` - mapping of net name to a net object exposing ``track``,
      ``layer`` and ``pins``; each pin exposes ``absolute_x``,
      ``cell_y_min`` and ``cell_y_max``.
    * ``track_y_positions`` - sequence of y coordinates indexed by track number.
    """

    def __init__(self, router):
        """Keep references to the router and the data that will be drawn.

        Args:
            router: Object providing ``cells``, ``nets`` and
                ``track_y_positions``.
        """
        self.router = router
        self.cells = router.cells
        self.nets = router.nets
        self.track_y_positions = router.track_y_positions
        # Populated by _draw_routing; defined here so that _add_legend never
        # hits a missing attribute when nothing has been drawn yet.
        self.net_legend_elements = []

    def visualize(self, save_path=None, dpi=300, show_grid=False):
        """Draw the routing result and either save it or show it.

        Args:
            save_path: Destination file. When falsy the figure is shown
                interactively instead of being written to disk.
            dpi: Resolution used when saving.
            show_grid: Whether to draw a dotted background grid.
        """
        fig, ax = plt.subplots(1, 1, figsize=(18, 12))
        try:
            ax.set_title('Single-Row Routing Visualization',
                         fontsize=16, fontweight='bold')

            bounds = self._calculate_bounds()

            # Draw order: cells, then track guides, then the routed nets.
            self._draw_cells(ax)
            self._draw_tracks(ax, bounds)
            self._draw_routing(ax)

            self._setup_axes(ax, bounds, show_grid)
            self._add_legend(ax)

            fig.tight_layout()

            if save_path:
                fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
                print(f"可视化图片已保存到: {save_path}")
            else:
                plt.show()
        finally:
            # Close exactly this figure, even when a drawing step raised, so
            # figures do not accumulate and unrelated figures stay open.
            plt.close(fig)

    def _calculate_bounds(self):
        """Return the bounding box that encloses every cell.

        Returns:
            dict with the keys ``'min_x'``, ``'max_x'``, ``'min_y'`` and
            ``'max_y'``. When there are no cells, all four values are 0 so
            that the axis limits derived from them stay finite.
        """
        lefts, rights, bottoms, tops = [], [], [], []
        for cell in self.cells:
            x, y = cell["location"]
            lefts.append(x)
            bottoms.append(y)
            rights.append(x + cell["width"])
            tops.append(y + cell["height"])

        if not lefts:
            return {'min_x': 0, 'max_x': 0, 'min_y': 0, 'max_y': 0}

        # Take the true extrema instead of seeding the maxima with 0, so
        # layouts that sit entirely at negative coordinates are handled.
        return {
            'min_x': min(lefts),
            'max_x': max(rights),
            'min_y': min(bottoms),
            'max_y': max(tops),
        }

    def _draw_cells(self, ax):
        """Draw each cell as a rounded box with its name and its pins.

        Args:
            ax: Matplotlib axes to draw on.
        """
        for cell in self.cells:
            x, y = cell["location"]
            width = cell["width"]
            height = cell["height"]

            # Cell outline.
            cell_rect = FancyBboxPatch((x, y), width, height,
                                       boxstyle="round,pad=100",
                                       facecolor='white',
                                       edgecolor='black',
                                       linewidth=2,
                                       alpha=0.8)
            ax.add_patch(cell_rect)

            # Cell name, centred, with any '|' characters stripped.
            cell_name = cell["id"].replace("|", "")
            ax.text(x + width / 2, y + height / 2, cell_name,
                    ha='center', va='center', fontsize=9, fontweight='bold')

            # Each pin is a vertical red bar spanning 10%-90% of the cell
            # height; only the x component of the pin offset is used.
            pin_bottom = y + 0.1 * height
            pin_top = y + 0.9 * height
            for pin in cell["pins"]:
                pin_x = x + pin["offset"][0]
                ax.plot([pin_x, pin_x], [pin_bottom, pin_top],
                        color='red', linewidth=3, alpha=0.8, zorder=4)

    def _draw_tracks(self, ax, bounds):
        """Draw a faint guide line and a label for every routing track.

        Args:
            ax: Matplotlib axes to draw on.
            bounds: Bounding box as returned by ``_calculate_bounds``.
        """
        if not self.track_y_positions:
            return

        # Labels are right-aligned 6000 units left of the leftmost cell edge.
        track_line_start = bounds['min_x'] - 5000
        label_x = track_line_start - 1000

        for i, track_y in enumerate(self.track_y_positions):
            # Guide line across the full width of the axes.
            ax.axhline(y=track_y, color='lightblue', linestyle=':',
                       alpha=0.3, linewidth=1)

            ax.text(label_x, track_y, f'Track {i}',
                    ha='right', va='center', fontsize=8,
                    bbox=dict(boxstyle="round,pad=0.3",
                              facecolor='lightyellow', alpha=0.7))

    def _draw_routing(self, ax):
        """Draw every routed net and rebuild ``self.net_legend_elements``.

        Nets whose ``track`` is -1 are skipped. For each remaining net this
        draws a horizontal trunk on its track (when it has at least two
        pins), a vertical stub from each pin's cell edge to the track, and a
        via marker at every pin position.

        Args:
            ax: Matplotlib axes to draw on.
        """
        # One colour per net name, sampled evenly from the Set3 colormap.
        net_names = list(self.nets.keys())
        colors = plt.cm.Set3(np.linspace(0, 1, len(net_names)))
        net_colors = dict(zip(net_names, colors))

        # Line style per metal layer; unknown layers fall back to layer 1.
        layer_styles = {
            1: {'linestyle': '-', 'linewidth': 3, 'alpha': 0.9},
            2: {'linestyle': '--', 'linewidth': 3, 'alpha': 0.8},
            3: {'linestyle': '-.', 'linewidth': 3, 'alpha': 0.7}
        }

        self.net_legend_elements = []

        for net_name, net in self.nets.items():
            if net.track == -1:
                continue

            color = net_colors[net_name]
            track_y = self.track_y_positions[net.track]
            pins = sorted(net.pins, key=lambda p: p.absolute_x)
            layer_style = layer_styles.get(net.layer, layer_styles[1])

            # Legend proxy for this net.
            self.net_legend_elements.append(
                Line2D([0], [0],
                       color=color,
                       linestyle=layer_style['linestyle'],
                       linewidth=2,
                       label=f'{net_name} (T{net.track}, M{net.layer})'))

            # Horizontal trunk between the leftmost and rightmost pin.
            if len(pins) >= 2:
                x_start = pins[0].absolute_x
                x_end = pins[-1].absolute_x

                ax.plot([x_start, x_end], [track_y, track_y],
                        color=color,
                        linestyle=layer_style['linestyle'],
                        linewidth=layer_style['linewidth'],
                        alpha=layer_style['alpha'])

                # Net name and layer, just above the middle of the trunk.
                ax.text((x_start + x_end) / 2, track_y + 300,
                        f"{net_name}\nMETAL{net.layer}",
                        ha='center', va='bottom', fontsize=7,
                        bbox=dict(boxstyle="round,pad=0.2",
                                  facecolor=color, alpha=0.3))

            for pin in pins:
                # A vertical stub is only needed when the track lies outside
                # the pin's cell; it starts at the cell edge nearest the track.
                if track_y < pin.cell_y_min:
                    stub_start = pin.cell_y_min
                elif track_y > pin.cell_y_max:
                    stub_start = pin.cell_y_max
                else:
                    stub_start = None

                if stub_start is not None:
                    # Vertical connection (labelled METAL0 in the original notes).
                    ax.plot([pin.absolute_x, pin.absolute_x],
                            [stub_start, track_y],
                            color=color, linewidth=2, alpha=0.8,
                            linestyle='-')

                # Via marker where the pin meets the track.
                via_circle = Circle((pin.absolute_x, track_y), 250,
                                    facecolor=color,
                                    edgecolor='black',
                                    linewidth=1.5,
                                    zorder=6)
                ax.add_patch(via_circle)

                # Nets above layer 1 get an extra white inner ring.
                if net.layer > 1:
                    inner_circle = Circle((pin.absolute_x, track_y), 150,
                                          facecolor='white',
                                          edgecolor='black',
                                          linewidth=1,
                                          zorder=7)
                    ax.add_patch(inner_circle)

    def _setup_axes(self, ax, bounds, show_grid):
        """Label the axes and set limits that cover the cells and the tracks.

        Args:
            ax: Matplotlib axes to configure.
            bounds: Bounding box as returned by ``_calculate_bounds``.
            show_grid: Whether to draw a dotted background grid.
        """
        ax.set_xlabel('X Coordinate', fontsize=12)
        ax.set_ylabel('Y Coordinate', fontsize=12)

        if show_grid:
            ax.grid(True, alpha=0.3, linestyle=':')

        # The vertical extent must include the tracks as well as the cells.
        y_values = [bounds['min_y'], bounds['max_y']]
        if self.track_y_positions:
            y_values.extend(self.track_y_positions)

        min_y = min(y_values)
        max_y = max(y_values)

        # Margin is 10% of the extent, but never less than 5000 units.
        x_margin = max((bounds['max_x'] - bounds['min_x']) * 0.1, 5000)
        y_margin = max((max_y - min_y) * 0.1, 5000)

        ax.set_xlim(bounds['min_x'] - x_margin, bounds['max_x'] + x_margin)
        ax.set_ylim(min_y - y_margin, max_y + y_margin)

        # Same scale on both axes.
        ax.set_aspect('equal', adjustable='box')

    def _add_legend(self, ax):
        """Add the net legend built by ``_draw_routing``.

        Nothing is drawn when there are no legend entries. The number of
        columns grows with the number of nets (1 up to 10 nets, 2 up to 20,
        otherwise 3).

        Args:
            ax: Matplotlib axes to attach the legend to.
        """
        handles = self.net_legend_elements
        if not handles:
            # No routed nets: skip the legend instead of drawing an empty box.
            return

        n_nets = len(handles)
        if n_nets <= 10:
            ncol = 1
        elif n_nets <= 20:
            ncol = 2
        else:
            ncol = 3

        # ax.legend() already installs the legend on the axes; adding it
        # again with add_artist would register the same legend twice.
        ax.legend(handles=handles,
                  bbox_to_anchor=(0.98, 0.98),
                  loc='upper right',
                  ncol=ncol,
                  title='Nets (Track, Metal)',
                  fontsize=8,
                  title_fontsize=9,
                  borderaxespad=0.)