"""
@Description :   全局拼接算法
@Author      :   tqychy 
@Time        :   2025/02/23 15:41:04
"""
import sys

sys.path.append("./")
sys.path.append("./nets")
import heapq
import os
from collections import deque
from copy import deepcopy

import cv2
import numpy as np
from torch_geometric.utils import to_undirected
from utils.assemble_utils import (Edge, Graph, Loop, UnionFind, Vertice,
                                  allclose_transform)
from utils.co_solver import COSolver


class HLMAssembler:
    """
    HLM 全局拼接器
    v_imgs: self.test_dataset.data['img_all']
    v_pcds: self.test_dataset.data['full_pcd_all']
    """

    def __init__(self, v_imgs, v_pcds, v_idx, idx_convert, cleaned_pairs, sim_mats, scores, results_path, metrics_handler, *args):
        """
        Build the fragment graph and the undirected edge lookup used by the
        HLM assembler.

        Args:
            v_imgs: image collection, handed to ``Graph`` unchanged.
            v_pcds: point-cloud collection, handed to ``Graph`` unchanged.
            v_idx: local fragment indices; one vertex is created per entry.
            idx_convert: lookup from a local index to its global index.
            cleaned_pairs: iterable of ``(idx1, idx2)`` pairs, one per edge.
            sim_mats: handed to ``Graph`` unchanged.
            scores: per-edge scores, indexed in the same order as ``cleaned_pairs``.
            results_path: parent directory; output goes to
                ``<results_path>/global_assemble`` (created if missing).
            metrics_handler: accepted but not stored by this class.
            *args: exactly two trailing values, unpacked as ``(cfg, logger)``.
        """
        self.cfg, self.logger = args
        self.results_path = os.path.join(results_path, "global_assemble")
        os.makedirs(self.results_path, exist_ok=True)
        self.v_num = len(v_idx)
        self.max_loop_num = 10

        # Vertices are keyed by their position in v_idx. The global index is
        # passed for all three Vertice arguments (presumably index, image key
        # and point-cloud key -- confirm against Vertice). The inner
        # single-item loop just binds the lookup result once.
        vertices = {
            position: Vertice(global_idx, global_idx, global_idx)
            for position, local_idx in enumerate(v_idx)
            for global_idx in (idx_convert[local_idx],)
        }

        edges = [
            Edge((int(first), int(second)), e_idx, scores[e_idx])
            for e_idx, (first, second) in enumerate(cleaned_pairs)
        ]

        self.ori_graph = Graph(vertices, edges, v_imgs, v_pcds, sim_mats)

        # Store every edge in both directions so membership tests do not
        # depend on endpoint order.
        self.edges_set = set()
        for edge in self.ori_graph.edges:
            head, tail = edge.vertices
            self.edges_set.update(((tail, head), (head, tail)))


    def assemble(self, batch: int, pic_idx: int):
        """
        Run the loop-based pipeline and dump intermediate results to disk.

        Finds induced loops, renders the graph and every valid / invalid
        loop, prints which vertices are covered by valid loops, then merges
        the valid loops bottom-up and renders each merged result.

        Args:
            batch: batch number, used only to name the output directory.
            pic_idx: not used by this implementation.

        Returns:
            The constant placeholder ``((0, 0, 0), 0)``.
        """
        # TODO: debug here
        induced_loops, invalid_loops = self.find_induced_loops()
        batch_dir = os.path.join(self.results_path, f"batch{batch}")
        path = os.path.join(batch_dir, "induced_loops")
        os.makedirs(path, exist_ok=True)
        self.ori_graph.visualize(
            os.path.join(batch_dir, "graph.png"), induced_loops)

        for i, loop in enumerate(induced_loops):
            cv2.imwrite(os.path.join(path, f"loop_{i}_valid.png"),
                        loop.build_pic())
        for i, loop in enumerate(invalid_loops):
            cv2.imwrite(os.path.join(path, f"loop_{i}_invalid.png"),
                        loop.build_pic())

        # Print every valid loop and work out which vertices are in none.
        vers_in_loop = set()
        for i, loop in enumerate(induced_loops):
            global_vers = sorted(
                vertice.global_index for vertice in loop.vertices)
            vers_in_loop.update(global_vers)
            print(global_vers, f"loop_{i}_valid")
        all_vers = set(range(len(self.ori_graph.vertices)))
        ver_not_in_loop = list(all_vers - vers_in_loop)
        print(f"没有在任何诱导环路中的点：{ver_not_in_loop}")

        loop_closures, _ = self.bottom_up_merge(induced_loops)
        path = os.path.join(batch_dir, "loop_closures")
        os.makedirs(path, exist_ok=True)
        for i, loop in enumerate(loop_closures):
            score = loop.score
            rendered = loop.build_pic()
            cv2.imwrite(os.path.join(
                path, f"loop_{i}_score_{score}.png"), rendered)
        return (0, 0, 0), 0

    def check_if_induce(self, loop: Loop):
        """
        检查环路是否是诱导环路
        Args:
            loop: 待检查的环路
        Returns:
            bool 是否是诱导环路
        """
        loop_edges = set()
        for edge in loop.edges:
            u, v = edge.vertices
            u = loop.vertices[u].global_index
            v = loop.vertices[v].global_index
            loop_edges.add((v, u))
            loop_edges.add((u, v))

        for u_idx in range(len(loop.vertices)):
            for v_idx in range(1 + u_idx, len(loop.vertices)):
                u_global = loop.vertices[u_idx].global_index
                v_global = loop.vertices[v_idx].global_index
                if (u_global, v_global) in self.edges_set and (u_global, v_global) not in loop_edges:
                    return False
        return True

    def find_induced_loops(self):
        """
        找到长度 3 ~ 4 的诱导环路
        Args: None
        Returns:
            valid_loops: 有效环路列表
        """
        valid_loops = []
        invalid_loops = []
        all_loops = self.ori_graph.gen_loops(max_loop_length=10)
        for loop in all_loops:
            if not self.check_if_induce(loop):  # 不是诱导环路
                continue
            if loop.check_closure() and loop.check_no_intersection():
                loop.set_global_transforms_by_path()
                valid_loops.append(loop)
            else:
                loop.set_global_transforms_by_path()
                invalid_loops.append(loop)

        self.logger.info(
            f"{len(valid_loops)} 个有效诱导环路，{len(invalid_loops)} 个无效诱导环路")
        return valid_loops, invalid_loops

    def bottom_up_merge(self, induced_loops):
        """
        自底向上合并环路
        Args:
            valid_loops: 初始有效环路列表
        Returns:
            merged_loops: 合并后的环路列表
            history: 合并历史记录
        """
        loop_closures = deepcopy(induced_loops)
        loop_closure_history = [deepcopy(induced_loops)]
        iteration = 1

        while True:
            self.logger.info(
                f"Iterate {iteration}, try to integrate {len(loop_closures)} loops...")
            temp_merged = []
            truncate = False

            for i, loop in enumerate(loop_closures):
                inds, edges = [], []
                for v in loop.vertices:
                    ind = v.global_index
                    inds.append(ind)
                for edge in loop.edges:
                    u, v = edge.vertices
                    u = loop.vertices[u].global_index
                    v = loop.vertices[v].global_index
                    edges.append((u, v))
                print(f"环路 {i}: 点 {inds}, 边 {edges}")
            
            # # debug
            # break

            # 遍历所有环路对寻找可合并项
            for i in range(len(loop_closures)):
                if truncate:
                    break
                loop_i = loop_closures[i]
                for j in range(i+1, len(loop_closures)):
                    loop_j = loop_closures[j]
                    # 查找公共边，对齐两个 loop 的坐标系
                    common_edges = self.get_common_edges(loop_i, loop_j)
                    transform = self.align_loops(loop_i, loop_j, common_edges)
                    if transform is None:
                        continue
                    # self.logger.info(f"找到有公共边的环路：{i} {j}")
                    # 合并环路
                    merged_loop = self.merge_loops(loop_i, loop_j, transform)
                    # 截断控制
                    if len(temp_merged) > self.max_loop_num:
                        truncate = True
                        break
                    # 去重
                    if not any(merged_loop == exist for exist in temp_merged):
                        temp_merged.append(merged_loop)

            # self.logger.info("Done!")
            iteration += 1
            
            # # debug
            # if iteration > 3:
            #     break

            if len(temp_merged) == 0:
                break
            else:
                loop_closures = temp_merged
                # loop_closure_history.append(loop_closures)

        return sorted(loop_closures, reverse=True), loop_closure_history

    @staticmethod
    def get_common_edges(loop1: Loop, loop2: Loop):
        """
        查找两个环路中有无公共边
        Args:
            loop1: 第一个环路
            loop2: 第二个环路
        Returns:
            common_edges: 所有公共边索引对，以及是否是反转对应的 (e_idx1, e_idx2, reverse)
        """
        common_edges = []
        for e_idx1, edge1 in enumerate(loop1.edges):
            u1, v1 = edge1.vertices
            u1 = loop1.vertices[u1].global_index
            v1 = loop1.vertices[v1].global_index
            for e_idx2, edge2 in enumerate(loop2.edges):
                u2, v2 = edge2.vertices
                u2 = loop2.vertices[u2].global_index
                v2 = loop2.vertices[v2].global_index

                if u1 == u2 and v1 == v2:
                    common_edges.append((e_idx1, e_idx2, False))
                elif u1 == v2 and u2 == v1:
                    common_edges.append((e_idx1, e_idx2, True))
        return common_edges

    @staticmethod
    def align_loops(loop1: Loop, loop2: Loop, common_edges: list):
        """
        对齐两个有公共边的环路
        Args:
            loop1: 第一个环路
            loop2: 第二个环路
            common_edges: get_common_edges 找到的公共边
        Returns:
            若成功，返回将 loop2 转换到 loop1 的坐标系的变换矩阵，否则返回 None
        """
        assert loop1.has_set_global_transforms and loop2.has_set_global_transforms, "环路尚未设置全局仿射变换矩阵"
        if len(common_edges) == 0:
            return None
        final_transform = None
        for e_idx1, e_idx2, reverse in common_edges:
            u1, v1 = loop1.edges[e_idx1].vertices
            u2, v2 = loop2.edges[e_idx2].vertices
            if reverse:
                u2, v2 = v2, u2
            u1_trans = loop1.vertices[u1].global_transform
            v1_trans = loop1.vertices[v1].global_transform
            u2_trans = loop2.vertices[u2].global_transform
            v2_trans = loop2.vertices[v2].global_transform

            # 计算变换矩阵
            transform1 = np.linalg.inv(u2_trans) @ u1_trans
            transform2 = np.linalg.inv(v2_trans) @ v1_trans
            if final_transform is None:
                final_transform = transform1
            all_same = allclose_transform(transform1, transform2) and allclose_transform(
                final_transform, transform1)
            if not all_same:
                return None

        return final_transform

    @staticmethod
    def merge_loops(loop1: Loop, loop2: Loop, transform: np.ndarray):
        """
        Merge two loops into one new loop.

        Both inputs are deep-copied; the copy of ``loop2`` is moved with
        ``overall_transform(transform)`` before merging. Vertices are
        de-duplicated by global index, keeping ``loop1``'s copy when both
        loops contain the vertex. Edges are de-duplicated by their unordered
        pair of global endpoint indices (the same notion of "same edge" that
        ``get_common_edges`` uses), so a shared edge appears only once in
        the result instead of once per source loop.

        Args:
            loop1: first loop.
            loop2: second loop.
            transform: transform taking ``loop2`` into ``loop1``'s frame.
        Returns:
            A new ``Loop`` on ``loop1.graph`` with
            ``has_set_global_transforms`` set to True.
        """
        assert loop1.has_set_global_transforms and loop2.has_set_global_transforms, "环路尚未设置全局仿射变换矩阵"
        assert loop1.graph is loop2.graph, "两个环路的图不同"
        loop_a = deepcopy(loop1)
        loop_b = deepcopy(loop2)
        loop_b.overall_transform(transform)

        # Map each global index to its position in the merged vertex list.
        idx_convert = {}
        merged_vertices = []
        for vertice in loop_a.vertices + loop_b.vertices:
            global_idx = vertice.global_index
            if global_idx not in idx_convert:
                idx_convert[global_idx] = len(merged_vertices)
                merged_vertices.append(vertice)

        merged_edges = []
        seen_pairs = set()
        for source in (loop_a, loop_b):
            for edge in source.edges:
                u, v = edge.vertices
                u_global = source.vertices[u].global_index
                v_global = source.vertices[v].global_index
                pair = frozenset((u_global, v_global))
                if pair in seen_pairs:
                    # Already contributed by the other loop; skipping avoids
                    # edge lists that grow with duplicates on every merge.
                    continue
                seen_pairs.add(pair)
                # Re-index the endpoints into the merged vertex list.
                edge.vertices = idx_convert[u_global], idx_convert[v_global]
                merged_edges.append(edge)

        # WARN: the merged loop's own sub-loops are not needed yet, so an
        # empty list is passed.
        merged_loop = Loop(merged_vertices, merged_edges,
                           [], loop1.graph)
        merged_loop.has_set_global_transforms = True
        return merged_loop


