"""
@Description :   训练集、验证集、测试集划分
@Author      :   tqychy 
@Time        :   2024/12/29 09:33:44
"""
import os
import pickle
import random

import numpy as np


# Keys carried by every split dictionary, in the order they are stored.
_SPLIT_KEYS = (
    "img_list",
    "full_pcd_all",
    "img_all",
    "belong_image",
    "shape_all",
    "gt_pose",
    "GT_pairs",
    "source_ind",
    "target_ind",
    "intersection_len",
)
# Per-fragment fields copied verbatim from the combined dataset.
_FRAGMENT_KEYS = ("full_pcd_all", "img_all", "belong_image", "shape_all", "gt_pose")
# Per-pair fields copied alongside every ground-truth pair.
_PAIR_KEYS = ("source_ind", "target_ind")
# Tolerance that absorbs binary floating-point noise when flooring a boundary.
_BOUNDARY_EPS = 1e-9


def _empty_split():
    """Return a fresh split dictionary with every expected key mapped to an empty list."""
    return {key: [] for key in _SPLIT_KEYS}


def _split_boundaries(num_images, train_valid_percent):
    """Return the (train_end, valid_end) cut positions in the shuffled image order."""
    # A bare int() truncates products such as 10 * (0.7 + 0.1) == 7.999999999999999
    # down to 7; the epsilon keeps floor semantics while ignoring that noise.
    train_end = int(num_images * train_valid_percent[0] + _BOUNDARY_EPS)
    valid_end = int(num_images * sum(train_valid_percent) + _BOUNDARY_EPS)
    return train_end, valid_end


def divide(data_path, *args):
    """Split the combined dataset pickle into train / validation / test pickles.

    The split is made per source image: the images listed in ``data["img_list"]``
    are shuffled with the ``random`` module, cut according to
    ``cfg.DATASET.TRAIN_VALID_PERCENT`` and every fragment follows the image
    named in its ``belong_image`` entry. Ground-truth pairs are routed by the
    image of their first fragment and re-indexed to positions inside the
    destination split.

    Args:
        data_path: Directory holding ``<cfg.GLOBALS.EXPR_NAME>_all.pkl``; the
            files ``train_set.pkl``, ``valid_set.pkl`` and ``test_set.pkl``
            are (re)written there.
        *args: Exactly two positional values, ``(cfg, logger)``. ``cfg`` must
            expose ``GLOBALS.EXPR_NAME`` and ``DATASET.TRAIN_VALID_PERCENT``
            (train fraction first; the sum of all entries marks the end of
            the validation part). ``logger`` must provide ``info``.

    Raises:
        ValueError: If ``args`` does not contain exactly two values.
        KeyError: If a fragment references an image name missing from
            ``img_list`` or a pair references an unknown fragment index.

    Note:
        ``intersection_len`` is created in every split but stays empty;
        nothing in this function populates it.
    """
    cfg, logger = args

    # NOTE(security): pickle.load can execute arbitrary code; only point this
    # at dataset files produced by a trusted pipeline.
    with open(os.path.join(data_path, cfg.GLOBALS.EXPR_NAME + '_all.pkl'), 'rb') as file:
        data = pickle.load(file)

    # Shuffle image indices and cut them into three contiguous ranges.
    num_images = len(data["img_list"])
    shuffled = list(range(num_images))
    random.shuffle(shuffled)
    train_end, valid_end = _split_boundaries(
        num_images, cfg.DATASET.TRAIN_VALID_PERCENT)

    split_of_index = {}
    for position, image_index in enumerate(shuffled):
        if position < train_end:
            split_of_index[image_index] = "train"
        elif position < valid_end:
            split_of_index[image_index] = "valid"
        else:
            split_of_index[image_index] = "test"
    # Image name -> split name (for duplicated names the last occurrence wins).
    split_of_image = {name: split_of_index[index]
                      for index, name in enumerate(data["img_list"])}

    splits = {"train": _empty_split(), "valid": _empty_split(),
              "test": _empty_split()}

    # Distribute per-fragment data and remember each fragment's new position
    # inside its split so that the pairs can be re-indexed afterwards.
    local_index = {}
    for global_idx in range(len(data["img_all"])):
        target = splits[split_of_image[data["belong_image"][global_idx]]]
        for key in _FRAGMENT_KEYS:
            target[key].append(data[key][global_idx])
        local_index[global_idx] = len(target["img_all"]) - 1

    splits["train"]["img_list"] = [data["img_list"][i]
                                   for i in shuffled[:train_end]]
    splits["valid"]["img_list"] = [data["img_list"][i]
                                   for i in shuffled[train_end:valid_end]]
    splits["test"]["img_list"] = [data["img_list"][i]
                                  for i in shuffled[valid_end:]]

    # Re-index the ground-truth pairs. Only the first fragment decides the
    # split; this assumes both fragments of a pair come from the same image.
    for pair_idx, (first, second) in enumerate(data['GT_pairs']):
        target = splits[split_of_image[data["belong_image"][first]]]
        target["GT_pairs"].append((local_index[first], local_index[second]))
        for key in _PAIR_KEYS:
            target[key].append(data[key][pair_idx])

    # Report the size of every split.
    for label, split_name in (("训练集", "train"), ("验证集", "valid"), ("测试集", "test")):
        split = splits[split_name]
        logger.info(
            f"{label}：{len(split['img_list'])} 张图片，{len(split['img_all'])} 个碎片，{len(split['GT_pairs'])} 对碎片对。")

    # Remove any previous outputs first, then write the fresh splits.
    outputs = [(os.path.join(data_path, f"{split_name}_set.pkl"), splits[split_name])
               for split_name in ("train", "valid", "test")]
    for out_path, _ in outputs:
        if os.path.exists(out_path):
            os.remove(out_path)
    for out_path, split in outputs:
        with open(out_path, 'wb') as file:
            pickle.dump(split, file)


if __name__ == '__main__':
    # Ad-hoc inspection: print how many of the first 1000 ground-truth pairs
    # of a previously written split belong to each source image.
    from collections import Counter

    data_path = "./dataset/1000_all"
    split_file = os.path.join(
        data_path, 'build_dataset_1000_pairing_valid_set.pkl')
    # NOTE(security): pickle.load can execute arbitrary code; only open
    # split files produced by a trusted pipeline.
    with open(split_file, 'rb') as f:
        valid_set = pickle.load(f)

    # A pair is attributed to the image of its first fragment.
    pairs_per_image = Counter(
        valid_set["belong_image"][idx1]
        for idx1, _ in valid_set['GT_pairs'][:1000]
    )
    for img_name, pair_count in pairs_per_image.items():
        print(f"{img_name}: {pair_count}")