"""
@Description :   全局匹配数据集，以 graph 的形式返回一批碎片（节点）和其中的所有的配对（边）
@Author      :   tqychy 
@Time        :   2025/01/22 13:56:17
"""
import sys

sys.path.append("./")

import numpy as np
import torch

from dataset.base_dataset import BaseDataSet


class PairingAllDataset(BaseDataSet):
    """Per-image graph dataset for global fragment matching.

    Each item groups every fragment that belongs to one image (the graph
    nodes) with the ground-truth pairs among those fragments (the graph
    edges). Node and edge tensors are zero-padded to the fixed sizes
    ``self.cfg.DATASET.MAX_SCRIPTS`` and ``self.cfg.DATASET.MAX_EDGES``.

    ``self.data`` and ``self.cfg`` are not assigned in this class; they are
    presumably provided by ``BaseDataSet`` -- confirm against the base class.
    """

    def __init__(self, data_path, *args, calc_adjs=True):
        """Initialise the base dataset and index fragments and pairs by image.

        Args:
            data_path: Forwarded unchanged to ``BaseDataSet.__init__``.
            *args: Forwarded unchanged to ``BaseDataSet.__init__``.
            calc_adjs: Forwarded to ``BaseDataSet.__init__`` as a keyword.
        """
        super().__init__(data_path, *args, calc_adjs=calc_adjs)
        self.img_list = self.data["img_list"]
        self.belong_img = self.data["belong_image"]
        self.gt_pairs = np.array(self.data['GT_pairs'])

        # Image name -> position of that image in img_list.
        self.indices_dict = {img: idx for idx, img in enumerate(self.img_list)}
        # One bucket per image: global fragment indices of that image, and
        # indices into gt_pairs of the pairs assigned to that image.
        self.img_hash_tab = [[] for _ in self.img_list]
        self.gt_pairs_hash_tab = [[] for _ in self.img_list]

        # Bucket every fragment under the image it belongs to.
        for frag_idx in range(len(self.data['img_all'])):
            img_idx = self.indices_dict[self.belong_img[frag_idx]]
            self.img_hash_tab[img_idx].append(frag_idx)

        # A pair is bucketed under the image of its first fragment only.
        for pair_idx, (first_frag, _) in enumerate(self.gt_pairs):
            img_idx = self.indices_dict[self.belong_img[first_frag]]
            self.gt_pairs_hash_tab[img_idx].append(pair_idx)

    def __len__(self):
        """Return the number of images (one graph sample per image)."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """Build the padded graph sample for the image at position ``idx``.

        At most ``MAX_SCRIPTS`` fragments and ``MAX_EDGES`` edges are kept;
        anything beyond those limits is dropped so that every write stays
        inside the fixed-size buffers.

        Args:
            idx: Position of the image in ``self.img_list``.

        Returns:
            A tuple ``(v, v_num, idx_convert, e, e_num)`` where

            * ``v`` maps each of ``"img"``, ``"full_pcd"``, ``"c_input"``,
              ``"t_input"``, ``"adj"`` and ``"factor"`` to a tensor of shape
              ``(MAX_SCRIPTS, *fragment_value_shape)`` filled in local
              fragment order, or to ``None`` when no fragment supplied
              that key;
            * ``v_num`` is the number of fragments stored;
            * ``idx_convert`` is a length-``MAX_SCRIPTS`` tensor mapping a
              local index to the fragment's global index (padding is 0);
            * ``e`` is an int64 tensor of shape ``(2, MAX_EDGES)`` whose
              first ``e_num`` columns hold the local endpoints of each pair;
            * ``e_num`` is the number of valid columns in ``e``.
        """
        max_scripts = self.cfg.DATASET.MAX_SCRIPTS
        max_edges = self.cfg.DATASET.MAX_EDGES

        # Node buffers are allocated lazily, once the first fragment reveals
        # the per-fragment shape of each field.
        v = {k: None for k in ["img", "full_pcd",
                               "c_input", "t_input", "adj", "factor"]}
        e = torch.zeros((2, max_edges), dtype=torch.int64)

        # Clamp to the padded capacity: writing fragment number max_scripts
        # or later would index past the end of the node buffers.
        frag_indices = self.img_hash_tab[idx][:max_scripts]
        v_num = len(frag_indices)
        reflect_tab = {}  # global fragment index -> local index

        for local_idx, frag_idx in enumerate(frag_indices):
            # Per-fragment data comes from the base dataset.
            frag_data = super().__getitem__(frag_idx)
            for key, val in frag_data.items():
                if key not in v:
                    continue
                # Identity check: once the slot holds a tensor, ``== None``
                # would go through tensor comparison instead of testing for
                # the "not yet allocated" state.
                if v[key] is None:
                    v[key] = torch.zeros((max_scripts, *val.shape))
                v[key][local_idx] = val
            reflect_tab[frag_idx] = local_idx

        # Fill edges; a pair is kept only if both endpoints are in this sample.
        e_num = 0
        for pair_idx in self.gt_pairs_hash_tab[idx]:
            if e_num >= max_edges:
                break

            idx1, idx2 = self.gt_pairs[pair_idx]
            if idx1 in reflect_tab and idx2 in reflect_tab:
                e[0, e_num] = reflect_tab[idx1]
                e[1, e_num] = reflect_tab[idx2]
                e_num += 1

        # Every edge endpoint is a value of reflect_tab, i.e. a local index
        # strictly below v_num, so v_num never needs adjusting afterwards.

        # Local index -> global fragment index lookup table.
        # NOTE(review): torch.zeros uses the default (floating) dtype here, so
        # global indices are stored as floats; kept as-is for compatibility
        # with existing callers.
        idx_convert = torch.zeros((max_scripts))
        for glb_idx, loc_idx in reflect_tab.items():
            idx_convert[loc_idx] = glb_idx

        return v, v_num, idx_convert, e, e_num


if __name__ == "__main__":
    # Script entry point: read --config_path from the command line and merge
    # that file into the project-wide cfg object. data_path is defined here
    # but not consumed within this block.
    import argparse

    from config.default import cfg

    data_path = "dataset/1000_all/train_set.pkl"

    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str,
                        default="./config/pairingall_cnn/train.yaml")
    args = parser.parse_args()
    cfg.merge_from_file(args.config_path)
