"""
@Description :   Trainer class for fragment assembly (matching)
@Author      :   tqychy 
@Time        :   2025/01/13 19:42:03
"""
import sys

sys.path.append("./")

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from dataset import MatchingDataset
from nets import (FocalLoss, MatchingNet, decoder_nets, feature_extract_nets,
                  feature_fuse_nets)
from trainers.base_trainers import BaseTrainer


class MatchingTrainer(BaseTrainer):
    """Trainer for the fragment-matching network.

    Supplies the dataset, loss, model and per-batch train/validation steps.
    Attributes such as ``self.cfg``, ``self.logger``, ``self.device``,
    ``self.model`` and ``self.criterion`` are not defined in this class;
    presumably they are provided by ``BaseTrainer`` -- TODO confirm.
    """

    def __init__(self, writer_path, *args):
        """Forward ``writer_path`` and any extra arguments to ``BaseTrainer``."""
        super().__init__(writer_path, *args)

    @staticmethod
    def check_if_best(current_result, results) -> bool:
        """Tell whether ``current_result`` beats every earlier result.

        Args:
            current_result: Metric value of the current evaluation
                (lower is better).
            results: Previously recorded metric values.

        Returns:
            True if ``results`` is non-empty and ``current_result`` is
            strictly smaller than its minimum, otherwise False.
        """
        # NOTE(review): with an empty history this returns False, so a very
        # first result is never reported as "best" -- confirm this is intended.
        return bool(len(results) > 0 and current_result < min(results))

    @staticmethod
    def pbar_desc(epoch_result_dict: dict) -> str:
        """Format the progress-bar description for one epoch.

        Args:
            epoch_result_dict: Mapping that contains ``"loss"`` and
                ``"positive_loss"`` entries, each with ``"train"`` and
                ``"valid"`` values.

        Returns:
            A string with train/valid loss and positive loss, each printed
            with three decimals.
        """
        loss_stats = epoch_result_dict["loss"]
        ploss_stats = epoch_result_dict["positive_loss"]
        mean_loss_t, mean_loss_v = loss_stats["train"], loss_stats["valid"]
        mean_ploss_t, mean_ploss_v = ploss_stats["train"], ploss_stats["valid"]
        return f"loss(t/v): {mean_loss_t:.3f}/{mean_loss_v:.3f}, ploss(t/v): {mean_ploss_t:.3f}/{mean_ploss_v:.3f}"

    def model_forward(self, batch) -> tuple:
        """Run the model and the criterion on one batch.

        Reads ``self.calc_adjs``, which is assigned in ``set_dataset``; that
        method therefore has to run before the first forward pass.

        Args:
            batch: An 8-item sequence ``(mask_para, imgs, pcd, c_input,
                t_input, adjs, factors, att_mask)``. For ``pcd``, ``c_input``,
                ``t_input`` and ``adjs``, index 0 is used for the source
                fragment and index 1 for the target fragment. ``imgs``,
                ``factors`` and ``att_mask`` are not used here.

        Returns:
            ``(result_dict, loss_np)`` where ``result_dict`` holds the two
            criterion outputs as Python floats under ``"loss"`` and
            ``"positive_loss"``, and ``loss_np`` is the first criterion
            output as a tensor (the one callers can backpropagate).
        """
        mask_para, _imgs, pcd, c_input, t_input, adjs, _factors, _att_mask = batch
        source_input = {
            "c_input": c_input[0].to(self.device),
            "t_input": t_input[0].to(self.device),
            "pcd": pcd[0].to(self.device)
        }
        target_input = {
            "c_input": c_input[1].to(self.device),
            "t_input": t_input[1].to(self.device),
            "pcd": pcd[1].to(self.device)
        }
        if self.calc_adjs:
            # Length of the first source sample; presumably the padded point
            # count shared by all samples in the batch -- TODO confirm.
            max_point_nums = len(pcd[0][0])
            adj_s = self.get_concat_adj2(adjs[0], max_point_nums)
            adj_t = self.get_concat_adj2(adjs[1], max_point_nums)
            source_input["adj"] = adj_s.to(self.device)
            target_input["adj"] = adj_t.to(self.device)

        # Mark the padded part of the similarity matrix.
        pad_mask = self.get_pad_mask(mask_para).to(self.device)
        similarity_matrix = self.model(source_input, target_input, pad_mask)

        # Add the ground-truth correspondences to the mask handed to the
        # criterion.
        gt_mask = mask_para[0].to(self.device)
        pad_mask = torch.add(pad_mask, gt_mask)
        loss_np, loss_p = self.criterion(similarity_matrix, gt_mask, pad_mask)

        result_dict = {
            "loss": loss_np.detach().cpu().item(),
            "positive_loss": loss_p.detach().cpu().item()
        }

        return result_dict, loss_np

    def set_dataset(self, train_dataset_path: str, valid_dataset_path: str, batch_size: int) -> tuple:
        """Build the training and validation data loaders.

        Also sets ``self.calc_adjs`` from the configured feature extractor
        (``True`` for ``"ResGCN"``, ``False`` for ``"ViT"``).

        Args:
            train_dataset_path: Path passed to the training ``MatchingDataset``.
            valid_dataset_path: Path passed to the validation ``MatchingDataset``.
            batch_size: Batch size used by both loaders.

        Returns:
            ``(train_loader, valid_loader)``; only the training loader
            shuffles. Both load in the main process (``num_workers=0``).

        Raises:
            KeyError: If ``cfg.TRAIN.MATCHING.FEATURE_EXTRACT`` is neither
                ``"ResGCN"`` nor ``"ViT"``.
        """
        calc_adjs_tab = {
            "ResGCN": True,
            "ViT": False
        }
        self.calc_adjs = calc_adjs_tab[self.cfg.TRAIN.MATCHING.FEATURE_EXTRACT]
        train_dataset = MatchingDataset(train_dataset_path, self.cfg, self.logger, calc_adjs=self.calc_adjs)
        valid_dataset = MatchingDataset(valid_dataset_path, self.cfg, self.logger, calc_adjs=self.calc_adjs)
        train_loader = DataLoader(train_dataset, batch_size, num_workers=0, shuffle=True)
        valid_loader = DataLoader(valid_dataset, batch_size, num_workers=0, shuffle=False)
        return train_loader, valid_loader

    def set_loss(self) -> nn.Module:
        """Return the criterion: a ``FocalLoss`` with default arguments."""
        return FocalLoss()

    def set_model(self) -> nn.Module:
        """Assemble the matching network from the configured components.

        The feature extractor, fusion module and decoder are looked up by
        name (``cfg.TRAIN.MATCHING.FEATURE_EXTRACT`` / ``FEATURE_FUSE`` /
        ``DECODER``) in their registries and each is constructed with
        ``(cfg, logger)``.

        Returns:
            A ``MatchingNet`` wrapping the three components.
        """
        matching_cfg = self.cfg.TRAIN.MATCHING
        feature_extract = feature_extract_nets[matching_cfg.FEATURE_EXTRACT](self.cfg, self.logger)
        fuse = feature_fuse_nets[matching_cfg.FEATURE_FUSE](self.cfg, self.logger)
        decoder = decoder_nets[matching_cfg.DECODER](self.cfg, self.logger)

        return MatchingNet(feature_extract, fuse, decoder)

    def train_batch(self, batch: tuple) -> tuple:
        """Run one training forward pass.

        Puts the model in training mode and returns the
        ``(result_dict, loss)`` pair produced by ``model_forward``. The
        backward pass and optimizer step are not performed here.
        """
        self.model.train()
        return self.model_forward(batch)

    def valid_batch(self, batch: tuple) -> tuple:
        """Run one validation forward pass.

        Puts the model in evaluation mode and evaluates the batch with
        gradient tracking disabled, since only the detached scalar results
        are returned and no backward pass follows.

        Returns:
            ``(result_dict, "positive_loss")``. The second item is a key of
            ``result_dict``; presumably the caller uses it to pick the
            metric for model selection -- TODO confirm against BaseTrainer.
        """
        self.model.eval()
        # No graph is needed for validation; skipping it saves memory and time.
        with torch.no_grad():
            result_dict, _ = self.model_forward(batch)
        return result_dict, "positive_loss"