class KruskalAssembler:
    def __init__(self, v_imgs, v_pcds, v_idx, idx_convert, cleaned_pairs, sim_mats, scores, results_path, metrics_handler, *args):
        """
        Build the fragment graph used by the Kruskal-based assembler.

        Args:
            v_imgs: image collection, handed to ``Graph`` unchanged.
            v_pcds: point-cloud collection, handed to ``Graph`` unchanged.
            v_idx: local fragment indices; each becomes a vertex key.
            idx_convert: lookup from a local index to its global index.
            cleaned_pairs: iterable of ``(idx1, idx2)`` pairs, one per edge.
            sim_mats: handed to ``Graph`` unchanged.
            scores: per-edge scores, indexed in the same order as ``cleaned_pairs``.
            results_path: parent directory; output goes to
                ``<results_path>/global_assemble`` (created if missing).
            metrics_handler: object later used to compute the returned metrics.
            *args: exactly two trailing values, unpacked as ``(cfg, logger)``.
        """
        self.cfg, self.logger = args
        self.metrics_handler = metrics_handler
        self.results_path = os.path.join(results_path, "global_assemble")
        os.makedirs(self.results_path, exist_ok=True)
        self.v_num = len(v_idx)

        # Vertices are keyed by local index. The global index is passed for
        # all three Vertice arguments (presumably index, image key and
        # point-cloud key -- confirm against Vertice). The inner single-item
        # loop just binds the lookup result once.
        vertices = {
            local_idx: Vertice(global_idx, global_idx, global_idx)
            for local_idx in v_idx
            for global_idx in (idx_convert[local_idx],)
        }

        edges = [
            Edge((int(first), int(second)), e_idx, scores[e_idx])
            for e_idx, (first, second) in enumerate(cleaned_pairs)
        ]

        self.ori_graph = Graph(vertices, edges, v_imgs, v_pcds, sim_mats)

    def assemble(self, batch: int, pic_idx: int):
        """
        Assemble each spanning-tree component and compute the metrics.

        One image per component is written to
        ``<results_path>/batch_<batch>/<pic_idx>_<component number>.png``.

        Args:
            batch: batch number, used to name the output directory.
            pic_idx: picture index, used as the output file-name prefix.

        Returns:
            A pair: the result of ``metrics_handler.assemble_metrices`` on the
            predicted edges and all global transforms, and the result of
            ``metrics_handler.ari_metrics`` on the component vertex lists.
        """
        components, _ = self.find_spanning_trees()

        # Inputs for the metrics. The adjacency map stores every tree edge
        # under both endpoints, so each edge contributes two rows here.
        post_clusters = [component["v"] for component in components]
        e_preds = np.array([
            [int(u), int(v)]
            for component in components
            for u, neighbours in component["e"].items()
            for v, _ in neighbours
        ])

        save_path = os.path.join(self.results_path, f"batch_{batch}")
        global_transformations = {}
        for num, component in enumerate(components):
            global_transformation = self.get_global_transformation(
                component["v"], component["e"])
            global_transformations.update(global_transformation)
            result = self.build_pic(global_transformation)
            os.makedirs(save_path, exist_ok=True)
            cv2.imwrite(os.path.join(
                save_path, f"{pic_idx}_{num}.png"), result)

        assemble_result = self.metrics_handler.assemble_metrices(
            e_preds, global_transformations)
        cluster_result = self.metrics_handler.ari_metrics(post_clusters)
        return assemble_result, cluster_result

    def check_edges(self, components, e_gt):
        """
        Log which spanning-tree edges are missing from the ground truth.

        Edges are compared as directed pairs stored in both directions, so
        every undirected edge is counted twice in the logged totals; the
        logged error rate is unaffected. When the spanning trees contain no
        edge at all the error rate is reported as 0.0 instead of dividing
        by zero.

        Args:
            components: spanning-forest components, each a dict whose ``"e"``
                entry maps a vertex to a list of ``(neighbour, transform)``.
            e_gt: ground-truth edges; converted with ``to_undirected`` and
                transposed before being read as ``(u, v)`` rows (presumably a
                ``[2, E]`` edge-index tensor -- confirm against the caller).
        """
        scores_dict = {}
        for edge in self.ori_graph.edges:
            u, v = edge.vertices
            scores_dict[(u, v)] = edge.score
            scores_dict[(v, u)] = edge.score

        spanning_tree_edges = set()
        for component in components:
            for u, neighbours in component["e"].items():
                for v, _ in neighbours:
                    u, v = int(u), int(v)
                    spanning_tree_edges.update(((v, u), (u, v)))

        gt_edges = set()
        for u, v in to_undirected(e_gt).T:
            u, v = int(u), int(v)
            gt_edges.update(((v, u), (u, v)))

        wrong_edges = spanning_tree_edges - gt_edges
        # Guard against an empty forest (no edge selected at all).
        if spanning_tree_edges:
            error_rate = len(wrong_edges) / len(spanning_tree_edges)
        else:
            error_rate = 0.0
        self.logger.debug(
            f"生成树中共 {len(wrong_edges)} 条错误边，错误率 {error_rate}")
        for u, v in wrong_edges:
            score = scores_dict[(u, v)]
            self.logger.debug(f"错误边：{u}-{v}，得分 {score}")

    def find_spanning_trees(self):
        """
        Compute a maximum spanning forest with Kruskal's algorithm.

        Edges are visited in descending score order. An edge is accepted
        when its endpoints are not yet connected and
        ``ori_graph.get_relative_transformation`` returns a transform for it
        (edges for which it returns ``None`` are skipped).

        Returns:
            components: one dict per connected component with
                ``"v"``: list of the component's vertex keys, and
                ``"e"``: adjacency map ``vertex -> [(neighbour, transform)]``.
                For an accepted edge ``(u, v)`` with transform ``t`` the map
                stores ``(v, inv(t))`` under ``u`` and ``(u, t)`` under ``v``.
            e_selected: list of booleans, one per graph edge, True for the
                edges that were accepted.
        """
        graph = self.ori_graph
        v_num, e_num = len(graph.vertices), len(graph.edges)
        union_find_set = UnionFind(list(graph.vertices.keys()))
        e_selected = [False] * e_num  # whether each edge was accepted
        transformations = [None] * e_num  # transform of each accepted edge

        # heapq is a min-heap, so negate the score to pop the best edge first.
        e_queue = [(-edge.score, e_idx)
                   for e_idx, edge in enumerate(graph.edges)]
        heapq.heapify(e_queue)

        # Kruskal. A forest on v_num vertices has at most v_num - 1 edges, so
        # stop as soon as that many were accepted instead of draining the heap.
        edge_cnt = 0
        max_tree_edges = v_num - 1
        while edge_cnt < max_tree_edges and e_queue:
            _, edge_idx = heapq.heappop(e_queue)
            u, v = graph.edges[edge_idx].vertices
            if union_find_set.connected(u, v):
                continue
            trans = graph.get_relative_transformation(edge_idx)
            if trans is None:
                continue
            edge_cnt += 1
            union_find_set.union(u, v)
            e_selected[edge_idx] = True
            transformations[edge_idx] = trans

        selected_indices = [e_idx for e_idx in range(e_num)
                            if e_selected[e_idx]]

        # Smallest and largest score among the accepted edges (for logging).
        min_score, max_score = float("inf"), -float("inf")
        for e_idx in selected_indices:
            score = graph.edges[e_idx].score
            min_score = min(score, min_score)
            max_score = max(score, max_score)
        edge_cnt = len(selected_indices)
        self.logger.debug(
            f"共选择 {edge_cnt} 条边，最小得分 {min_score}，最大得分 {max_score}")

        # Group vertices by the representative of their connected component.
        components = {}
        for v_idx in graph.vertices.keys():
            father = union_find_set.find(v_idx)
            if father not in components:
                components[father] = {"v": [], "e": {}}
            components[father]["v"].append(v_idx)

        # Attach every accepted edge to its component, in both directions.
        for e_idx in selected_indices:
            u, v = graph.edges[e_idx].vertices
            t = transformations[e_idx]
            e_dict = components[union_find_set.find(u)]["e"]
            e_dict.setdefault(u, [])
            e_dict.setdefault(v, [])
            e_dict[u].append((v, np.linalg.inv(t)))
            e_dict[v].append((u, t))

        return list(components.values()), e_selected

    @staticmethod
    def get_global_transformation(v, e):
        """
        遍历生成树，得到各碎片的全局变换矩阵
        """
        root = v[0]
        que = deque([root])
        global_trans = {root: np.eye(3, dtype=np.float32)}
        visited = {v_idx: False for v_idx in v}
        visited[root] = True

        # 处理只有一个碎片的情况
        if len(e.keys()) == 0:
            return global_trans

        # BFS
        while len(que) > 0:
            u_idx = que.popleft()
            for v_idx, trans in e[u_idx]:
                if visited[v_idx]:
                    continue
                que.append(v_idx)
                global_trans[v_idx] = global_trans[u_idx] @ trans
                visited[v_idx] = True

        return global_trans

    def build_pic(self, transformations):
        """
        Render the fragments of one assembly onto a single square canvas.

        Args:
            transformations: mapping from vertex key to its 3x3 global
                transform; iterated by key and indexed by key.
        Returns:
            ``uint8`` image of shape ``(size, size, 3)`` holding the sum of
            all warped fragment images, clipped to ``[0, 255]``.
        """
        graph = self.ori_graph

        # Pass 1: transform every point cloud to find the bounding box.
        warped_clouds = []
        for idx in transformations:
            pose = transformations[idx]
            cloud = graph.pcds[graph.vertices[idx].pcd]
            homogeneous = np.hstack([cloud, np.ones((len(cloud), 1))])
            warped_clouds.append((pose @ homogeneous.T).T)
        all_points = np.concatenate(warped_clouds, axis=0)

        xs, ys = all_points[:, 0], all_points[:, 1]
        min_x, max_x = np.min(xs), np.max(xs)
        min_y, max_y = np.min(ys), np.max(ys)

        # The canvas is square, sized by the larger bounding-box extent.
        canvas_height = int(np.ceil(max_y - min_y))
        canvas_width = int(np.ceil(max_x - min_x))
        canvas_size = max(canvas_height, canvas_width)
        canvas = np.zeros((canvas_size, canvas_size, 3), dtype=np.float32)

        # Pass 2: warp every fragment image and add it onto the canvas.
        for idx in transformations:
            pose = transformations[idx]
            # 2x3 affine matrix, translated so the bounding box starts at 0.
            shifted_pose = np.array([
                [pose[0, 0], pose[0, 1], pose[0, 2] - min_x],
                [pose[1, 0], pose[1, 1], pose[1, 2] - min_y]
            ], dtype=np.float32)

            # The first two image axes are swapped before warping and swapped
            # back afterwards (presumably to match the point clouds' axis
            # order -- confirm).
            piece = graph.imgs[graph.vertices[idx].img].transpose(1, 0, 2)
            warped = cv2.warpAffine(piece, shifted_pose, (canvas_size, canvas_size),
                                    flags=cv2.INTER_LINEAR,
                                    borderMode=cv2.BORDER_CONSTANT,
                                    borderValue=(0, 0, 0))
            canvas += warped.transpose(1, 0, 2)

        # Clamp to the valid pixel range and convert to uint8.
        return np.clip(canvas, 0, 255).astype(np.uint8)
    
