"""
@Description :   Trainer class for classification performed after local feature matching
@Author      :   tqychy 
@Time        :   2025/02/15 11:41:53
"""
from trainers.base_trainers import BaseTrainer
from nets import (CNNScoreEvaluator, FocalLoss2, MatchingNet, decoder_nets,
                  feature_extract_nets, feature_fuse_nets)
from dataset import ClassifyDataset
from torch.utils.data import DataLoader
from sklearn.metrics import (accuracy_score, f1_score, precision_score,
                             recall_score)
import torch.nn as nn
import torch
import numpy as np
import cv2
import sys

# Add the current working directory to the module search path.
# NOTE(review): this statement executes after the project imports above
# (trainers, nets, dataset), so it cannot affect how those were resolved —
# confirm whether it was meant to run before them.
sys.path.append("./")


class ClassifyTrainer(BaseTrainer):
    """Trainer for the classification head that runs after feature matching.

    The matching part of the network is loaded from a checkpoint and frozen
    (see ``set_model``); only parameters whose names start with ``classify``
    are left trainable.

    Attributes such as ``self.model``, ``self.criterion``, ``self.device``,
    ``self.cfg`` and ``self.logger``, as well as the helpers
    ``get_concat_adj2`` and ``get_pad_mask``, are not defined in this class —
    presumably they are provided by ``BaseTrainer``; verify there.
    """

    def __init__(self, writer_path: str, *args):
        """Forward all arguments unchanged to ``BaseTrainer.__init__``."""
        super().__init__(writer_path, *args)

    @staticmethod
    def check_if_best(current_result, results) -> bool:
        """Return True if ``current_result`` beats every entry of ``results``.

        An empty ``results`` always counts as "best"; otherwise the current
        value must be strictly greater than the maximum seen so far.
        """
        return bool(len(results) == 0 or current_result > max(results))

    @staticmethod
    def pbar_desc(epoch_result_dict: dict) -> str:
        """Format train/valid loss and accuracy for the progress bar.

        Reads ``epoch_result_dict["loss"]`` and ``epoch_result_dict["acc"]``,
        each of which must contain ``"train"`` and ``"valid"`` entries.
        """
        mean_loss_t, mean_loss_v = epoch_result_dict["loss"]["train"], epoch_result_dict["loss"]["valid"]
        mean_acc_t, mean_acc_v = epoch_result_dict["acc"]["train"], epoch_result_dict["acc"]["valid"]
        return f"loss(t/v): {mean_loss_t:.3f}/{mean_loss_v:.3f}, acc(t/v): {mean_acc_t:.3f}/{mean_acc_v:.3f}"

    @staticmethod
    def _binarize_similarity(similarity_matrices: torch.Tensor, threshold: float) -> torch.Tensor:
        """Filter each similarity matrix along the anti-diagonal and threshold it.

        Every matrix in the batch is eroded with a 3x3 anti-diagonal kernel
        whose centre is excluded, dilated with the full 3x3 anti-diagonal
        kernel, and then compared against ``threshold``.

        Returns a tensor with the same shape, dtype and device as
        ``similarity_matrices`` holding 1 where the filtered value exceeds
        ``threshold`` and 0 elsewhere. The result carries no gradient.
        """
        # Kernels are identical for every sample, so build them once.
        # ascontiguousarray turns the rot90 view into a plain array for OpenCV.
        dilate_kernel = np.ascontiguousarray(
            np.rot90(np.eye(3, dtype=np.uint8)))
        erode_kernel = dilate_kernel.copy()
        erode_kernel[1, 1] = 0

        preds = torch.zeros_like(similarity_matrices)
        # detach() is required before numpy() if the tensor is attached to the
        # autograd graph; no gradient can pass through the OpenCV ops anyway.
        matrices = similarity_matrices.detach().cpu().numpy()
        for index, matrix in enumerate(matrices):
            # Pixels outside the matrix are treated as 0 by both operations.
            filtered = cv2.erode(
                matrix, erode_kernel, borderType=cv2.BORDER_CONSTANT, borderValue=0)
            filtered = cv2.dilate(
                filtered, dilate_kernel, borderType=cv2.BORDER_CONSTANT, borderValue=0)
            mask = (filtered > threshold).astype(np.int32)
            preds[index] = torch.from_numpy(mask).to(preds.device)
        return preds

    def model_forward(self, batch: tuple) -> tuple:
        """Run one forward pass and compute the loss and metrics.

        Args:
            batch: A 9-tuple ``(mask_para, imgs, pcd, c_input, t_input, adjs,
                factors, att_mask, gts)``. ``pcd``, ``c_input``, ``t_input``
                and ``adjs`` are indexed with 0 for the source and 1 for the
                target. ``imgs``, ``factors`` and ``att_mask`` are not used.

        Returns:
            ``(result_dict, loss)`` where ``result_dict`` holds the float
            values ``loss``, ``acc``, ``f1``, ``prec`` and ``rec`` and
            ``loss`` is the tensor returned by ``self.criterion``.
        """
        mask_para, _imgs, pcd, c_input, t_input, adjs, _factors, _att_mask, gts = batch
        gts = gts.to(self.device)
        source_input = {
            "c_input": c_input[0].to(self.device),
            "t_input": t_input[0].to(self.device),
            "pcd": pcd[0].to(self.device)
        }
        target_input = {
            "c_input": c_input[1].to(self.device),
            "t_input": t_input[1].to(self.device),
            "pcd": pcd[1].to(self.device)
        }
        if self.calc_adjs:
            # Both adjacency inputs are built with the length of the first
            # source sample's pcd entry.
            max_point_nums = len(pcd[0][0])
            adj_s = self.get_concat_adj2(adjs[0], max_point_nums)
            adj_t = self.get_concat_adj2(adjs[1], max_point_nums)
            source_input["adj"] = adj_s.to(self.device)
            target_input["adj"] = adj_t.to(self.device)

        pad_mask = self.get_pad_mask(mask_para).to(self.device)
        # NOTE(review): self.model is called with two different signatures
        # (three inputs here, one tensor below); presumably its forward
        # dispatches on the arguments — confirm against MatchingNet.
        similarity_matrices = self.model(source_input, target_input, pad_mask)

        preds = self._binarize_similarity(
            similarity_matrices, self.cfg.TRAIN.CLASSIFY.CONV_THRES).to(self.device)

        scores = self.model(preds).to(torch.double)
        loss = self.criterion(scores, gts)

        # Scores above 0.5 are counted as positive predictions.
        pred_labels = (scores.detach().cpu() > .5).numpy()
        gt_labels = gts.detach().cpu().numpy()

        result_dict = {
            "loss": loss.detach().cpu().item(),
            "acc": accuracy_score(gt_labels, pred_labels),
            "f1": f1_score(gt_labels, pred_labels),
            "prec": precision_score(gt_labels, pred_labels),
            "rec": recall_score(gt_labels, pred_labels)
        }

        return result_dict, loss

    def set_dataset(self, train_dataset_path: str, valid_dataset_path: str, batch_size: int) -> tuple:
        """Build the training and validation data loaders.

        Also sets ``self.calc_adjs`` from the configured feature extractor
        (True for ``"ResGCN"``, False for ``"ViT"``); any other name raises
        ``KeyError``.

        Returns:
            ``(train_loader, valid_loader)``; only the training loader
            shuffles.
        """
        calc_adjs_tab = {
            "ResGCN": True,
            "ViT": False
        }
        self.calc_adjs = calc_adjs_tab[self.cfg.TRAIN.CLASSIFY.FEATURE_EXTRACT]
        train_dataset = ClassifyDataset(
            train_dataset_path, self.cfg, self.logger, calc_adjs=self.calc_adjs)
        valid_dataset = ClassifyDataset(
            valid_dataset_path, self.cfg, self.logger, calc_adjs=self.calc_adjs)
        train_loader = DataLoader(
            train_dataset, batch_size, num_workers=0, shuffle=True)
        valid_loader = DataLoader(
            valid_dataset, batch_size, num_workers=0, shuffle=False)
        return train_loader, valid_loader

    def set_loss(self) -> nn.Module:
        """Return the training criterion (binary cross-entropy)."""
        return nn.BCELoss()

    def set_model(self) -> nn.Module:
        """Assemble the network, load pretrained weights and freeze them.

        Weights are read from ``cfg.TRAIN.CLASSIFY.STAT_DICT`` under the key
        ``"model_state_dict"``. Entries whose names start with ``classify``
        or that do not exist in the new model are ignored. Afterwards only
        parameters whose names start with ``classify`` remain trainable.
        """
        feature_extract = feature_extract_nets[self.cfg.TRAIN.CLASSIFY.FEATURE_EXTRACT](
            self.cfg, self.logger)
        fuse = feature_fuse_nets[self.cfg.TRAIN.CLASSIFY.FEATURE_FUSE](
            self.cfg, self.logger)
        decoder = decoder_nets[self.cfg.TRAIN.CLASSIFY.DECODER](
            self.cfg, self.logger)
        classify = CNNScoreEvaluator(self.cfg, self.logger)
        model = MatchingNet(feature_extract, fuse, decoder, classify)

        # map_location="cpu" lets a checkpoint saved on a GPU be loaded on a
        # host without one; load_state_dict copies the values into the
        # model's own parameters regardless of where they were loaded.
        checkpoint = torch.load(
            self.cfg.TRAIN.CLASSIFY.STAT_DICT, map_location="cpu", weights_only=True)
        state_dict = checkpoint["model_state_dict"]
        model_dict = model.state_dict()
        pretrained_dict = {
            key: value for key, value in state_dict.items()
            if key in model_dict and not key.startswith("classify")
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

        for name, param in model.named_parameters():
            param.requires_grad = name.startswith("classify")

        return model

    def train_batch(self, batch: tuple) -> tuple:
        """Switch the model to training mode and run ``model_forward``.

        NOTE(review): ``train()`` is applied to the whole model, including the
        sub-modules frozen in ``set_model``; if those contain dropout or batch
        normalisation they will still behave in training mode — confirm this
        is intended.
        """
        self.model.train()
        return self.model_forward(batch)

    def valid_batch(self, batch: tuple) -> tuple:
        """Evaluate one batch without tracking gradients.

        Returns:
            ``(result_dict, "acc")``; the second element is presumably the key
            of the metric the base trainer uses for model selection — verify
            against the caller.
        """
        self.model.eval()
        # The loss tensor is discarded here, so no autograd graph is needed.
        with torch.no_grad():
            result_dict, _ = self.model_forward(batch)
        return result_dict, "acc"