"""
@Description :   从 raw 中构建碎片数据集，提取边缘并分割训练、验证、测试集
@Author      :   tqychy 
@Time        :   2024/12/29 11:40:35
"""
import sys

sys.path.append("./")
import argparse
import os
import pickle
import random
import warnings

import numpy as np
from scripts.extract_contours import extract_contour
from scripts.gen_fragments import gen_fragment
from scripts.train_test_split import divide
from tqdm import tqdm

from config.default import cfg
from logger.logger import build_logger
from visualize.dataset_vistools import DatasetVis


def display_hyper_parameters(matching_set, logger):
    """Log four dataset-wide maxima computed from the fragment metadata.

    The values reported (in one ``logger.info`` call) are:

    * ``image_size``: the largest height or width found in ``shape_all``.
    * ``contour_max_len``: the largest ``shape[0]`` among ``full_pcd_all``.
    * ``max_scripts``: the largest number of fragments owned by one image.
    * ``max_edges``: the largest number of ground-truth pairs attributed to
      one image (a pair is attributed to the image owning its first fragment).

    Args:
        matching_set: Mapping providing the keys ``shape_all``,
            ``full_pcd_all``, ``img_list``, ``belong_image``, ``GT_pairs``
            and ``img_all``.
        logger: Object exposing an ``info(message)`` method.
    """
    # Every shape entry must unpack into exactly three values; the third is unused.
    image_size = 0
    for h, w, _ in matching_set["shape_all"]:
        image_size = max(h, w, image_size)

    contour_max_len = max(
        (pcd.shape[0] for pcd in matching_set["full_pcd_all"]), default=0)

    image_names = matching_set["img_list"]
    owner_of = matching_set["belong_image"]
    pairs = np.array(matching_set['GT_pairs'])

    # Image name -> slot in the per-image counters below.
    slot_of = {name: slot for slot, name in enumerate(image_names)}
    frags_per_image = [0] * len(image_names)
    pairs_per_image = [0] * len(image_names)

    # Count the fragments of each image; the fragment count comes from img_all.
    for frag_idx in range(len(matching_set['img_all'])):
        frags_per_image[slot_of[owner_of[frag_idx]]] += 1

    # Count the ground-truth pairs of each image, keyed by the first fragment.
    for first, _ in pairs:
        pairs_per_image[slot_of[owner_of[first]]] += 1

    max_scripts = max(frags_per_image, default=0)
    max_edges = max(pairs_per_image, default=0)

    logger.info(
        f"数据集中，最大碎片宽高（image_size）: {image_size}, 最长碎片长度（contour_max_len）{contour_max_len}, 图片的最大碎片数量（max_scripts）: {max_scripts}，图片的最大碎片对数量（max_edges）: {max_edges}")


def set_seed(seed: int):
    """Seed Python's ``random`` module and NumPy's global random state.

    Args:
        seed: Value passed unchanged to ``random.seed`` and then
            ``np.random.seed``.
    """
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)


def main(cfg):
    """Build the fragment dataset: generate, extract contours, split, visualise.

    When ``cfg.DATASET.BUILD_FROM_EXIST_DATASET`` is false, fragments are
    generated with ``gen_fragment``, metadata is accumulated by calling
    ``extract_contour`` on every entry of ``<FRAGMENT_PATH>/fragments`` and the
    result is pickled to ``<FRAGMENT_PATH>/<EXPR_NAME>_all.pkl``. ``divide`` is
    then called in both modes. After a fresh build, each of the ``train``,
    ``valid`` and ``test`` splits is additionally visualised into
    ``recover_images/<split>`` next to the ``fragments`` directory.

    Finally the dataset-wide maxima are logged. When building from an existing
    dataset no metadata is produced in this run, so it is reloaded from the
    pickle above if that file exists; otherwise the summary is skipped with a
    log message instead of failing on an unbound ``matching_set``.

    Args:
        cfg: Configuration object. This function reads ``GLOBALS.SEED``,
            ``GLOBALS.EXPR_NAME``, ``DATASET.FRAGMENT_PATH`` and
            ``DATASET.BUILD_FROM_EXIST_DATASET`` and forwards ``cfg`` to the
            helper functions it calls.
    """
    logger, _ = build_logger(cfg)
    logger.info(cfg)
    set_seed(cfg.GLOBALS.SEED)

    build_from_raw = not cfg.DATASET.BUILD_FROM_EXIST_DATASET
    save_name = os.path.join(
        cfg.DATASET.FRAGMENT_PATH, cfg.GLOBALS.EXPR_NAME + "_all.pkl")
    # Stays None when no metadata is built or found in this run.
    matching_set = None

    # Generate the fragment dataset.
    if build_from_raw:
        logger.info("开始生成碎片数据集")
        gen_fragment(cfg, logger)

        # Extract contours and store the metadata.
        data_path = os.path.join(cfg.DATASET.FRAGMENT_PATH, "fragments")
        sub_list = os.listdir(data_path)
        # Drop any stale metadata file before rebuilding it.
        if os.path.exists(save_name):
            os.remove(save_name)

        matching_set = {
            "img_list": list(sub_list),  # image names
            'full_pcd_all': [],  # fragment contours before down-sampling
            'img_all': [],  # fragment images
            'belong_image': [],  # name of the image each fragment belongs to
            'shape_all': [],  # fragment image sizes
            "gt_pose": [],  # ground-truth transform matrix of each fragment
            'GT_pairs': [],  # matching fragment pairs
            # For the i-th entry of GT_pairs: the points on the first
            # fragment's contour that match the other fragment.
            'source_ind': [],
            # For the i-th entry of GT_pairs: the points on the second
            # fragment's contour that match the other fragment.
            'target_ind': [],
            'down_sample_pcd': []  # fragment contours after down-sampling
        }

        logger.info("开始提取边缘")
        with tqdm(total=len(sub_list)) as pbar:
            for img_name in sub_list:
                # extract_contour returns the updated metadata dict.
                matching_set = extract_contour(os.path.join(
                    data_path, img_name), matching_set, cfg, logger)
                pbar.update(1)

        logger.info("保存碎片元信息")
        with open(save_name, 'wb') as file:
            pickle.dump(matching_set, file)

    # Split the dataset.
    data_path = cfg.DATASET.FRAGMENT_PATH
    logger.info("分割数据集")
    divide(data_path, cfg, logger)

    # Recover the images from the ground truth (only after a fresh build).
    if build_from_raw:
        logger.info("可视化碎片数据集")
        root = os.path.join(cfg.DATASET.FRAGMENT_PATH, "fragments")
        parent_path = os.path.dirname(root)
        overall_folder = os.path.join(parent_path, "recover_images")
        for split in ["train", "valid", "test"]:
            res_path = os.path.join(overall_folder, split)
            os.makedirs(res_path, exist_ok=True)
            vis = DatasetVis(os.path.join(cfg.DATASET.FRAGMENT_PATH, split + "_set.pkl"), res_path)
            vis.recover_images_by_gtinfo()
    logger.info("数据集生成完成")

    if not build_from_raw and os.path.exists(save_name):
        # NOTE(review): assumes this file was written by an earlier run of this
        # script and so has the same schema - confirm. pickle must only be
        # used on trusted, locally produced files.
        with open(save_name, 'rb') as file:
            matching_set = pickle.load(file)

    if matching_set is None:
        logger.info(f"未找到碎片元信息文件 {save_name}，跳过超参数统计")
        return
    display_hyper_parameters(matching_set, logger)


