"""
@Description : 边选择组合优化求解器
@Author : tqychy
@Time : 2025/03/01 10:34:19
"""
import sys

sys.path.append("./")
import gtsam
import numpy as np

from nets.utils.assemble_utils import Graph, UnionFind


class COSolver:
    def __init__(self, graph: Graph, init_u: list, init_A: dict, *args) -> None:
        """
        初始化组合优化求解器
        :params:
            graph: 图对象，包含节点和边的信息
            init_u: 初始的边选择变量 u_{ij}，列表形式，长度等于边的数量
            init_A: 初始的全局变换矩阵 A_i，字典形式，键为节点索引，值为 3x3 变换矩阵
        """
        self.cfg, self.logger = args
        self.graph = graph
        self.u = init_u  # u_{ij} 列表，与 graph.edges 对应
        self.A = init_A  # A_i 字典，键为节点索引
        self.factor = 0.2
        self.root_idx = self.find_fixed_vertice()

    def get_current_edges(self):
        """获取当前选中的边索引和边对象"""
        return [(i, edge) for i, edge in enumerate(self.graph.edges) if self.u[i]]

    def update_u(self):
        """
        边选择更新方法，并保证图的连通性
        1. 首先计算所有边的损失
        2. 按损失排序，优先选择损失小的边
        3. 使用Kruskal算法思想逐步添加边，保证连通性
        """
        # 计算所有边的损失
        edge_losses = []
        for e_idx, edge in enumerate(self.graph.edges):
            u, v = edge.vertices
            score = edge.score
            trans = edge.transform
            A_u, A_v = self.A[u], self.A[v]
            global_trans = np.linalg.inv(A_v) @ A_u
            loss = -score + self.factor * \
                np.linalg.norm(trans - global_trans, ord='fro')
            edge_losses.append((e_idx, edge, loss))

        # 按损失从小到大排序
        edge_losses.sort(key=lambda x: x[2])

        # 初始化并查集
        uf = UnionFind(self.graph.vertices.keys())
        current_edges = self.get_current_edges()

        # 先保留当前选中的边(保证不减少连通性)
        for e_idx, edge in current_edges:
            u, v = edge.vertices
            uf.union(u, v)

        # 重置所有边选择为False
        self.u = [False] * len(self.graph.edges)

        # 逐步添加边，优先添加损失小的边
        for e_idx, edge, loss in edge_losses:
            u, v = edge.vertices
            if uf.find(u) != uf.find(v):  # 如果不连通，则必须添加
                self.u[e_idx] = True
                uf.union(u, v)
            else:  # 如果已经连通，则根据损失决定
                self.u[e_idx] = loss <= 0

        # 特殊情况处理：如果没有选中任何边，强制连接至少生成树
        if sum(self.u) == 0:
            self.logger.debug("没有选中任何边，强制生成最小生成树")
            for e_idx, edge, loss in edge_losses:
                u, v = edge.vertices
                if uf.find(u) != uf.find(v):
                    self.u[e_idx] = True
                    uf.union(u, v)

    def update_A(self):
        """
        固定 u_{ij}，优化 A_i
        使用 GTSAM 的姿态图优化，优化全局变换矩阵
        """
        # 创建非线性因子图
        graph = gtsam.NonlinearFactorGraph()

        # 创建初始估计
        initial = gtsam.Values()

        # 将 np.ndarray 转换为 GTSAM 的 Pose2
        def matrix_to_pose2(A):
            """从 3x3 变换矩阵转换为 Pose2 (x, y, theta)"""
            theta = np.arctan2(A[1, 0], A[0, 0])  # 提取旋转角度
            x, y = A[0, 2], A[1, 2]  # 提取平移
            return gtsam.Pose2(x, y, theta)

        # 添加初始估计
        for v_idx in self.graph.vertices.keys():
            pose = matrix_to_pose2(self.A[v_idx])
            initial.insert(v_idx, pose)

        # 添加因子：固定一个节点为原点
        fixed_vertex = self.root_idx
        prior_noise = gtsam.noiseModel.Diagonal.Sigmas(
            np.array([1e-6, 1e-6, 1e-6]))  # 小方差表示强约束
        graph.add(gtsam.PriorFactorPose2(
            fixed_vertex, gtsam.Pose2(0, 0, 0), prior_noise))

        # 添加选中边的因子
        between_noise = gtsam.noiseModel.Diagonal.Sigmas(
            np.array([0.1, 0.1, 0.1]))  # 边的噪声模型
        for e_idx, edge in enumerate(self.graph.edges):
            if self.u[e_idx]:
                u, v = edge.vertices
                T_ij = edge.transform
                pose_ij = matrix_to_pose2(T_ij)
                # BetweenFactor 表示 v = u * T_ij
                graph.add(gtsam.BetweenFactorPose2(
                    v, u, pose_ij, between_noise))

        # 设置优化参数并执行优化
        params = gtsam.LevenbergMarquardtParams()
        optimizer = gtsam.LevenbergMarquardtOptimizer(graph, initial, params)
        result = optimizer.optimize()

        # 更新 A_i
        def pose2_to_matrix(pose):
            """从 Pose2 转换为 3x3 变换矩阵"""
            rot = np.array([[np.cos(pose.theta()), -np.sin(pose.theta())],
                            [np.sin(pose.theta()), np.cos(pose.theta())]])
            trans = np.array([[pose.x()], [pose.y()]])
            return np.vstack((np.hstack((rot, trans)), [0, 0, 1]))

        for v_idx in self.graph.vertices.keys():
            pose = result.atPose2(v_idx)
            self.A[v_idx] = pose2_to_matrix(pose)

    def solve(self, max_iter=10, tol=1e-6, patience=2):
        """
        改进的优化流程，增加连通性检查
        """
        cnt = 0
        for i in range(max_iter):
            old_u = self.u.copy()
            old_A = {k: v.copy() for k, v in self.A.items()}

            # 更新u，自动保证连通性
            self.update_u()

            # 检查连通性
            conn = self.check_connectivity()
            if conn > 1:
                self.logger.info(f"图不连通(连通分量={conn})，尝试修复...")
                self.force_connectivity()
                conn = self.check_connectivity()
                if conn > 1:
                    self.logger.info("无法修复图连通性，停止优化！")
                    return self.u, self.A

            # 更新A
            self.update_A()

            # 检查收敛
            u_diff = sum(a ^ b for a, b in zip(self.u, old_u))
            A_diff = sum(np.linalg.norm(
                self.A[k] - old_A[k], ord='fro') for k in self.A)
            print(f"CO iter {i} update A, u_diff: {u_diff}, A_diff: {A_diff}")
            if u_diff < tol and A_diff < tol:
                cnt += 1
            if cnt == patience:
                break

        return self.u, self.A

    def force_connectivity(self):
        """
        强制保证图的连通性
        使用Kruskal算法思想添加必要的边
        """
        uf = UnionFind(self.graph.vertices.keys())

        # 先添加当前选中的边
        for e_idx, edge in enumerate(self.graph.edges):
            if self.u[e_idx]:
                u, v = edge.vertices
                uf.union(u, v)

        # 如果已连通则返回
        if len(set(uf.find(v) for v in self.graph.vertices)) == 1:
            return

        # 否则添加必要的边使图连通
        edge_losses = []
        for e_idx, edge in enumerate(self.graph.edges):
            if not self.u[e_idx]:  # 只考虑未选中的边
                u, v = edge.vertices
                score = edge.score
                trans = edge.transform
                A_u, A_v = self.A[u], self.A[v]
                global_trans = np.linalg.inv(A_v) @ A_u
                loss = -score + self.factor * \
                    np.linalg.norm(trans - global_trans, ord='fro')
                edge_losses.append((e_idx, edge, loss))

        # 按损失排序
        edge_losses.sort(key=lambda x: x[2])

        # 添加必要的边
        for e_idx, edge, _ in edge_losses:
            u, v = edge.vertices
            if uf.find(u) != uf.find(v):
                self.u[e_idx] = True
                uf.union(u, v)
                if len(set(uf.find(v) for v in self.graph.vertices)) == 1:
                    break

    def check_connectivity(self):
        union_find_set = UnionFind(self.graph.vertices.keys())
        num = set()
        for e_idx, edge in enumerate(self.graph.edges):
            if self.u[e_idx]:
                u, v = edge.vertices
                union_find_set.union(u, v)
        for v_idx in self.graph.vertices.keys():
            num.add(union_find_set.find(v_idx))
        return len(num)

    def find_fixed_vertice(self):
        root = None
        for v_idx in self.graph.vertices:
            if np.allclose(np.eye(3), self.A[v_idx]):
                root = v_idx
                break
        if root is None:
            raise ValueError("CO 求解器未找到全局坐标系固定碎片。")
        return root

    def check_edge_loss(self, u_list, A, iter_num):
        for e_idx, edge in enumerate(self.graph.edges):
            u, v = edge.vertices
            score = edge.score
            trans = edge.transform  # T_{ij}
            A_u, A_v = A[u], A[v]  # A_i, A_j
            global_trans = np.linalg.inv(A_v) @ A_u  # A_j^{-1} A_i
            selected = u_list[e_idx]

            # 使用 Frobenius 范数计算偏差
            edge_loss = np.linalg.norm(trans - global_trans, ord='fro')
            self.logger.debug(
                f"edge {u}&{v}: score: {score}, loss: {edge_loss}, selected: {selected}")

