from copy import deepcopy
import os
import sys
import warnings
import argparse
from argparse import Namespace

warnings.filterwarnings("ignore")
import math
import time
import json
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torchvision import transforms
from pathlib import Path
import glob
import collections
import re  # 添加正则表达式模块

# Assumes these util modules are available on your import path
from util.icbhi_dataset import ICBHIDataset
from util.icbhi_util import get_score
from util.augmentation import SpecAugment
from util.misc import adjust_learning_rate, warmup_learning_rate, set_optimizer, AverageMeter, accuracy
from models import get_backbone_class


# ==============================================================================
# 🧩 Part 1: Mixup & Loss Helpers
# ==============================================================================

def mixup_data(x, y, alpha=1.0):
    """Mix a batch with a randomly permuted copy of itself (mixup).

    Args:
        x: Input batch tensor; dimension 0 is the batch dimension.
        y: Target tensor indexed along the same batch dimension.
        alpha: Beta-distribution parameter. When ``alpha > 0`` the mixing
            coefficient is drawn from ``Beta(alpha, alpha)``; otherwise it is
            fixed to 1, which leaves ``x`` unmixed.

    Returns:
        Tuple ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x`` is
        ``lam * x + (1 - lam) * x[perm]``, ``y_a`` is ``y``, ``y_b`` is
        ``y[perm]`` and ``lam`` is the mixing coefficient.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    batch_size = x.size(0)
    # Draw the permutation with the CPU generator (keeps the seeded random
    # stream unchanged) and move it to wherever the inputs live instead of
    # hard-coding CUDA, so CPU tensors and non-default devices work too.
    index = torch.randperm(batch_size).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Return the mixup loss: ``lam``-weighted blend of the loss on both targets.

    Args:
        criterion: Callable ``criterion(pred, target)`` returning a loss.
        pred: Model predictions for the mixed inputs.
        y_a: Targets of the original samples.
        y_b: Targets of the permuted samples.
        lam: Mixing coefficient applied to the ``y_a`` loss; ``1 - lam`` is
            applied to the ``y_b`` loss.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b


class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The target class receives probability ``1 - smoothing`` and every other
    class receives ``smoothing / (n_classes - 1)``.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        """Return the batch-mean smoothed cross-entropy.

        Args:
            pred: Logits with the class dimension at index 1.
            target: Integer class indices, one per row of ``pred``.
        """
        num_classes = pred.size(1)
        log_probs = F.log_softmax(pred, dim=1)
        with torch.no_grad():
            off_value = self.smoothing / (num_classes - 1)
            on_value = 1.0 - self.smoothing
            soft_targets = torch.zeros_like(pred).fill_(off_value)
            soft_targets.scatter_(1, target.data.unsqueeze(1), on_value)
        per_sample = torch.sum(-soft_targets * log_probs, dim=1)
        return torch.mean(per_sample)


# ==============================================================================
# 🧩 Part 2: Deep Ensembles Implementation (fixed version)
# ==============================================================================

class DeepEnsemble:
    """Manage several trained models and run ensemble inference over them."""

    def __init__(self, model_paths, device='cuda', args=None):
        """Load every checkpoint in ``model_paths`` onto ``device``.

        Args:
            model_paths: List of checkpoint file paths.
            device: Device the models (and validation inputs) are moved to.
            args: Fallback arguments used for a checkpoint that stores no
                ``'args'`` entry.
        """
        self.models = []
        self.device = device
        self.model_paths = model_paths
        self.args = args
        self.model_seeds = []  # seed of each model, parallel to self.models

        print(f"🔧 初始化深度集成，加载 {len(model_paths)} 个模型...")

        for i, model_path in enumerate(model_paths):
            print(f"  [{i + 1}/{len(model_paths)}] 加载模型: {model_path}")

            model, seed = self._create_model(model_path)
            self.models.append(model)
            self.model_seeds.append(seed)

        print("✅ 所有模型加载完成！")
        print(f"   模型种子: {self.model_seeds}")

    @staticmethod
    def _resolve_seed(seed_from_checkpoint, seed_from_path):
        """Return the checkpoint seed if present, else the path seed, else 0.

        Uses ``is not None`` rather than truthiness so that a legitimate seed
        of 0 stored in the checkpoint is not overridden by the path seed.
        """
        if seed_from_checkpoint is not None:
            return seed_from_checkpoint
        if seed_from_path is not None:
            return seed_from_path
        return 0

    @staticmethod
    def _forward_logits(model, x):
        """Run one model on ``x`` under CUDA autocast and return its logits."""
        with torch.cuda.amp.autocast():
            if hasattr(model, 'beats'):
                feats = model(x, training=False)
                output = model.classifier(feats)
                # A 3-D classifier output is averaged over dimension 1.
                if output.dim() == 3:
                    output = output.mean(dim=1)
            else:
                output = model(x, training=False)
                if hasattr(model, 'classifier'):
                    output = model.classifier(output)
        return output

    @staticmethod
    def _tally(hits, counts, preds, targets):
        """Accumulate per-class sample counts and correct predictions in place."""
        for pred_label, true_label in zip(preds, targets):
            true_label = int(true_label)
            counts[true_label] += 1.0
            if int(pred_label) == true_label:
                hits[true_label] += 1.0

    @staticmethod
    def _synchronize():
        """Wait for pending CUDA work so wall-clock timings are meaningful."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def _create_model(self, model_path):
        """Build a model from a checkpoint and determine its training seed.

        Args:
            model_path: Path of the checkpoint file.

        Returns:
            Tuple ``(model, seed)``; the model is on ``self.device`` in eval
            mode.
        """
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

        # Try to read the seed from a "seed<digits>" fragment of the path.
        seed_from_path = None
        match = re.search(r'seed(\d+)', str(model_path))
        if match:
            seed_from_path = int(match.group(1))
            print(f"   从路径提取到种子: {seed_from_path}")

        # Resolve the model configuration.
        if 'args' in checkpoint:
            checkpoint_args = checkpoint['args']
            if isinstance(checkpoint_args, dict):
                # Convert the dict (and one level of nested dicts) to Namespace.
                args_obj = Namespace()
                for k, v in checkpoint_args.items():
                    if isinstance(v, dict):
                        v = Namespace(**v)
                    setattr(args_obj, k, v)
                checkpoint_args = args_obj
        else:
            # No args stored in the checkpoint: fall back to the ones passed in.
            checkpoint_args = self.args
        if checkpoint_args is None:
            # Neither source supplied args; use an empty namespace so the
            # getattr defaults below apply and attributes can be assigned.
            checkpoint_args = Namespace()

        final_seed = self._resolve_seed(getattr(checkpoint_args, 'seed', None), seed_from_path)

        # Build the backbone.
        model_name = getattr(checkpoint_args, 'model', 'beats')
        kwargs = {}
        if model_name == 'beats':
            if getattr(checkpoint_args, 'nospec', False):
                kwargs['spec_transform'] = None
            else:
                # Make sure the attributes SpecAugment is given are present.
                if not hasattr(checkpoint_args, 'specaug_policy'):
                    checkpoint_args.specaug_policy = 'icbhi_ast_sup'
                if not hasattr(checkpoint_args, 'specaug_mask'):
                    checkpoint_args.specaug_mask = 'mean'
                kwargs['spec_transform'] = SpecAugment(checkpoint_args)

        base_model = get_backbone_class(model_name)(**kwargs)

        try:
            feat_dim = base_model.final_feat_dim
        except AttributeError:
            feat_dim = 768

        # Attach the classification head.
        n_cls = getattr(checkpoint_args, 'n_cls', 4)
        base_model.classifier = nn.Linear(feat_dim, n_cls)

        # Locate the weights under any of the supported checkpoint layouts.
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        elif 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint

        # Strip the 'module.' fragment from parameter names.
        new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}

        try:
            # First try to load everything at once.
            base_model.load_state_dict(new_state_dict, strict=False)
        except Exception as e:
            print(f"⚠️ 模型权重加载警告: {e}")
            try:
                # Fall back to the entries whose name and shape both match.
                model_dict = base_model.state_dict()
                pretrained_dict = {k: v for k, v in new_state_dict.items()
                                   if k in model_dict and v.shape == model_dict[k].shape}
                model_dict.update(pretrained_dict)
                base_model.load_state_dict(model_dict, strict=False)
                print(f"  -> 成功加载 {len(pretrained_dict)}/{len(model_dict)} 个参数")
            except Exception as e2:
                print(f"❌ 模型权重加载失败: {e2}")

        # Load separately stored classifier weights when the checkpoint has them.
        if 'classifier' in checkpoint:
            clf_state = checkpoint['classifier']
            clf_state = {k.replace('module.', ''): v for k, v in clf_state.items()}
            try:
                base_model.classifier.load_state_dict(clf_state)
                print("  -> 成功加载分类器权重")
            except Exception as e:
                print(f"⚠️ 分类器权重加载失败: {e}")

        base_model.to(self.device)
        base_model.eval()

        return base_model, final_seed

    def predict(self, x, n_passes=1, ensemble_method='vote'):
        """Return ensemble class probabilities for the batch ``x``.

        Args:
            x: Input batch; dimension 0 is the batch dimension.
            n_passes: Forward passes per model; their logits are averaged.
            ensemble_method: ``'average'`` averages the per-model softmax
                probabilities, ``'vote'`` applies softmax to the mean of the
                per-model logits, ``'hard_vote'`` returns the fraction of
                models whose argmax is each class.

        Returns:
            Tensor of shape ``[batch_size, n_classes]``.

        Raises:
            ValueError: If ``ensemble_method`` is not one of the above.
        """
        all_model_predictions = []
        all_model_logits = []

        with torch.no_grad():
            for model in self.models:
                model.eval()
                model_logits = [self._forward_logits(model, x) for _ in range(n_passes)]

                if len(model_logits) > 1:
                    model_avg_logits = torch.stack(model_logits).mean(dim=0)
                else:
                    model_avg_logits = model_logits[0]

                all_model_logits.append(model_avg_logits)
                all_model_predictions.append(torch.argmax(model_avg_logits, dim=1))

        if ensemble_method == 'average':
            all_probs = [F.softmax(logits, -1) for logits in all_model_logits]
            all_probs_tensor = torch.stack(all_probs)  # [n_models, batch_size, n_classes]
            ensemble_probs = all_probs_tensor.mean(dim=0)

        elif ensemble_method == 'vote':
            avg_logits = torch.stack(all_model_logits).mean(dim=0)
            ensemble_probs = F.softmax(avg_logits, -1)

        elif ensemble_method == 'hard_vote':
            n_models = len(self.models)
            n_classes = all_model_logits[0].size(1)

            all_preds = torch.stack(all_model_predictions)  # [n_models, batch_size]

            # One-hot encode each model's vote and sum over models; this
            # replaces a per-model, per-sample Python loop.
            votes = F.one_hot(all_preds, num_classes=n_classes).sum(dim=0)
            ensemble_probs = votes.to(torch.float32) / n_models
        else:
            raise ValueError(f"不支持的集成方法: {ensemble_method}")

        return ensemble_probs

    def validate_ensemble(self, val_loader, criterion, args, n_passes=1, ensemble_method='hard_vote'):
        """Evaluate every member model and then the ensemble on ``val_loader``.

        Args:
            val_loader: Iterable of ``(images, labels)`` batches where
                ``labels[0]`` holds the class labels.
            criterion: Unused; kept for interface compatibility.
            args: Must provide ``n_cls`` and ``print_freq``.
            n_passes: Forwarded to :meth:`predict`.
            ensemble_method: Forwarded to :meth:`predict`.

        Returns:
            Tuple ``(ensemble_result, model_results)``: a metrics dict for the
            ensemble and a list with one metrics dict per member model.
        """
        top1 = AverageMeter()
        hits, counts = [0.0] * args.n_cls, [0.0] * args.n_cls
        all_probs, all_targets = [], []

        model_results = []

        # ---- Per-model metrics -------------------------------------------
        for i, (model, seed) in enumerate(zip(self.models, self.model_seeds)):
            model_hits, model_counts = [0.0] * args.n_cls, [0.0] * args.n_cls
            model_probs, model_targets = [], []

            with torch.no_grad():
                for images, labels in val_loader:
                    images = images.to(self.device, non_blocking=True)
                    # Labels are only needed on the host for bookkeeping.
                    targets_np = labels[0].cpu().numpy()

                    output = self._forward_logits(model, images)
                    probs = F.softmax(output, -1)
                    _, preds = torch.max(probs, 1)

                    model_probs.extend(probs.float().cpu().numpy())
                    model_targets.extend(targets_np)
                    self._tally(model_hits, model_counts, preds.cpu().numpy(), targets_np)

            # A single scoring call supplies all six metrics (the inputs are
            # identical, so a second call would only repeat the work).
            sp, se, sc, f1, ece, mce = get_score(model_hits, model_counts, probs=model_probs,
                                                 targets=model_targets, pflag=True)

            model_results.append({
                "Sp": round(sp, 2), "Se": round(se, 2), "AS": round(sc, 2),
                "F1": round(f1, 4), "ECE": round(ece, 2), "MCE": round(mce, 2),
                "Latency_ms": 0.0, "Seed": seed, "Model_Index": i + 1
            })
            print(f"  模型 {i + 1} (Seed {seed}): Sp={sp:.2f}, Se={se:.2f}, Score={sc:.2f}, ECE={ece:.2f}")

        # ---- Ensemble metrics --------------------------------------------
        total_inference_time = 0.0
        total_samples = 0

        with torch.no_grad():
            for idx, (images, labels) in enumerate(val_loader):
                images = images.to(self.device, non_blocking=True)
                targets_np = labels[0].cpu().numpy()
                bsz = targets_np.shape[0]

                self._synchronize()
                t_start = time.time()

                ensemble_probs = self.predict(images, n_passes=n_passes,
                                              ensemble_method=ensemble_method)

                self._synchronize()
                t_end = time.time()
                total_inference_time += (t_end - t_start)
                total_samples += bsz

                probs_np = ensemble_probs.float().cpu().numpy()
                preds_np = np.argmax(probs_np, axis=1)

                all_probs.extend(probs_np)
                all_targets.extend(targets_np)

                correct = (preds_np == targets_np).sum()
                acc = correct / bsz * 100
                top1.update(acc, bsz)

                self._tally(hits, counts, preds_np, targets_np)

                if (idx + 1) % args.print_freq == 0:
                    print(f'Ensemble: [{idx + 1}/{len(val_loader)}] Acc {top1.val:.3f}')

        # Per-sample latency; guard against an empty loader.
        latency_ms = (total_inference_time / total_samples) * 1000 if total_samples else 0.0

        sp, se, sc, f1, ece, mce = get_score(hits, counts, probs=all_probs,
                                             targets=all_targets, pflag=True)

        ensemble_result = {
            "Sp": round(sp, 2), "Se": round(se, 2), "AS": round(sc, 2),
            "F1": round(f1, 4), "ECE": round(ece, 2), "MCE": round(mce, 2),
            "Latency_ms": round(latency_ms, 2),
            "Ensemble_Size": len(self.models),
            "Ensemble_Method": ensemble_method,
            "Model_Seeds": self.model_seeds
        }

        print(f'\n📊 DeepEnsemble ({ensemble_method}):')
        print(f'   Latency: {latency_ms:.2f}ms')
        print(f'   Sp: {sp:.2f}, Se: {se:.2f}, Score: {sc:.2f}')
        print(f'   ECE: {ece:.2f}, F1: {f1:.4f}')

        return ensemble_result, model_results

