"""
@Description :   碎片组合的测试类
@Author      :   tqychy 
@Time        :   2025/01/20 16:48:07
"""
import sys

sys.path.append("./")
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import MatchingDataset
from nets import (MatchingNet, decoder_nets, feature_extract_nets,
                  feature_fuse_nets)
from trainers.base_trainers import BaseTester
from visualize.visualization import Visualization


class MatchingTester(BaseTester):
    def __init__(self, *args):
        """Initialise the base tester, then make sure the results directory exists.

        Every positional argument is forwarded unchanged to ``BaseTester``.
        Matching results are written below
        ``<cfg.TEST.RES_SAVE_PATH>/matching_test_results``, which is created
        here if it does not exist yet.
        """
        super().__init__(*args)
        save_root = self.cfg.TEST.RES_SAVE_PATH
        self.results_path = os.path.join(save_root, "matching_test_results")
        os.makedirs(self.results_path, exist_ok=True)

    def set_dataset(self, test_dataset_path: str):
        """Build the matching test dataset and an unshuffled DataLoader over it.

        Side effect: sets ``self.calc_adjs``, which tells both the dataset and
        the test loop whether adjacency matrices must be computed for the
        configured feature-extraction backbone.

        Args:
            test_dataset_path: path forwarded unchanged to ``MatchingDataset``.

        Returns:
            A ``(test_dataset, test_loader)`` tuple.

        Raises:
            KeyError: if ``cfg.TEST.MATCHING.FEATURE_EXTRACT`` is not one of the
                backbones listed in the table below.
        """
        # Whether each feature-extraction backbone needs adjacency ("adj") inputs.
        calc_adjs_tab = {
            "ResGCN": True,
            "ViT": False,
        }
        feature_extract = self.cfg.TEST.MATCHING.FEATURE_EXTRACT
        try:
            self.calc_adjs = calc_adjs_tab[feature_extract]
        except KeyError:
            # Keep the KeyError type callers may rely on, but say what went wrong.
            raise KeyError(
                f"Unsupported TEST.MATCHING.FEATURE_EXTRACT {feature_extract!r}; "
                f"expected one of {sorted(calc_adjs_tab)}"
            ) from None

        test_dataset = MatchingDataset(
            test_dataset_path, self.cfg, self.logger, calc_adjs=self.calc_adjs)
        # shuffle=False keeps dataset order, so the per-sample output files
        # written by the test loop are numbered by dataset position.
        test_loader = DataLoader(
            test_dataset, batch_size=self.cfg.TEST.BATCH_SIZE, shuffle=False)
        return test_dataset, test_loader

    def set_model(self) -> nn.Module:
        """Assemble the ``MatchingNet`` selected by ``cfg.TEST.MATCHING``.

        The feature extractor, feature-fusion module and decoder are looked up
        by name in their respective registries (``FEATURE_EXTRACT``,
        ``FEATURE_FUSE`` and ``DECODER``). Each is constructed as
        ``factory(cfg, logger)``, and the three are then wrapped into a single
        ``MatchingNet``.
        """
        matching_cfg = self.cfg.TEST.MATCHING
        build_args = (self.cfg, self.logger)

        extractor = feature_extract_nets[matching_cfg.FEATURE_EXTRACT](*build_args)
        fuser = feature_fuse_nets[matching_cfg.FEATURE_FUSE](*build_args)
        head = decoder_nets[matching_cfg.DECODER](*build_args)

        return MatchingNet(extractor, fuser, head)

    def test(self):
        """Run local-feature matching over the test set and save visualizations.

        For every source/target pair produced by ``self.test_loader``:

        1. The model predicts a similarity matrix (padded regions are marked
           with the mask from ``self.get_pad_mask``).
        2. The matrix is refined by ``_refine_similarity``, an erode/dilate
           pass with anti-diagonal 3x3 kernels.
        3. The refined matrix, the dense ground-truth mask and the original
           point clouds/images are passed to ``Visualization``.

        Three PNGs are written per sample under ``self.results_path``:
        ``preds/pred{k}.png``, ``gts/gt{k}.png`` and ``corres/corres{k}.png``.
        Here ``k`` counts samples in loader order.
        """
        self.logger.debug("开始测试局部特征匹配。")
        ori_data = self.test_dataset.data
        conv_threshold = self.cfg.TEST.MATCHING.CONV_THRES

        # Create the output directories once, not once per sample.
        pred_dir = os.path.join(self.results_path, "preds")
        gt_dir = os.path.join(self.results_path, "gts")
        corres_dir = os.path.join(self.results_path, "corres")
        for out_dir in (pred_dir, gt_dir, corres_dir):
            os.makedirs(out_dir, exist_ok=True)

        # Loop-invariant morphology kernels along the anti-diagonal.
        # The erode kernel omits the centre cell; the dilate kernel includes it.
        erode_kernel = np.array([[0, 0, 1],
                                 [0, 0, 0],
                                 [1, 0, 0]], dtype=np.uint8)
        dilate_kernel = np.array([[0, 0, 1],
                                  [0, 1, 0],
                                  [1, 0, 0]], dtype=np.uint8)

        # Inference only. eval() switches dropout/batch-norm layers (if any) to
        # inference behaviour and is idempotent. no_grad() avoids building an
        # autograd graph; without it, `.numpy()` below raises on outputs that
        # still require grad.
        self.model.eval()
        save_idx = 0
        with torch.no_grad(), tqdm(total=len(self.test_loader)) as pbar:
            for (mask_para, _imgs, pcd, c_input, t_input, adjs,
                 _factors, _att_mask) in self.test_loader:
                source_input, target_input = self._build_inputs(
                    pcd, c_input, t_input, adjs)

                # mark the padded part in similarity matrix
                pad_mask = self.get_pad_mask(mask_para).to(self.device)
                similarity_matrices = self.model(
                    source_input, target_input, pad_mask)

                for batch in range(similarity_matrices.shape[0]):
                    similarity_matrix = self._refine_similarity(
                        similarity_matrices[batch].cpu().numpy(),
                        erode_kernel, dilate_kernel)

                    mask = mask_para[0][batch]
                    idx_s, idx_t = mask_para[3][batch], mask_para[4][batch]
                    gt_matrix = mask.to_dense().float().numpy()

                    s_pcd, t_pcd = ori_data['full_pcd_all'][idx_s], ori_data['full_pcd_all'][idx_t]
                    source_img, target_img = ori_data['img_all'][idx_s], ori_data['img_all'][idx_t]
                    vis = Visualization(gt_matrix, similarity_matrix, s_pcd, t_pcd, source_img,
                                        target_img, conv_threshold=conv_threshold)
                    transformation, _pairs = vis.get_transformation()

                    # save preds / groundtruth / correspondences
                    vis.get_img(os.path.join(pred_dir, f'pred{save_idx}.png'), transformation)
                    vis.get_gt_img(os.path.join(gt_dir, f'gt{save_idx}.png'))
                    vis.get_corresponding(os.path.join(corres_dir, f'corres{save_idx}.png'))

                    # A running counter stays correct even if the loader's
                    # batch size differs from cfg.TEST.BATCH_SIZE.
                    save_idx += 1

                pbar.update(1)

    def _build_inputs(self, pcd, c_input, t_input, adjs):
        """Split a collated batch into source (index 0) and target (index 1) model inputs.

        Each input is a dict with "c_input", "t_input" and "pcd" tensors moved to
        ``self.device``. When ``self.calc_adjs`` is set, an "adj" entry is also
        added, built with ``self.get_concat_adj2``.
        """
        source_input = {
            "c_input": c_input[0].to(self.device),
            "t_input": t_input[0].to(self.device),
            "pcd": pcd[0].to(self.device),
        }
        target_input = {
            "c_input": c_input[1].to(self.device),
            "t_input": t_input[1].to(self.device),
            "pcd": pcd[1].to(self.device),
        }
        if self.calc_adjs:
            # Length of the first source sample's point tensor. Presumably this
            # is the padded per-sample point count (TODO confirm). It is used
            # for both source and target adjacency, as in the original code.
            max_point_nums = len(pcd[0][0])
            source_input["adj"] = self.get_concat_adj2(adjs[0], max_point_nums).to(self.device)
            target_input["adj"] = self.get_concat_adj2(adjs[1], max_point_nums).to(self.device)
        return source_input, target_input

    @staticmethod
    def _refine_similarity(similarity_matrix, erode_kernel, dilate_kernel):
        """Erode, then dilate, a 2-D similarity matrix with the given kernels.

        The anti-diagonal kernels built by ``test`` work as follows:

        - Erosion: each cell becomes the minimum of its up-right and down-left
          neighbours, so only responses that continue along the anti-diagonal
          survive.
        - Dilation: each cell becomes the maximum of itself and those two
          neighbours, restoring the surviving runs.

        Cells outside the matrix are treated as 0.
        """
        eroded = cv2.erode(
            similarity_matrix, erode_kernel,
            borderType=cv2.BORDER_CONSTANT, borderValue=0)
        return cv2.dilate(
            eroded, dilate_kernel,
            borderType=cv2.BORDER_CONSTANT, borderValue=0)