"""
@Description :   碎片匹配的训练
@Author      :   tqychy 
@Time        :   2025/01/02 12:14:16
"""
import sys

sys.path.append("./")
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from dataset import PairingDataset
from nets import InfoNCE, PairingNet, feature_extract_nets, feature_fuse_nets
from trainers.base_trainers import BaseTrainer


class PairingTrainer(BaseTrainer):
    """Trainer for the fragment-pairing model.

    Source and target fragments of a batch are encoded by the same model
    (``self.model``), trained with an InfoNCE criterion, and evaluated with
    in-batch top-1 / top-5 retrieval recall.
    """

    # Weight applied to each of the source-source / target-target InfoNCE
    # terms before they are averaged and added to the cross term.
    ONLY_NEGATIVE_WEIGHT = 0.5

    def __init__(self, writer_path, *args):
        super().__init__(writer_path, *args)

    @staticmethod
    def check_if_best(current_result, results) -> bool:
        """Return True if ``current_result`` is strictly lower than every value in ``results``.

        Lower is treated as better (presumably the monitored value is the
        validation ``"infor_loss"`` named by ``valid_batch``). An empty
        ``results`` returns False.
        """
        if len(results) > 0 and current_result < min(results):
            return True
        return False

    @staticmethod
    def pbar_desc(epoch_result_dict: dict) -> str:
        """Format train/valid InfoNCE loss and top-1 recall for the progress bar.

        ``epoch_result_dict`` must map ``"infor_loss"`` and ``"top1_recall"``
        to dicts with ``"train"`` and ``"valid"`` numeric entries.
        """
        mean_infor_loss_t, mean_infor_loss_v = epoch_result_dict["infor_loss"]["train"], epoch_result_dict["infor_loss"]["valid"]
        mean_top1_recall_t, mean_top1_recall_v = epoch_result_dict["top1_recall"]["train"], epoch_result_dict["top1_recall"]["valid"]
        return f"infor(t/v): {mean_infor_loss_t:.3f}/{mean_infor_loss_v:.3f}, recall(t/v): {mean_top1_recall_t:.3f}/{mean_top1_recall_v:.3f}"

    def model_forward(self, batch):
        """Encode both sides of a batch and compute losses and recall.

        Args:
            batch: 7-tuple ``(idx, imgs, pcd, c_input, t_input, adjs, factors)``.
                ``idx``, ``pcd``, ``c_input``, ``t_input`` and ``adjs`` are
                (source, target) pairs indexed with ``[0]`` / ``[1]``;
                ``imgs`` and ``factors`` are not used here.

        Returns:
            ``(result_dict, total_loss)``. ``result_dict`` holds Python floats
            for ``"infor_loss"``, ``"total_loss"``, ``"top1_recall"`` and
            ``"top5_recall"``. ``total_loss`` is the undetached loss tensor.
        """
        idx, imgs, pcd, c_input, t_input, adjs, factors = batch

        idx_s, idx_t = idx
        source_input = {
            "c_input": c_input[0].to(self.device),
            "t_input": t_input[0].to(self.device),
            "pcd": pcd[0].to(self.device)
        }
        target_input = {
            "c_input": c_input[1].to(self.device),
            "t_input": t_input[1].to(self.device),
            "pcd": pcd[1].to(self.device)
        }
        # ``self.calc_adjs`` is set by ``set_dataset`` from the configured
        # feature extractor (True for ResGCN).
        if self.calc_adjs:
            # Length of the first source sample, used as the padded point count.
            max_point_nums = len(pcd[0][0])
            adj_s = self.get_concat_adj2(adjs[0], max_point_nums)
            adj_t = self.get_concat_adj2(adjs[1], max_point_nums)
            source_input["adj"] = adj_s.to(self.device)
            target_input["adj"] = adj_t.to(self.device)

        f_s = self.model(source_input)
        f_t = self.model(target_input)
        # Cross source/target term, plus same-side terms (source vs source,
        # target vs target) that are down-weighted and averaged.
        infor_loss = self.criterion(
            f_s, f_t, gt_pairs=(idx_s, idx_t))
        infor_s = self.criterion(
            f_s, f_s, gt_pairs=(idx_s, idx_s))
        infor_t = self.criterion(
            f_t, f_t, gt_pairs=(idx_t, idx_t))
        only_negative_weight = self.ONLY_NEGATIVE_WEIGHT
        total_loss = infor_loss + \
            (only_negative_weight*infor_s +
             only_negative_weight*infor_t)/2
        top1_recall, top5_recall = self.calc_recall(
            f_s.detach().cpu(), f_t.detach().cpu())

        result_dict = {
            "infor_loss": infor_loss.detach().cpu().item(),
            "total_loss": total_loss.detach().cpu().item(),
            "top1_recall": top1_recall.item(),
            "top5_recall": top5_recall.item()
        }

        return result_dict, total_loss

    def set_dataset(self, train_dataset_path: str, valid_dataset_path: str, batch_size: int) -> tuple:
        """Build the train and validation DataLoaders.

        This also sets ``self.calc_adjs``, which is read by ``model_forward``,
        from ``cfg.TRAIN.PAIRING.FEATURE_EXTRACT``.

        Raises:
            KeyError: if the configured feature extractor is not supported.
        """
        calc_adjs_tab = {
            "ResGCN": True,
            "ViT": False
        }
        extractor_name = self.cfg.TRAIN.PAIRING.FEATURE_EXTRACT
        try:
            self.calc_adjs = calc_adjs_tab[extractor_name]
        except KeyError:
            raise KeyError(
                f"Unsupported FEATURE_EXTRACT {extractor_name!r}; "
                f"expected one of {sorted(calc_adjs_tab)}"
            ) from None
        train_dataset = PairingDataset(train_dataset_path, self.cfg, self.logger, calc_adjs=self.calc_adjs)
        valid_dataset = PairingDataset(valid_dataset_path, self.cfg, self.logger, calc_adjs=self.calc_adjs)
        # NOTE(review): the validation loader is shuffled too. Recall is computed
        # against in-batch candidates, so validation metrics depend on batch
        # composition and vary between epochs. Kept as-is to preserve behavior.
        return DataLoader(train_dataset, batch_size, num_workers=0, shuffle=True), DataLoader(valid_dataset, batch_size, num_workers=0, shuffle=True)

    def set_loss(self) -> nn.Module:
        """Return the InfoNCE criterion with the configured temperature."""
        return InfoNCE(self.cfg.TRAIN.PAIRING.INFONCE_TEMPERATURE)

    def set_model(self) -> nn.Module:
        """Assemble a PairingNet from the configured extractor and fusion nets."""
        feature_extract = feature_extract_nets[self.cfg.TRAIN.PAIRING.FEATURE_EXTRACT](self.cfg, self.logger)
        fuse = feature_fuse_nets[self.cfg.TRAIN.PAIRING.FEATURE_FUSE](self.cfg, self.logger)

        return PairingNet(feature_extract, fuse)

    def train_batch(self, batch: tuple):
        """Run one training forward pass; return ``(result_dict, total_loss)``."""
        self.model.train()
        return self.model_forward(batch)

    def valid_batch(self, batch: tuple):
        """Run one validation forward pass.

        Returns:
            ``(result_dict, "infor_loss")``. The second element names the
            entry of ``result_dict`` used for validation.
        """
        self.model.eval()
        # Validation never back-propagates, so don't build an autograd graph.
        with torch.no_grad():
            result_dict, _ = self.model_forward(batch)
        return result_dict, "infor_loss"

    @staticmethod
    def calc_recall(batch_feature_s, batch_feature_t):
        """Compute in-batch top-1 / top-5 retrieval recall.

        Row ``i`` of ``batch_feature_s`` is treated as matching row ``i`` of
        ``batch_feature_t``. Features are L2-normalized along dim 1, and each
        source row ranks all target rows by cosine similarity.

        Returns:
            ``(top1, top5)``: float64 CPU tensors of shape ``(1,)`` holding the
            fraction of source rows whose match ranks first / in the first
            five. Both are NaN for an empty batch.
        """
        bs = batch_feature_s.shape[0]
        if bs == 0:
            # Previously this was a numpy 0/0 (NaN plus RuntimeWarning).
            # Make it explicit.
            nan = torch.tensor([float("nan")], dtype=torch.float64)
            return nan, nan.clone()

        normalized_s = nn.functional.normalize(batch_feature_s, p=2, dim=1)
        normalized_t = nn.functional.normalize(batch_feature_t, p=2, dim=1)
        cos_sim_matrix = torch.matmul(normalized_s, normalized_t.T)
        order = torch.sort(cos_sim_matrix, dim=-1, descending=True)[1]

        # Each row of ``order`` is a permutation of target indices, so it
        # contains column ``i`` exactly once. ``nonzero`` yields matches in
        # row-major order, so column 1 is the rank of the GT target per row.
        targets = torch.arange(bs, device=order.device).unsqueeze(1)
        ranks = (order == targets).nonzero()[:, 1].cpu()

        top1 = (ranks < 1).double().mean()
        top5 = (ranks < 5).double().mean()
        return top1.reshape(1), top5.reshape(1)


if __name__ == "__main__":
    # Ad-hoc sanity check, unrelated to PairingTrainer: binary accuracy of
    # scores thresholded at 0.5 against 0/1 labels.
    outputs = torch.tensor([0.6, 0.4, 0.1, 0.8, 0.8])
    preds = outputs > 0.5
    labels = torch.tensor([1., 1., 0., 1., 0.])
    print(preds)
    # Reduce in torch instead of the Python builtin ``sum``, which iterates
    # over the tensor one element at a time.
    print((preds == labels).float().mean())
