"""
@Description :   全局拼接工具函数
@Author      :   tqychy 
@Time        :   2025/02/23 20:20:19
"""
import sys

sys.path.append("./")
sys.path.append("./visualize")

from copy import deepcopy
from dataclasses import dataclass

import cv2
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from matplotlib import cm

from visualize.ransac import ransac


def allclose_transform(mat1, mat2, translate_tolerance=100, rotation_tolerance=10):
    # 计算误差矩阵 (mat1 * mat2 的逆)
    err_mat = mat1 @ np.linalg.inv(mat2)
    
    # 平移误差（[0,2] 和 [1,2] 元素的平方和开根号）
    t_err = np.sqrt(err_mat[0, 2]**2 + err_mat[1, 2]**2)
    
    # 根据容差 (1e-3) 调整 err_mat[0,0]
    if abs(err_mat[0, 0] - 1) < 1e-2:
        err_mat[0, 0] = 1.0
    if abs(err_mat[0, 0] + 1) < 1e-2:
        err_mat[0, 0] = -1.0
    
    # 旋转误差（转换为角度）
    r_err = np.arccos(err_mat[0, 0]) * 180 / np.pi
    # print(f"t_err: {t_err}, r_err: {r_err}, {t_err < translate_tolerance and r_err < rotation_tolerance}")
    return t_err < translate_tolerance and r_err < rotation_tolerance

@dataclass
class Vertice:
    """
    图的节点
    :params:
        global_index: 节点在上一级的序号
        img: 节点原始图像索引
        pcd: 节点原始边缘索引
        global_transform: 节点的全局变换矩阵
    """
    global_index: int
    img: int
    pcd: int
    global_transform: np.ndarray = np.eye(3)

    def __eq__(self, other):
        if isinstance(other, Vertice):
            return self.global_index == other.global_index
        else:
            return False
    
    def __lt__(self, other):
        return self.global_index < other.global_index


@dataclass
class Edge:
    """
    图的边
    :params:
        vertices: 边的两个点的序号 (u, v)
        sim_mats: 模型预测的轮廓点相似度矩阵索引
        score: 模型预测的置信率
        transform: u 到 v 的变换矩阵
        rank: 置信率排名
    """
    vertices: tuple
    sim_mats: int
    score: float
    transform: np.ndarray = np.eye(3)
    rank: int = 0

    # def __eq__(self, other):
    #     if isinstance(other, Edge):
    #         u1, v1 = self.vertices
    #         u2, v2 = other.vertices
    #         return (u1 == u2 and v1 == v2) or (u1 == v2 and v1 == u2)
    #     return False