if __name__ == "__main__":
    # Script entry point: silence all warnings, read the config path from the
    # command line, merge that file into the shared cfg and run the pipeline.
    warnings.filterwarnings("ignore")

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--config_path", type=str,
                            default="./config/build_dataset/20.yaml")
    config_path = arg_parser.parse_args().config_path

    cfg.merge_from_file(config_path)
    # Freeze cfg before handing it to main.
    cfg.freeze()

    main(cfg)
    # import random
    # import cv2

    # random.seed(2048)
    # np.random.seed(2048)

    # data_path = "./dataset/1000_all/test_set.pkl"

    # def pcd2img(pcd):
    #     hw_max = 20 + pcd.max()
    #     edge_image = np.zeros((1, hw_max, hw_max))
    #     x_coords = pcd[:, 0]
    #     y_coords = pcd[:, 1]
    #     edge_image[0, x_coords, y_coords] = 255
    #     return 255 - edge_image

    # with open(data_path, 'rb') as f:
    #     data = pickle.load(f)

    # img_list = data["img_list"]
    # belong_img = data["belong_image"]
    # gt_pairs = np.array(data['GT_pairs'])

    # # 建立图片到碎片的映射表
    # indices_dict = {img: idx for idx, img in enumerate(img_list)}
    # img_hash_tab = [[] for _ in range(len(img_list))]
    # gt_pairs_hash_tab = [[] for _ in range(len(img_list))]

    # # 初始化映射关系
    # for frag_idx in range(len(data['img_all'])):
    #     img_name = belong_img[frag_idx]
    #     img_idx = indices_dict[img_name]
    #     img_hash_tab[img_idx].append(frag_idx)

    # for pair_idx in range(len(gt_pairs)):
    #     idx1, _ = gt_pairs[pair_idx]
    #     img_name = belong_img[idx1]
    #     img_idx = indices_dict[img_name]
    #     gt_pairs_hash_tab[img_idx].append(pair_idx)

    # for i in range(3):
    #     idx1, idx2 = random.choice(data["GT_pairs"])
    #     img_idx = indices_dict[belong_img[idx1]]
    #     idx3 = random.choice(img_hash_tab[img_idx])
    #     while (idx1, idx3) in gt_pairs_hash_tab[img_idx] or (idx3, idx1) in gt_pairs_hash_tab[img_idx] or idx1 == idx3:
    #         idx3 = random.choice(img_hash_tab[img_idx])

    #     img1, img2, img3 = data["img_all"][idx1], data["img_all"][idx2], data["img_all"][idx3]

    #     pcd1, pcd2, pcd3 = data["full_pcd_all"][idx1], data["full_pcd_all"][idx2], data["full_pcd_all"][idx3]
    #     pcd1 = pcd2img(pcd1).squeeze()
    #     pcd2 = pcd2img(pcd2).squeeze()
    #     pcd3 = pcd2img(pcd3).squeeze()

    #     cv2.imwrite(f"./temp/{i}_1i.png", img1)
    #     cv2.imwrite(f"./temp/{i}_2i.png", img2)
    #     cv2.imwrite(f"./temp/{i}_3i.png", img3)
    #     cv2.imwrite(f"./temp/{i}_1e.png", pcd1)
    #     cv2.imwrite(f"./temp/{i}_2e.png", pcd2)
    #     cv2.imwrite(f"./temp/{i}_3e.png", pcd3)
