
import argparse
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import torch.nn as nn
from models.TransFaceGCNEmotion import TransFaceGCNEmotion
from loss.ContrastiveLoss import GraphContrastiveLoss
from loss.SemiSupervisedLoss import SemiSupervisedExpressionLoss
import json
import os
import pandas as pd
import torch
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
from tabulate import tabulate
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import numpy as np
import re
# Device configuration: run on CUDA when torch reports it available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")




# Matches 7-bit ANSI escape sequences: two-character "ESC x" sequences
# (x in '@'..'Z' or '\\'..'_') and CSI sequences such as "\x1b[1;31m".
_ANSI_ESCAPE_RE = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')


def clean_ansi(text):
    """Return *text* with all ANSI escape sequences (e.g. colour codes) removed.

    Used to strip terminal colours from tables before they are appended to the
    plain-text log file. The previous pattern contained a stray space after the
    two-character alternative, so such sequences were only removed when a space
    happened to follow them.
    """
    return _ANSI_ESCAPE_RE.sub('', text)


def print_side_by_side(val_dataset_list, log_file, loss, epoch, final_metrics, best_metrics, val_loss):
    """Print one table comparing current and best validation metrics per dataset.

    A column is produced for every key of ``val_dataset_list``. Each cell shows
    the current value, a coloured arrow, and the best value so far. The arrow is
    a red "↑" when current > best and a green "↓" otherwise. The table is printed
    to stdout and appended to ``log_file`` with its ANSI colour codes stripped.

    Args:
        val_dataset_list: mapping whose keys name the datasets (values unused).
        log_file: path of the text log the table is appended to.
        loss: average training loss of the epoch.
        epoch: zero-based epoch index (displayed as ``epoch + 1``).
        final_metrics: ``{dataset: {metric: value}}`` for this epoch.
        best_metrics: ``{dataset: {metric: value}}`` holding the best values so far.
        val_loss: average validation loss of the epoch.
    """
    dataset_names = [name for name in val_dataset_list]
    headers = ["Metric"] + [f'{name}(Current|Best)' for name in dataset_names]

    def _cell(name, metric_name):
        # Format "current  <arrow> best" for one dataset/metric pair.
        now = final_metrics[name][metric_name]
        prev = best_metrics[name][metric_name]
        if now > prev:
            arrow, colour = "↑", "\033[1;31m"
        else:
            arrow, colour = "↓", "\033[1;32m"
        return f"{now:.4f}  {colour}{arrow}\033[0m {prev:.4f}"

    rows = [
        [metric_name] + [_cell(name, metric_name) for name in dataset_names]
        for metric_name in ('acc', 'f1', 'precision', 'recall', 'auc')
    ]

    banner = f"===  Validation @ 第{epoch + 1}轮 train loss:{loss:.4f} ==== val loss:{val_loss:.4f}==="
    print(f"\n{banner}")
    print(f"\n\033[1mVALIDATION RESULTS COMPARISON\033[0m train loss:{loss:.4f} ====  val loss:{val_loss:.4f}")
    table = tabulate(rows, headers=headers, tablefmt="grid")
    print(table)
    # The log file receives the same table without colour codes.
    with open(log_file, "a", encoding="utf-8") as log:
        log.write(f"\n\n{banner}\n")
        log.write(clean_ansi(table))
        log.write("\n" + "=" * 50 + "\n")

def _loss_(opt, gcn_cls_outputs, gcn_feature, emotion_out, labels, emotion):
    """Return the loss used to score a validation batch.

    Only the binary deception cross-entropy between the GCN classifier logits
    ``gcn_cls_outputs`` and ``labels`` is computed. The graph-contrastive and
    semi-supervised expression terms are disabled. ``opt``, ``gcn_feature``,
    ``emotion_out`` and ``emotion`` are accepted for those terms but are
    currently ignored.

    Returns:
        Scalar tensor: mean cross-entropy over the batch.
    """
    # F.cross_entropy with default arguments equals nn.CrossEntropyLoss()(...).
    return nn.functional.cross_entropy(gcn_cls_outputs, labels)