class ModelWithTemperature(nn.Module):
    """Wrap a model and divide its logits by one learnable temperature."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Single-element temperature parameter, initialised to 1.5.
        initial_temperature = torch.ones(1) * 1.5
        self.temperature = nn.Parameter(initial_temperature)

    def forward(self, input, **kwargs):
        """Run the wrapped model and return its temperature-scaled logits."""
        raw_logits = self.model(input, **kwargs)
        return self.temperature_scale(raw_logits)

    def temperature_scale(self, logits):
        """Divide ``logits`` by the temperature expanded to its first two dims."""
        n_rows, n_cols = logits.size(0), logits.size(1)
        expanded = self.temperature.unsqueeze(1).expand(n_rows, n_cols)
        return logits / expanded


def find_optimal_temperature(loader, model, args):
    """Fit one softmax temperature on ``loader`` and report the resulting ECE.

    Logits for the whole loader are collected without gradients, then a single
    temperature (initialised to 1.5) is fitted by minimising cross-entropy of
    ``logits / temperature`` with one LBFGS step call.

    Args:
        loader: Iterable of ``(images, labels)`` batches where ``labels[0]``
            holds the class labels.
        model: Model exposing a ``classifier`` head, either directly or under
            ``model.module``.
        args: Must provide ``model``; passed to non-BEATs models as ``args``.

    Returns:
        Tuple ``(temperature, ece)`` with the fitted temperature as a float
        and the value returned by ``get_ece`` for the scaled logits.
    """
    model.eval()
    nll_criterion = nn.CrossEntropyLoss().cuda()
    ece_criterion = get_ece
    collected_logits = []
    collected_labels = []

    print("Optimization: Collecting logits...")
    with torch.no_grad():
        for images, labels in loader:
            images = images.cuda()
            labels = labels[0].cuda()

            uses_beats = args.model == 'beats'
            if uses_beats:
                feats = model(images, training=False)
            else:
                feats = model(images, args=args, training=False)

            # The classifier sits under .module when the model is wrapped.
            head = model.module.classifier if hasattr(model, 'module') else model.classifier
            batch_logits = head(feats)

            # Only the BEATs path averages a 3-D output over dimension 1.
            if uses_beats and batch_logits.dim() == 3:
                batch_logits = batch_logits.mean(dim=1)

            collected_logits.append(batch_logits)
            collected_labels.append(labels)

    logits = torch.cat(collected_logits).cuda()
    labels = torch.cat(collected_labels).cuda()

    # LBFGS optimisation of the temperature.
    temperature = nn.Parameter(torch.ones(1).cuda() * 1.5)
    optimizer = optim.LBFGS([temperature], lr=0.01, max_iter=50)

    def nll_closure():
        optimizer.zero_grad()
        nll = nll_criterion(logits / temperature, labels)
        nll.backward()
        return nll

    optimizer.step(nll_closure)

    after_ece = ece_criterion(logits / temperature, labels)
    fitted_temperature = temperature.item()

    print('Test (TempScal): Latency N/A     | Optimal T {:.3f} ECE {:.2f}'.format(
        fitted_temperature, after_ece))

    return temperature.item(), after_ece

def get_ece(logits, labels, n_bins=15):
    """Return the expected calibration error of ``logits`` in percent.

    Confidences (max softmax probability) are grouped into ``n_bins``
    equal-width bins over (0, 1]; each non-empty bin contributes
    ``|mean confidence - accuracy|`` weighted by its share of the samples.

    Args:
        logits: Tensor with the class dimension at index 1.
        labels: Integer class labels, one per row of ``logits``.
        n_bins: Number of confidence bins.

    Returns:
        The ECE as a Python float multiplied by 100.
    """
    probabilities = F.softmax(logits, dim=1)
    confidences, predictions = torch.max(probabilities, 1)
    accuracies = predictions.eq(labels)
    ece = torch.zeros(1, device=logits.device)
    edges = torch.linspace(0, 1, n_bins + 1)
    for lower_edge, upper_edge in zip(edges[:-1], edges[1:]):
        lo, hi = lower_edge.item(), upper_edge.item()
        in_bin = (confidences > lo) & (confidences <= hi)
        prop_in_bin = in_bin.float().mean()
        # Written as "not > 0" so a NaN proportion (empty input) is skipped.
        if not prop_in_bin.item() > 0:
            continue
        bin_accuracy = accuracies[in_bin].float().mean()
        bin_confidence = confidences[in_bin].mean()
        ece += torch.abs(bin_confidence - bin_accuracy) * prop_in_bin
    return ece.item() * 100.0
def parse_args():
    """Parse command-line options and derive the run configuration.

    Besides parsing, this function:
      * converts ``--lr_decay_epochs`` from a comma-separated string to a
        list of ints,
      * sets ``model_name`` and ``save_folder`` and creates that folder,
      * sets ``warmup_from`` / ``warmup_to`` when ``--warm`` is given,
      * sets ``cls_list`` / ``device_list`` for the ICBHI ``lungsound`` split.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser('Calibration Benchmark')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--print_freq', type=int, default=50)
    parser.add_argument('--save_dir', type=str, default='./save_calibration_benchmark')
    parser.add_argument('--resume', type=str, default=None, help='path of model checkpoint to resume')
    parser.add_argument('--eval', action='store_true', help='only evaluation')

    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--optimizer', type=str, default='adam')
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--learning_rate', type=float, default=5e-5)
    parser.add_argument('--lr_decay_epochs', type=str, default='120,160')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1)
    parser.add_argument('--weight_decay', type=float, default=1e-6)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--cosine', action='store_true', help='using cosine annealing')
    parser.add_argument('--warm', action='store_true', help='warm-up')
    parser.add_argument('--warm_epochs', type=int, default=10)

    parser.add_argument('--dataset', type=str, default='icbhi')
    parser.add_argument('--data_folder', type=str, default='')
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--class_split', type=str, default='lungsound')
    parser.add_argument('--n_cls', type=int, default=4)
    # Default desired_length matches your training script (5)
    parser.add_argument('--desired_length', type=int, default=5)
    parser.add_argument('--pad_types', type=str, default='repeat')
    parser.add_argument('--n_mels', type=int, default=128)
    parser.add_argument('--sample_rate', type=int, default=16000)
    parser.add_argument('--resz', type=float, default=1)
    parser.add_argument('--nospec', default=None, action='store_true')

    # Arguments from original training script
    parser.add_argument('--test_fold', type=str, default='official')
    parser.add_argument('--raw_augment', type=int, default=0)
    parser.add_argument('--specaug_policy', type=str, default='icbhi_ast_sup')
    parser.add_argument('--specaug_mask', type=str, default='mean')

    parser.add_argument('--model', type=str, default='beats')
    parser.add_argument('--from_sl_official', action='store_true')

    parser.add_argument('--method', type=str, default='baseline',
                        choices=['baseline', 'ls', 'mixup', 'spd', 'ensemble'])
    parser.add_argument('--smoothing', type=float, default=0.1)
    parser.add_argument('--mixup_alpha', type=float, default=0.8)

    # Deep Ensembles options
    parser.add_argument('--ensemble_models', type=str, default='')
    parser.add_argument('--ensemble_size', type=int, default=5)
    parser.add_argument('--ensemble_method', type=str, default='hard_vote',
                        choices=['average', 'vote', 'hard_vote'], help='集成方法')
    parser.add_argument('--ensemble_passes', type=int, default=1)
    parser.add_argument('--save_all_ensemble_results', action='store_true')

    # post-hoc calibration
    parser.add_argument('--calc_temp_scaling', action='store_true')
    parser.add_argument('--calc_mc_dropout', action='store_true')
    parser.add_argument('--mc_passes', type=int, default=5)

    args = parser.parse_args()

    if isinstance(args.lr_decay_epochs, str):
        args.lr_decay_epochs = [int(it) for it in args.lr_decay_epochs.split(',')]

    args.model_name = '{}_{}_{}'.format(args.dataset, args.model, args.method)
    args.save_folder = os.path.join(args.save_dir, f"{args.model_name}_seed{args.seed}")
    # exist_ok avoids the check-then-create race when several runs start
    # at the same time.
    os.makedirs(args.save_folder, exist_ok=True)

    if args.warm:
        args.warmup_from = args.learning_rate * 0.1
        if args.cosine:
            eta_min = args.learning_rate * (0.1 ** 3)
            args.warmup_to = eta_min + (args.learning_rate - eta_min) * (
                    1 + math.cos(math.pi * args.warm_epochs / args.epochs)) / 2
        else:
            args.warmup_to = args.learning_rate
    if args.dataset == 'icbhi':
        if args.class_split == 'lungsound':
            args.cls_list = ['normal', 'crackle', 'wheeze', 'both']
            args.device_list = ['L', 'A', 'M', '3']
    return args


