"""
@Description :   全局拼接测试
@Author      :   tqychy 
@Time        :   2025/01/22 16:36:23
"""
import sys

sys.path.append("./")
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch_geometric.utils import to_undirected
from tqdm import tqdm

from dataset import DeNoiseDataset, PairingAllDataset
from metrics.metrics_handler import MetricsHandler
from nets import (MatchingNet, PairingNet, decoder_nets, feature_extract_nets,
                  feature_fuse_nets, global_assemblers, graph_cluster,
                  score_evaluator)
from trainers.base_trainers import BaseTester


def graph_concat(v, v_num, idx_convert, e, e_num, metrics_handler):
    """Flatten a padded batch of graphs into one disjoint graph.

    Args:
        v: dict of padded vertex tensors indexed as ``v[key][graph, vertex, ...]``.
        v_num: number of valid vertices of each graph in the batch.
        idx_convert: per-graph local-vertex -> global-index tables (padded).
        e: padded edge tensor indexed as ``e[graph, 0:2, edge]``.
        e_num: number of valid edges of each graph in the batch.
        metrics_handler: object whose ``set_batch_info`` receives the batch
            ground truth (flattened edges and per-vertex graph labels).

    Returns:
        ``(concat_v, total_v_num, concat_idx_convert, metrics_handler)`` where
        ``concat_v`` maps each key to a tensor of the stacked valid vertices,
        ``total_v_num`` is ``sum(v_num)`` and ``concat_idx_convert`` maps a
        flattened vertex index to its global index (``int``).
    """
    n_graphs = len(v_num)
    total_v_num = sum(v_num)
    total_e_num = sum(e_num)

    # NOTE(review): these buffers use torch's default dtype/device, so
    # non-float features are cast on copy — confirm this is intended.
    concat_v = {}
    for name, feat in v.items():
        concat_v[name] = torch.zeros((total_v_num, *feat.shape[2:]))
    concat_e = torch.zeros(
        (2, total_e_num), dtype=torch.int64, device=e.device)
    cluster_labels = torch.zeros(
        (total_v_num), dtype=torch.int64, device=e.device)
    concat_idx_convert = {}

    v_offset = 0
    e_offset = 0
    for graph_id in range(n_graphs):
        n_v = v_num[graph_id]
        n_e = e_num[graph_id]
        v_end = v_offset + n_v

        # Every vertex of graph `graph_id` gets that graph id as its label.
        cluster_labels[v_offset:v_end] = graph_id
        for name in concat_v:
            concat_v[name][v_offset:v_end] = v[name][graph_id, :n_v]

        # Shift local vertex ids so edges point into the flattened vertex list.
        concat_e[:, e_offset:e_offset + n_e] = e[graph_id, :, :n_e] + v_offset

        for local_id, global_id in enumerate(idx_convert[graph_id][:n_v]):
            concat_idx_convert[v_offset + local_id] = int(global_id)

        v_offset += int(n_v)
        e_offset += int(n_e)

    # Hand the ground-truth description of the flattened batch to the handler.
    metrics_handler.set_batch_info(
        {
            "v_num": total_v_num,
            "pic_num": n_graphs,
            "e_gt": concat_e.T,
            "cluster_gt": cluster_labels
        },
        concat_idx_convert,
    )

    return concat_v, total_v_num, concat_idx_convert, metrics_handler


def binarizer(sim_mat, k):
    """Turn a similarity matrix into a 0/1 ``torch.long`` mask.

    Args:
        sim_mat: similarity tensor whose last dimension holds the candidates.
        k: when ``k < 1`` it is a threshold, and entries strictly greater than
            ``k`` become 1. Otherwise ``int(k)`` is the number of top entries
            kept along the last dimension, limited to at most
            ``sim_mat.size(-1)`` and at least 1.

    Returns:
        A ``torch.long`` tensor with the same shape as ``sim_mat``.
    """
    if k < 1:
        return (sim_mat > k).long()

    keep = min(int(k), sim_mat.size(-1))
    keep = max(keep, 1)
    top_cols = sim_mat.topk(keep, dim=-1).indices
    mask = torch.zeros_like(sim_mat, dtype=torch.long)
    return mask.scatter_(-1, top_cols, 1)