def calculate_metrics(true, pred, prob):
    """Compute the classification metrics shared by validation and testing.

    Args:
        true: ground-truth class labels.
        pred: predicted class labels (argmax of the logits).
        prob: predicted probability of class 1, used for ROC-AUC.

    Returns:
        dict with:
        - 'acc': accuracy.
        - 'f1', 'precision', 'recall': weighted averages.
        - 'auc': ROC-AUC, or NaN unless exactly two classes occur in ``true``.
        - 'composite': ``0.6 * acc + 0.4 * f1``.
    """
    # Computed once and reused for 'composite' (previously computed twice).
    acc = accuracy_score(true, pred)
    # zero_division=0 returns the same value sklearn's default ("warn") uses,
    # but without an UndefinedMetricWarning. That warning fires whenever a
    # class has no predicted or true samples, which is common for the running
    # metrics computed after the first few batches.
    f1 = f1_score(true, pred, average='weighted', zero_division=0)
    return {
        'acc': acc,
        'f1': f1,
        'precision': precision_score(true, pred, average='weighted', zero_division=0),
        'recall': recall_score(true, pred, average='weighted', zero_division=0),
        # ROC-AUC is undefined unless both classes are present in ``true``.
        'auc': roc_auc_score(true, prob) if len(np.unique(true)) == 2 else float('nan'),
        'composite': 0.6 * acc + 0.4 * f1,
    }

'''
面部视频数据集
'''
'''
计算更多指标
'''
def validate_2(opt, loss, epoch, best_metrics, model, val_dataset_list):
    """Evaluate ``model`` on every validation loader and report the metrics.

    For each ``(name, loader)`` in ``val_dataset_list``, the model runs in eval
    mode without gradients. Predictions, class-1 probabilities and per-batch
    losses are collected and turned into metrics with ``calculate_metrics``.
    After all datasets are evaluated, one comparison table against
    ``best_metrics`` is printed and logged to ``opt.log_file``.

    Args:
        opt: parsed options. Uses ``log_file``; ``gen_feature_map`` is optional
            and defaults to False.
        loss: average training loss of the epoch (for display only).
        epoch: zero-based epoch index.
        best_metrics: ``{dataset: {metric: value}}`` holding the best values so far.
        model: network returning ``(cls_logits, emotion_out, gcn_feature)``.
        val_dataset_list: mapping of dataset name to a DataLoader yielding
            ``(inputs, emotion, labels)``.

    Returns:
        ``(results, mean_val_loss)``:
        - ``results[name]`` is ``(acc, f1, precision, recall, auc)``.
        - ``mean_val_loss`` is the mean loss over all batches of all datasets.
    """
    model.eval()
    results = {}
    metrics = {}
    val_loss = []
    # get_args() does not necessarily define ``gen_feature_map``; reading it with
    # getattr avoids an AttributeError on the first batch.
    collect_features = getattr(opt, 'gen_feature_map', False) and (epoch + 1) % 10 == 0
    with torch.no_grad():
        for key, value in val_dataset_list.items():
            val_preds, val_true, val_probs = [], [], []
            feature_list, label_list = [], []
            with tqdm(value, desc=f"Validating {key}", leave=True) as pbar:
                for inputs, emotion, labels in pbar:
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    cls, emotion_out, gcn_feature = model(inputs)
                    # Pass the outputs in the order _loss_ declares its parameters.
                    val_loss.append(_loss_(opt, cls, gcn_feature, emotion_out, labels, emotion).item())
                    batch_preds = torch.argmax(cls, 1).cpu().numpy()
                    # Binary task: keep the probability of class 1.
                    batch_probs = torch.softmax(cls, 1)[:, 1].cpu().numpy()
                    batch_true = labels.cpu().numpy()
                    # Features for the (currently disabled) feature-distribution plot.
                    if collect_features:
                        feature_list.append(gcn_feature.flatten(1).cpu().numpy())
                        label_list.append(batch_true)

                    val_preds.extend(batch_preds)
                    val_true.extend(batch_true)
                    val_probs.extend(batch_probs)

                    # Running metrics shown in the progress bar.
                    temp_metrics = calculate_metrics(
                        np.array(val_true),
                        np.array(val_preds),
                        np.array(val_probs)
                    )
                    pbar.set_postfix(temp_metrics)

            final_metrics = calculate_metrics(val_true, val_preds, val_probs)
            results[key] = (final_metrics['acc'], final_metrics['f1'], final_metrics['precision'],
                            final_metrics['recall'], final_metrics['auc'])
            metrics[key] = final_metrics
            if collect_features and feature_list:
                feature_list = np.concatenate(feature_list, axis=0)
                label_list = np.concatenate(label_list, axis=0)
                # tsne_feature_vision(key, feature_list, label_list, epoch=epoch, opt=opt)

    mean_val_loss = np.mean(val_loss) if val_loss else float('nan')
    # Printed once all datasets have metrics. Printing inside the loop raised
    # KeyError for datasets that had not been evaluated yet.
    print_side_by_side(val_dataset_list, opt.log_file, loss, epoch, metrics, best_metrics, mean_val_loss)
    return results, mean_val_loss