def get_class_counts(dataset, args):
    """Count how many samples of each class ``dataset`` contains.

    Labels are read from ``dataset.labels`` or ``dataset.train_labels`` when
    either exists; otherwise the dataset is iterated through a temporary
    DataLoader and ``labels[0]`` of every batch is used. Labels outside
    ``[0, args.n_cls)`` are ignored.

    Args:
        dataset: Dataset object, optionally exposing ``labels`` or
            ``train_labels``.
        args: Must provide ``n_cls`` (and ``num_workers`` for the DataLoader
            fallback).

    Returns:
        Float tensor of length ``args.n_cls``; on CUDA when available,
        otherwise on the CPU.
    """
    counts = torch.zeros(args.n_cls)

    def _count(label):
        label = int(label)
        # Reject negatives as well: a negative index would silently wrap
        # around and be counted towards one of the last classes.
        if 0 <= label < args.n_cls:
            counts[label] += 1

    labels_list = getattr(dataset, 'labels', getattr(dataset, 'train_labels', None))
    if labels_list is None:
        from torch.utils.data import DataLoader
        temp_loader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=args.num_workers)
        for _, labels in temp_loader:
            for l in labels[0]:
                _count(l)
    else:
        for label_item in labels_list:
            # A sequence-like entry carries the class label in position 0.
            if isinstance(label_item, (list, tuple, np.ndarray)):
                label_item = label_item[0]
            _count(label_item)
    return counts.cuda() if torch.cuda.is_available() else counts