# """
# @Description : 边选择组合优化求解器
# @Author : tqychy
# @Time : 2025/03/01 10:34:19
# """
# import sys

# sys.path.append("./")
# from scipy.optimize import least_squares
# import numpy as np

# from nets.utils.assemble_utils import Graph, UnionFind


# class COSolver:
#     def __init__(self, graph: Graph, init_u: list, init_A: dict, *args) -> None:
#         """
#         初始化组合优化求解器
#         :params:
#             graph: 图对象，包含节点和边的信息
#             init_u: not use
#             init_A: 初始的全局变换矩阵 A_i，字典形式，键为节点索引，值为 3x3 变换矩阵
#             root_idx: 固定的根节点索引
#         """
#         self.cfg, self.logger = args
#         self.graph = graph
#         self.A = init_A
#         self.node_indices = list(init_A.keys())
#         self.root_idx = self.find_fixed_vertice()


#     def residual(self, X: np.ndarray) -> np.ndarray:
#         """
#         计算残差向量
#         :param X: 优化变量，长度为 9*(n-1)，表示除root_idx外的变换矩阵
#         :return: 展平的残差向量，长度为 9*m
#         """
#         # 构造变换矩阵字典
#         A = {self.root_idx: np.eye(3)}  # 固定根节点的变换为单位矩阵
#         idx = 0
#         for i in self.node_indices:
#             if i != self.root_idx:
#                 A[i] = X[9*idx:9*(idx+1)].reshape(3, 3)
#                 idx += 1