def test(model, test_dataset_list, opt):
    """Evaluate the best saved checkpoint of every dataset on its test loader.

    For each ``(name, loader)`` in ``test_dataset_list``:
    - The first file in ``{opt.save_pth}/{name}`` is loaded.
    - Its ``best_weights['acc']`` state dict is copied into ``model``.
    - The loader is evaluated without gradients.

    The metrics of all evaluated datasets are printed as one table and appended,
    without colour codes, to ``opt.log_file``. Datasets with no saved
    checkpoint are skipped with a message.

    Args:
        model: network to evaluate. Its weights are overwritten by the checkpoints.
        test_dataset_list: mapping of dataset name to a DataLoader yielding
            ``(inputs, emotion, labels)``.
        opt: parsed options (uses ``save_pth`` and ``log_file``).

    Returns:
        dict mapping dataset name to ``(acc, f1, precision, recall, auc)``.
    """
    model.eval()
    results = {}
    # Created once, outside the loop, so every dataset's metrics survive until
    # the table is built.
    metrics = {}
    epoch = None  # epoch of the last loaded checkpoint, shown in the header
    for key, value in test_dataset_list.items():
        ckpt_dir = f'{opt.save_pth}/{key}'
        if not os.path.exists(ckpt_dir) or not os.listdir(ckpt_dir):
            print(f"\nNo checkpoint in {ckpt_dir}, skipping test of {key}")
            continue
        weight_path = f'{ckpt_dir}/{os.listdir(ckpt_dir)[0]}'
        # map_location lets GPU-saved checkpoints load on a CPU-only host.
        best_weight = torch.load(weight_path, map_location=device)
        model.load_state_dict(best_weight['best_weights']['acc'])
        epoch = best_weight['epoch']

        val_preds, val_true, val_probs = [], [], []
        with torch.no_grad(), tqdm(value, desc=f"Testing {key}", leave=True) as pbar:
            for inputs, emotion, labels in pbar:
                cls, _, _ = model(inputs.to(device))
                val_preds.extend(torch.argmax(cls, 1).cpu().numpy())
                # Binary task: keep the probability of class 1.
                val_probs.extend(torch.softmax(cls, 1)[:, 1].cpu().numpy())
                val_true.extend(labels.cpu().numpy())
                # Running metrics shown in the progress bar.
                pbar.set_postfix(calculate_metrics(np.array(val_true), np.array(val_preds), np.array(val_probs)))

        final_metrics = calculate_metrics(val_true, val_preds, val_probs)
        results[key] = (final_metrics['acc'], final_metrics['f1'], final_metrics['precision'],
                        final_metrics['recall'], final_metrics['auc'])
        metrics[key] = final_metrics

    if not metrics:
        print("\nNo dataset was tested: no saved checkpoint found")
        return results

    datasets = list(metrics.keys())
    headers = ["Metric"] + [f'{dataset}(Current|Best)' for dataset in datasets]
    rows = [
        [metric] + [f"\033[1;31m{metrics[dataset][metric]:.4f}\033[0m " for dataset in datasets]
        for metric in ['acc', 'f1', 'precision', 'recall', 'auc']
    ]

    print(f"\n===  Test @ 第{epoch}轮权重 ===")
    print(f"\n\033[1mVALIDATION RESULTS COMPARISON\033[0m ")
    file_content = tabulate(rows, headers=headers, tablefmt="grid")
    print(file_content)
    # The log file receives the same table without colour codes.
    with open(opt.log_file, "a", encoding="utf-8") as f:
        f.write(f"\n\n===  Test @ 第{epoch}轮 权重 ===\n")
        f.write(clean_ansi(file_content))
        f.write("\n" + "=" * 50 + "\n")
    return results