def set_model(args):
    """Build the backbone, classification head, loss and optimizer.

    Args:
        args: Run configuration; reads ``model``, ``nospec``, ``n_cls``,
            ``method``, ``smoothing`` and ``from_sl_official``.

    Returns:
        Tuple ``(model, criterion, optimizer)`` with model and criterion
        moved to CUDA.
    """
    kwargs = {}
    if args.model == 'beats':
        kwargs['spec_transform'] = None if args.nospec else SpecAugment(args)
    base_model = get_backbone_class(args.model)(**kwargs)

    # Backbones without a final_feat_dim attribute default to 768 features.
    # Only the missing-attribute case is handled; other errors propagate
    # instead of being swallowed by a bare except.
    feat_dim = getattr(base_model, 'final_feat_dim', 768)

    model = base_model
    # Attach the classifier to the model instance (unless it already has
    # one) so that its weights are saved and loaded with the model.
    if not hasattr(model, 'classifier'):
        model.classifier = nn.Linear(feat_dim, args.n_cls)

    model.cuda()

    if args.method == 'ls':
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()

    if args.model != 'beats' and args.from_sl_official:
        base_model.load_sl_official_weights()

    optimizer = set_optimizer(args, model.parameters())
    return model, criterion, optimizer


def train(train_loader, model, criterion, optimizer, epoch, args, scaler=None):
    """Train ``model`` for one epoch.

    Args:
        train_loader: Iterable of ``(images, labels)`` batches where
            ``labels[0]`` holds the class labels.
        model: Model exposing a ``classifier`` head, either directly or under
            ``model.module``.
        criterion: Loss callable ``criterion(output, target)``.
        optimizer: Optimizer updating the model parameters.
        epoch: Current epoch, used for warm-up and logging.
        args: Run configuration; reads ``method``, ``model``, ``nospec``,
            ``mixup_alpha``, ``print_freq`` and, for non-BEATs models without
            ``nospec``, ``transforms``.
        scaler: Optional gradient scaler for mixed-precision training. When
            ``None`` a plain backward/step is performed instead.
    """
    model.train()
    losses = AverageMeter()
    top1 = AverageMeter()
    use_mixup = args.method == 'mixup'

    for idx, (images, labels) in enumerate(train_loader):
        images = images.cuda(non_blocking=True)
        class_labels = labels[0].cuda(non_blocking=True)
        bsz = class_labels.shape[0]
        warmup_learning_rate(args, epoch, idx, len(train_loader), optimizer)

        with torch.cuda.amp.autocast():
            if use_mixup:
                inputs, targets_a, targets_b, lam = mixup_data(images, class_labels, args.mixup_alpha)
            else:
                inputs = images

            # The classifier sits under .module when the model is wrapped.
            classifier = model.module.classifier if hasattr(model, 'module') else model.classifier

            if args.model == 'beats':
                output = classifier(model(inputs, training=True))
                if output.dim() == 3:
                    output = output.mean(dim=1)
            else:
                # Honour --nospec on every path: args.transforms is only
                # defined when spectrogram augmentation is enabled.
                net_inputs = inputs if args.nospec else args.transforms(inputs)
                output = classifier(model(net_inputs, args=args, training=True))

            if use_mixup:
                loss = mixup_criterion(criterion, output, targets_a, targets_b, lam)
            else:
                loss = criterion(output, class_labels)

        losses.update(loss.item(), bsz)
        [acc1], _ = accuracy(output[:bsz], class_labels, topk=(1,))
        top1.update(acc1[0], bsz)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # No scaler supplied (the signature's default): fall back to an
            # unscaled update rather than failing on None.
            loss.backward()
            optimizer.step()

        if (idx + 1) % args.print_freq == 0:
            print(
                f'Train: [{epoch}][{idx + 1}/{len(train_loader)}] Loss {losses.val:.3f} ({losses.avg:.3f}) Acc {top1.val:.3f} ({top1.avg:.3f})')

