"""
@Description :   Base classes for Trainer and Tester
@Author      :   tqychy 
@Time        :   2025/01/20 14:39:39
"""
import sys

sys.path.append("./")
import os
from glob import glob

import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

def get_pad_mask(mask_para):
    """
    Build a boolean mask marking the padded region of a square matrix.

    For every batch element ``m`` the returned mask is ``True`` at position
    ``(row, col)`` when ``row >= mask_para[1][m]`` or
    ``col >= mask_para[2][m]``, and ``False`` elsewhere.

    Args:
        mask_para: indexable with three entries.  ``mask_para[0]`` must have
            a 3-D ``.shape`` whose first two sizes give the batch size and
            the padded length; ``mask_para[1]`` and ``mask_para[2]`` hold one
            integer per batch element (presumably the un-padded lengths of
            the two fragments being compared -- confirm against the caller).

    Returns:
        A ``torch.bool`` tensor of shape ``(bs, maxs, maxs)``.  It is created
        on the default (CPU) device; callers move it if needed.
    """
    bs, maxs, _ = mask_para[0].shape
    pad_mask = torch.zeros((bs, maxs, maxs), dtype=torch.bool)
    for m in range(bs):
        row_len = mask_para[1][m]
        col_len = mask_para[2][m]
        # Everything at or beyond the valid length on either axis is padding.
        pad_mask[m, :, col_len:] = True
        pad_mask[m, row_len:, :] = True

    return pad_mask

def get_concat_adj2(adj, max_len):
    """
    Concatenate the non-zero index pairs of a batch of matrices.

    For each ``adj[i]`` the indices of its non-zero entries are taken,
    transposed so that each column is one index pair, shifted by
    ``i * max_len`` and appended along the column axis.

    Args:
        adj: tensor indexable along its first dimension; each ``adj[i]`` is
            expected to be 2-D so that its index pairs form a ``(2, k)``
            tensor.
        max_len: offset added per batch element (``i * max_len``).

    Returns:
        A ``(2, total_nonzero)`` tensor on ``adj.device``.  When ``adj`` has
        no elements the result is an empty ``(2, 0)`` ``torch.int`` tensor.
    """
    device = adj.device
    # The empty int seed keeps the shape and dtype of the result for an empty
    # batch exactly as before; with real data hstack promotes to the index
    # dtype returned by torch.nonzero.
    pieces = [torch.zeros((2, 0), dtype=torch.int, device=device)]
    pieces.extend(
        torch.nonzero(adj[i]).transpose(0, 1) + i * max_len
        for i in range(len(adj))
    )
    # Concatenate once instead of re-copying the accumulated result on every
    # iteration.
    return torch.hstack(pieces)