class FaceVideoDataset(Dataset):
    """Dataset of face-frame sequences, loaded eagerly from per-video frame folders.

    Column 1 of ``file_list`` holds, for each sample, a directory whose frames
    are named ``0.jpg``, ``1.jpg``, and so on. ``frame_num`` frame indices are
    spread evenly (``np.linspace``) over the number of entries in that
    directory. Each frame is loaded as RGB, resized to
    ``crop_face_frame_size`` x ``crop_face_frame_size`` and converted with
    ``transforms.ToTensor()``. The frames of one video are stacked with
    ``np.stack``.

    Items are ``(frames, label)``.
    """

    def __init__(self, file_list, labels, crop_face_frame_size=112, frame_num=4):
        self.sequences = []
        self.frame_transform = transforms.Compose([
            transforms.Resize((crop_face_frame_size, crop_face_frame_size)),
            transforms.ToTensor(),
        ])
        # NOTE(review): never populated in this file; kept in case other code reads it.
        self.face_frames = []
        self.labels = labels
        video_path_list = file_list[:, 1]
        self.load_video(video_path_list, frame_num)

    def load_image_rgb(self, image_path, transform=None):
        """Load ``image_path`` as an RGB image and apply ``transform`` if one is given.

        Returns the transformed image, or the RGB PIL image when ``transform``
        is None.
        """
        # ``with`` closes the file handle deterministically; convert() has
        # already produced an independent in-memory image.
        with Image.open(image_path) as image:
            rgb = image.convert("RGB")
        return rgb if transform is None else transform(rgb)

    def load_video(self, video_path_list, frame_num):
        """Append one stacked array of ``frame_num`` evenly sampled frames per directory.

        Raises:
            ValueError: if ``frame_num`` is larger than 32 or smaller than 1.
            FileNotFoundError: if a frame directory is empty.
        """
        # Validated once rather than per video. A real exception, unlike
        # ``assert``, is not stripped by ``python -O``.
        if frame_num > 32:
            raise ValueError("人脸帧最多使用32帧")
        if frame_num < 1:
            raise ValueError(f"frame_num must be at least 1, got {frame_num}")
        for path in video_path_list:
            frame_count = len(os.listdir(path))
            if frame_count == 0:
                # linspace(0, -1, ...) would otherwise yield negative frame indices.
                raise FileNotFoundError(f"no frames found in {path}")
            sample_positions = np.linspace(0, frame_count - 1, frame_num, dtype=int)
            frames = [
                self.load_image_rgb(image_path=f'{path}/{index}.jpg', transform=self.frame_transform)
                for index in sample_positions
            ]
            self.sequences.append(np.stack(frames))

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.labels[idx]


'''
面部情绪视频数据集
'''


class FaceEmotionVideoDataset(FaceVideoDataset):
    """Face-frame dataset that also yields per-frame one-hot emotion labels.

    Column 2 of ``file_list`` holds a JSON list of per-frame emotion class ids;
    ``None`` entries are mapped to class 7. ``frame_num`` of them are selected:
    - the first ``frame_num`` positions when ``order`` is true;
    - otherwise positions spread evenly over ``total_frame_num``.
    The selected ids are one-hot encoded with 8 classes.

    NOTE(review): ``FaceVideoDataset.load_video`` always samples the image
    frames evenly over the files on disk and never sees ``order``. So with
    ``order=True``, or when a video's frame count differs from
    ``total_frame_num``, the emotion labels may not belong to the same frames
    as the images. Confirm this is intended.

    Items are ``(frames, emotion_one_hot, label)``.
    """

    def __init__(self, file_list, labels, crop_face_frame_size=112, frame_num=4, order=True, total_frame_num=32):
        super().__init__(file_list, labels, crop_face_frame_size=crop_face_frame_size, frame_num=frame_num)
        self.emotion_labels = []
        self.load_emotion_labels(emotion_list=file_list[:, 2], frame_num=frame_num, order=order,
                                 total_frame_num=total_frame_num)

    def load_emotion_labels(self, emotion_list, frame_num, order, total_frame_num):
        """Append the one-hot (frame_num x 8) emotion labels of every sample."""
        assert frame_num <= 32, "人脸帧最多使用32帧"
        if order:
            picks = np.arange(0, frame_num)
        else:
            picks = np.linspace(0, total_frame_num - 1, frame_num, dtype=int)
        for raw in emotion_list:
            per_frame = np.array(json.loads(raw.replace('None', '7')), dtype=np.int64)
            chosen = np.array([per_frame[position] for position in picks])
            encoded = torch.nn.functional.one_hot(torch.from_numpy(chosen), num_classes=8)
            self.emotion_labels.append(encoded.float())

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        frames, emotions, label = self.sequences[idx], self.emotion_labels[idx], self.labels[idx]
        return frames, emotions, label


"""
抑郁症数据集
"""


class DepressionFaceEmotionVideoDataset(FaceEmotionVideoDataset):
    """Face/emotion dataset that additionally yields a one-hot depression score.

    Column 4 of ``file_list`` holds each sample's depression score. It is
    one-hot encoded with 64 classes, so scores must be integers in [0, 63];
    TODO confirm against the CSV. The parent's ``order`` and
    ``total_frame_num`` arguments are left at their defaults.

    Items are ``(frames, emotion_one_hot, label, depression_one_hot)``.
    """

    def __init__(self, file_list, labels, crop_face_frame_size=112, frame_num=4):
        super().__init__(file_list, labels, crop_face_frame_size=crop_face_frame_size, frame_num=frame_num)
        self.depression_scores = []
        self.load_depression_scores(depression_score_list=file_list[:, 4])

    def load_depression_scores(self, depression_score_list):
        """Append the 64-class one-hot encoding of every score."""
        self.depression_scores.extend(
            torch.nn.functional.one_hot(torch.tensor(score), num_classes=64).float()
            for score in depression_score_list
        )

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sample = (self.sequences[idx], self.emotion_labels[idx], self.labels[idx], self.depression_scores[idx])
        return sample