class Graph:
    def __init__(self, vertices: dict, edges: list, imgs: list, pcds: list, sim_mats:list):
        self.vertices = vertices
        self.edges = edges
        self.imgs = imgs
        self.pcds = pcds
        self.sim_mats = sim_mats
        self.init_edges()

    def init_edges(self):
        """
        初始化边：遍历所有边，计算相对变换矩阵并删去无法得到矩阵的边，最后计算 rank
        """
        # print(f"原始边数 {len(self.edges)}")
        error_idx = set()
        for e_idx in range(len(self.edges)):
            transform = self.get_relative_transformation(e_idx)
            if transform is None:
                error_idx.add(e_idx)
            else:
                self.edges[e_idx].transform = transform
        self.edges = [edge for e_idx, edge in enumerate(self.edges) if e_idx not in error_idx]

        # print(f"处理后边数 {len(self.edges)}")
        score_ranks = []
        for e_idx in range(len(self.edges)):
            score_ranks.append((self.edges[e_idx].score, e_idx))
        score_ranks.sort()
        for rank, (_, e_idx) in enumerate(score_ranks):
            self.edges[e_idx].rank = rank

    def get_relative_transformation(self, e_idx: int):
        """
        计算相对变换矩阵，默认返回 idx1 旋转到 idx2 的变换矩阵
        :params:
            e_idx: 边的编号
        :returns:
            transformation: 相对变换矩阵，未找到(即检测到不可拼接)时为 None
        """
        v_idx1, v_idx2 = self.edges[e_idx].vertices
        pcd1, pcd2 = self.pcds[self.vertices[v_idx1].pcd], self.pcds[self.vertices[v_idx2].pcd]

        idx = np.where(self.sim_mats[self.edges[e_idx].sim_mats])
        pcd1_inter = pcd1[idx[0]].reshape(-1, 2)
        pcd2_inter = pcd2[idx[1]].reshape(-1, 2)

        transformation, valid = ransac(pcd1_inter, pcd2_inter)
        if valid == False:
            return None
        transformation = np.delete(transformation[:2], 2, axis=-1)
        transformation = np.vstack([transformation, [0, 0, 1]])

        return transformation

    @staticmethod
    def normalize_loop(vertices):
        """标准化环路顶点顺序以去重"""
        n = len(vertices)
        if n == 0:
            return tuple()
        # 生成所有循环排列
        rotations = [tuple(vertices[i:] + vertices[:i]) for i in range(n)]
        # 将 reversed(vertices) 转换为列表后再进行切片
        reversed_vertices = list(reversed(vertices))
        reversed_rot = [tuple(reversed_vertices[i:] + reversed_vertices[:i]) for i in range(n)]
        return min(rotations + reversed_rot)  # 选择字典序最小的表示
    
    def create_loop(self, loop_vertices, loop_edges):
        idx_convert = {}
        sub_vertices = []
        sub_edges = []
        for local_idx, global_idx in enumerate(loop_vertices):
            idx_convert[global_idx] = local_idx
            new_v = deepcopy(self.vertices[global_idx])
            new_v.global_index = global_idx
            sub_vertices.append(new_v)
        for e_idx in loop_edges:
            new_e = deepcopy(self.edges[e_idx])
            global_idx1, global_idx2 = new_e.vertices
            new_e.vertices = idx_convert[global_idx1], idx_convert[global_idx2]
            sub_edges.append(new_e)
        return Loop(sub_vertices, sub_edges, [(i, i) for i in range(len(loop_vertices))], self)

    def gen_loops(self, max_loop_length=4):
        """
        生成所有长度不超过max_loop_length的环路子图
        Args:
            max_loop_length: 环路的最大边数（默认4）
        Yields:
            Graph: 包含环路顶点和边的子图
        """
        # 构建邻接表
        adj_tab = {}
        for e_idx, edge in enumerate(self.edges):
            u, v = edge.vertices
            # 无向图处理：为两个顶点都添加邻接信息
            if u not in adj_tab:
                adj_tab[u] = []
            if v not in adj_tab:
                adj_tab[v] = []
            adj_tab[u].append((v, e_idx))
            adj_tab[v].append((u, e_idx))

        visited_loops = set()  # 记录已发现的唯一环路
        ret_loops = []

        # 从每个顶点出发进行DFS搜索
        for start in range(len(self.vertices)):
            # 堆栈元素：(当前顶点, 路径顶点列表, 已用边索引列表, 当前边数)
            stack = [(start, [start], [], 0)]

            while len(stack) > 0:
                current, path, edges, length = stack.pop()

                # 遍历当前顶点的所有邻接边
                for neighbor, e_idx in adj_tab.get(current, []):
                    # 跳过已使用的边
                    if e_idx in edges:
                        continue

                    # 发现闭合环路
                    if neighbor == start:
                        if 3 <= length + 1 <= max_loop_length:
                            # 标准化环路的顶点顺序
                            loop_vertices = path  # 路径不包含最后的start节点
                            norm = self.normalize_loop(loop_vertices)
                            if norm not in visited_loops:
                                visited_loops.add(norm)
                                # 构建子图的边列表
                                loop_edges = edges + [e_idx]
                                ret_loops.append(self.create_loop(loop_vertices, loop_edges))
                    # 继续DFS搜索
                    else:
                        if neighbor not in path and (length + 1) < max_loop_length:
                            new_path = path + [neighbor]
                            new_edges = edges + [e_idx]
                            stack.append(
                                (neighbor, new_path, new_edges, length + 1))
        
        return ret_loops
    
    def visualize(self, name, loops):
        # 创建 NetworkX 图对象
        G = nx.Graph()
        for v_idx in self.vertices.keys():
            G.add_node(v_idx)
        for e in self.edges:
            u, v = e.vertices
            G.add_edge(u, v)

        # 计算布局
        pos = nx.spring_layout(G)

        # 绘制图形
        plt.figure(figsize=(10, 10))
        # 绘制整个图（背景）
        nx.draw(
            G, 
            pos, 
            with_labels=True, 
            node_color='lightblue', 
            node_size=500, 
            font_size=10, 
            edge_color='gray', 
            width=0.5
        )

        # 高亮显示环路
        if loops:
            colors = cm.get_cmap('tab10', len(loops))  # 使用 tab10 颜色图为每个环路分配颜色
            for i, loop in enumerate(loops):
                # 获取环路的全局边
                loop_global_edges = []
                for edge in loop.edges:
                    u, v = edge.vertices
                    u_global = loop.vertices[u].global_index
                    v_global = loop.vertices[v].global_index
                    loop_global_edges.append((u_global, v_global))
                # 绘制环路的边
                nx.draw_networkx_edges(
                    G, 
                    pos, 
                    edgelist=loop_global_edges, 
                    edge_color=colors(i), 
                    width=2
                )

        # 添加标题并显示
        plt.title("Graph with Highlighted Loops")
        plt.savefig(name)