def validate(val_loader, model, criterion, args, best_acc, best_model=None, log_prefix="Test"):
    """Evaluate ``model`` on ``val_loader`` and track the best result.

    When ``args.calc_mc_dropout`` is set, dropout layers are switched back to
    train mode and the softmax outputs of ``args.mc_passes`` forward passes
    are averaged per batch.

    Args:
        val_loader: Iterable of ``(images, labels)`` batches where
            ``labels[0]`` holds the class labels.
        model: Model exposing a ``classifier`` head, either directly or under
            ``model.module``.
        criterion: Loss callable ``criterion(output, target)``.
        args: Run configuration; reads ``n_cls``, ``model``,
            ``calc_mc_dropout``, ``mc_passes`` and ``calc_temp_scaling``.
        best_acc: Best metrics dict so far (its ``'AS'`` entry is compared),
            or any non-dict value when there is none yet.
        best_model: State dict of the best model so far.
        log_prefix: Label printed in front of the summary line.

    Returns:
        Tuple ``(best_acc, best_model, save_bool)``; ``save_bool`` is True
        when this evaluation produced a new best model.
    """
    save_bool = False
    model.eval()

    is_mc_mode = args.calc_mc_dropout
    if is_mc_mode:
        # Re-enable only the dropout layers for Monte-Carlo sampling.
        model.apply(lambda m: m.train() if isinstance(m, nn.Dropout) else None)

    losses = AverageMeter()
    top1 = AverageMeter()
    hits, counts = [0.0] * args.n_cls, [0.0] * args.n_cls
    all_probs, all_targets = [], []

    n_passes = args.mc_passes if is_mc_mode else 1

    total_inference_time = 0.0
    total_samples = 0

    # The classifier sits under .module when the model is wrapped.
    classifier = model.module.classifier if hasattr(model, 'module') else model.classifier

    with torch.no_grad():
        for idx, (images, labels) in enumerate(val_loader):
            images = images.cuda(non_blocking=True)
            class_labels = labels[0].cuda(non_blocking=True)
            bsz = class_labels.shape[0]

            torch.cuda.synchronize()
            t_start = time.time()

            batch_probs_sum = torch.zeros(bsz, args.n_cls).cuda()

            for _ in range(n_passes):
                with torch.cuda.amp.autocast():
                    if args.model == 'beats':
                        output = classifier(model(images, training=False))
                        if output.dim() == 3:
                            output = output.mean(dim=1)
                    else:
                        output = classifier(model(images, args=args, training=False))
                    probs = F.softmax(output, -1)
                    loss_val = criterion(output, class_labels)
                batch_probs_sum += probs

            final_probs = batch_probs_sum / n_passes

            torch.cuda.synchronize()
            t_end = time.time()
            total_inference_time += (t_end - t_start)
            total_samples += bsz

            # The recorded loss is the one from the last forward pass.
            losses.update(loss_val.item(), bsz)
            _, preds = torch.max(final_probs, 1)
            [acc1], _ = accuracy(final_probs, class_labels, topk=(1,))
            top1.update(acc1[0], bsz)
            all_probs.extend(final_probs.float().cpu().numpy())
            all_targets.extend(class_labels.cpu().numpy())

            # Copy predictions and labels to the host once per batch instead
            # of issuing several .item() synchronisations per sample.
            for pred_label, true_label in zip(preds.cpu().tolist(), class_labels.cpu().tolist()):
                counts[true_label] += 1.0
                if pred_label == true_label:
                    hits[true_label] += 1.0

    # Per-sample latency; guard against an empty loader.
    latency_ms = (total_inference_time / total_samples) * 1000 if total_samples else 0.0
    sp, se, sc, f1, ece, mce = get_score(hits, counts, probs=all_probs, targets=all_targets, pflag=True)

    result_dict = {
        "Sp": round(sp, 2), "Se": round(se, 2), "AS": round(sc, 2),
        "F1": round(f1, 4), "ECE": round(ece, 2), "MCE": round(mce, 2),
        "Latency_ms": round(latency_ms, 2)
    }

    current_best_score = best_acc['AS'] if isinstance(best_acc, dict) else 0.0

    if not (args.calc_temp_scaling or is_mc_mode):
        # Regular evaluation: keep the model only when the score improves.
        if sc > current_best_score and se > 0.1:
            save_bool = True
            best_acc = result_dict
            best_model = deepcopy(model.state_dict())
    else:
        # Post-hoc calibration modes always report the current result.
        best_acc = result_dict

    mode_str = "MC-Drop" if is_mc_mode else log_prefix
    print('{}: Latency {:.2f}ms | Sp {:.2f} Se {:.2f} Score {:.2f} ECE {:.2f}'.format(
        mode_str, latency_ms, sp, se, sc, ece))

    return best_acc, best_model, save_bool