def get_args():
    """Parse the command-line options of the deception-detection training script.

    Boolean options are parsed from text (true/false, yes/no, 1/0, ...). They
    may also be given without a value to mean True, e.g. ``--order`` or
    ``--order false``. When ``--model`` is non-empty, ``save_pth`` and
    ``log_file`` are moved into a run directory under the given ``--save_pth``.
    The run directory's name encodes the main hyper-parameters.

    Returns:
        argparse.Namespace with the parsed options.
    """

    def str2bool(value):
        # ``type=bool`` would call bool("False") -> True, so any explicit value
        # enabled the flag. Parse the text instead.
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    # Shared settings for boolean options: a bare ``--flag`` means True.
    bool_arg = dict(type=str2bool, nargs='?', const=True)

    parser = argparse.ArgumentParser(
        'Detecting Deception in Continuous Frames of Training Faces in Pytorch')
    parser.add_argument('--crop_face_frame_size', type=int, default=112,
                        help='这是面部尺寸大小')
    parser.add_argument('--frame_num', type=int, default=32,
                        help='这是设置面部帧数 最多32帧')
    parser.add_argument('--order', default=False,
                        help='是否顺序取帧', **bool_arg)
    parser.add_argument('--dataset_way', type=str, default='face_emotion_video',
                        help='数据集类型 face_video/face_emotion_video/depression_video')
    parser.add_argument('--model', type=str, default='TransFaceGCNEmotion',
                        help='这是可选择模型 【TransFaceGCNEmotion】')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='这是学习率')
    parser.add_argument('--contrastive_learn', default=False,
                        help='是否对比学习', **bool_arg)
    # The help lists the exact names get_model() recognises (it checks 'RMSprop').
    parser.add_argument('--optimizer', type=str, default='Adam',
                        help='优化器 Adam/AdamW/SGD/RMSprop')
    parser.add_argument('--random_state', type=int, default=1234,
                        help='这是随机种子 包括训练测试拆分的随机种子')
    parser.add_argument('--epoch', type=int, default=300,
                        help='这是训练轮次')
    parser.add_argument('--is_continue_train', default=True,
                        help='是否继续训练', **bool_arg)
    parser.add_argument('--early_stop_counter', type=int, default=10,
                        help='是否早停轮次')
    parser.add_argument('--early_stop', default=False,
                        help='是否早停', **bool_arg)
    parser.add_argument('--batch_size', type=int, default=128,
                        help='这是每轮训练批次大小')
    parser.add_argument('--test_size', type=float, default=0.2,
                        help='这是训练集测试拆分 0.1，表示10倍交叉验证拆分')
    parser.add_argument('--save_pth', type=str, default='./runs/',
                        help='这是权重保存地址')
    parser.add_argument('--log_file', type=str, default='./runs/metrics.txt',
                        help='这是日志文件存储地址')
    parser.add_argument('--shuffle', default=True,
                        help='是否打乱数据集', **bool_arg)
    parser.add_argument('--num_workers', type=int, default=4,
                        help='数据加载的子进程数量，default: 0，即数据只在主进程中加载')
    parser.add_argument('--pin_memory', default=True,
                        help=' 当pin_memory设置为True时，数据加载器会将数据加载到的固定页(锁页内存)中，而GPU可以直接访问固定页(锁页内存)中的数据，而不需要经过额外的数据拷贝操作，因此cpu内存花销增大，但可以提高使用gpu训练时数据加载的效率，default: False',
                        **bool_arg)
    parser.add_argument('--drop_last', default=True,
                        help='如果数据集样本总数不能被批次大小整除，是否丢弃最后一个不完整的批次，default: False',
                        **bool_arg)
    # Read by the validation routine; it was previously missing here.
    parser.add_argument('--gen_feature_map', default=False,
                        help='是否生成特征分布图（每10轮收集一次验证集特征）', **bool_arg)
    args = parser.parse_args()
    # NOTE: global seeding is disabled; re-enable these lines for reproducible weights.
    # torch.manual_seed(args.random_state)
    # np.random.seed(args.random_state)
    if args.model != '':
        run_dir = (f'{args.save_pth}/{args.model}_size_{args.crop_face_frame_size}_frame_{args.frame_num}'
                   f'_lr_{args.lr}_bt_{args.batch_size}_optim_{args.optimizer}'
                   f'（前20轮图对比学习+表情半监督损失）(gcn欺骗损失)/pth')
        args.save_pth = run_dir
        args.log_file = f'{run_dir}/metrics.txt'
    return args