class Loop:
    def __init__(self, vertices: list, edges: list, path: list, graph: Graph):
        """
        Args:
            vertices: 环路的所有节点
            edges: 环路的所有边
            path: 环路路径 [(v_idx, e_idx), ...]
            graph: 所属的图
        """
        self.vertices = vertices
        self.edges = edges
        self.path = path
        self.graph = graph
        self.score = 0
        self.has_set_global_transforms = False

        for edge in edges:
            self.score += edge.rank
    
    def __eq__(self, other):
        """
        检查已设置全局变换矩阵的环路是否相等，如果含有相同的节点即认为相等
        """
        has_set_global_transforms = other.has_set_global_transforms and self.has_set_global_transforms
        if isinstance(other, Loop) and has_set_global_transforms:
            return sorted(self.vertices) == sorted(other.vertices)
        else:
            return False
    
    def __lt__(self, other):
        return self.score < other.score
    
    def check_closure(self):
        """
        检查环路是否满足环路闭包条件
        """
        cur_trans = np.eye(3)
        cur_v = self.path[0][0]
        for _, e_idx in self.path:
            u, v = self.edges[e_idx].vertices
            trans = self.edges[e_idx].transform
            if cur_v == u:
                cur_trans = cur_trans @ trans
                cur_v = v
            else:
                trans = np.linalg.inv(trans)
                cur_trans = cur_trans @ trans
                cur_v = u

        return allclose_transform(cur_trans, np.eye(3))
    
    def check_no_intersection(self, tolerance=0.):
        """
        检查环路拼合图像是否有交叠
        TODO: 还未实现这个函数
        """
        return True
    
    def overall_transform(self, transform):
        """
        全体应用一个变换矩阵
        """
        assert self.has_set_global_transforms, "环路尚未设置全局仿射变换矩阵"
        for vertice in self.vertices:
            vertice.global_transform = vertice.global_transform @ transform
    

    def set_global_transforms_by_path(self, root_idx=None):
        if root_idx == None or root_idx < 0 or root_idx >= len(self.vertices):
            root_idx = self.path[0][0]
        cur_trans = np.eye(3)
        cur_v = root_idx
        for _, e_idx in self.path[:-1]:
            u, v = self.edges[e_idx].vertices
            trans = self.edges[e_idx].transform
            if cur_v == u:
                trans = np.linalg.inv(trans)
                cur_trans = cur_trans @ trans
                cur_v = v
            else:
                cur_trans = cur_trans @ trans
                cur_v = u
            self.vertices[cur_v].global_transform = cur_trans
        self.has_set_global_transforms = True

    def build_pic(self):
        assert self.has_set_global_transforms, "环路尚未设置全局仿射变换矩阵"
        # 收集所有变换后的点云坐标以确定画布尺寸
        all_points = []
        imgs, pcds = self.graph.imgs, self.graph.pcds
        for vertice in self.vertices:
            T = vertice.global_transform
            pcd = pcds[vertice.pcd]
            homogeneous = np.hstack([pcd, np.ones((len(pcd), 1))])
            transformed = (T @ homogeneous.T).T
            all_points.append(transformed)
        all_points = np.concatenate(all_points, axis=0)

        min_x, max_x = np.min(all_points[:, 0]), np.max(all_points[:, 0])
        min_y, max_y = np.min(all_points[:, 1]), np.max(all_points[:, 1])

        # 计算画布的尺寸
        canvas_width = int(np.ceil(max_x - min_x))
        canvas_height = int(np.ceil(max_y - min_y))
        canvas_size = max(canvas_height, canvas_width)
        canvas = np.zeros((canvas_size, canvas_size, 3), dtype=np.float32)

        # 处理每个碎片，叠加到画布
        for i, vertice in enumerate(self.vertices):
            T = vertice.global_transform
            adjusted_T = np.array([
                [T[0, 0], T[0, 1], T[0, 2] - min_x],
                [T[1, 0], T[1, 1], T[1, 2] - min_y]
            ], dtype=np.float32)

            img = imgs[vertice.img]
            img = img.transpose(1, 0, 2)
            transformed_img = cv2.warpAffine(img, adjusted_T, (canvas_size, canvas_size),
                                             flags=cv2.INTER_LINEAR,
                                             borderMode=cv2.BORDER_CONSTANT,
                                             borderValue=(0, 0, 0))
            transformed_img = transformed_img.transpose(1, 0, 2)
            canvas += transformed_img
    
        # 将图像数据限制在0-255范围内并转换为uint8类型
        canvas = np.clip(canvas, 0, 255).astype(np.uint8)
        return canvas
        


class UnionFind:
    def __init__(self, v_idx):
        self.parent = {idx: idx for idx in v_idx}  # 初始化每个元素的父节点为自身
        self.rank = {idx: 0 for idx in v_idx}          # 初始化秩为0

    def find(self, x):
        # 递归查找根节点，并应用路径压缩
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return  # 已在同一集合中，无需合并

        # 按秩合并：将较小秩的树合并到较大秩的树上
        if self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        else:
            self.parent[root_x] = root_y
            # 若秩相等，则合并后秩加1
            if self.rank[root_x] == self.rank[root_y]:
                self.rank[root_y] += 1

    def connected(self, x, y):
        return self.find(x) == self.find(y)