def find_ensemble_models(args):
    """Return the checkpoint paths to use for the deep ensemble.

    Paths come from ``args.ensemble_models`` (comma-separated, order kept)
    when given. Otherwise ``best_model.pth`` files are searched under
    ``args.save_dir``: first in folders matching the dataset/model(/method)
    naming scheme, then in any direct sub-folder. At most
    ``args.ensemble_size`` paths are returned.

    Args:
        args: Reads ``ensemble_models``, ``ensemble_size``, ``method``,
            ``save_dir``, ``dataset`` and ``model``.

    Returns:
        List of checkpoint paths.

    Raises:
        ValueError: If no checkpoint is found.
    """
    if args.ensemble_models:
        model_paths = [p.strip() for p in args.ensemble_models.split(',') if p.strip()]
    else:
        if args.method == 'ensemble':
            base_pattern = os.path.join(args.save_dir, f"{args.dataset}_{args.model}_*_seed*", "best_model.pth")
        else:
            base_pattern = os.path.join(args.save_dir, f"{args.dataset}_{args.model}_{args.method}_seed*",
                                        "best_model.pth")

        # glob returns matches in arbitrary order; sort so that the
        # "first N" selection below is the same on every run and platform.
        model_paths = sorted(glob.glob(base_pattern))

        if not model_paths:
            model_paths = sorted(glob.glob(os.path.join(args.save_dir, "*", "best_model.pth")))

    if len(model_paths) > args.ensemble_size:
        print(f"找到 {len(model_paths)} 个模型，使用前 {args.ensemble_size} 个")
        model_paths = model_paths[:args.ensemble_size]

    if not model_paths:
        raise ValueError(f" 没有找到可用的模型文件！请检查模型路径或使用 --ensemble_models 参数指定模型")

    print(f" 找到 {len(model_paths)} 个模型用于深度集成:")
    for i, path in enumerate(model_paths):
        print(f"  [{i + 1}] {path}")

    return model_paths