'''
加载数据集
'''


def load_way_dataset(x, y, way='face_video', frame_num=32, crop_face_frame_size=224, order=False):
    """Build the dataset selected by ``way`` from sample rows ``x`` and labels ``y``.

    ``way`` is one of 'face_video', 'face_emotion_video' or 'depression_video'.
    ``order`` is forwarded only to the 'face_emotion_video' dataset.

    Raises:
        ValueError: if ``way`` names no known dataset mode.
    """
    # Lambdas defer construction (and name lookup) until a mode actually matches.
    factories = (
        ('face_video', lambda: FaceVideoDataset(x, y, frame_num=frame_num,
                                                crop_face_frame_size=crop_face_frame_size)),
        ('face_emotion_video', lambda: FaceEmotionVideoDataset(x, y, frame_num=frame_num,
                                                               crop_face_frame_size=crop_face_frame_size,
                                                               order=order)),
        ('depression_video', lambda: DepressionFaceEmotionVideoDataset(x, y, frame_num=frame_num,
                                                                       crop_face_frame_size=crop_face_frame_size)),
    )
    for mode, build in factories:
        if way == mode:
            return build()
    raise ValueError(f"无对应数据集模式: {way}")

'''
读取并拆分数据集
'''
def load_split_data_set(path, random_state=1234, test_size=0.2, label_index=-1):
    """Read the CSV at ``path`` and split its rows into train and test parts.

    The whole row (every column, the label column included) is kept as the
    sample record, so later code can pick any column from it. Column
    ``label_index`` is the target.

    Returns:
        tuple ``(train_x, test_x, train_y, test_y)``.
    """
    rows = pd.read_csv(path, index_col=False).to_numpy()
    label_column = rows[:, label_index]
    splits = train_test_split(rows[:, :], label_column,
                              random_state=random_state, test_size=test_size)
    return tuple(splits)

def load_dataset_train_test(opt):
    """Build the DOLOS train loader and the validation/test loader dicts.

    The CSV is split into train and "rest" with ``opt.test_size``. The rest is
    then halved into validation and test. Both splits use ``opt.random_state``;
    the second split previously hard-coded 1234 and ignored the option.

    Returns:
        ``(train_loader, {'DOLOS': val_loader}, {'DOLOS': test_loader})``
    """
    train_x, other_x, train_y, other_y = load_split_data_set('./datasets/DOLOS/face-img-all.csv',
                                                             random_state=opt.random_state, test_size=opt.test_size)

    valid_x, test_x, valid_y, test_y = train_test_split(other_x, other_y, random_state=opt.random_state,
                                                        test_size=0.5)

    def _dataset(x, y):
        # Every split uses the same dataset options.
        return load_way_dataset(x, y, way=opt.dataset_way, frame_num=opt.frame_num,
                                crop_face_frame_size=opt.crop_face_frame_size, order=opt.order)

    def _eval_loader(x, y):
        # Evaluation loaders are not shuffled. The final metrics do not depend
        # on sample order, and a fixed order keeps runs comparable.
        return DataLoader(_dataset(x, y), batch_size=1, shuffle=False, num_workers=opt.num_workers,
                          pin_memory=opt.pin_memory)

    train_loader = DataLoader(
        _dataset(train_x, train_y),
        batch_size=opt.batch_size, drop_last=opt.drop_last, shuffle=opt.shuffle, num_workers=opt.num_workers,
        pin_memory=opt.pin_memory)
    val_dataset_list = {'DOLOS': _eval_loader(valid_x, valid_y)}
    test_dataset_list = {'DOLOS': _eval_loader(test_x, test_y)}
    return train_loader, val_dataset_list, test_dataset_list


'''
加载模型
'''


def get_model(opt):
    """Create the model, its optimizer and (currently no) learning-rate scheduler.

    Every value of ``opt.model`` builds ``TransFaceGCNEmotion``. ``opt.optimizer``
    selects the optimizer:
    - 'AdamW': weight_decay=0.1;
    - 'SGD';
    - 'RMSprop';
    - anything else: Adam.
    All optimizers use ``lr=opt.lr``.

    Returns:
        ``(model, optimizer, scheduler)``; ``scheduler`` is always ``None``.
    """
    # Only one architecture exists, so unknown names also build it. It was
    # annotated with 1e-4, presumably its learning rate (matches the --lr default).
    model = TransFaceGCNEmotion(image_size=opt.crop_face_frame_size, frame=opt.frame_num).to(device)
    params = model.parameters()
    optimizer_factories = {
        'AdamW': lambda: torch.optim.AdamW(params, lr=opt.lr, weight_decay=0.1),
        'SGD': lambda: torch.optim.SGD(params, lr=opt.lr),
        'RMSprop': lambda: torch.optim.RMSprop(params, lr=opt.lr),
    }
    make_optimizer = optimizer_factories.get(opt.optimizer, lambda: torch.optim.Adam(params, lr=opt.lr))
    optimizer = make_optimizer()
    # Learning-rate decay is disabled; a cosine schedule was tried:
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)
    scheduler = None
    return model, optimizer, scheduler