class BaseTrainer():
    """
    Base class for trainers.

    Owns the data loaders, the model, an Adam optimizer, a cosine-annealing
    LR scheduler, a TensorBoard writer and checkpoint handling, and drives
    the epoch loop.  Subclasses supply the task-specific parts by overriding
    ``set_dataset``, ``set_model``, ``set_loss``, ``train_batch``,
    ``valid_batch`` and ``pbar_desc``.
    """

    def __init__(self, writer_path: str, *args):
        """
        Args:
            writer_path: directory for the TensorBoard event files (created
                if missing).
            *args: exactly two positional values, ``(cfg, logger)``.
        """
        self.cfg, self.logger = args
        self.device = torch.device(self.cfg.GLOBALS.DEVICE)
        self.checkpoint_path = os.path.join(
            self.cfg.TRAIN.CHECKPOINT_PATH, self.cfg.GLOBALS.EXPR_NAME, "parameter")
        self.writer_path = writer_path
        os.makedirs(self.checkpoint_path, exist_ok=True)
        os.makedirs(self.writer_path, exist_ok=True)
        self.writer = SummaryWriter(self.writer_path)

        self.logger.debug("加载训练集和验证集。")
        train_dataset_path, valid_dataset_path = self.cfg.TRAIN.TRAIN_DATA_PATH, self.cfg.TRAIN.VALID_DATA_PATH
        self.train_loader, self.valid_loader = self.set_dataset(
            train_dataset_path, valid_dataset_path, self.cfg.TRAIN.BATCH_SIZE)

        self.logger.debug("加载模型。")
        self.model = self.set_model().to(self.device)

        self.logger.debug("设置损失函数和优化器。")
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.cfg.TRAIN.LR,
                                          weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=self.cfg.TRAIN.EPOCH)
        self.criterion = self.set_loss()

    def train(self):
        """
        Run the training loop from the start (or resume) epoch up to
        ``cfg.TRAIN.EPOCH``.

        After every epoch the per-metric results are written to TensorBoard
        and logged, the best checkpoint is refreshed when the main validation
        result improves, and a numbered checkpoint is written every
        ``cfg.TRAIN.SAVE_INTERVAL`` epochs (never for epoch 0).
        """
        self.logger.debug(f"开始训练 {self.cfg.TRAIN.TYPE}。")
        self.logger.debug(self.model)
        # History of the main validation result, one entry per finished
        # epoch; used to decide which parameters are the best so far.
        self.main_results = []
        start_epoch = self.load_checkpoint() if self.cfg.TRAIN.LOAD_FROM_CHECKPOINTS else 0

        with tqdm(total=self.cfg.TRAIN.EPOCH - start_epoch) as pbar:
            for epoch in range(start_epoch, self.cfg.TRAIN.EPOCH):

                main_result, epoch_results_dict = self.run_epoch()

                if self.check_if_best(main_result, self.main_results):
                    self.save_best_checkpoint()
                    log_str = f"找到目前最优模型，epoch: {epoch}"
                    for key, val in epoch_results_dict.items():
                        valid_val = val["valid"]
                        log_str += f", 验证 {key}: {valid_val}"
                    log_str += " 。"
                    self.logger.debug(log_str)
                # Appended after the comparison so the current epoch is not
                # compared against itself, but before save_checkpoint so the
                # stored history includes it.
                self.main_results.append(main_result)

                log_str = f"epoch: {epoch}"
                for key, val in epoch_results_dict.items():
                    self.writer.add_scalars(key, val, epoch)
                    train_val = val["train"]
                    valid_val = val["valid"]
                    log_str += f"; 训练 {key}: {train_val}, 验证 {key}: {valid_val}"
                log_str += " 。"

                if epoch > 0 and epoch % self.cfg.TRAIN.SAVE_INTERVAL == 0:
                    self.save_checkpoint(epoch)

                self.logger.debug(log_str)
                pbar.set_description(self.pbar_desc(epoch_results_dict))
                pbar.update(1)

    def run_epoch(self):
        """
        Run one training pass followed by one validation pass.

        Returns:
            ``(main_result, epoch_results_dict)`` where
            ``epoch_results_dict[name]`` is ``{"train": mean, "valid": mean}``
            for every metric name reported by ``train_batch`` and
            ``main_result`` is the validation mean of the metric named by
            ``valid_batch``.

        Raises:
            RuntimeError: if the validation loader yields no batch, so the
                main metric name is unknown.
        """
        train_results_dict = {}  # per-metric values collected while training
        valid_results_dict = {}  # per-metric values collected while validating
        main_result_name = None
        # Training
        with tqdm(total=len(self.train_loader), leave=False, desc="训练") as pbar_t:
            for batch in self.train_loader:
                train_results, loss = self.train_batch(batch)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                for k, v in train_results.items():
                    train_results_dict.setdefault(k, []).append(v)
                pbar_t.update(1)

            # The scheduler advances once per epoch, not once per batch.
            self.scheduler.step()
        # Validation
        with tqdm(total=len(self.valid_loader), leave=False, desc="测试") as pbar_v:
            with torch.no_grad():
                for batch in self.valid_loader:
                    valid_results, name = self.valid_batch(batch)
                    if main_result_name is None:
                        # The first batch decides which metric is the main one.
                        main_result_name = name
                    for k, v in valid_results.items():
                        valid_results_dict.setdefault(k, []).append(v)
                    pbar_v.update(1)

        if main_result_name is None:
            raise RuntimeError(
                "validation loader yielded no batches; cannot determine the main result")

        epoch_results_dict = {}
        for k, train_val in train_results_dict.items():
            # A metric that validation never reported averages to 0.
            valid_val = valid_results_dict.get(k, [])

            mean_train_val = sum(train_val) / len(train_val) if len(train_val) > 0 else 0.
            mean_valid_val = sum(valid_val) / len(valid_val) if len(valid_val) > 0 else 0.

            epoch_results_dict[k] = {
                "train": mean_train_val, "valid": mean_valid_val}

        main_result = epoch_results_dict[main_result_name]["valid"]

        return main_result, epoch_results_dict

    @staticmethod
    def check_if_best(current_result, results) -> bool:
        """
        Return True when ``current_result`` beats every value in ``results``.

        An empty history counts as "best", so the very first epoch also
        produces a best checkpoint.
        """
        if len(results) == 0:
            return True
        return bool(current_result > max(results))

    @staticmethod
    def pbar_desc(epoch_result_dict: dict) -> str:
        """Build the progress-bar description for one epoch (subclass hook)."""
        raise NotImplementedError

    def model_forward(self, batch: tuple):
        """Subclass hook; not called by this base class."""
        raise NotImplementedError

    def set_dataset(self, train_dataset_path: str, valid_dataset_path: str, batch_size: int) -> tuple:
        """Return ``(train_loader, valid_loader)`` (subclass hook)."""
        raise NotImplementedError

    def set_loss(self) -> nn.Module:
        """Return the loss stored as ``self.criterion`` (subclass hook)."""
        raise NotImplementedError

    def set_model(self) -> nn.Module:
        """Return the model; it is moved to ``self.device`` by ``__init__``."""
        raise NotImplementedError

    def train_batch(self, batch: tuple):
        """
        Process one training batch and return ``(results_dict, loss)``.

        Implementations are expected to call ``model.train()`` here.
        """
        raise NotImplementedError

    def valid_batch(self, batch: tuple):
        """
        Process one validation batch and return
        ``(results_dict, main_result_name)``.

        Implementations are expected to call ``model.eval()`` here.
        """
        raise NotImplementedError

    def save_checkpoint(self, epoch: int):
        """
        Store the training state reached at the end of ``epoch``.

        An existing file for the same epoch is left untouched.
        """
        state_dict = {
            "epoch": epoch,
            "main_results": self.main_results,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            # Needed so a resumed run continues the LR schedule instead of
            # restarting it.
            "scheduler_state_dict": self.scheduler.state_dict(),
        }
        path = os.path.join(self.checkpoint_path, f"checkpoint_{epoch}.tar")
        if not os.path.exists(path):
            torch.save(state_dict, path)

    def save_best_checkpoint(self):
        """
        Store the model weights of the checkpoint with the best main
        validation result, replacing any previous one.
        """
        path = os.path.join(self.checkpoint_path, "checkpoint_best.tar")
        # Write to a temporary file first and swap it in, so an interrupted
        # save never leaves the run without a usable best checkpoint.
        tmp_path = path + ".tmp"
        torch.save({'model_state_dict': self.model.state_dict()}, tmp_path)
        os.replace(tmp_path, path)

    def load_checkpoint(self):
        """
        Restore the newest numbered checkpoint, if any.

        Only files named ``checkpoint_<number>.tar`` are considered;
        ``checkpoint_best.tar`` holds weights only and is skipped.

        Returns:
            The epoch index training should continue from: ``0`` when no
            numbered checkpoint exists, otherwise the stored epoch plus one
            (the stored epoch has already been completed).
        """
        saved_epochs = []
        for ckpt in glob(os.path.join(self.checkpoint_path, "checkpoint_*.tar")):
            tag = os.path.splitext(os.path.basename(ckpt))[0].split('_')[-1]
            if tag.isascii() and tag.isdigit():
                saved_epochs.append(int(tag))

        if len(saved_epochs) == 0:
            self.logger.debug(
                f'No checkpoints found at {self.checkpoint_path}')
            return 0

        path = os.path.join(self.checkpoint_path,
                            f"checkpoint_{max(saved_epochs)}.tar")

        self.logger.info(f'Loaded checkpoint from: {path}')
        # map_location lets a checkpoint written on another device be loaded.
        state_dict = torch.load(
            path, map_location=self.device, weights_only=True)
        self.main_results = state_dict["main_results"]
        self.model.load_state_dict(state_dict['model_state_dict'])
        self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
        # Checkpoints written before the scheduler state was stored lack
        # this key; they are still accepted.
        if "scheduler_state_dict" in state_dict:
            self.scheduler.load_state_dict(state_dict["scheduler_state_dict"])

        return state_dict['epoch'] + 1

    @staticmethod
    def get_pad_mask(mask_para):
        """Delegate to the module-level ``get_pad_mask``."""
        return get_pad_mask(mask_para)

    @staticmethod
    def get_concat_adj2(adj, max_len):
        """Delegate to the module-level ``get_concat_adj2``."""
        return get_concat_adj2(adj, max_len)


