"""
布线器核心算法
"""
from typing import List, Dict, Tuple, Set
from collections import defaultdict
from datetime import datetime

from Routing.data_loader import Pin, Net, Segment


class SingleRowRouter:
    """Greedy track-assignment router for a single row of placed cells.

    Pins are read from the cells and grouped into nets by their ``net`` field.
    Every net gets one horizontal trunk on a track. Nets are placed first-fit
    in order of their leftmost pin x (left-edge style): a net joins the first
    track on which it overlaps no other net; otherwise it opens a new track.
    Each pin is then tied to its net's track by a vertical METAL0 stub and a
    via stack.

    Typical use: ``router = SingleRowRouter(data); router.route()``, followed
    by ``get_complete_output()`` and/or ``print_routing_summary()``.
    """

    def __init__(self, json_data):
        """Initialize the router from parsed placement JSON.

        Args:
            json_data: Mapping with a required ``"cells"`` list and an optional
                ``"nets"`` entry. Each cell needs ``id``, ``location`` (x, y),
                ``height`` and ``pins``. Each pin needs ``name``, ``offset``
                (dx, dy) and ``net``.

        Raises:
            KeyError: if ``"cells"`` (or a required cell/pin field) is missing.
        """
        self.cells = json_data["cells"]
        # "nets" is optional in the input and is only stored here; routing is
        # driven by the pins' own ``net`` fields.
        self.nets_data = json_data.get("nets", [])
        self.pins = []
        self.nets = {}
        self.layers = {}  # layer number -> list of tracks; each track is a list of Segments
        self.track_y_positions = []  # y coordinate for each track index
        self.track_spacing = 5000  # vertical distance between adjacent tracks
        self.base_y = None  # tracks are centered on this y (set from the first pin)
        self.max_layers = 1  # number of horizontal routing layers (METAL1..METALn)

        # Routing statistics, recomputed by _generate_routing_result()
        self.total_wirelength = 0
        self.total_via_count = 0
        self.net_statistics = {}  # net name -> per-net statistics dict

        self._parse_cells()
        self._create_nets()

    def _parse_cells(self):
        """Collect every cell pin into ``self.pins`` with absolute coordinates.

        A pin's absolute position is the cell location plus the pin offset. Its
        vertical extent (``cell_y_min``/``cell_y_max``) is modeled as
        ``cell_y +/- 0.4 * height``. NOTE(review): this presumably treats the
        cell location as the cell's vertical center; confirm against the input
        format. ``self.base_y`` is set to the y of the first pin seen.
        """
        for cell in self.cells:
            cell_x, cell_y = cell["location"]
            cell_height = cell["height"]

            for pin in cell["pins"]:
                abs_x = cell_x + pin["offset"][0]
                abs_y = cell_y + pin["offset"][1]

                self.pins.append(Pin(
                    name=pin["name"],
                    cell_id=cell["id"],
                    absolute_x=abs_x,
                    absolute_y=abs_y,
                    net=pin["net"],
                    cell_y_min=cell_y - 0.4 * cell_height,
                    cell_y_max=cell_y + 0.4 * cell_height,
                ))

                # The first pin's y is the reference around which tracks are centered.
                if self.base_y is None:
                    self.base_y = abs_y

    def _create_nets(self):
        """Group ``self.pins`` by net name into ``Net`` objects in ``self.nets``.

        Nets are keyed by name, in order of first appearance.
        """
        net_pins = defaultdict(list)
        for pin in self.pins:
            net_pins[pin.net].append(pin)

        for net_name, pins in net_pins.items():
            self.nets[net_name] = Net(name=net_name, pins=pins)

    def _get_net_span(self, net: "Net") -> Tuple[int, int]:
        """Return ``(min_x, max_x)`` over the absolute x of the net's pins."""
        x_coords = [pin.absolute_x for pin in net.pins]
        return min(x_coords), max(x_coords)

    def _check_overlap(self, seg1: Tuple[int, int], seg2: Tuple[int, int]) -> bool:
        """Return True if the closed intervals ``seg1`` and ``seg2`` intersect.

        Touching endpoints count as overlapping, so two nets that meet at the
        same x are never put on the same track.
        """
        return not (seg1[1] < seg2[0] or seg2[1] < seg1[0])

    def _build_vertical_constraint_graph(self) -> Dict[str, Set[str]]:
        """Build the net conflict graph used for track assignment.

        Two nets conflict (and may not share a track) when their horizontal
        spans overlap. The result is a ``defaultdict(set)``, symmetric, and
        only nets with at least one conflict appear as keys.
        """
        # Compute each span once instead of once per pair.
        spans = {name: self._get_net_span(net) for name, net in self.nets.items()}
        names = list(spans)

        constraints = defaultdict(set)
        for i, name1 in enumerate(names):
            for name2 in names[i + 1:]:
                if self._check_overlap(spans[name1], spans[name2]):
                    constraints[name1].add(name2)
                    constraints[name2].add(name1)

        return constraints

    def _find_free_track(self, tracks: List[List["Segment"]], blocked: Set[str],
                         x_start, x_end):
        """Return the index of the first track that can accept ``[x_start, x_end]``.

        A track is rejected if any of its segments belongs to a net in
        ``blocked`` or overlaps the span. Returns None when no track fits.
        """
        for track_idx, segments in enumerate(tracks):
            if not any(seg.net_name in blocked
                       or self._check_overlap((x_start, x_end), (seg.x_start, seg.x_end))
                       for seg in segments):
                return track_idx
        return None

    def _assign_tracks_multi_layer(self, constraints: Dict[str, Set[str]]):
        """Assign each net a (layer, track) and record its ``Segment``.

        Nets are processed in order of their leftmost x (stable sort). On each
        layer, a net goes on the first existing track with no conflicting or
        overlapping segment. If no existing METAL1 track fits, a new METAL1
        track is opened. Layer 1 therefore always accepts a net, so higher
        layers are never used by this policy.

        Sets ``net.track`` and ``net.layer`` on every net and rebuilds
        ``self.layers``.
        """
        placements = sorted(
            ((net, self._get_net_span(net)) for net in self.nets.values()),
            key=lambda item: item[1][0],
        )

        self.layers = {layer: [] for layer in range(1, self.max_layers + 1)}

        for net, (x_start, x_end) in placements:
            blocked = constraints.get(net.name, set())
            for layer in range(1, self.max_layers + 1):
                tracks = self.layers[layer]
                track_idx = self._find_free_track(tracks, blocked, x_start, x_end)
                if track_idx is None:
                    if layer != 1:
                        continue
                    # Only METAL1 grows: open a new track for this net.
                    track_idx = len(tracks)
                    tracks.append([])
                net.track = track_idx
                net.layer = layer
                tracks[track_idx].append(Segment(net.name, x_start, x_end, track_idx, layer))
                break

    def _calculate_track_positions(self):
        """Compute ``self.track_y_positions``, one y per track index.

        The number of positions equals the largest track count on any layer,
        because track indices are shared across layers. Positions are spaced
        ``track_spacing`` apart and centered on ``base_y``.
        """
        max_tracks = max((len(tracks) for tracks in self.layers.values()), default=0)

        self.track_y_positions = []
        if max_tracks == 0:
            return

        # Offset of track 0 so the tracks are symmetric around base_y.
        start_offset = -(max_tracks - 1) * self.track_spacing / 2
        for i in range(max_tracks):
            self.track_y_positions.append(self.base_y + start_offset + i * self.track_spacing)

    def route(self):
        """Run the full routing flow and return the per-net routing result.

        Builds the conflict graph, assigns tracks, computes the track y
        positions, then generates wires and vias (which also refreshes the
        routing statistics).
        """
        constraints = self._build_vertical_constraint_graph()
        self._assign_tracks_multi_layer(constraints)
        self._calculate_track_positions()
        return self._generate_routing_result()

    def _calculate_wire_length(self, wire: Dict) -> float:
        """Return the Manhattan length of a two-point wire dict (``wire["location"]``)."""
        (x1, y1), (x2, y2) = wire["location"][0], wire["location"][1]
        return abs(x2 - x1) + abs(y2 - y1)

    def _route_net(self, net_name: str, net: "Net") -> Dict:
        """Generate the wires and vias for one net that already has a track.

        Returns the per-net result dict with keys ``name``, ``wires``,
        ``vias``, ``layer``, ``wirelength`` and ``via_count``.
        """
        track_y = self.track_y_positions[net.track]
        pins = sorted(net.pins, key=lambda p: p.absolute_x)

        wires = []
        vias = []
        wirelength = 0

        # Horizontal trunk on the assigned metal layer, from leftmost to rightmost pin.
        if len(pins) >= 2:
            x_start = pins[0].absolute_x
            x_end = pins[-1].absolute_x
            wires.append({
                "id": f"{net_name}_track{net.track}_layer{net.layer}",
                "location": [[x_start, track_y], [x_end, track_y]],
                "layer": f"METAL{net.layer}",
            })
            wirelength += abs(x_end - x_start)

        for pin in pins:
            # Vertical METAL0 stub from the nearest edge of the pin's y extent
            # to the track. None is needed if the track lies within that extent.
            if track_y < pin.cell_y_min:
                edge_y = pin.cell_y_min
            elif track_y > pin.cell_y_max:
                edge_y = pin.cell_y_max
            else:
                edge_y = None

            if edge_y is not None:
                wires.append({
                    "id": f"{net_name}_vertical_{pin.cell_id}_{pin.name}",
                    "location": [[pin.absolute_x, edge_y],
                                 [pin.absolute_x, track_y]],
                    "layer": "METAL0",
                })
                wirelength += abs(edge_y - track_y)

            # Via stack METAL0 -> ... -> METAL{net.layer} at the pin's x on the track.
            for layer in range(1, net.layer + 1):
                vias.append({
                    "id": f"{net_name}_via_{pin.cell_id}_{pin.name}_L{layer-1}to{layer}",
                    "location": [pin.absolute_x, track_y],
                    "layer": [f"METAL{layer-1}", f"METAL{layer}"],
                })

        return {
            "name": net_name,
            "wires": wires,
            "vias": vias,
            "layer": net.layer,  # metal layer of the horizontal trunk
            "wirelength": wirelength,
            "via_count": len(vias),
        }

    def _generate_routing_result(self) -> List[Dict]:
        """Build wires and vias for every routed net and refresh the statistics.

        Resets and recomputes ``total_wirelength``, ``total_via_count`` and
        ``net_statistics``. Nets whose ``track`` is -1 (unassigned) are skipped.

        Returns:
            One dict per routed net (see ``_route_net``), in ``self.nets`` order.
        """
        routing_result = []
        self.total_wirelength = 0
        self.total_via_count = 0
        self.net_statistics = {}

        for net_name, net in self.nets.items():
            if net.track == -1:
                continue

            net_result = self._route_net(net_name, net)

            self.total_wirelength += net_result["wirelength"]
            self.total_via_count += net_result["via_count"]
            self.net_statistics[net_name] = {
                "wirelength": net_result["wirelength"],
                "via_count": net_result["via_count"],
                "layer": net.layer,
                "track": net.track,
                "pin_count": len(net.pins),
            }
            routing_result.append(net_result)

        return routing_result

    def _nets_by_layer(self) -> Dict[int, List[str]]:
        """Group net names by assigned layer, for layers 1..max_layers.

        Nets with ``layer <= 0`` (not assigned) are omitted.
        """
        by_layer = {layer: [] for layer in range(1, self.max_layers + 1)}
        for net_name, net in self.nets.items():
            if net.layer > 0:
                by_layer[net.layer].append(net_name)
        return by_layer

    def get_complete_output(self):
        """Build the full output structure: original cells plus the routing result.

        Regenerates the routing result (and statistics) from the current track
        assignment.

        Returns:
            A two-element list: ``[{"cells": ...}, {"nets": ..., "routing_info": ...}]``.
        """
        routing_nets = self._generate_routing_result()

        # Number of nets on each layer (int keys).
        layer_usage = {layer: len(names) for layer, names in self._nets_by_layer().items()}

        return [
            {
                "cells": self.cells  # original cell information, unchanged
            },
            {
                "nets": routing_nets,
                "routing_info": {
                    "total_nets": len(self.nets),
                    "total_layers": self.max_layers,
                    "layer_usage": layer_usage,
                    "track_spacing": self.track_spacing,
                    "base_y": self.base_y,
                    "track_y_positions": self.track_y_positions,
                    "total_wirelength": self.total_wirelength,
                    "total_via_count": self.total_via_count,
                    "average_wirelength_per_net": self.total_wirelength / len(self.nets) if self.nets else 0,
                    "average_vias_per_net": self.total_via_count / len(self.nets) if self.nets else 0,
                    "timestamp": datetime.now().isoformat()
                }
            }
        ]

    def print_routing_summary(self):
        """Print a human-readable routing summary to stdout.

        Statistics come from the most recent ``_generate_routing_result`` call
        (for example via ``route()``). Nets without statistics report 0.
        """
        print("=" * 50)
        print("布线摘要")
        print("=" * 50)
        print(f"总线网数: {len(self.nets)}")
        print(f"使用金属层数: {self.max_layers}")
        print(f"轨道间距: {self.track_spacing}")
        print(f"Track基准Y坐标: {self.base_y}")
        print(f"总线长: {self.total_wirelength}")
        print(f"总过孔数: {self.total_via_count}")

        by_layer = self._nets_by_layer()

        print("\n层分配情况:")
        for layer, net_names in by_layer.items():
            print(f"\nMETAL{layer} ({len(net_names)} nets):")
            for net_name in net_names:
                net = self.nets[net_name]
                span = self._get_net_span(net)
                stats = self.net_statistics.get(net_name, {})
                print(f"  - {net_name}: Track {net.track}, Span [{span[0]}, {span[1]}], "
                      f"Wirelength: {stats.get('wirelength', 0)}, Vias: {stats.get('via_count', 0)}")