'''
模型训练
'''
def _load_best_metrics(opt, key):
    """Return the starting best-metric record for validation dataset ``key``.

    Creates ``{opt.save_pth}/{key}``. When ``opt.is_continue_train`` is set and
    that directory already holds a checkpoint, its epoch, metrics and
    ``best_weights`` are loaded. Otherwise every metric starts at -inf.
    """
    data_pth = os.path.join(opt.save_pth, key)
    os.makedirs(data_pth, exist_ok=True)
    best_metrics = {
        'epoch': 0,
        'precision': -np.inf,
        'recall': -np.inf,
        'auc': -np.inf,
        'acc': -np.inf,
        'f1': -np.inf,
        'best_weights': None,
    }
    if opt.is_continue_train:
        saved = os.listdir(data_pth)
        if saved:
            # map_location lets a checkpoint written on GPU be resumed on a CPU-only host.
            checkpoint = torch.load(os.path.join(data_pth, saved[0]), map_location=device)
            best_metrics['epoch'] = checkpoint['epoch']
            for name in ('precision', 'recall', 'auc', 'acc', 'f1'):
                # Checkpoints store the metrics as formatted strings.
                best_metrics[name] = float(checkpoint[name])
            best_metrics['best_weights'] = checkpoint['best_weights']
    return best_metrics


def _replace_single_file(directory, obj, filename):
    """Save *obj* as ``directory/filename`` with ``torch.save``.

    The first file already in *directory* is deleted first, because each of
    these directories keeps a single checkpoint.
    """
    os.makedirs(directory, exist_ok=True)
    existing = os.listdir(directory)
    if existing:
        old_path = os.path.join(directory, existing[0])
        if os.path.exists(old_path):
            os.remove(old_path)
    torch.save(obj, os.path.join(directory, filename))


