"""
@Description :   碎片对分类数据集，输出平衡的正例（成对碎片）和反例（非成对碎片）
@Author      :   tqychy 
@Time        :   2025/02/15 11:28:25
"""
import sys

sys.path.append("./")
import random

import numpy as np
import torch
from tqdm import tqdm

from dataset.base_dataset import BaseDataSet


class ClassifyDataset(BaseDataSet):
    """Fragment-pair classification dataset with positive and negative pairs.

    Indices ``[0, len(gt_pairs))`` map to the ground-truth pairs from
    ``data['GT_pairs']`` (label ``1.``); the remaining indices map to sampled
    negative pairs (label ``0.``). The number of negatives is chosen so that
    positives make up roughly a ``cfg.DATASET.CLASSIFY.FACTOR`` share of the
    whole dataset.
    """

    def __init__(self, data_path: str, *args, calc_adjs=True):
        """Load data via ``BaseDataSet``, build lookup tables, sample negatives.

        Args:
            data_path: forwarded to ``BaseDataSet.__init__``.
            *args: forwarded to ``BaseDataSet.__init__``.
            calc_adjs: forwarded to ``BaseDataSet.__init__``.
        """
        super().__init__(data_path, *args, calc_adjs=calc_adjs)
        self.gt_pairs = np.array(self.data['GT_pairs'])
        self.img_list = self.data["img_list"]
        self.belong_img = self.data["belong_image"]
        # Target share of positive pairs in the dataset; read by _neg_sample.
        self.factor = self.cfg.DATASET.CLASSIFY.FACTOR

        num_frags = len(self.data['img_all'])
        # image name -> image index
        self.indices_dict = {img: idx for idx, img in enumerate(self.img_list)}
        # image index -> list of fragment indices belonging to that image
        self.img_hash_tab = [[] for _ in range(len(self.img_list))]
        # image index -> positive pairs (both orderings), filed under the
        # image of the pair's first fragment
        self.gt_pairs_hash_tab = [set() for _ in range(len(self.img_list))]
        # fragment index -> fragments it forms a positive pair with (symmetric)
        self.hash_tab = {i: set() for i in range(num_frags)}

        for frag_idx in range(num_frags):
            img_idx = self.indices_dict[self.belong_img[frag_idx]]
            self.img_hash_tab[img_idx].append(frag_idx)

        for idx1, idx2 in self.gt_pairs:
            img_idx = self.indices_dict[self.belong_img[idx1]]
            self.gt_pairs_hash_tab[img_idx].add((idx1, idx2))
            self.gt_pairs_hash_tab[img_idx].add((idx2, idx1))
            self.hash_tab[idx1].add(idx2)
            self.hash_tab[idx2].add(idx1)

        self.neg_pairs = self._neg_sample()

    def _neg_sample(self):
        """Sample distinct negative (non-matching) fragment pairs.

        The i-th negative is first searched for inside image
        ``i % len(img_list)`` (round-robin over images); if that image yields
        no usable pair, a pair is drawn from all fragments instead. A pair is
        usable when its two fragments differ, it is not a ground-truth pair,
        and it has not been sampled yet in either order.

        Returns:
            list of ``(idx1, idx2)`` tuples of fragment indices.

        Raises:
            ValueError: if more negatives are requested than distinct
                non-positive pairs exist (sampling could never finish).
        """
        self.logger.debug("开始负采样")
        neg_pairs = set()
        # Chosen so that positives / (positives + negatives) ~= factor.
        num = int(len(self.gt_pairs) / self.factor * (1 - self.factor))
        capacity = self._num_possible_negatives()
        if num > capacity:
            raise ValueError(
                f"Requested {num} negative pairs but only {capacity} distinct "
                f"non-positive fragment pairs exist")

        for i in tqdm(range(num), desc="负采样"):
            img_idx = i % len(self.img_list)
            pair = self._sample_in_image(img_idx, neg_pairs)
            if pair is None:
                pair = self._sample_global(neg_pairs)
            neg_pairs.add(pair)

        self.logger.info(f"负采样结束，正例 {len(self.gt_pairs)}, 反例 {num}")
        return list(neg_pairs)

    def _num_possible_negatives(self):
        """Count unordered pairs of distinct fragments that are not positives."""
        n = len(self.data["img_all"])
        # hash_tab is symmetric, so every unordered positive pair is counted
        # twice; self-pairs are excluded because they can never be negatives.
        positives = sum(
            len(partners - {frag}) for frag, partners in self.hash_tab.items()
        ) // 2
        return n * (n - 1) // 2 - positives

    def _sample_in_image(self, img_idx, neg_pairs, max_tries=1000):
        """Try to draw a new negative pair from the fragments of one image.

        Returns:
            ``(idx1, idx2)``, or ``None`` if the image has fewer than two
            fragments or no usable pair was found within ``max_tries`` draws.
        """
        frags = self.img_hash_tab[img_idx]
        if len(frags) < 2:
            # An image needs at least two fragments to form any pair.
            return None
        positives = self.gt_pairs_hash_tab[img_idx]
        for _ in range(max_tries):
            # Drawing without replacement guarantees idx1 != idx2.
            idx1, idx2 = random.sample(frags, 2)
            if ((idx1, idx2) not in positives
                    and (idx1, idx2) not in neg_pairs
                    and (idx2, idx1) not in neg_pairs):
                return idx1, idx2
        return None

    def _sample_global(self, neg_pairs):
        """Draw a new negative pair uniformly from all fragments (fallback).

        Both indices are redrawn on every attempt so sampling cannot get stuck
        on a fragment whose possible partners are all used up. The capacity
        check in ``_neg_sample`` ensures a usable pair still exists.
        """
        n = len(self.data["img_all"])
        while True:
            idx1 = random.randrange(n)
            idx2 = random.randrange(n)
            # hash_tab is symmetric, so one membership test covers both orders.
            if (idx1 != idx2
                    and idx2 not in self.hash_tab[idx1]
                    and (idx1, idx2) not in neg_pairs
                    and (idx2, idx1) not in neg_pairs):
                return idx1, idx2

    def __len__(self):
        """Number of positive plus negative pairs."""
        return len(self.neg_pairs) + len(self.gt_pairs)

    def __getitem__(self, idx):
        """Return the two fragments of pair ``idx`` together with its label.

        Indices below ``len(gt_pairs)`` are positive pairs (label ``1.``);
        the rest are negative pairs (label ``0.``). A negative ``idx`` counts
        from the end of the whole dataset, as for a sequence.

        Raises:
            IndexError: if ``idx`` is out of range.
        """
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError(
                f"index {idx} out of range for dataset of size {len(self)}")

        if idx < len(self.gt_pairs):
            gt = 1.
            idx1, idx2 = self.gt_pairs[idx]
        else:
            gt = 0.
            idx1, idx2 = self.neg_pairs[idx - len(self.gt_pairs)]
        # gt_pairs is a numpy array; cast so positive and negative items both
        # carry plain Python ints.
        idx1, idx2 = int(idx1), int(idx2)

        data1 = super().__getitem__(idx1)
        data2 = super().__getitem__(idx2)

        # Placeholder: this mask is not actually used.
        mask = torch.zeros(
            (self.max_points, self.max_points), dtype=torch.bool)

        return (
            (mask, self.long[idx1], self.long[idx2], idx1, idx2),
            (data1['img'], data2['img']),
            (data1['full_pcd'], data2['full_pcd']),
            (data1['c_input'], data2['c_input']),
            (data1['t_input'], data2['t_input']),
            (data1['adj'], data2['adj']),
            (data1['factor'], data2['factor']),
            (torch.zeros((1, 1)), torch.zeros((1, 1))),  # att_mask placeholder
            gt
        )