"""
@Description :   Ground Truth 信息存储类，用于计算评价指标
@Author      :   tqychy 
@Time        :   2025/03/20 14:11:15
"""
import os

import numpy as np
import torch

from metrics.metrics_tools import *
from metrics.vis_tools import *


class MetricsHandler:
    """
    Stores dataset ground truth and computes / visualizes evaluation metrics batch by batch.

    Keys read from ``data``: "img_list", "belong_image", "GT_pairs", "img_all",
    "full_pcd_all" and "gt_pose". Call ``set_batch_info`` before any per-batch method.

    NOTE(review): ``cv2``, ``build_pic``, ``draw_dashed_line``, ``get_edge_result``,
    ``vis_pairing_result``, ``vis_distribution`` and the module-level metric functions are
    not imported explicitly in this file; they are presumably provided by the wildcard
    imports from ``metrics.metrics_tools`` / ``metrics.vis_tools`` - confirm.
    """

    def __init__(self, data, save_path, *args):
        """
        Args:
            data (dict): dataset dictionary (see the class docstring for the keys used)
            save_path (str): output directory for figures; created if it does not exist
            *args: exactly two extra positional values, stored as ``self.cfg`` and
                ``self.logger``
        """
        self.cfg, self.logger = args
        self.data = data
        self.save_path = save_path
        os.makedirs(save_path, exist_ok=True)

        # Collect, for every full image, its fragments and its GT fragment pairs
        img_list = self.data["img_list"]
        belong_img = self.data["belong_image"]
        gt_pairs = np.array(self.data["GT_pairs"])

        # image name -> index into img_list
        self.indices_dict = {img: idx for idx, img in enumerate(img_list)}
        # img_hash_tab[i]: global fragment indices that belong to image i
        self.img_hash_tab = [[] for _ in range(len(img_list))]
        # gt_pairs_hash_tab[i]: indices into GT_pairs whose first fragment belongs to image i
        self.gt_pairs_hash_tab = [[] for _ in range(len(img_list))]

        for frag_idx in range(len(self.data["img_all"])):
            img_idx = self.indices_dict[belong_img[frag_idx]]
            self.img_hash_tab[img_idx].append(frag_idx)

        for pair_idx, (idx1, _) in enumerate(gt_pairs):
            img_idx = self.indices_dict[belong_img[idx1]]
            self.gt_pairs_hash_tab[img_idx].append(pair_idx)

        # Ground truth of the current batch (filled by set_batch_info)
        self.batch_gt = {
            "v_num": None,
            "pic_num": None,
            "e_gt": None,
            "cluster_gt": None,
            "cur_batch": -1
        }

    @staticmethod
    def _to_numpy(x):
        """Convert a torch tensor (any device, with or without grad) or an array-like to np.ndarray."""
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def set_batch_info(self, info_dict, idx_convert):
        """
        Store the ground truth of a new batch and advance the batch counter.

        Args:
            info_dict (dict): batch ground-truth entries merged into ``self.batch_gt``;
                "cluster_gt" (one GT label per batch vertex) must be available afterwards
            idx_convert: mapping batch-local vertex index -> global fragment index
        Side effects:
            sets ``batch_gt["idx_convert"]``, increments ``batch_gt["cur_batch"]`` and
            fills ``batch_gt["img_names"]`` with (img_idx, img_name) for each image in
            the batch, in order of first appearance
        """
        self.batch_gt.update(info_dict)
        self.batch_gt["idx_convert"] = idx_convert
        self.batch_gt["cur_batch"] += 1  # update the batch number

        self.batch_gt["img_names"] = []
        labels = self._to_numpy(self.batch_gt["cluster_gt"]).reshape(-1)
        if labels.size == 0:
            return

        # Each run of equal labels is one image; take the first vertex of every run.
        # "!= 0" (rather than "> 0") also detects boundaries where labels decrease.
        run_starts = [0] + (np.flatnonzero(np.diff(labels) != 0) + 1).tolist()
        seen = set()
        for start in run_starts:
            img_name = self.data["belong_image"][idx_convert[start]]
            img_idx = self.indices_dict[img_name]
            if img_idx in seen:
                continue  # an image split into several runs is listed only once
            seen.add(img_idx)
            self.batch_gt["img_names"].append((img_idx, img_name))

    def pairing_metrics(self, e_pred, sim_mat=None, gen_pic=True):
        """
        Compute precision / recall / F1 of the predicted edges against the batch GT edges.

        Args:
            e_pred (torch.Tensor): predicted edges, shape [N, 2]
            sim_mat: similarity matrix forwarded to ``vis_pairing_result``; only used
                when gen_pic is True
            gen_pic (bool): if True, also save a figure to
                <save_path>/edge_select/batch_<cur_batch>.pdf
        Returns:
            prec (float): precision of the predicted edges
            rec (float): recall of the predicted edges
            f1 (float): F1 score of the predicted edges
        """
        e_gt = self.batch_gt["e_gt"]
        v_num = self.batch_gt["v_num"]
        if gen_pic:
            save_path = os.path.join(self.save_path, "edge_select")
            os.makedirs(save_path, exist_ok=True)
            pdf_path = os.path.join(save_path, f"batch_{self.batch_gt['cur_batch']}.pdf")
            prec, rec, f1 = vis_pairing_result(e_pred, e_gt, sim_mat, v_num, pdf_path)
        else:
            prec, rec, f1 = pairing_metrics(e_pred, e_gt, v_num)
        return prec, rec, f1

    @staticmethod
    def _clusters_to_labels(cluster_pred, num_vertices):
        """
        Turn a list of clusters into one label per vertex.

        Cluster k gets label k. Vertices that appear in no cluster get their own singleton
        label (num_clusters, num_clusters + 1, ...) instead of silently joining cluster 0.
        If a vertex appears in several clusters, the last one wins.
        """
        labels = np.full(num_vertices, -1, dtype=np.int64)
        num_clusters = 0
        for label, cluster in enumerate(cluster_pred):
            num_clusters = label + 1
            for v in cluster:
                labels[v] = label
        unassigned = np.flatnonzero(labels < 0)
        labels[unassigned] = num_clusters + np.arange(len(unassigned))
        return labels

    def ari_metrics(self, cluster_pred, gen_pic=True):
        """
        Compute the ARI between predicted clusters and the batch GT clustering.

        Args:
            cluster_pred (list): predicted clusters, each an iterable of batch-local
                vertex indices
            gen_pic (bool): currently unused; kept for interface symmetry
        Returns:
            ari (float): ARI
        """
        cluster_gt = self._to_numpy(self.batch_gt["cluster_gt"])
        pred = self._clusters_to_labels(cluster_pred, len(cluster_gt))
        return ari_metrics(pred, cluster_gt)

    def distribution_metrics(self, noisy_scores, gen_pic=True):
        """
        Plot the score distribution and compute score-based metrics (e.g. ROC-AUC).

        Args:
            noisy_scores (list): (u, v, score) triples for candidate edges
            gen_pic (bool): if True, save a figure to
                <save_path>/other_vis/distributions/batch_<cur_batch>.pdf
        Returns:
            whatever ``score_eval_metrics`` returns (AUC-style metrics)
        """
        e_gt = self.batch_gt["e_gt"]
        edges, scores = [], []
        for u, v, score in noisy_scores:
            edges.append([u, v])
            scores.append(score)
        # reshape keeps the [N, 2] layout even when there are no candidate edges
        e_pred = np.array(edges).reshape(-1, 2)
        scores = np.array(scores)
        if gen_pic:
            save_path = os.path.join(self.save_path, "other_vis", "distributions")
            os.makedirs(save_path, exist_ok=True)
            vis_distribution(torch.tensor(e_pred), torch.tensor(scores), e_gt,
                             os.path.join(save_path, f"batch_{self.batch_gt['cur_batch']}.pdf"))

        return score_eval_metrics(scores, e_pred, self._to_numpy(e_gt))

    def assemble_metrices(self, e_preds, trans_preds):
        """
        Compute assembly recall / precision and their mean.

        Args:
            e_preds: predicted edges, passed to ``assemble_prec``
            trans_preds: predicted transformations, passed to ``assemble_rec``
        Returns:
            (rec, prec, (rec + prec) / 2)
        """
        idx_convert = self.batch_gt["idx_convert"]
        e_gt = self.batch_gt["e_gt"]
        pose_gt = self.data["gt_pose"]
        rec = assemble_rec(trans_preds, idx_convert, pose_gt, e_gt)
        prec = assemble_prec(e_preds, e_gt)
        return rec, prec, (rec + prec) / 2

    # Correctly spelled alias; the original name is kept for existing callers.
    assemble_metrics = assemble_metrices

    def _gt_of_image(self, img_idx):
        """Return (transformations, gt_pairs) of image img_idx: GT pose per fragment and its GT pairs."""
        transformations = {idx: self.data["gt_pose"][idx] for idx in self.img_hash_tab[img_idx]}
        gt_pairs = [self.data["GT_pairs"][idx] for idx in self.gt_pairs_hash_tab[img_idx]]
        return transformations, gt_pairs

    def vis_gt_pics(self):
        """
        Render the ground-truth assembly of every image in the current batch, with GT
        pairs drawn as green lines, to <save_path>/other_vis/gts/batch_<cur_batch>/<img_name>.png.
        """
        imgs = self.data["img_all"]
        pcds = self.data["full_pcd_all"]
        cur_batch = self.batch_gt["cur_batch"]
        save_path = os.path.join(self.save_path, "other_vis", "gts", f"batch_{cur_batch}")
        os.makedirs(save_path, exist_ok=True)

        for img_idx, img_name in self.batch_gt["img_names"]:
            transformations, gt_pairs = self._gt_of_image(img_idx)
            result_pic, centers = build_pic(imgs, pcds, transformations)
            # one line per GT edge between the two fragment centers
            for u, v in gt_pairs:
                cv2.line(result_pic, centers[u], centers[v], (0, 255, 0), thickness=2)
            cv2.imwrite(os.path.join(save_path, img_name + ".png"), result_pic)

    @staticmethod
    def _draw_legend(result_pic, num_correct, num_missing, num_wrong):
        """
        Return a copy of ``result_pic`` extended by 200 px on the right (filled with 255)
        carrying a legend with the numbers of correct / missing / wrong edges.
        """
        height, width, _ = result_pic.shape
        extend_width = 200
        new_pic = 255 - np.zeros((height, width + extend_width, 3), dtype=result_pic.dtype)
        new_pic[:, :width, :] = result_pic  # original picture on the left

        start_x = width + 20  # legend left edge
        start_y = 20          # y of the first legend row
        line_length = 30      # length of the sample line
        text_offset = 15      # gap between sample line and text
        text_x = start_x + line_length + text_offset
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.6
        font_color = (0, 0, 0)
        thickness = 2

        # Correct: solid green
        cv2.line(new_pic, (start_x, start_y), (start_x + line_length, start_y), (0, 255, 0), thickness=2)
        cv2.putText(new_pic, f"Correct: {num_correct}",
                    (text_x, start_y), font, font_scale, font_color, thickness)

        # Missing: dashed purple (sample dash is 15 px long)
        y_missing = start_y + 30
        draw_dashed_line(new_pic, (start_x, y_missing), (start_x + 15, y_missing), (128, 0, 128), thickness=2)
        cv2.putText(new_pic, f"Missing: {num_missing}",
                    (text_x, y_missing), font, font_scale, font_color, thickness)

        # Wrong: solid purple
        y_wrong = start_y + 60
        cv2.line(new_pic, (start_x, y_wrong), (start_x + line_length, y_wrong), (128, 0, 128), thickness=2)
        cv2.putText(new_pic, f"Wrong: {num_wrong}",
                    (text_x, y_wrong), font, font_scale, font_color, thickness)
        return new_pic

    def vis_graph(self, clusters, e_preds):
        """
        Visualize the predicted edges of each predicted cluster on top of the GT assembly
        of the image the cluster most likely belongs to.

        Args:
            clusters (list): predicted clusters, each a list of batch-local vertex indices
            e_preds (list): per-cluster predicted edges (objects with ``.tolist()``),
                each of shape [N, 2] in batch-local indices
        Output:
            <save_path>/other_vis/graphs/batch_<cur_batch>/<img_name>.png
        """
        imgs = self.data["img_all"]
        pcds = self.data["full_pcd_all"]
        cur_batch = self.batch_gt["cur_batch"]
        idx_convert = self.batch_gt["idx_convert"]
        save_path = os.path.join(self.save_path, "other_vis", "graphs", f"batch_{cur_batch}")
        os.makedirs(save_path, exist_ok=True)

        # inverse mapping global fragment index -> batch-local index (same for every cluster)
        global_to_local = {g: l for l, g in idx_convert.items()}
        font = cv2.FONT_HERSHEY_SIMPLEX
        label_offset = 10  # text offset from the fragment center

        for local_idx, e_pred in zip(clusters, e_preds):
            global_idx = [idx_convert[i] for i in local_idx]
            img_idx, img_name = self._find_img_idx(global_idx)
            global_e_pred = [(idx_convert[u], idx_convert[v]) for u, v in e_pred.tolist()]

            transformations, gt_pairs = self._gt_of_image(img_idx)
            result_pic, centers = build_pic(imgs, pcds, transformations)

            correct, missing, wrong = get_edge_result(np.array(gt_pairs), np.array(global_e_pred))
            # correct: solid green, missing: dashed purple, wrong: solid purple
            for u, v in correct:
                cv2.line(result_pic, centers[u], centers[v], (0, 255, 0), thickness=2)
            for u, v in missing:
                draw_dashed_line(result_pic, centers[u], centers[v], (128, 0, 128), thickness=2)
            for u, v in wrong:
                # a wrong edge may touch a fragment of another image, which has no center here
                if u not in centers or v not in centers:
                    continue
                cv2.line(result_pic, centers[u], centers[v], (128, 0, 128), thickness=2)

            # write each fragment's batch-local index next to its center
            for global_id, center in centers.items():
                local_id = global_to_local.get(global_id)
                if local_id is None:
                    continue  # fragment of this image that is not part of the batch
                text_position = (center[0] + label_offset, center[1] + label_offset)
                cv2.putText(result_pic, str(local_id), text_position, font, 0.6, (0, 0, 255), 2)

            new_pic = self._draw_legend(result_pic, len(correct), len(missing), len(wrong))
            cv2.imwrite(os.path.join(save_path, img_name + ".png"), new_pic)

    def _find_img_idx(self, v_idx):
        """
        Find the image that a set of (global) fragment indices most likely belongs to.

        Args:
            v_idx (list): global fragment indices
        Returns:
            img_idx (int): index of the image sharing the most fragments with v_idx
                (first such image on ties, 0 if nothing overlaps)
            img_name (str): name of that image
        """
        v_idx = set(v_idx)
        max_same, idx = 0, 0
        for i, frags in enumerate(self.img_hash_tab):
            same_num = len(v_idx & set(frags))
            if same_num > max_same:
                max_same = same_num
                idx = i
        return idx, self.data["img_list"][idx]

        


    


        