class COAssembler(KruskalAssembler):
    def assemble(self, batch: int, pic_idx: int):
        """
        Assemble with ``COSolver`` starting from the spanning-tree result.

        The solver's final pose is rendered to
        ``<results_path>/batch_<batch>/<pic_idx>.png``.

        Args:
            batch: batch number, used to name the output directory.
            pic_idx: picture index, used as the output file name.

        Returns:
            A pair: the result of ``metrics_handler.assemble_metrices`` on the
            predicted edges and the final pose, and the result of
            ``metrics_handler.ari_metrics`` on the component vertex lists.
        """
        u_init, A_init, e_preds, post_clusters = self.init_variables()
        _, final_pose = COSolver(
            self.ori_graph, u_init, A_init, self.cfg, self.logger).solve()

        result = self.build_pic(final_pose)
        save_path = os.path.join(self.results_path, f"batch_{batch}")
        os.makedirs(save_path, exist_ok=True)
        cv2.imwrite(os.path.join(save_path, f"{pic_idx}.png"), result)

        assemble_result = self.metrics_handler.assemble_metrices(
            e_preds, final_pose)
        cluster_result = self.metrics_handler.ari_metrics(post_clusters)
        return assemble_result, cluster_result
    
    def init_variables(self):
        """
        Derive the solver's starting point from the spanning forest.

        Returns:
            u_init: the per-edge selection flags from ``find_spanning_trees``.
            A_init: dict of global transforms for every vertex, merged over
                all components.
            e_preds: array of ``[u, v]`` rows read from the components'
                adjacency maps (each tree edge is stored under both
                endpoints, so it appears twice).
            post_clusters: list of the components' vertex lists.
        """
        components, u_init = self.find_spanning_trees()

        post_clusters = [component["v"] for component in components]
        e_preds = np.array([
            [int(u), int(v)]
            for component in components
            for u, neighbours in component["e"].items()
            for v, _ in neighbours
        ])

        print(f"components num: {len(components)}")
        A_init = {}
        for component in components:
            A_init.update(self.get_global_transformation(
                component["v"], component["e"]))

        return u_init, A_init, e_preds, post_clusters
