"""
@Description :   数据集可视化
@Author      :   tqychy 
@Time        :   2025/03/01 17:04:57
"""
import os
import pickle

import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import scienceplots
import seaborn as sns
from scipy.stats import beta, norm

# Module-wide plotting configuration.
# "Agg" is a non-interactive backend: figures are only written to files.
matplotlib.use("Agg")
# Style names presumably registered by the `scienceplots` import above.
plt.style.use(["science", "grid", "no-latex"])
plt.rcParams.update({
    # KaiTi is set as the sans-serif font (presumably so CJK labels render).
    "font.sans-serif": ['KaiTi'],
    # Draw the minus sign as an ASCII hyphen rather than the Unicode minus.
    "axes.unicode_minus": False,
    "figure.figsize": (15.2, 8.8),
})
# Shared font settings for axis labels of every plot in this module.
fontdict = dict(fontsize=15)


class DatasetVis:
    """Visualise a pickled fragment dataset.

    The pickle must contain a dict. The keys read by this class are
    ``img_list``, ``belong_image``, ``GT_pairs``, ``img_all``, ``gt_pose``,
    ``shape_all`` and ``full_pcd_all``. Fragments and ground-truth (GT) pairs
    are indexed by the complete image they belong to, which the other methods
    use to draw statistics plots and ground-truth reassemblies.
    """

    def __init__(self, data_path: str, result_path="./temp/data_vis"):
        """Load the dataset and group fragments / GT pairs by source image.

        Args:
            data_path: Path of the pickle file to load.
            result_path: Directory receiving every generated file; it is
                created when missing.
        """
        # NOTE(security): pickle.load can execute arbitrary code stored in the
        # file; only open datasets that come from a trusted source.
        with open(data_path, "rb") as f:
            self.data = pickle.load(f)
        self.result_path = result_path
        os.makedirs(result_path, exist_ok=True)

        img_list = self.data["img_list"]
        belong_img = self.data["belong_image"]
        gt_pairs = np.array(self.data['GT_pairs'])

        # Image name -> position of that image in img_list.
        self.indices_dict = {img: idx for idx, img in enumerate(img_list)}
        # Per image: indices of the fragments that belong to it.
        self.img_hash_tab = [[] for _ in range(len(img_list))]
        # Per image: indices (into GT_pairs) of the pairs that belong to it.
        self.gt_pairs_hash_tab = [[] for _ in range(len(img_list))]

        for frag_idx in range(len(self.data['img_all'])):
            img_idx = self.indices_dict[belong_img[frag_idx]]
            self.img_hash_tab[img_idx].append(frag_idx)

        # A pair is filed under the image of its first fragment.
        for pair_idx, (first_frag, _) in enumerate(gt_pairs):
            img_idx = self.indices_dict[belong_img[first_frag]]
            self.gt_pairs_hash_tab[img_idx].append(pair_idx)

    def recover_images_by_gtinfo(self):
        """Reassemble every image from its fragments using the GT poses.

        For each image the fragments are warped with their ``gt_pose``, the
        fragment centres and GT pair edges are drawn on top, and the result is
        written to ``<result_path>/gt_pics/<image name>.png``. Images without
        any fragment are skipped.
        """
        path = os.path.join(self.result_path, "gt_pics")
        os.makedirs(path, exist_ok=True)
        for img_idx, img_name in enumerate(self.data["img_list"]):
            frag_indices = self.img_hash_tab[img_idx]
            if not frag_indices:
                # Without fragments there is no extent to size a canvas from.
                continue
            transformations = {idx: self.data["gt_pose"][idx]
                               for idx in frag_indices}
            gt_pairs = [self.data["GT_pairs"][idx]
                        for idx in self.gt_pairs_hash_tab[img_idx]]

            result_pic, centers, real_shape = self._build_pic(transformations)
            result_pic = self._add_edges(result_pic, centers, gt_pairs)
            # The canvas is square; crop it back to the real extent.
            result_pic = result_pic[:real_shape[0], :real_shape[1], :]
            cv2.imwrite(os.path.join(path, img_name + ".png"), result_pic)

    def display_hyper_parameters(self, logger_display=None):
        """Report the dataset-wide maxima used as model hyper-parameters.

        Args:
            logger_display: Optional callable receiving the summary string;
                when ``None`` the summary is printed.
        """
        image_size = max(
            (max(h, w) for h, w, _ in self.data["shape_all"]), default=0)
        contour_max_len = max(
            (pcd.shape[0] for pcd in self.data["full_pcd_all"]), default=0)
        max_scripts = max((len(imgs) for imgs in self.img_hash_tab), default=0)
        max_edges = max(
            (len(edges) for edges in self.gt_pairs_hash_tab), default=0)
        display_str = f"数据集中，最大碎片宽高（image_size）: {image_size}, 最长碎片长度（contour_max_len）{contour_max_len}, 图片的最大碎片数量（max_scripts）: {max_scripts}，图片的最大碎片对数量（max_edges）: {max_edges}"

        if logger_display is None:
            print(display_str)
        else:
            logger_display(display_str)

    def area_displot(self):
        """Plot the fragment-area distribution with a fitted Beta density.

        Draws a density histogram, a KDE curve and the fitted Beta pdf, then
        saves the figure as ``area_displot.pdf`` inside ``result_path``.
        """
        area_list = [self._calc_area(pcd) for pcd in self.data["full_pcd_all"]]
        sns.histplot(area_list, color="green", kde=False, stat="density", alpha=.7, label="Histogram")
        sns.kdeplot(area_list, color="purple", linestyle='--', label="KDE")

        a, b, loc, scale = beta.fit(area_list)
        # Evaluate the fitted pdf over the x-range the histogram produced.
        xmin, xmax = plt.xlim()
        x = np.linspace(xmin, xmax, 1000)
        p = beta.pdf(x, a, b, loc=loc, scale=scale)
        skew, kurt = beta.stats(a, b, loc=loc, scale=scale, moments='sk')
        expression = f'B({a:.2f}, {b:.2f})\nSkewness: {skew:.2f}, Kurtosis: {kurt:.2f}'
        plt.plot(x, p, 'purple', linewidth=2, label=expression)

        self._finish_density_plot("Area(pix)", "area_displot.pdf")

    def script_num_displot(self):
        """Plot the fragments-per-image distribution with a fitted normal density.

        Draws a density histogram, a KDE curve, the mean and the fitted normal
        pdf, then saves the figure as ``script_num_displot.pdf`` inside
        ``result_path``.
        """
        script_num_list = [len(imgs) for imgs in self.img_hash_tab]
        mean_num = np.mean(script_num_list)

        sns.histplot(script_num_list, color="green", kde=False, stat="density", alpha=.7, label="Histogram")
        sns.kdeplot(script_num_list, color="purple", linestyle='--', label="KDE")
        plt.axvline(mean_num, color='orange', linestyle='dashed', linewidth=2, label=f'Mean: {mean_num:.2f}')
        mu, std = norm.fit(script_num_list)
        # Evaluate the fitted pdf over the x-range the histogram produced.
        xmin, xmax = plt.xlim()
        x = np.linspace(xmin, xmax, 1000)
        p = norm.pdf(x, mu, std)
        expression = f'$N({mu:.2f}, {std:.2f}^2)$'
        plt.plot(x, p, 'purple', linewidth=2, label=expression)

        self._finish_density_plot("Number of Fragments", "script_num_displot.pdf")

    def _finish_density_plot(self, xlabel, file_name):
        """Apply the shared legend/axis styling, save the figure and clear it."""
        plt.legend(prop={'size': 15})
        plt.xticks(fontsize=15)
        plt.yticks(fontsize=15)
        plt.xlabel(xlabel, fontdict=fontdict)
        plt.ylabel("Density", fontdict=fontdict)
        plt.savefig(os.path.join(self.result_path, file_name))
        # Clear the current figure so the next plot starts from a blank one.
        plt.clf()

    @staticmethod
    def _calc_area(pcd):
        """Return the area enclosed by a contour given as an (N, 2) point array.

        The points are treated as the ordered vertices of a closed polygon and
        the area is computed with the shoelace formula, so the result does not
        depend on the traversal direction or on the coordinate order.
        (Feeding the point list to ``cv2.findContours`` would interpret it as
        an N x 2 pixel image and measure that instead of the fragment.)

        Args:
            pcd: Array-like of shape (N, >=2); only the first two columns are
                used.

        Returns:
            The polygon area as a float, or 0.0 for fewer than three points
            or an input that is not two-dimensional.
        """
        pts = np.asarray(pcd, dtype=np.float64)
        if pts.ndim != 2 or pts.shape[0] < 3 or pts.shape[1] < 2:
            return 0.0
        x, y = pts[:, 0], pts[:, 1]
        # Shoelace formula: sum of cross products of consecutive vertices.
        cross = x * np.roll(y, -1) - np.roll(x, -1) * y
        return float(abs(cross.sum()) * 0.5)

    @staticmethod
    def _homogeneous_xy(pcd):
        """Return contour points as homogeneous rows ``[col1, col0, 1]``.

        The two stored columns are swapped before the constant 1 is appended
        (presumably the points are stored as (row, col) and the poses expect
        (x, y) -- confirm against the dataset builder).
        """
        xy = np.hstack((pcd[:, 1].reshape(-1, 1),
                        pcd[:, 0].reshape(-1, 1)))
        return np.hstack([xy, np.ones((len(xy), 1))])

    def _build_pic(self, transformations):
        """Warp the given fragments onto one canvas.

        Args:
            transformations: Mapping fragment index -> affine pose matrix
                (indexed as ``T[0..1, 0..2]`` and multiplied with homogeneous
                2-D points).

        Returns:
            A tuple ``(canvas, centers, (height, width))``: the square uint8
            canvas, a dict fragment index -> int32 centre in canvas
            coordinates, and the real extent of the drawing inside the canvas.

        Raises:
            ValueError: If ``transformations`` is empty.
        """
        if not transformations:
            raise ValueError(
                "transformations is empty; at least one fragment pose is required")

        # Transform every contour once to find the overall extent.
        all_points = []
        centers = {}
        for idx, T in transformations.items():
            homogeneous = self._homogeneous_xy(self.data["full_pcd_all"][idx])
            transformed = np.matmul(homogeneous, T.T)
            # Fragment centre: mean of the transformed x / y coordinates.
            centers[idx] = np.mean(transformed[:, :2], axis=0)
            all_points.append(transformed)
        all_points = np.concatenate(all_points, axis=0)

        min_x, max_x = np.min(all_points[:, 0]), np.max(all_points[:, 0])
        min_y, max_y = np.min(all_points[:, 1]), np.max(all_points[:, 1])

        # Shift the centres so they are relative to the canvas origin.
        origin = np.array([min_x, min_y])
        adjusted_centers = {idx: (center - origin).astype(np.int32)
                            for idx, center in centers.items()}

        # The canvas is square, sized by the larger side of the extent.
        canvas_width = int(np.ceil(max_x - min_x))
        canvas_height = int(np.ceil(max_y - min_y))
        canvas_size = max(canvas_height, canvas_width)
        # float32 so overlapping fragments can be summed before clipping.
        canvas = np.zeros((canvas_size, canvas_size, 3), dtype=np.float32)

        # Warp each fragment and accumulate it on the canvas.
        for idx, T in transformations.items():
            # Same pose with the translation moved to the canvas origin.
            adjusted_T = np.array([
                [T[0, 0], T[0, 1], T[0, 2] - min_x],
                [T[1, 0], T[1, 1], T[1, 2] - min_y]
            ], dtype=np.float32)
            img = self.data["img_all"][idx]
            homogeneous = self._homogeneous_xy(self.data["full_pcd_all"][idx])
            transformed = np.matmul(homogeneous, adjusted_T.T)
            transformed_img = cv2.warpAffine(img, adjusted_T, (canvas_size, canvas_size),
                                             flags=cv2.INTER_LINEAR,
                                             borderMode=cv2.BORDER_CONSTANT,
                                             borderValue=(0, 0, 0))
            # Thicken the fragment outline with small white dots.
            for px, py in transformed.astype(int):
                cv2.circle(canvas, (int(px), int(py)), 2, (255, 255, 255), -1)
            canvas += transformed_img

        # Clamp to the valid pixel range, then convert to uint8.
        canvas = np.clip(canvas, 0, 255).astype(np.uint8)
        return canvas, adjusted_centers, (canvas_height, canvas_width)

    def _add_edges(self, canvas, centers, e_pairs):
        """Draw fragment centres and GT pair edges onto ``canvas`` in place.

        Args:
            canvas: Image to draw on.
            centers: Mapping fragment index -> (x, y) centre in canvas
                coordinates.
            e_pairs: Iterable of ``(u, v)`` fragment index pairs to connect.

        Returns:
            The same ``canvas`` object.
        """
        height, width = canvas.shape[:2]
        # OpenCV drawing functions expect points as tuples of plain ints.
        points = {idx: (int(center[0]), int(center[1]))
                  for idx, center in centers.items()}

        # Mark the fragment centres that fall inside the canvas.
        for x, y in points.values():
            if 0 <= x < width and 0 <= y < height:
                cv2.circle(canvas, (x, y), radius=6,
                           color=(0, 0, 0), thickness=6)
        # Connect the centres of every paired fragment.
        for u, v in e_pairs:
            cv2.line(canvas, points[u], points[v], (0, 255, 0), thickness=2)

        return canvas


if __name__ == "__main__":
    # Script entry point: render both distribution plots for the training set.
    visualizer = DatasetVis("./dataset/2192/train_set.pkl")
    for render_plot in (visualizer.area_displot, visualizer.script_num_displot):
        render_plot()