def train(opt):
    """Train the model and keep the best checkpoint, by accuracy, per validation set.

    Each batch minimises::

        CE(deception logits) + expression_loss / frame_num + graph_contrastive_loss / frame_num

    With ``opt.contrastive_learn``, only the training loss is tracked. The
    weights with the lowest training loss are saved under
    ``{save_pth}/contrastive``.

    Otherwise every epoch is validated and the best-accuracy checkpoint of each
    dataset is saved under ``{save_pth}/{dataset}``. Training stops early on
    validation-loss patience (``opt.early_stop``) or when the average training
    loss drops below 0.05.

    The final weights are saved under ``{save_pth}_end/{dataset_way}``, and the
    best checkpoints are then evaluated on the test split.
    """
    train_loader, val_dataset_list, test_dataset_list = load_dataset_train_test(opt=opt)
    os.makedirs(opt.save_pth, exist_ok=True)
    best_metrics_list = {key: _load_best_metrics(opt, key) for key in val_dataset_list}

    criterion = nn.CrossEntropyLoss().to(device)
    # Graph contrastive loss
    contrastive_loss_fn = GraphContrastiveLoss(temperature=0.1).to(device)
    # Semi-supervised expression loss
    emotion_loss_fn = SemiSupervisedExpressionLoss(
        sup_weight=1.0,  # weight of the supervised term
        unsup_weight=0.4,  # weight of the unsupervised term
        temp=0.05,  # temperature (smaller -> sharper similarity distribution)
        frame=opt.frame_num
    ).to(device)
    model, optimizer, scheduler = get_model(opt)

    # Resume state comes from the first validation dataset's record.
    train_continue_info = best_metrics_list[list(best_metrics_list.keys())[0]]
    start_epoch = int(train_continue_info['epoch']) if train_continue_info['epoch'] else 0
    if opt.is_continue_train and train_continue_info['best_weights']:
        print(f"读取权重文件成功 => 第{start_epoch}轮继续训练 \n")
        model.load_state_dict(train_continue_info['best_weights']['acc'])

    best_weights = {}
    best_epoch_loss = np.inf
    total_train_loss = []
    total_val_loss = []
    early_stop_counter = 0
    epoch = start_epoch  # keeps ``epoch`` defined for the final save if the loop never runs
    for epoch in range(start_epoch, opt.epoch):
        # === training ===
        model.train()
        pbar = tqdm(train_loader, desc=f"Train Epoch {epoch + 1}", bar_format="{l_bar}{bar:20}{r_bar}", leave=True)
        pbar.clear()
        train_loss = []
        for inputs, emotion, labels in pbar:
            inputs, emotion, labels = inputs.to(device), emotion.to(device), labels.to(device)
            optimizer.zero_grad()
            gcn_cls_outputs, emotion_out, gcn_feature = model(inputs, is_contrastive=True)
            # Binary deception (real vs. deceptive) loss on the GCN logits
            gcn_deceive_loss = criterion(gcn_cls_outputs, labels)
            # Graph-relation (contrastive) loss on the face-frame graph features
            face_gcn_contrastive_loss = contrastive_loss_fn(gcn_feature, labels) / opt.frame_num
            # Semi-supervised expression-class loss
            emotion_cls_loss = emotion_loss_fn(emotion_out, emotion) / opt.frame_num
            loss = gcn_deceive_loss + emotion_cls_loss + face_gcn_contrastive_loss
            # Gradients must exist before stepping. Without backward(),
            # optimizer.step() had nothing to apply and the weights never changed.
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
            pbar.set_postfix({
                "total_average_loss": f"{np.mean(train_loss):.4f}",
                "lr": f"{optimizer.param_groups[0]['lr']:.1e}",
                "early_stop_counter": f"{early_stop_counter}"
            })
        if scheduler is not None:
            scheduler.step()
        average_train_loss = np.mean(train_loss)
        total_train_loss.append(average_train_loss)

        # === contrastive pre-training: keep only the lowest-training-loss weights ===
        if opt.contrastive_learn:
            if average_train_loss < best_epoch_loss:
                best_epoch_loss = average_train_loss
                contrastive_dir = os.path.join(opt.save_pth, 'contrastive')
                file_name = f"best_{epoch}_{average_train_loss:.4f}_weights.pt"
                _replace_single_file(contrastive_dir, model.state_dict().copy(), file_name)
                # Report the file that was actually written (4-decimal loss in the name).
                print(f"\n 当前对比损失{average_train_loss:.4f} ===> 保存至 {contrastive_dir}/{file_name}")
            continue

        # === validation ===
        pbar.clear()
        results, average_val_loss = validate_2(opt, average_train_loss, epoch, best_metrics_list, model,
                                               val_dataset_list)
        total_val_loss.append(average_val_loss)

        # === update best metrics / checkpoints ===
        # Done before the early-stopping checks, so an epoch that triggers a
        # stop is still saved if it is the best one.
        for key, (current_acc, current_f1, current_precision, current_recall, current_auc) in results.items():
            best_metrics = best_metrics_list[key]
            if not current_acc > best_metrics['acc']:
                continue
            best_metrics['acc'] = current_acc
            best_metrics['f1'] = current_f1
            best_metrics['auc'] = current_auc
            best_metrics['recall'] = current_recall
            best_metrics['precision'] = current_precision
            best_weights['acc'] = model.state_dict().copy()
            weights = {
                'best_weights': best_weights,
                'epoch': f"{epoch + 1}",
                'acc': f"{best_metrics['acc']:.2f}",
                'f1': f"{best_metrics['f1']:.2f}",
                'auc': f"{best_metrics['auc']:.2f}",
                'recall': f"{best_metrics['recall']:.2f}",
                'precision': f"{best_metrics['precision']:.2f}",
            }
            best_result_score = f"{epoch + 1}_acc-{best_metrics['acc']:.2f}_f1-{best_metrics['f1']:.2f}_recall-{best_metrics['recall']:.2f}_precision-{best_metrics['precision']:.2f}_auc-{best_metrics['auc']:.2f}"
            _replace_single_file(os.path.join(opt.save_pth, key), weights, f"best_{best_result_score}_weights.pt")

        # === early stopping ===
        if opt.early_stop:
            # Patience counter on the validation loss.
            if average_val_loss < best_epoch_loss:
                best_epoch_loss = average_val_loss
                early_stop_counter = 0
            else:
                early_stop_counter += 1
            if early_stop_counter > opt.early_stop_counter:
                print(f"\n⏹ Early stopping at epoch {epoch + 1} ")
                break
        if average_train_loss < 5e-2:
            print(f"\n⏹ Early stopping at epoch {epoch + 1} ")
            break

    # Save the weights of the last trained epoch.
    _replace_single_file(f"{opt.save_pth}_end/{opt.dataset_way}", model.state_dict().copy(),
                         f"end_{epoch}_weights.pt")
    if test_dataset_list is not None:
        test(model, test_dataset_list, opt)


if __name__ == '__main__':
    # Script entry point: parse the command-line options and run training.
    # train() finishes by evaluating the saved best checkpoints on the test split.
    opt = get_args()
    train(opt)
