"""
@Description :   评价指标
@Author      :   tqychy 
@Time        :   2025/03/20 15:05:07
"""
import math

import numpy as np
from sklearn.metrics import adjusted_rand_score, roc_auc_score
from torch_geometric.utils import to_dense_adj, to_undirected


def pairing_metrics(e_pred, e_gt, v_num):
    """
    Compute precision / recall / F1 of predicted edges against ground-truth edges.

    Both edge lists are symmetrised with ``to_undirected`` and rasterised into
    dense ``v_num x v_num`` adjacency matrices that are compared cell-wise.
    Each undirected non-loop edge occupies two symmetric cells, so it is counted
    twice in TP/FP/FN. This cancels in the ratios; a self-loop occupies a single
    diagonal cell.

    Args:
        e_pred (torch.Tensor): predicted edges, shape: [N, 2]
        e_gt (torch.Tensor): ground-truth edges, shape: [M, 2]
        v_num (int): number of nodes
    Returns:
        prec (float): precision of the predicted edges (0.0 when nothing is predicted)
        rec (float): recall of the predicted edges (0.0 when there are no gt edges)
        f1 (float): F1 score (0.0 when precision + recall == 0)
    """
    def _to_adj(edges):
        # to_dense_adj returns [batch, v_num, v_num] with a single batch entry.
        # Index it explicitly instead of squeeze(), which would also collapse
        # the node dims when v_num == 1.
        adj = to_dense_adj(to_undirected(edges.T), max_num_nodes=v_num)[0]
        # .cpu() so the numpy conversion also works for CUDA tensors.
        # Binarise with > 0 so every nonzero cell counts as an edge.
        return adj.cpu().numpy() > 0

    gt_adj = _to_adj(e_gt)
    pred_adj = _to_adj(e_pred)

    tp = int(np.sum(pred_adj & gt_adj))
    fp = int(np.sum(pred_adj & ~gt_adj))
    fn = int(np.sum(~pred_adj & gt_adj))

    precision = tp / (tp + fp) if (tp + fp) != 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) != 0 else 0.0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) != 0 else 0.0

    return precision, recall, f1_score

def ari_metrics(cluster_pred, cluster_gt):
    """
    Adjusted Rand Index between predicted and ground-truth cluster assignments.

    Args:
        cluster_pred (np.ndarray): predicted cluster label per node, shape: [N]
        cluster_gt (np.ndarray): ground-truth cluster label per node, shape: [N]
    Returns:
        ari (float): the value returned by sklearn's ``adjusted_rand_score``
    """
    # sklearn's signature is (labels_true, labels_pred): ground truth goes first.
    return adjusted_rand_score(labels_true=cluster_gt, labels_pred=cluster_pred)

def score_eval_metrics(scores, e_pred, e_gt):
    """
    ROC-AUC of per-edge scores, where a predicted edge is a positive iff it is
    a ground-truth edge (direction ignored).

    Args:
        scores (np.ndarray): score of each predicted edge, shape: [N]
        e_pred (np.ndarray): predicted edges, row i scored by scores[i], shape: [N, 2]
        e_gt (np.ndarray): ground-truth edges, shape: [M, 2]
    Returns:
        auc (float): ROC-AUC from sklearn's ``roc_auc_score``.
    Note:
        ``roc_auc_score`` raises ValueError when the labels contain a single
        class, i.e. when all or none of ``e_pred`` appear in ``e_gt``.
    """
    # Store each gt edge with its smaller endpoint first so (u, v) == (v, u).
    gt_edges = {(a, b) if a < b else (b, a) for a, b in e_gt.tolist()}

    # Labels share the dtype of scores; a matched row gets label 1.
    labels = np.zeros_like(scores)
    hit_rows = [
        row
        for row, (a, b) in enumerate(e_pred.tolist())
        if ((b, a) if a >= b else (a, b)) in gt_edges
    ]
    labels[hit_rows] = 1
    return roc_auc_score(labels, scores)


def get_edge_result(e_gt, e_pred):
    """
    Split edges into correctly predicted, missed, and spurious ones (direction ignored).

    Args:
        e_gt (np.ndarray): ground-truth edges, one (u, v) row each, shape: [M, 2]
        e_pred (np.ndarray): predicted edges, one (u, v) row each, shape: [N, 2]
    Returns:
        correct (list): edges present in both inputs
        missing (list): ground-truth edges that were not predicted
        wrong (list): predicted edges that are not in the ground truth
        Each edge is a (u, v) tuple with u <= v; list order is unspecified.
    """
    def normalize(rows):
        # Cast to int and sort each pair so (u, v) and (v, u) share one key.
        return {tuple(sorted(row.astype(int))) for row in rows}

    gt_edges = normalize(e_gt)
    pred_edges = normalize(e_pred)

    shared = gt_edges.intersection(pred_edges)
    return (
        list(shared),
        list(gt_edges.difference(shared)),
        list(pred_edges.difference(shared)),
    )

