"""
血管路径查找器
功能：从血管分割图中提取中心线，找出从左上端点到所有其他端点的路径并可视化
"""

import numpy as np
import cv2
from skimage.morphology import skeletonize, binary_dilation, disk
from scipy import ndimage
from collections import deque, defaultdict
import matplotlib.pyplot as plt
from pathlib import Path


class VesselPathFinder:
    """血管路径查找器类"""
    
    def __init__(self, image_path):
        """
        初始化
        Args:
            image_path: 血管分割图路径
        """
        self.image_path = image_path
        self.image = None
        self.binary_mask = None
        self.skeleton = None
        self.endpoints = []
        self.branch_points = []
        self.graph = defaultdict(list)
        self.paths = []
        self.first_branch_points = []  # 存储每条路径的第一个分支点
        self.leftmost_point = None  # 最左侧点
        self.main_start_point = None  # 主起始点（第一分支点）
        self.next_branch_points = []  # 从主起始点出发，每条路径的下一级分支点
        
    def load_image(self):
        """加载并预处理图像"""
        img = cv2.imread(self.image_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise ValueError(f"无法读取图像: {self.image_path}")
        
        self.image = img
        # 二值化
        _, self.binary_mask = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
        self.binary_mask = (self.binary_mask > 0).astype(np.uint8)
        print(f"图像加载成功，尺寸: {self.binary_mask.shape}")
        
    def extract_skeleton(self):
        """提取骨架/中心线"""
        # 使用skimage的skeletonize提取骨架
        self.skeleton = skeletonize(self.binary_mask).astype(np.uint8)
        print(f"骨架提取完成，骨架点数: {np.sum(self.skeleton)}")
        
    def find_key_points(self):
        """找出端点和分支点"""
        skeleton = self.skeleton.copy()
        h, w = skeleton.shape
        
        endpoints = []
        branch_points = []
        
        # 遍历骨架上的每个点
        for i in range(1, h-1):
            for j in range(1, w-1):
                if skeleton[i, j] == 0:
                    continue
                
                # 计算3x3邻域内的骨架点数量
                neighbors = skeleton[i-1:i+2, j-1:j+2].copy()
                neighbors[1, 1] = 0  # 不计算中心点
                neighbor_count = np.sum(neighbors)
                
                if neighbor_count == 1:
                    # 端点：只有一个邻居
                    endpoints.append((i, j))
                elif neighbor_count >= 3:
                    # 分支点：有3个或更多邻居
                    branch_points.append((i, j))
        
        self.endpoints = endpoints
        self.branch_points = branch_points
        print(f"找到端点数: {len(endpoints)}, 分支点数: {len(branch_points)}")
        
    def get_neighbors(self, point):
        """获取骨架上某点的相邻点"""
        i, j = point
        neighbors = []
        
        # 8邻域
        for di in [-1, 0, 1]:
            for dj in [-1, 0, 1]:
                if di == 0 and dj == 0:
                    continue
                ni, nj = i + di, j + dj
                if 0 <= ni < self.skeleton.shape[0] and 0 <= nj < self.skeleton.shape[1]:
                    if self.skeleton[ni, nj] > 0:
                        neighbors.append((ni, nj))
        
        return neighbors
    
    def build_graph(self):
        """构建骨架的图结构"""
        skeleton_points = np.argwhere(self.skeleton > 0)
        
        # 为每个骨架点构建邻接关系
        for point in skeleton_points:
            point = tuple(point)
            neighbors = self.get_neighbors(point)
            self.graph[point] = neighbors
        
        print(f"图构建完成，节点数: {len(self.graph)}")
    
    def find_leftmost_point(self):
        """找出骨架上最靠近左侧边缘的点（可以是任意骨架点，不限于端点）"""
        # 获取所有骨架点
        skeleton_points = np.argwhere(self.skeleton > 0)
        
        if len(skeleton_points) == 0:
            print("未找到骨架点")
            return None
        
        # 找出x坐标（列）最小的点
        # 注意：argwhere返回的是(row, col)格式
        # 如果有多个点x坐标相同，则选择y坐标（行）最小的
        leftmost_point = min(skeleton_points, key=lambda p: (p[1], p[0]))
        leftmost_point = tuple(leftmost_point)
        
        print(f"最左侧骨架点位置: {leftmost_point} (行={leftmost_point[0]}, 列={leftmost_point[1]})")
        
        # 判断该点是否为端点或分支点
        point_type = "骨架点"
        if leftmost_point in self.endpoints:
            point_type = "端点"
        elif leftmost_point in self.branch_points:
            point_type = "分支点"
        print(f"该点类型: {point_type}")
        
        return leftmost_point
    
    def find_path_bfs(self, start, end):
        """使用BFS查找两点之间的路径"""
        if start == end:
            return [start]
        
        queue = deque([(start, [start])])
        visited = {start}
        
        while queue:
            current, path = queue.popleft()
            
            for neighbor in self.graph[current]:
                if neighbor in visited:
                    continue
                
                new_path = path + [neighbor]
                
                if neighbor == end:
                    return new_path
                
                visited.add(neighbor)
                queue.append((neighbor, new_path))
        
        return None  # 未找到路径
    
    def find_first_branch_in_path(self, path):
        """找出路径中的第一个分支点"""
        for point in path:
            if point in self.branch_points:
                return point
        return None
    
    def find_all_paths(self):
        """找出从最左侧骨架点到所有端点的路径"""
        start_point = self.find_leftmost_point()
        if start_point is None:
            print("未找到起始点")
            return
        
        self.leftmost_point = start_point
        self.paths = []
        self.first_branch_points = []
        
        for endpoint in self.endpoints:
            if endpoint == start_point:
                continue
            
            path = self.find_path_bfs(start_point, endpoint)
            if path:
                self.paths.append(path)
                # 找出该路径的第一个分支点
                first_branch = self.find_first_branch_in_path(path)
                self.first_branch_points.append(first_branch)
                
                branch_info = f", 第一个分支点: {first_branch}" if first_branch else ", 无分支点"
                print(f"找到路径: {start_point} -> {endpoint}, 长度: {len(path)}{branch_info}")
            else:
                print(f"未找到路径: {start_point} -> {endpoint}")
        
        # 统计第一个分支点
        unique_first_branches = set([b for b in self.first_branch_points if b is not None])
        print(f"总共找到 {len(self.paths)} 条路径")
        print(f"不同的第一分支点数量: {len(unique_first_branches)}")
    
    def find_all_paths_from_first_branch(self):
        """两阶段路径查找：先找第一分支点，然后从第一分支点作为起始点找所有路径"""
        # 第一阶段：找到最左侧点
        leftmost = self.find_leftmost_point()
        if leftmost is None:
            print("未找到起始点")
            return
        
        self.leftmost_point = leftmost
        
        # 找到第一个分支点（最常见的第一分支点）
        print("\n【第一阶段】从最左侧点找第一分支点...")
        temp_paths = []
        temp_first_branches = []
        
        for endpoint in self.endpoints:
            if endpoint == leftmost:
                continue
            
            path = self.find_path_bfs(leftmost, endpoint)
            if path:
                temp_paths.append(path)
                first_branch = self.find_first_branch_in_path(path)
                temp_first_branches.append(first_branch)
        
        # 找出最常见的第一分支点作为主起始点
        valid_branches = [b for b in temp_first_branches if b is not None]
        if not valid_branches:
            print("警告：未找到任何分支点，将使用最左侧点作为起始点")
            self.main_start_point = leftmost
        else:
            # 统计每个分支点出现的次数
            from collections import Counter
            branch_counter = Counter(valid_branches)
            # 选择出现次数最多的分支点
            self.main_start_point = branch_counter.most_common(1)[0][0]
            print(f"第一分支点（主起始点）: {self.main_start_point}")
            print(f"  - 共有 {branch_counter[self.main_start_point]} 条路径通过此分支点")
        
        # 第二阶段：从主起始点（第一分支点）找到所有端点的路径
        print(f"\n【第二阶段】从主起始点 {self.main_start_point} 找所有路径...")
        self.paths = []
        self.next_branch_points = []
        
        for endpoint in self.endpoints:
            if endpoint == self.main_start_point:
                continue
            
            path = self.find_path_bfs(self.main_start_point, endpoint)
            if path:
                self.paths.append(path)
                # 找出从主起始点出发的下一级分支点
                next_branch = self.find_first_branch_in_path(path)
                self.next_branch_points.append(next_branch)
                
                branch_info = f", 下一级分支点: {next_branch}" if next_branch else ", 无分支点"
                print(f"找到路径: {self.main_start_point} -> {endpoint}, 长度: {len(path)}{branch_info}")
            else:
                print(f"未找到路径: {self.main_start_point} -> {endpoint}")
        
        # 统计下一级分支点
        unique_next_branches = set([b for b in self.next_branch_points if b is not None])
        print(f"\n总共找到 {len(self.paths)} 条路径（从主起始点出发）")
        print(f"不同的下一级分支点数量: {len(unique_next_branches)}")
    
    def visualize_paths(self, output_path=None, use_two_stage=False):
        """可视化所有路径
        
        Args:
            output_path: 输出路径
            use_two_stage: 是否使用两阶段模式（从第一分支点出发）
        """
        # 创建RGB图像用于可视化
        h, w = self.skeleton.shape
        vis_image = np.zeros((h, w, 3), dtype=np.uint8)
        
        # 背景：灰色显示原始分割
        vis_image[self.binary_mask > 0] = [50, 50, 50]
        
        # 骨架：白色
        vis_image[self.skeleton > 0] = [200, 200, 200]
        
        # 为每条路径分配不同颜色
        colors = self._generate_colors(len(self.paths))
        
        # 绘制路径
        for path_idx, path in enumerate(self.paths):
            color = colors[path_idx]
            for point in path:
                vis_image[point[0], point[1]] = color
        
        # 标记端点
        for endpoint in self.endpoints:
            # 用圆圈标记
            cv2.circle(vis_image, (endpoint[1], endpoint[0]), 5, (255, 255, 0), 2)
        
        if use_two_stage and self.main_start_point:
            # 两阶段模式：标记最左侧点和主起始点（第一分支点）
            if self.leftmost_point:
                cv2.circle(vis_image, (self.leftmost_point[1], self.leftmost_point[0]), 6, (0, 255, 255), -1)
            
            # 主起始点（第一分支点）用大的绿色圆圈
            cv2.circle(vis_image, (self.main_start_point[1], self.main_start_point[0]), 9, (0, 255, 0), -1)
            cv2.circle(vis_image, (self.main_start_point[1], self.main_start_point[0]), 11, (255, 255, 255), 2)
            
            # 收集所有唯一的下一级分支点
            unique_next_branches = set([b for b in self.next_branch_points if b is not None])
            
            # 标记普通分支点（不是下一级分支点的，也不是主起始点）
            for branch in self.branch_points:
                if branch != self.main_start_point and branch not in unique_next_branches:
                    cv2.circle(vis_image, (branch[1], branch[0]), 3, (255, 0, 255), -1)
            
            # 特别标记下一级分支点（用红色圆圈）
            for next_branch in unique_next_branches:
                cv2.circle(vis_image, (next_branch[1], next_branch[0]), 7, (0, 0, 255), -1)
                cv2.circle(vis_image, (next_branch[1], next_branch[0]), 9, (255, 255, 255), 2)
        else:
            # 单阶段模式：原有逻辑
            start_point = self.leftmost_point if self.leftmost_point else self.find_leftmost_point()
            
            # 最左侧起始点用特殊颜色标记
            if start_point:
                cv2.circle(vis_image, (start_point[1], start_point[0]), 7, (0, 255, 0), -1)
            
            # 收集所有唯一的第一分支点
            unique_first_branches = set([b for b in self.first_branch_points if b is not None])
            
            # 标记普通分支点（不是第一分支点的）
            for branch in self.branch_points:
                if branch not in unique_first_branches:
                    cv2.circle(vis_image, (branch[1], branch[0]), 3, (255, 0, 255), -1)
            
            # 特别标记第一分支点（用更大的红色圆圈）
            for first_branch in unique_first_branches:
                cv2.circle(vis_image, (first_branch[1], first_branch[0]), 8, (0, 0, 255), -1)
                # 外圈标记
                cv2.circle(vis_image, (first_branch[1], first_branch[0]), 10, (255, 255, 255), 2)
        
        # 使用matplotlib显示
        plt.figure(figsize=(12, 12))
        plt.imshow(vis_image)
        
        # 设置标题
        if use_two_stage and self.main_start_point:
            unique_next_branches = set([b for b in self.next_branch_points if b is not None])
            plt.title(f'Vessel Path Visualization - Two Stage ({len(self.paths)} paths from main start point)', fontsize=14)
            
            # 两阶段模式图例
            legend_elements = [
                plt.Line2D([0], [0], marker='o', color='w', label='Leftmost Point',
                          markerfacecolor='c', markersize=8),
                plt.Line2D([0], [0], marker='o', color='w', label='Main Start Point (1st Branch)',
                          markerfacecolor='g', markersize=10),
                plt.Line2D([0], [0], marker='o', color='w', label='Endpoints',
                          markerfacecolor='y', markersize=10),
                plt.Line2D([0], [0], marker='o', color='w', label='Next Level Branch Points',
                          markerfacecolor='r', markersize=9),
                plt.Line2D([0], [0], marker='o', color='w', label='Other Branch Points',
                          markerfacecolor='m', markersize=6),
            ]
        else:
            unique_first_branches = set([b for b in self.first_branch_points if b is not None])
            plt.title(f'Vessel Path Visualization ({len(self.paths)} paths, {len(unique_first_branches)} first branch points)', fontsize=14)
            
            # 单阶段模式图例
            legend_elements = [
                plt.Line2D([0], [0], marker='o', color='w', label='Leftmost Point (Start)',
                          markerfacecolor='g', markersize=10),
                plt.Line2D([0], [0], marker='o', color='w', label='Endpoints',
                          markerfacecolor='y', markersize=10),
                plt.Line2D([0], [0], marker='o', color='w', label='First Branch Points',
                          markerfacecolor='r', markersize=10),
                plt.Line2D([0], [0], marker='o', color='w', label='Other Branch Points',
                          markerfacecolor='m', markersize=8),
            ]
        
        plt.axis('off')
        plt.legend(handles=legend_elements, loc='upper right', prop={'size': 10})
        
        if output_path:
            plt.savefig(output_path, dpi=150, bbox_inches='tight')
            print(f"可视化结果已保存到: {output_path}")
        
        plt.show()
        
        return vis_image
    
    def visualize_individual_paths(self, output_dir=None, use_two_stage=False):
        """单独可视化每条路径
        
        Args:
            output_dir: 输出目录
            use_two_stage: 是否使用两阶段模式
        """
        h, w = self.skeleton.shape
        
        if output_dir:
            Path(output_dir).mkdir(parents=True, exist_ok=True)
        
        # 确定起始点和分支点列表
        if use_two_stage and self.main_start_point:
            start_point = self.main_start_point
            branch_points_list = self.next_branch_points
            branch_label = "Next Branch"
        else:
            start_point = self.leftmost_point if self.leftmost_point else self.find_leftmost_point()
            branch_points_list = self.first_branch_points
            branch_label = "First Branch"
        
        for path_idx, path in enumerate(self.paths):
            # 创建图像
            vis_image = np.zeros((h, w, 3), dtype=np.uint8)
            vis_image[self.binary_mask > 0] = [50, 50, 50]
            vis_image[self.skeleton > 0] = [200, 200, 200]
            
            # 绘制当前路径
            for point in path:
                vis_image[point[0], point[1]] = [0, 255, 255]  # 青色
            
            # 标记起点和终点
            cv2.circle(vis_image, (start_point[1], start_point[0]), 7, (0, 255, 0), -1)
            end_point = path[-1]
            cv2.circle(vis_image, (end_point[1], end_point[0]), 7, (255, 0, 0), -1)
            
            # 标记该路径的分支点
            branch_point = None
            if path_idx < len(branch_points_list):
                branch_point = branch_points_list[path_idx]
            
            if branch_point:
                cv2.circle(vis_image, (branch_point[1], branch_point[0]), 8, (0, 0, 255), -1)
                cv2.circle(vis_image, (branch_point[1], branch_point[0]), 10, (255, 255, 255), 2)
            
            # 准备标题
            branch_text = f", {branch_label}: {branch_point}" if branch_point else f", No {branch_label}"
            
            plt.figure(figsize=(10, 10))
            plt.imshow(vis_image)
            plt.title(f'Path {path_idx + 1}: {start_point} -> {end_point} (Length: {len(path)}{branch_text})', fontsize=12)
            plt.axis('off')
            
            if output_dir:
                output_path = f"{output_dir}/path_{path_idx + 1:03d}.png"
                plt.savefig(output_path, dpi=150, bbox_inches='tight')
                print(f"路径 {path_idx + 1} 已保存到: {output_path}")
            
            plt.close()
    
    def get_paths_from_first_branch(self):
        """获取从第一分支点开始的所有路径（两阶段模式的简化接口）
        
        Returns:
            dict: 包含路径和相关信息的字典
                - main_start_point: 主起始点（第一分支点）
                - paths: 所有路径列表
                - num_paths: 路径数量
                - path_info: 每条路径的详细信息
        """
        if not self.main_start_point:
            raise ValueError("请先使用 use_two_stage=True 运行 run() 方法")
        
        result = {
            'main_start_point': self.main_start_point,
            'leftmost_point': self.leftmost_point,
            'paths': self.paths,
            'num_paths': len(self.paths),
            'path_info': []
        }
        
        for i, (path, next_branch) in enumerate(zip(self.paths, self.next_branch_points)):
            info = {
                'path_id': i + 1,
                'path': path,  # 完整的路径点列表
                'start': path[0],  # 起点（即主起始点）
                'end': path[-1],  # 终点
                'length': len(path),
                'next_branch_point': next_branch
            }
            result['path_info'].append(info)
        
        return result
    
    def get_first_branch_info(self, use_two_stage=False):
        """获取分支点的详细信息
        
        Args:
            use_two_stage: 是否返回两阶段模式的信息
        """
        if use_two_stage and self.main_start_point:
            # 两阶段模式：返回从主起始点出发的路径信息
            unique_next_branches = set([b for b in self.next_branch_points if b is not None])
            
            info = {
                'mode': 'two_stage',
                'leftmost_point': self.leftmost_point,
                'main_start_point': self.main_start_point,
                'total_paths': len(self.paths),
                'unique_next_branches': list(unique_next_branches),
                'num_unique_next_branches': len(unique_next_branches),
                'path_details': []
            }
            
            for i, (path, next_branch) in enumerate(zip(self.paths, self.next_branch_points)):
                path_info = {
                    'path_id': i + 1,
                    'start': path[0],
                    'end': path[-1],
                    'length': len(path),
                    'next_branch_point': next_branch,
                    'has_next_branch': next_branch is not None
                }
                info['path_details'].append(path_info)
        else:
            # 单阶段模式：原有逻辑
            unique_first_branches = set([b for b in self.first_branch_points if b is not None])
            
            info = {
                'mode': 'single_stage',
                'total_paths': len(self.paths),
                'unique_first_branches': list(unique_first_branches),
                'num_unique_first_branches': len(unique_first_branches),
                'path_details': []
            }
            
            for i, (path, first_branch) in enumerate(zip(self.paths, self.first_branch_points)):
                path_info = {
                    'path_id': i + 1,
                    'start': path[0],
                    'end': path[-1],
                    'length': len(path),
                    'first_branch_point': first_branch,
                    'has_branch': first_branch is not None
                }
                info['path_details'].append(path_info)
        
        return info
    
    def _generate_colors(self, n):
        """生成n种不同的颜色"""
        colors = []
        for i in range(n):
            hue = int(180 * i / max(n, 1))
            color = cv2.cvtColor(np.uint8([[[hue, 255, 255]]]), cv2.COLOR_HSV2BGR)[0][0]
            colors.append(tuple(map(int, color)))
        return colors
    
    def run(self, output_path=None, save_individual=False, individual_output_dir=None, use_two_stage=False):
        """运行完整流程
        
        Args:
            output_path: 输出可视化路径
            save_individual: 是否保存单独路径
            individual_output_dir: 单独路径保存目录
            use_two_stage: 是否使用两阶段模式（从第一分支点作为起始点）
        """
        print("=" * 50)
        if use_two_stage:
            print("开始处理血管路径查找（两阶段模式）")
        else:
            print("开始处理血管路径查找")
        print("=" * 50)
        
        # 1. 加载图像
        self.load_image()
        
        # 2. 提取骨架
        self.extract_skeleton()
        
        # 3. 找出关键点
        self.find_key_points()
        
        # 4. 构建图
        self.build_graph()
        
        # 5. 找出所有路径
        if use_two_stage:
            self.find_all_paths_from_first_branch()
        else:
            self.find_all_paths()
        
        # 6. 可视化
        self.visualize_paths(output_path, use_two_stage=use_two_stage)
        
        # 7. 可选：单独保存每条路径
        if save_individual:
            self.visualize_individual_paths(individual_output_dir, use_two_stage=use_two_stage)
        
        print("=" * 50)
        print("处理完成！")
        print("=" * 50)
        
        return self.paths


def main():
    """主函数 - 示例用法"""
    import argparse
    
    parser = argparse.ArgumentParser(description='血管路径查找和可视化')
    parser.add_argument('--input', '-i', type=str, required=True,
                       help='输入血管分割图路径')
    parser.add_argument('--output', '-o', type=str, default=None,
                       help='输出可视化图像路径')
    parser.add_argument('--save-individual', action='store_true',
                       help='是否单独保存每条路径')
    parser.add_argument('--individual-dir', type=str, default='./individual_paths',
                       help='单独路径保存目录')
    parser.add_argument('--two-stage', action='store_true',
                       help='使用两阶段模式（从第一分支点作为起始点）')
    
    args = parser.parse_args()
    
    # 创建路径查找器
    finder = VesselPathFinder(args.input)
    
    # 运行
    paths = finder.run(
        output_path=args.output,
        save_individual=args.save_individual,
        individual_output_dir=args.individual_dir,
        use_two_stage=args.two_stage
    )
    
    # 打印统计信息
    print(f"\n统计信息:")
    print(f"  - 总端点数: {len(finder.endpoints)}")
    print(f"  - 总路径数: {len(paths)}")
    if paths:
        print(f"  - 平均路径长度: {np.mean([len(p) for p in paths]):.2f}")
        print(f"  - 最长路径: {max([len(p) for p in paths])}")
        print(f"  - 最短路径: {min([len(p) for p in paths])}")
    
    # 打印分支点信息
    branch_info = finder.get_first_branch_info(use_two_stage=args.two_stage)
    
    if args.two_stage:
        print(f"\n两阶段路径分析:")
        print(f"  - 最左侧点: {branch_info['leftmost_point']}")
        print(f"  - 主起始点（第一分支点）: {branch_info['main_start_point']}")
        print(f"  - 不同的下一级分支点数量: {branch_info['num_unique_next_branches']}")
        print(f"  - 有下一级分支的路径数: {sum([1 for d in branch_info['path_details'] if d['has_next_branch']])}")
        print(f"  - 无下一级分支的路径数: {sum([1 for d in branch_info['path_details'] if not d['has_next_branch']])}")
        
        if branch_info['unique_next_branches']:
            print(f"\n下一级分支点位置:")
            for i, branch in enumerate(branch_info['unique_next_branches'], 1):
                print(f"    {i}. {branch}")
    else:
        print(f"\n第一分支点信息:")
        print(f"  - 不同的第一分支点数量: {branch_info['num_unique_first_branches']}")
        print(f"  - 有分支点的路径数: {sum([1 for d in branch_info['path_details'] if d['has_branch']])}")
        print(f"  - 无分支点的路径数: {sum([1 for d in branch_info['path_details'] if not d['has_branch']])}")
        
        if branch_info['unique_first_branches']:
            print(f"\n第一分支点位置:")
            for i, branch in enumerate(branch_info['unique_first_branches'], 1):
                print(f"    {i}. {branch}")


if __name__ == '__main__':
    main()