class GradualPairingAllTester(BaseTester):
    """End-to-end tester for the global pairing ("pairing all") pipeline.

    For every batch from ``PairingAllDataset`` the tester runs these steps:

    1. Coarse pairing with ``PairingNet``, using cosine similarity and ``binarizer``.
    2. Optional vertex clustering with ``graph_cluster``.
    3. Per-cluster denoising with ``MatchingNet`` (see ``denoise``).
    4. Optional global assembly with the configured ``global_assemblers`` entry.

    It collects metrics through ``MetricsHandler`` and logs their averages at the end.
    """

    def __init__(self, *args):
        """Build the datasets, load both pretrained networks and the metrics handler.

        Args:
            *args: exactly ``(cfg, logger)`` — the experiment config and a logger.
        """
        self.cfg, self.logger = args
        self.device = torch.device(self.cfg.GLOBALS.DEVICE)
        self.results_path = self.cfg.TEST.RES_SAVE_PATH
        os.makedirs(self.results_path, exist_ok=True)

        # Whether a given feature-extractor type consumes adjacency matrices.
        calc_adjs_tab = {
            "ResGCN": True,
            "ViT": False
        }

        # Global-pairing test set.
        self.logger.info("加载全局拼接测试集。")
        self.pairing_calc_adjs = calc_adjs_tab[self.cfg.TEST.PAIRING_ALL.PAIRING_FEATURE_EXTRACT]
        self.test_dataset = PairingAllDataset(
            self.cfg.TEST.TEST_DATA_PATH, self.cfg, self.logger, calc_adjs=self.pairing_calc_adjs)
        self.test_loader = DataLoader(
            self.test_dataset, self.cfg.TEST.PAIRING_ALL.PICS_NUM, shuffle=False)

        # Auxiliary dataset that feeds candidate pairs to the matching network;
        # it reuses the adjacency matrices already computed by the test set.
        self.logger.info("加载辅助数据集。")
        self.matching_calc_adjs = calc_adjs_tab[self.cfg.TEST.PAIRING_ALL.MATCHING_FEATURE_EXTRACT]
        self.denoise_dataset = DeNoiseDataset(self.cfg.TEST.TEST_DATA_PATH, self.cfg,
                                              self.logger, adjs=self.test_dataset.adj_all, calc_adjs=self.matching_calc_adjs)

        # Global feature-pairing network (frozen, eval mode).
        self.logger.info("加载全局特征匹配模型。")
        pairing_feature_extract = feature_extract_nets[self.cfg.TEST.PAIRING_ALL.PAIRING_FEATURE_EXTRACT](
            self.cfg, self.logger)
        pairing_fuse = feature_fuse_nets[self.cfg.TEST.PAIRING_ALL.PAIRING_FEATURE_FUSE](
            self.cfg, self.logger)
        self.pairing_net = PairingNet(
            pairing_feature_extract, pairing_fuse).to(self.device)
        self.pairing_net.load_state_dict(torch.load(
            self.cfg.TEST.PAIRING_ALL.PAIRING_STAT_DICT_PATH, weights_only=True)["model_state_dict"])
        self.pairing_net.eval()
        self.pairing_net.requires_grad_(False)

        # Local feature-matching network (frozen, eval mode).
        self.logger.info("加载局部特征匹配模型。")
        matching_feature_extract = feature_extract_nets[self.cfg.TEST.PAIRING_ALL.MATCHING_FEATURE_EXTRACT](
            self.cfg, self.logger)
        matching_fuse = feature_fuse_nets[self.cfg.TEST.PAIRING_ALL.MATCHING_FEATURE_FUSE](
            self.cfg, self.logger)
        matching_decoder = decoder_nets[self.cfg.TEST.PAIRING_ALL.MATCHING_DECODER](
            self.cfg, self.logger)
        classifier = score_evaluator[self.cfg.TEST.PAIRING_ALL.SCORE_EVAL_TYPE](
            self.cfg, self.logger)
        self.matching_net = MatchingNet(
            matching_feature_extract, matching_fuse, matching_decoder, classifier).to(self.device)
        self.matching_net.load_state_dict(torch.load(
            self.cfg.TEST.PAIRING_ALL.MATCHING_STAT_DICT_PATH, weights_only=True)["model_state_dict"])
        self.matching_net.eval()
        self.matching_net.requires_grad_(False)

        self.metrics_handler = MetricsHandler(
            self.test_dataset.data, self.results_path, self.cfg, self.logger)

    @staticmethod
    def _safe_mean(values):
        """Return ``sum(values) / len(values)``, or NaN when ``values`` is empty.

        Some metric lists legitimately stay empty for a given config, e.g.
        when no assembler is configured. A bare ``sum / len`` would then raise
        ``ZeroDivisionError``.
        """
        return sum(values) / len(values) if values else float("nan")

    def _log_mean_metrics(self, prefix, named_values):
        """Log the mean of each ``(name, values)`` pair on one INFO line after ``prefix``."""
        self.logger.info(f"{prefix} " + ", ".join(
            f"平均 {name}: {self._safe_mean(values)}" for name, values in named_values))

    def test(self):
        """Run the whole pipeline over ``self.test_loader`` and log mean metrics.

        Metrics that were never collected under the active configuration
        are reported as NaN instead of aborting the run.
        """
        self.logger.debug("开始测试全局拼接。")
        preliminary_metrics = {m: [] for m in ["f1", "rec", "prec"]}
        cluster_metrics = {m: [] for m in ["ari"]}
        further_metrics = {m: [] for m in ["f1", "rec", "prec"]}
        score_eval_metrics = {m: [] for m in ["auc"]}
        assemble_metrics = {m: [] for m in ["f1", "rec", "prec"]}
        assembler_type = self.cfg.TEST.PAIRING_ALL.ASSEMBLER_TYPE
        with tqdm(total=len(self.test_loader)) as pbar:
            for i, (v, v_num, idx_convert, e, e_num) in enumerate(self.test_loader):
                with tqdm(total=5, leave=False) as step:
                    # Merge the padded batch into one disjoint graph and
                    # register its ground truth with the metrics handler.
                    v, v_num, idx_convert, self.metrics_handler = graph_concat(
                        v, v_num, idx_convert, e, e_num, self.metrics_handler)
                    pcd, c_input, t_input, adjs = v["full_pcd"], v["c_input"], v["t_input"], v["adj"]
                    inputs = {
                        'c_input': c_input.to(self.device),
                        't_input': t_input.to(self.device),
                        'pcd': pcd.to(self.device)
                    }
                    if self.pairing_calc_adjs:
                        max_point_nums = len(pcd[0][0])
                        adj = self.get_concat_adj2(adjs, max_point_nums)
                        inputs["adj"] = adj.to(self.device)

                    # Step 1: coarse pairing from the cosine similarity of
                    # L2-normalised per-vertex features.
                    features = self.pairing_net(inputs)
                    F_normalized = nn.functional.normalize(
                        features, p=2, dim=1)
                    cos_sim_matrix = torch.matmul(
                        F_normalized, F_normalized.T).cpu()

                    # Binarize. Cosine similarity is >= -1, so filling the
                    # diagonal with -1 pushes self-pairs to the bottom.
                    cos_sim_matrix.fill_diagonal_(-1)
                    noisy_pred = binarizer(
                        cos_sim_matrix, self.cfg.TEST.PAIRING_ALL.K_DENOISE).nonzero()
                    # Symmetrize, then keep each undirected edge once (i < j).
                    noisy_pred = to_undirected(noisy_pred.T).T
                    noisy_pred = noisy_pred[noisy_pred[:, 0]
                                            < noisy_pred[:, 1]]
                    prec, rec, f1 = self.metrics_handler.pairing_metrics(
                        noisy_pred, gen_pic=False)
                    preliminary_metrics["f1"].append(f1)
                    preliminary_metrics["prec"].append(prec)
                    preliminary_metrics["rec"].append(rec)
                    self.logger.debug(
                        f"batch_{i} 初步匹配 F1: {f1}, Prec: {prec}, Rec: {rec}")
                    step.set_description(
                        f"初步匹配 F1：{f1:.3f}, Prec: {prec:.3f}, Rec: {rec:.3f}")
                    step.update(1)

                    # Step 2: optional vertex clustering; otherwise the whole
                    # batch is treated as one cluster.
                    if self.cfg.TEST.PAIRING_ALL.CLUSTER:
                        clusters, noisy_edges = graph_cluster(
                            v_num, noisy_pred.numpy(), num_runs=1)
                        ari = self.metrics_handler.ari_metrics(clusters)
                        cluster_metrics["ari"].append(ari)
                        self.logger.debug(f"batch_{i} 聚类 ARI: {ari}")
                        step.set_description(f"聚类 ARI：{ari:.3f}")
                    else:
                        clusters = [np.arange(v_num)]
                        noisy_edges = [noisy_pred.numpy()]
                    step.update(1)
                    cleaned_pairs = []
                    noisy_scores = []

                    # Steps 3-4: denoise and (optionally) assemble each cluster.
                    with tqdm(total=len(clusters), leave=False) as pbar_cluster:
                        for pic_idx, (cluster, noisy_edge) in enumerate(zip(clusters, noisy_edges)):
                            self.denoise_dataset.set_noisy_pairs(
                                noisy_edge, idx_convert)
                            cleaned_pair, sim_mats, scores, all_scores = self.denoise()
                            cleaned_pairs.append(cleaned_pair)
                            noisy_scores.extend(all_scores)

                            if assembler_type is not None:
                                v_imgs = self.test_dataset.data['img_all']
                                v_pcds = self.test_dataset.data['full_pcd_all']
                                global_assembler = global_assemblers[assembler_type](
                                    v_imgs, v_pcds, cluster, idx_convert, cleaned_pair, sim_mats, scores, self.results_path, self.metrics_handler, self.cfg, self.logger)

                                (rec, prec, f1), post_ari = global_assembler.assemble(
                                    i, pic_idx)
                                if not self.cfg.TEST.PAIRING_ALL.CLUSTER:
                                    # No explicit clustering step: record the
                                    # ARI returned by the assembler instead.
                                    cluster_metrics["ari"].append(post_ari)
                                    self.logger.debug(f"batch_{i} POST_ARI: {post_ari}")
                                assemble_metrics["rec"].append(rec)
                                assemble_metrics["prec"].append(prec)
                                assemble_metrics["f1"].append(f1)
                                self.logger.debug(
                                    f"batch: {i}, pic_idx: {pic_idx}, 拼接 rec: {rec}, prec: {prec}, f1: {f1}")
                                pbar_cluster.set_description(
                                    f"拼接 rec：{rec:.3f}, prec: {prec:.3f}, f1: {f1:.3f}")
                            # Advance once per cluster, even without an assembler.
                            pbar_cluster.update(1)
                    step.set_description("拼接")
                    step.update(1)

                    # Score-distribution metric and graph visualisation.
                    auc = self.metrics_handler.distribution_metrics(
                        noisy_scores)
                    score_eval_metrics["auc"].append(auc)
                    self.logger.debug(f"batch_{i} 分数估计 auc: {auc}")
                    step.set_description(f"分数估计 auc: {auc:.3f}")
                    step.update(1)
                    self.metrics_handler.vis_graph(clusters, cleaned_pairs)

                    # Fine-matching metrics over all clusters' cleaned pairs.
                    # Empty per-cluster results are dropped first. If nothing is
                    # left, np.vstack would raise, so fall back to an empty
                    # (0, 2) edge list.
                    # NOTE(review): assumes pairing_metrics accepts a
                    # zero-row edge tensor — confirm.
                    cleaned_pairs = [pair for pair in cleaned_pairs
                                     if len(pair) > 0 and pair.shape[1] > 0]
                    if cleaned_pairs:
                        cleaned_pairs = torch.tensor(np.vstack(cleaned_pairs))
                    else:
                        cleaned_pairs = torch.empty((0, 2), dtype=torch.long)
                    cos_sim_matrix.fill_diagonal_(0)
                    prec, rec, f1 = self.metrics_handler.pairing_metrics(
                        cleaned_pairs, cos_sim_matrix, True)
                    further_metrics["f1"].append(f1)
                    further_metrics["prec"].append(prec)
                    further_metrics["rec"].append(rec)
                    self.logger.debug(
                        f"batch_{i} 精细匹配 F1: {f1}, Prec: {prec}, Rec: {rec}")

                    step.update(1)
                    running_fine_f1 = self._safe_mean(further_metrics["f1"])
                    running_assemble_f1 = self._safe_mean(assemble_metrics["f1"])
                    pbar.set_description(
                        f"精细 f1: {running_fine_f1:.3f}, 拼接 f1: {running_assemble_f1:.3f}")
                    pbar.update(1)

        # Averages over all batches (NaN for metrics that were never collected).
        self._log_mean_metrics("初步匹配", [
            ("F1", preliminary_metrics["f1"]),
            ("Prec", preliminary_metrics["prec"]),
            ("Rec", preliminary_metrics["rec"]),
        ])
        self._log_mean_metrics("聚类", [("ARI", cluster_metrics["ari"])])
        self._log_mean_metrics("精细匹配", [
            ("F1", further_metrics["f1"]),
            ("Prec", further_metrics["prec"]),
            ("Rec", further_metrics["rec"]),
        ])
        self._log_mean_metrics("分数估计", [("AUC", score_eval_metrics["auc"])])
        self._log_mean_metrics("拼接", [
            ("Rec", assemble_metrics["rec"]),
            ("Prec", assemble_metrics["prec"]),
            ("F1", assemble_metrics["f1"]),
        ])

    def denoise_forward(self, batch: tuple):
        """Run the matching network on one collated batch of (source, target) pairs.

        Args:
            batch: unpacked as ``(mask_para, imgs, pcd, c_input, t_input, adjs,
                factors, att_mask)``. Index 0 of ``pcd``/``c_input``/``t_input``/
                ``adjs`` is used as the source side and index 1 as the target
                side. ``imgs``, ``factors`` and ``att_mask`` are not used here.

        Returns:
            The similarity matrices returned by ``self.matching_net``.
        """
        mask_para, imgs, pcd, c_input, t_input, adjs, factors, att_mask = batch
        source_input = {
            "c_input": c_input[0].to(self.device),
            "t_input": t_input[0].to(self.device),
            "pcd": pcd[0].to(self.device)
        }
        target_input = {
            "c_input": c_input[1].to(self.device),
            "t_input": t_input[1].to(self.device),
            "pcd": pcd[1].to(self.device)
        }
        if self.matching_calc_adjs:
            # get_concat_adj2 is not defined here (presumably from BaseTester).
            max_point_nums = len(pcd[0][0])
            adj_s = self.get_concat_adj2(adjs[0], max_point_nums)
            adj_t = self.get_concat_adj2(adjs[1], max_point_nums)
            source_input["adj"] = adj_s.to(self.device)
            target_input["adj"] = adj_t.to(self.device)

        # mark the padded part in similarity matrix
        pad_mask = self.get_pad_mask(mask_para).to(self.device)
        similarity_matrix = self.matching_net(
            source_input, target_input, pad_mask)
        return similarity_matrix

    def denoise(self):
        """Score every candidate pair currently held by ``self.denoise_dataset``.

        Each predicted similarity matrix goes through three steps:

        1. Erosion with a 3x3 anti-diagonal kernel that excludes the centre.
        2. Dilation with a 3x3 anti-diagonal kernel that includes the centre.
        3. Thresholding at ``TEST.MATCHING.CONV_THRES``.

        The matching network then scores each pair from its binarized matrix.

        Returns:
            ``(clean_pairs, sim_matrices, scores, all_scores)``:

            - ``clean_pairs``: tensor of ``(idx1, idx2)`` rows whose score
              exceeds ``TEST.PAIRING_ALL.DESISION_THRES``. It is 1-D and empty
              when no pair passes.
            - ``sim_matrices``: the binarized matrices of the kept pairs.
            - ``scores``: list of the kept pairs' scores.
            - ``all_scores``: list of ``(idx1, idx2, score)`` for every candidate.
        """
        clean_pairs = []  # pairs that survive denoising
        sim_matrice = []
        scores = []
        all_scores = []

        denoise_loader = DataLoader(
            self.denoise_dataset, self.cfg.TEST.BATCH_SIZE, shuffle=False)
        self.logger.debug("开始去噪。")

        # Loop-invariant setup. The kernels are np.rot90 of the 3x3 identity
        # (anti-diagonal), copied so they are contiguous and independent.
        erode_kernel = np.rot90(np.eye(3, dtype=np.uint8)).copy()
        erode_kernel[1, 1] = 0
        dilate_kernel = np.rot90(np.eye(3, dtype=np.uint8)).copy()
        conv_threshold = self.cfg.TEST.MATCHING.CONV_THRES
        decision_threshold = self.cfg.TEST.PAIRING_ALL.DESISION_THRES

        with tqdm(total=len(denoise_loader), leave=False, desc="去噪") as pbar_denoise:
            for batch_sample in denoise_loader:
                # The last item holds the (idx1, idx2) ids of each pair.
                mask_para, local_idx = batch_sample[0], batch_sample[-1]
                similarity_matrices = self.denoise_forward(batch_sample[:-1])
                preds = torch.zeros_like(similarity_matrices).to(self.device)
                for batch in range(similarity_matrices.shape[0]):
                    similarity_matrix = similarity_matrices[batch].cpu().numpy()
                    similarity_matrix = cv2.erode(
                        similarity_matrix, erode_kernel, borderType=cv2.BORDER_CONSTANT, borderValue=0)
                    similarity_matrix = cv2.dilate(
                        similarity_matrix, dilate_kernel, borderType=cv2.BORDER_CONSTANT, borderValue=0)

                    pred = np.array(
                        (similarity_matrix > conv_threshold), dtype=np.int32)
                    preds[batch] = torch.tensor(pred).to(preds.device)

                # Score every candidate edge in the batch.
                final_scores = self.matching_net(
                    preds, mask_para[1] + mask_para[2]).cpu().numpy()
                preds = preds.cpu().numpy()

                for batch, score in enumerate(final_scores):
                    idx1, idx2 = int(local_idx[0][batch]), int(local_idx[1][batch])
                    score = score.item()
                    all_scores.append((idx1, idx2, score))
                    if score > decision_threshold:
                        clean_pairs.append((idx1, idx2))
                        sim_matrice.append(preds[batch])
                        scores.append(score)

                pbar_denoise.update(1)

        return torch.tensor(np.array(clean_pairs)), torch.tensor(np.array(sim_matrice)), scores, all_scores