class BaseTester():
    """
    Base class for testers.

    Loads the test data, builds the model, restores its weights from
    ``cfg.TEST.STAT_DICT_PATH`` and freezes it for inference.  Subclasses
    override ``set_model``, ``set_dataset`` and ``test``.
    """

    def __init__(self, *args):
        """
        Args:
            *args: exactly two positional values, ``(cfg, logger)``.
        """
        self.cfg, self.logger = args
        self.device = torch.device(self.cfg.GLOBALS.DEVICE)

        self.logger.debug("加载测试集。")
        self.test_dataset, self.test_loader = self.set_dataset(
            self.cfg.TEST.TEST_DATA_PATH)

        self.model = self.set_model().to(self.device)
        # map_location lets weights saved on another device (e.g. a GPU that
        # is not present here) be loaded onto the configured device.
        checkpoint = torch.load(
            self.cfg.TEST.STAT_DICT_PATH,
            map_location=self.device,
            weights_only=True)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        # Inference only: evaluation mode and no gradient tracking.
        self.model.eval()
        self.model.requires_grad_(False)

    def set_model(self) -> nn.Module:
        """Return the model; it is moved to ``self.device`` by ``__init__``."""
        raise NotImplementedError

    def set_dataset(self, test_dataset_path: str) -> tuple:
        """Return ``(test_dataset, test_loader)`` (subclass hook)."""
        raise NotImplementedError

    def test(self):
        """Run the evaluation (subclass hook)."""
        raise NotImplementedError

    @staticmethod
    def get_pad_mask(mask_para):
        """Delegate to the module-level ``get_pad_mask``."""
        return get_pad_mask(mask_para)

    @staticmethod
    def get_concat_adj2(adj, max_len):
        """Delegate to the module-level ``get_concat_adj2``."""
        return get_concat_adj2(adj, max_len)