def main():
    """Seed the RNGs, build the ICBHI loaders, then run one of three modes.

    Modes, selected from the parsed command-line arguments:
      * Deep-ensemble inference -- ``args.method == 'ensemble'`` and ``args.eval``.
      * Single-model evaluation -- ``args.eval``: a standard pass, then an
        optional temperature-scaling analysis (``args.calc_temp_scaling``) and
        an optional MC Dropout analysis (``args.calc_mc_dropout``).
      * Training -- otherwise; the best checkpoint is written to
        ``args.save_folder`` whenever validation reports an improvement.

    All metrics are merged into
    ``<args.save_dir>/calibration_results_<args.method>.json``.
    """
    args = parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and possibly pick different kernels
    # from run to run, which undermines the determinism requested above.
    cudnn.benchmark = False

    # Loader Setup
    if not args.nospec and args.model != 'beats':
        args.transforms = SpecAugment(args)

    args.h, args.w = int(args.desired_length * 100 - 2), 128

    if args.model == 'beats':
        t_tr, t_val = None, None
    else:
        resized_hw = (int(args.h * args.resz), int(args.w * args.resz))
        t_tr = transforms.Compose([transforms.ToTensor(), transforms.Resize(resized_hw)])
        t_val = transforms.Compose([transforms.ToTensor(), transforms.Resize(resized_hw)])

    tr_set = ICBHIDataset(train_flag=True, transform=t_tr, args=args, print_flag=True)
    val_set = ICBHIDataset(train_flag=False, transform=t_val, args=args, print_flag=True)

    train_loader = torch.utils.data.DataLoader(tr_set, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.num_workers, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False,
                                             num_workers=args.num_workers, pin_memory=True)

    args.cls_num_list = get_class_counts(tr_set, args)

    # ----- JSON result persistence (shared by every mode) -----

    def _results_path():
        """Return the path of the per-method results file."""
        return os.path.join(args.save_dir, f"calibration_results_{args.method}.json")

    def _read_results(json_path):
        """Load previously saved results, or an empty dict if none are usable."""
        if not os.path.exists(json_path):
            return {}
        with open(json_path, 'r') as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError:
                data = None
        if not isinstance(data, dict):
            # Best-effort: start over, but say so because the old content
            # is about to be overwritten.
            print(f"Warning: could not read existing results in {json_path}; starting from an empty file")
            return {}
        return data

    def _write_results(json_path, data):
        """Write the merged results, creating the output directory if needed."""
        parent = os.path.dirname(json_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(json_path, 'w') as f:
            json.dump(data, f, indent=4)

    def save_results_to_json(result_entry, key_suffix=""):
        """Store one result dict under ``Seed_<seed><key_suffix>``."""
        json_path = _results_path()
        data = _read_results(json_path)

        seed_key = f"Seed_{args.seed}{key_suffix}"
        data[seed_key] = result_entry

        _write_results(json_path, data)
        print(f"Results saved to {json_path} (Key: {seed_key})")

    def save_single_model_results(model_results, ensemble_identifier=""):
        """Store the per-member results of an ensemble, one key per member."""
        json_path = _results_path()
        data = _read_results(json_path)

        for i, model_result in enumerate(model_results):
            # Fall back to positional values when a member carries no metadata.
            model_seed = model_result.get("Seed", i)
            model_idx = model_result.get("Model_Index", i + 1)

            if ensemble_identifier:
                key = f"Ensemble_{ensemble_identifier}_Model_{model_idx}_Seed_{model_seed}"
            else:
                key = f"Model_{model_idx}_Seed_{model_seed}"

            data[key] = model_result
            print(f"Saved {key}: Sp={model_result['Sp']:.2f}, Se={model_result['Se']:.2f}, AS={model_result['AS']:.2f}")

        _write_results(json_path, data)
        print(f"All model results saved to {json_path}")

    def save_ensemble_results(ensemble_result, ensemble_identifier=""):
        """Store the aggregated ensemble result and return the key used."""
        json_path = _results_path()
        data = _read_results(json_path)

        if ensemble_identifier:
            key = f"Ensemble_{ensemble_identifier}"
        else:
            model_seeds = ensemble_result.get("Model_Seeds", [])
            seed_str = "_".join(str(s) for s in model_seeds)
            key = f"Ensemble_Seeds_{seed_str}"

        data[key] = ensemble_result

        _write_results(json_path, data)
        print(f"Ensemble results saved to {json_path} (Key: {key})")
        return key

    # ================= Deep Ensembles Mode =================
    if args.method == 'ensemble' and args.eval:
        print("=" * 60)
        print("🎯 Deep Ensembles 推理模式")
        print("=" * 60)

        model_paths = find_ensemble_models(args)

        ensemble = DeepEnsemble(model_paths, device='cuda', args=args)

        criterion = nn.CrossEntropyLoss().cuda()

        ensemble_result, model_results = ensemble.validate_ensemble(
            val_loader, criterion, args,
            n_passes=args.ensemble_passes,
            ensemble_method=args.ensemble_method
        )

        save_ensemble_results(ensemble_result, ensemble_identifier=f"Seed_{args.seed}")

        if args.save_all_ensemble_results:
            save_single_model_results(model_results, ensemble_identifier=f"Seed_{args.seed}")

        return

    model, criterion, optimizer = set_model(args)
    scaler = torch.cuda.amp.GradScaler()

    best_acc = {"AS": 0.0}
    best_model_state = None

    # ================= Checkpoint Loading =================
    if args.resume:
        if os.path.isfile(args.resume):
            print(f"=> loading checkpoint '{args.resume}'")
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only resume from checkpoints that come from a trusted source.
            checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)

            # 1. Load Backbone (checkpoint is either {'model': ...} or a bare state dict)
            state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
            clean_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
            model.load_state_dict(clean_state_dict, strict=False)

            # 2. Load Classifier (stored separately in some checkpoints)
            if 'classifier' in checkpoint:
                print("   [Classifier Found] Loading classifier weights separately...")
                clf_state = {k.replace('module.', ''): v for k, v in checkpoint['classifier'].items()}
                model.classifier.load_state_dict(clf_state)
                print("   [Classifier Loaded] Successfully loaded into model.classifier")
        else:
            print(f"=> no checkpoint found at '{args.resume}'")
            return

    # ================= Evaluation Mode =================
    if args.eval:
        print(f"=> Evaluation / Post-hoc Analysis Mode (Method: {args.method})")

        # Remember what the user asked for before the flag is forced off for
        # the standard pass; otherwise the MC Dropout request is lost.
        run_mc_dropout = bool(getattr(args, 'calc_mc_dropout', False))

        # 1. Standard Evaluation
        print("--- Standard Evaluation ---")
        args.calc_mc_dropout = False
        std_results, _, _ = validate(val_loader, model, criterion, args, best_acc={"AS": 0.0},
                                     log_prefix="Test (Standard)")
        save_results_to_json(std_results, key_suffix="_Standard")

        # 2. Temperature Scaling
        if args.calc_temp_scaling:
            print(" - -- Temperature Scaling Analysis - --")
            temp, scaled_ece = find_optimal_temperature(val_loader, model, args)
            # Same metrics as the standard pass, with the calibrated ECE and T.
            ts_results = std_results.copy()
            ts_results['ECE'] = round(scaled_ece, 2)
            ts_results['T'] = round(temp, 3)
            save_results_to_json(ts_results, key_suffix="_TempScaling")

        # 3. MC Dropout (independent of temperature scaling)
        if run_mc_dropout:
            print("  - -- MC Dropout Analysis - --")
            args.calc_mc_dropout = True
            args.mc_passes = 5
            mc_results, _, _ = validate(val_loader, model, criterion, args, best_acc={"AS": 0.0})
            save_results_to_json(mc_results, key_suffix="_MCDropout")

        return  # End of Eval

    # ================= Training Mode =================
    print(f"Start Training: {args.method} (Seed {args.seed})")
    for epoch in range(1, args.epochs + 1):
        adjust_learning_rate(args, optimizer, epoch)
        train(train_loader, model, criterion, optimizer, epoch, args, scaler)
        best_acc, best_model_state, save_bool = validate(val_loader, model, criterion, args, best_acc, best_model_state,
                                                         log_prefix="Test (Val)")

        if save_bool:
            print(f" New Best Score: {best_acc['AS']:.2f}")
            torch.save(best_model_state, os.path.join(args.save_folder, 'best_model.pth'))

    save_results_to_json(best_acc, key_suffix="_Best")


# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