def max_continuous_diagonal_length(mat: torch.Tensor) -> int:
    """Return the length of the longest run of 1s along any diagonal of ``mat``.

    All diagonals of the 2-D tensor are scanned, i.e. offsets
    ``-(rows - 1)`` through ``cols - 1``, so non-square matrices are fully
    covered. Works for tensors on any device.

    >>> max_continuous_diagonal_length(torch.tensor([[0, 0, 0], [0, 1, 1], [0, 0, 1]]))
    2
    """
    rows, cols = mat.shape[0], mat.shape[1]
    best = 0
    for offset in range(-rows + 1, cols):
        is_one = torch.diagonal(mat, offset=offset) == 1
        if not is_one.any():
            continue
        # Positions that break a run, plus virtual breaks just outside both
        # ends; the longest run is the largest gap between breaks minus one.
        breaks = torch.cat([
            torch.tensor([-1], dtype=torch.long, device=mat.device),
            torch.nonzero(~is_one).flatten(),
            torch.tensor([is_one.numel()], dtype=torch.long, device=mat.device),
        ])
        best = max(best, int((breaks[1:] - breaks[:-1] - 1).max()))
    return best


if __name__ == "__main__":
    # Quick manual check: the main diagonal holds the longest run (length 2).
    mat = torch.tensor([
        [0, 0, 0],
        [0, 1, 1],
        [0, 0, 1]
    ])

    print(max_continuous_diagonal_length(mat))