#         # 计算所有边的残差
#         r = []
#         for e_idx, edge in enumerate(self.graph.edges):
#             u, v = edge.vertices
#             score = edge.score
#             trans = edge.transform  # T_{ij}
#             A_u, A_v = A[u], A[v]
#             global_trans = np.linalg.inv(A_v) @ A_u

#             residual_matrix = global_trans - trans  # Frobenius范数的平方对应元素差平方和
#             r.append(np.sqrt(score) * residual_matrix.flatten())

#         return np.concatenate(r)


#     def solve(self):
#         """优化全局变换矩阵"""
#         X0 = []
#         for i in self.node_indices:
#             if i != self.root_idx:
#                 X0.append(self.A[i].flatten())
#         X0 = np.concatenate(X0)  # 长度为 9*(n-1)

#         # 运行非线性最小二乘优化
#         res = least_squares(self.residual, X0, method='lm', verbose=1)
#         self.logger.info(f"Optimization converged: {res.success}, cost: {res.cost}")

#         # 构造优化后的变换矩阵字典
#         A_opt = {self.root_idx: np.eye(3)}
#         idx = 0
#         for i in self.node_indices:
#             if i != self.root_idx:
#                 A_opt[i] = res.x[9*idx:9*(idx+1)].reshape(3, 3)
#                 idx += 1

#         return None, A_opt


#     def check_connectivity(self):
#         union_find_set = UnionFind(self.graph.vertices.keys())
#         num = set()
#         for e_idx, edge in enumerate(self.graph.edges):
#             if self.u[e_idx]:
#                 u, v = edge.vertices
#                 union_find_set.union(u, v)
#         for v_idx in self.graph.vertices.keys():
#             num.add(union_find_set.find(v_idx))
#         return len(num)

#     def find_fixed_vertice(self):
#         root = None
#         for v_idx in self.graph.vertices:
#             if np.allclose(np.eye(3), self.A[v_idx]):
#                 root = v_idx
#                 break
#         if root is None:
#             raise ValueError("CO 求解器未找到全局坐标系固定碎片。")
#         return root

#     def check_edge_loss(self, u_list, A, iter_num):
#         for e_idx, edge in enumerate(self.graph.edges):
#             u, v = edge.vertices
#             score = edge.score
#             trans = edge.transform  # T_{ij}
#             A_u, A_v = A[u], A[v]  # A_i, A_j
#             global_trans = np.linalg.inv(A_v) @ A_u  # A_j^{-1} A_i
#             selected = u_list[e_idx]

#             # 使用 Frobenius 范数计算偏差
#             edge_loss = np.linalg.norm(trans - global_trans, ord='fro')
#             print(f"edge {u}&{v}: score: {score}, loss: {edge_loss}, selected: {selected}")