def assemble_rec(trans_preds, idx_convert, pose_gt, e_gt, shift_err=100, rotate_err=5):
    """
    Global assembly recall: the fraction of ground-truth edges whose predicted
    relative transform matches the ground-truth one. An edge counts whether or
    not it was explicitly predicted; only the per-node poses matter.

    Only ground-truth edges whose two endpoints both have a predicted pose enter
    the denominator.

    Args:
        trans_preds (dict): predicted 3x3 pose per node, keyed by local node index
        idx_convert (dict): maps a local node index to the global index that keys ``pose_gt``
        pose_gt (dict): ground-truth 3x3 pose per node, keyed by global node index
        e_gt (np.ndarray): ground-truth edges as local index pairs, shape: [M, 2]
        shift_err (float): translation tolerance forwarded to ``transform_error``
            (default matches ``transform_error``'s default)
        rotate_err (float): rotation tolerance in degrees forwarded to ``transform_error``
            (default matches ``transform_error``'s default)
    Returns:
        rec (float): fraction of eligible ground-truth edges whose transforms match;
            0. when no ground-truth edge has both endpoints predicted.
    """
    # s swaps the first two axes. Conjugating by it (s @ T @ s) presumably maps
    # the predicted poses into the ground-truth axis convention -- TODO confirm.
    s = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
    tot_num = 0
    rec = 0
    for local_u, local_v in e_gt.tolist():
        # Edges touching a node without a predicted pose are not scored.
        if local_u not in trans_preds or local_v not in trans_preds:
            continue
        tot_num += 1
        # Ground-truth relative transform inv(T_v) @ T_u, looked up via global indices.
        global_u, global_v = idx_convert[local_u], idx_convert[local_v]
        trans_gt = np.linalg.inv(pose_gt[global_v]) @ pose_gt[global_u]
        # Predicted relative transform, same formula after the axis swap.
        pred_u = s @ trans_preds[local_u] @ s
        pred_v = s @ trans_preds[local_v] @ s
        trans_pred = np.linalg.inv(pred_v) @ pred_u
        # transform_error returns a bool; a match adds 1.
        rec += transform_error(trans_pred, trans_gt, shift_err=shift_err, rotate_err=rotate_err)
    return rec / tot_num if tot_num != 0 else 0.

def assemble_prec(e_pred, e_gt):
    """
    Global assembly precision: the fraction of predicted edges that are
    ground-truth edges (direction ignored).

    Args:
        e_pred (np.ndarray): predicted edges, shape: [N, 2]; duplicates count separately
        e_gt (np.ndarray): ground-truth edges, shape: [M, 2]
    Returns:
        prec (float): matched predicted edges / N, or 0. when e_pred is empty
    """
    # Key each gt edge by (smaller, larger) endpoint.
    gt_edges = {(a, b) if a < b else (b, a) for a, b in e_gt.tolist()}
    hits = sum(
        ((b, a) if a >= b else (a, b)) in gt_edges
        for a, b in e_pred.tolist()
    )
    if len(e_pred) == 0:
        return 0.
    return hits / len(e_pred)


def transform_error(T1, T2, shift_err=100, rotate_err=5):
    """
    Decide whether two 2D homogeneous transforms agree within tolerance.

    Despite the name, this returns a bool match flag, not an error value.

    Args:
        T1, T2 (np.ndarray): 3x3 transforms; [:2, :2] is read as the rotation
            and [:2, 2] as the translation.
        shift_err (float): maximum allowed Euclidean distance between the translations
        rotate_err (float): maximum allowed rotation difference, in degrees
    Returns:
        bool: True if both the rotation gap and the translation gap are within tolerance
    """
    # Angle of the relative rotation R1^T @ R2, in degrees.
    rel_rot = T1[:2, :2].T @ T2[:2, :2]
    gap_deg = abs(math.degrees(math.atan2(rel_rot[1, 0], rel_rot[0, 0])))
    # Shorter way around the circle (atan2 already lies in [-180, 180]).
    gap_deg = min(gap_deg, 360 - gap_deg)
    if gap_deg > rotate_err:
        return False

    offset = T2[:2, 2] - T1[:2, 2]
    return not (math.hypot(offset[0], offset[1]) > shift_err)
