# engine_dil.py (修正routing_info处理的版本)

import torch
import torch.nn as nn
import numpy as np
import copy
import os
import util.lr_decay as lrd
import torch.optim as optim
import util.misc as misc
import util.lr_sched as lr_sched
from util.losses import MultiClassFocalLoss
from torch.utils.data import WeightedRandomSampler
from fft_utils import compute_and_save_amp_key
# 从单领域engine导入成熟的评估和路由分析函数
from engine_finetune_soft_dil import evaluate, kl_sparsity_loss, log_detailed_routing_analysis
# 导入MetricsManager
from metrics import MetricsManager


class EarlyStopping:
    """Signal when validation loss on a single task has stopped improving.

    Used to prevent over-fitting on one task: once ``patience`` consecutive
    calls fail to improve on the best loss by more than ``min_delta``,
    ``early_stop`` is set to True (and stays True).
    """

    def __init__(self, patience=10, min_delta=0.001):
        self.patience, self.min_delta = patience, min_delta
        self.counter = 0                  # consecutive calls without improvement
        self.best_loss = float('inf')     # best validation loss seen so far
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and update the stop flag."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss, self.counter = val_loss, 0
            return
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

def _is_pull_sim_entry(item):
    """Return True if ``item`` is the ``('pull_sim', value)`` marker in ``routing_info``.

    ``item[0]`` is type-checked before the string comparison so that ordinary
    routing entries (whose first element is a gate tensor) are never compared
    against a string.
    """
    return (
        isinstance(item, tuple)
        and len(item) == 2
        and isinstance(item[0], str)
        and item[0] == 'pull_sim'
    )


def _split_routing_info(routing_info):
    """Split the model's ``routing_info`` into ``(pull_sim, real_routing_entries)``.

    ``pull_sim`` is the value of the first ``('pull_sim', value)`` entry, or
    None if there is none; ``real_routing_entries`` keeps every other entry in
    order. A None/empty ``routing_info`` yields ``(None, [])``.
    """
    pull_sim = None
    real_entries = []
    for item in routing_info or ():
        if not _is_pull_sim_entry(item):
            real_entries.append(item)
        elif pull_sim is None:
            pull_sim = item[1]
    return pull_sim, real_entries


def _detach_routing_entry(entry):
    """Return ``entry`` with every tensor it holds detached from the autograd graph.

    Tuples (including namedtuples) and lists keep their container type;
    non-tensor elements are passed through unchanged.
    """
    if torch.is_tensor(entry):
        return entry.detach()
    if isinstance(entry, (tuple, list)):
        items = [t.detach() if torch.is_tensor(t) else t for t in entry]
        if isinstance(entry, tuple):
            return type(entry)(*items) if hasattr(entry, '_fields') else tuple(items)
        return items
    return entry


def _print_grad_report(model, epoch, data_iter_step):
    """Print per-parameter and total L2 gradient norms of LoRA/head/router/domain-key params."""
    print(f"\n--- Gradient Check at Epoch {epoch}, Step {data_iter_step} ---")
    total_norm = 0
    for name, p in model.named_parameters():
        if 'lora' in name or 'head' in name or 'router' in name or 'domain_keys' in name:
            if p.requires_grad and p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
                print(f"  - Grad norm for {name}: {param_norm.item():.4e}")
    total_norm = total_norm ** 0.5
    print(f"  - Total grad norm: {total_norm:.4e}\n")


def _compute_legacy_train_metrics(true_labels, pred_labels, pred_softmax, metric_logger,
                                  optimizer, log_writer, args, domain_name, epoch):
    """Compute train metrics with scikit-learn when no MetricsManager is used.

    Logs them to TensorBoard (if ``log_writer``), appends a row to
    ``<output_dir>/<task>/metrics_train_<domain>.csv``, prints a summary and
    returns them as a dict.
    """
    import csv
    import torch.nn.functional as F
    from sklearn.metrics import (
        accuracy_score, roc_auc_score, f1_score, average_precision_score,
        cohen_kappa_score
    )

    true_labels = np.array(true_labels)
    pred_labels = np.array(pred_labels)
    pred_softmax = np.array(pred_softmax)

    train_accuracy = accuracy_score(true_labels, pred_labels)
    train_f1 = f1_score(true_labels, pred_labels, average='macro', zero_division=0)
    true_onehot = F.one_hot(torch.from_numpy(true_labels), num_classes=args.nb_classes).numpy()
    try:
        train_roc_auc = roc_auc_score(true_onehot, pred_softmax, multi_class='ovr', average='macro')
    except ValueError as err:
        # roc_auc_score raises when a one-hot column has a single value (e.g. a
        # class absent from this epoch); report NaN instead of aborting training.
        print(f"Warning: train ROC AUC undefined for domain '{domain_name}' ({err}).")
        train_roc_auc = float('nan')
    train_average_precision = average_precision_score(true_onehot, pred_softmax, average='macro')
    train_kappa = cohen_kappa_score(true_labels, pred_labels)
    train_score = (train_f1 + train_roc_auc + train_kappa) / 3

    if log_writer:
        log_writer.add_scalar(f'3_Performance_Train/accuracy_{domain_name}', train_accuracy, epoch)
        log_writer.add_scalar(f'3_Performance_Train/f1_score_{domain_name}', train_f1, epoch)
        log_writer.add_scalar(f'3_Performance_Train/roc_auc_{domain_name}', train_roc_auc, epoch)
        log_writer.add_scalar(f'3_Performance_Train/kappa_{domain_name}', train_kappa, epoch)
        log_writer.add_scalar(f'3_Performance_Train/score_{domain_name}', train_score, epoch)

    os.makedirs(os.path.join(args.output_dir, args.task), exist_ok=True)
    train_results_path = os.path.join(args.output_dir, args.task, f'metrics_train_{domain_name}.csv')
    with open(train_results_path, 'a', newline='', encoding='utf8') as cfa:
        wf = csv.writer(cfa)
        if cfa.tell() == 0:
            # Header only for a freshly created (empty) file.
            wf.writerow(['epoch', 'train_loss', 'accuracy', 'f1', 'roc_auc', 'kappa', 'score', 'lr'])
        wf.writerow([
            epoch,
            metric_logger.meters["loss"].global_avg,
            train_accuracy,
            train_f1,
            train_roc_auc,
            train_kappa,
            train_score,
            optimizer.param_groups[0]["lr"]
        ])

    print(f'Train - Accuracy: {train_accuracy:.4f}, F1: {train_f1:.4f}, ROC AUC: {train_roc_auc:.4f}, Kappa: {train_kappa:.4f}, Score: {train_score:.4f}')

    return {
        'loss': metric_logger.meters["loss"].global_avg,
        'accuracy': train_accuracy,
        'f1': train_f1,
        'roc_auc': train_roc_auc,
        'average_precision': train_average_precision,
        'kappa': train_kappa,
        'score': train_score
    }


def train_one_epoch_dil(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    loss_scaler,
    log_writer,
    args,
    domain_name: str,
    metrics_manager: MetricsManager = None  # optional; kept for backward compatibility
):
    """Train ``model`` for one epoch on a single domain in the DIL setting.

    Loss = classification loss (``MultiClassFocalLoss(gamma=0.5)`` when
    ``args.use_focal_loss``, else ``criterion``)
    - ``args.pull_constraint_coeff * pull_sim`` when ``args.pull_constraint`` is
    set and the model reports a ``('pull_sim', value)`` routing entry
    + ``args.sparsity_lambda * kl_sparsity_loss(gates)`` when that lambda is > 0.

    Args:
        model: network returning ``(class_logits, routing_info)`` in train mode.
        criterion: classification loss used when focal loss is disabled.
        data_loader: training loader for ``domain_name``.
        optimizer: optimizer; its LR is rescheduled every ``args.accum_iter`` steps.
        device: device the batches are moved to.
        epoch: epoch index within the current domain.
        loss_scaler: called every step with ``update_grad`` True on the last
            micro-batch of each accumulation window (presumably backward +
            optimizer step, as in MAE's NativeScaler — confirm).
        log_writer: optional TensorBoard writer.
        args: run configuration.
        domain_name: domain being trained; must be in ``args.domains`` when
            ``metrics_manager`` is given.
        metrics_manager: optional MetricsManager; when None, metrics are computed
            locally with scikit-learn and appended to a per-domain CSV.

    Returns:
        dict of epoch-level training statistics.
    """
    model.train(True)

    if metrics_manager:
        metrics_manager.reset_epoch_states()

    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = f'Epoch: [{epoch+1}/{args.epochs_per_domain}] Domain: [{domain_name}]'

    criterion_cls = MultiClassFocalLoss(gamma=0.5) if args.use_focal_loss else criterion
    rho_target, lambda_kl = args.sparsity_target, args.sparsity_lambda
    use_pull_constraint = bool(getattr(args, 'pull_constraint', False))
    all_routing_data, all_targets = [], []

    # Without a MetricsManager, collect predictions for the scikit-learn metrics.
    if not metrics_manager:
        true_labels, pred_labels, pred_softmax = [], [], []

    optimizer.zero_grad()
    for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, 50, header)):
        is_update_step = (data_iter_step + 1) % args.accum_iter == 0

        # Per-iteration LR schedule, indexed by the epoch within this domain.
        if data_iter_step % args.accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        samples, targets = samples.to(device, non_blocking=True), targets.to(device, non_blocking=True)

        # Mixed-precision forward, consistent with the single-domain engine.
        with torch.cuda.amp.autocast():
            class_logits, routing_info = model(samples)
            pull_sim, real_routing_info = _split_routing_info(routing_info)
            loss = criterion_cls(class_logits, targets)

            if use_pull_constraint and pull_sim is not None:
                # Maximising the similarity == minimising its negative.
                loss = loss - args.pull_constraint_coeff * pull_sim
                metric_logger.update(pull_sim=float(pull_sim))

            loss_router = torch.tensor(0.0)
            if real_routing_info and lambda_kl > 0:
                all_gates = torch.cat([g for g, _ in real_routing_info], dim=1)
                loss_router = kl_sparsity_loss(all_gates, rho=rho_target)
                loss = loss + lambda_kl * loss_router

        if data_iter_step % 10 == 0:
            # Detached snapshot of the real routing entries for end-of-epoch
            # analysis; detaching drops the references to this step's graph.
            all_routing_data.append([_detach_routing_entry(e) for e in real_routing_info])
            all_targets.append(targets.clone())

        if metrics_manager:
            metrics_manager.update_epoch_states(class_logits, targets)
        else:
            with torch.no_grad():
                output_softmax = nn.Softmax(dim=1)(class_logits)
                output_labels = output_softmax.argmax(dim=1)

                true_labels.extend(targets.cpu().numpy())
                pred_labels.extend(output_labels.cpu().numpy())
                pred_softmax.extend(output_softmax.cpu().numpy())

        loss_value = loss.item()
        loss /= args.accum_iter

        loss_scaler(loss, optimizer, clip_grad=args.clip_grad,
                    parameters=model.parameters(), create_graph=False,
                    update_grad=is_update_step)

        if is_update_step:
            # Debug report of the gradients just used by loss_scaler (which was
            # asked to update here), before they are cleared below.
            if data_iter_step % 50 == 0:
                _print_grad_report(model, epoch, data_iter_step)
            optimizer.zero_grad()

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)
        if lambda_kl > 0:
            metric_logger.update(loss_router=loss_router.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    if metrics_manager:
        task_id = args.domains.index(domain_name)
        train_stats = metrics_manager.compute_and_log_epoch_metrics(
            epoch=epoch, task_id=task_id, phase='train', loss=metric_logger.meters["loss"].global_avg
        )
    else:
        train_stats = _compute_legacy_train_metrics(
            true_labels, pred_labels, pred_softmax, metric_logger, optimizer,
            log_writer, args, domain_name, epoch,
        )

    # End-of-epoch routing analysis.
    if log_writer and all_routing_data:
        log_detailed_routing_analysis(
            all_routing_data, all_targets, args.nb_classes, log_writer, epoch,
            prefix=f'train_{domain_name}', save_plots=True, output_dir=args.output_dir, args=args
        )

    metric_logger.synchronize_between_processes()
    print("Averaged stats for domain", domain_name, ":", metric_logger)

    return train_stats

# +++ 新增：专门用于DIL验证的评估函数 +++
@torch.no_grad()
def evaluate_dil_with_metrics(
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    device: torch.device,
    epoch: int,
    args,
    task_id: int,
    phase: str,
    metrics_manager: MetricsManager
):
    """Evaluate ``model`` on one domain's loader, recording metrics via ``metrics_manager``.

    Runs under ``torch.no_grad`` and mixed precision. The per-batch
    cross-entropy is averaged by the metric logger and passed to
    ``metrics_manager.compute_and_log_epoch_metrics``, whose result is returned.
    """
    model.eval()
    metrics_manager.reset_epoch_states()

    eval_logger = misc.MetricLogger(delimiter="  ")
    header = f'{phase}: Domain: [{args.domains[task_id]}]'
    ce_loss = nn.CrossEntropyLoss()

    for batch_x, batch_y in eval_logger.log_every(data_loader, 50, header):
        batch_x = batch_x.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)

        with torch.cuda.amp.autocast():
            logits, _unused = model(batch_x)
            batch_loss = ce_loss(logits, batch_y)

        eval_logger.update(loss=batch_loss.item())
        metrics_manager.update_epoch_states(logits, batch_y)

    eval_logger.synchronize_between_processes()
    print(f"Averaged {phase} stats:", eval_logger)

    mean_loss = eval_logger.meters["loss"].global_avg
    return metrics_manager.compute_and_log_epoch_metrics(
        epoch=epoch, task_id=task_id, phase=phase, loss=mean_loss
    )

# +++ 新增：任务结束后的评估函数，包含领域准确率计算 +++
@torch.no_grad()
def evaluate_all_seen_tasks_and_domain_accuracy(
    model: torch.nn.Module,
    continual_dataloader,
    device: torch.device,
    task_id: int,  # ID of the task that has just been completed
    args,
    metrics_manager: MetricsManager,
    acc_matrix: np.ndarray  # existing accuracy matrix; column task_id is filled in place
):
    """Evaluate every task seen so far and measure domain-prediction accuracy.

    For each task ``0..task_id`` the model is run on that task's test loader;
    its second output is treated as the predicted domain index per sample and
    compared with the task index.

    Side effects:
        * ``acc_matrix[prev_task_id, task_id]`` is set to each task's test
          accuracy (the passed-in array is modified in place and returned).
        * Per-task and domain-prediction confusion matrices are generated via
          ``metrics_manager.generate_task_final_cm`` (``eval_task_id=-1`` marks
          the domain-prediction one).
        * ``domain_accuracies_after_task_{task_id + 1}.json`` is written to
          ``<args.output_dir>/<args.task>`` (created if missing).
        * Continual-learning metrics are logged via ``metrics_manager``.

    Returns:
        ``(acc_matrix, domain_accuracy)``; ``domain_accuracy`` is the fraction of
        all evaluated samples whose domain was predicted correctly.
    """
    import json

    model.eval()
    task_output_dir = os.path.join(args.output_dir, args.task)

    # Domain predictions and true domain indices accumulated over all seen tasks.
    all_domain_preds = []
    all_domain_trues = []

    print(f"\n--- Evaluating all seen tasks after completing Task {task_id + 1} ---")
    for prev_task_id in range(task_id + 1):
        test_loader = continual_dataloader[prev_task_id]['test']

        # Throw-away manager used only for this task's classification metrics.
        temp_manager = MetricsManager(
            num_classes=args.nb_classes,
            epochs_per_task=1,
            output_dir=task_output_dir
        )
        temp_manager.reset_epoch_states()

        for samples, targets in test_loader:
            samples, targets = samples.to(device, non_blocking=True), targets.to(device, non_blocking=True)

            with torch.cuda.amp.autocast():
                # Class logits and predicted domain indices in a single forward pass.
                class_logits, predicted_domain_indices = model(samples)

            temp_manager.update_epoch_states(class_logits, targets)

            all_domain_preds.extend(predicted_domain_indices.cpu().numpy())
            all_domain_trues.extend([prev_task_id] * len(targets))

        test_stats = temp_manager.compute_and_log_epoch_metrics(
            epoch=0, task_id=prev_task_id, phase=f'test_after_task{task_id+1}', loss=0
        )
        acc_matrix[prev_task_id, task_id] = test_stats['accuracy']

        # Final classification confusion matrix for this task.
        if temp_manager.all_targets and temp_manager.all_preds_prob:
            all_targets_np = torch.cat(temp_manager.all_targets).numpy()
            all_preds_np = np.argmax(torch.cat(temp_manager.all_preds_prob).numpy(), axis=1)
            metrics_manager.generate_task_final_cm(
                targets=all_targets_np,
                preds=all_preds_np,
                current_task_id=task_id,
                eval_task_id=prev_task_id
            )

    # Overall domain-prediction accuracy.
    domain_preds = np.array(all_domain_preds)
    domain_trues = np.array(all_domain_trues)
    domain_accuracy = np.mean(domain_preds == domain_trues)

    print(f"Domain Prediction Accuracy after Task {task_id + 1}: {domain_accuracy:.4f}")
    if metrics_manager.tb_writer:
        metrics_manager.tb_writer.add_scalar('1_DIL_Performance/domain_prediction_acc', domain_accuracy, task_id)

    # Domain-prediction confusion matrix (eval_task_id=-1 marks it as such).
    print(f"\n--- Generating domain prediction confusion matrix after Task {task_id + 1} ---")
    metrics_manager.generate_task_final_cm(
        targets=domain_trues,
        preds=domain_preds,
        current_task_id=task_id,
        eval_task_id=-1
    )

    # Per-domain recognition accuracy.
    domain_accuracies = {}
    for domain_idx in range(task_id + 1):
        domain_name = args.domains[domain_idx]
        domain_mask = (domain_trues == domain_idx)
        if domain_mask.sum() > 0:
            domain_specific_acc = np.mean(domain_preds[domain_mask] == domain_trues[domain_mask])
            domain_accuracies[domain_name] = float(domain_specific_acc)
            print(f"  {domain_name} (Task {domain_idx}): {domain_specific_acc:.4f}")
        else:
            domain_accuracies[domain_name] = 0.0
            print(f"  {domain_name} (Task {domain_idx}): No samples found")

    # Ensure the output directory exists; otherwise the write below can fail
    # after an entire task has already been trained.
    os.makedirs(task_output_dir, exist_ok=True)
    domain_acc_path = os.path.join(task_output_dir, f'domain_accuracies_after_task_{task_id + 1}.json')
    with open(domain_acc_path, 'w', encoding='utf-8') as f:
        json.dump({
            'task_id': task_id,
            'overall_domain_accuracy': float(domain_accuracy),
            'per_domain_accuracies': domain_accuracies,
            'total_samples': len(domain_preds),
            'seen_domains': [args.domains[i] for i in range(task_id + 1)]
        }, f, indent=2)

    print(f"Domain accuracies saved to: {domain_acc_path}")

    metrics_manager.compute_and_log_cl_metrics(acc_matrix, task_id)

    return acc_matrix, domain_accuracy

# +++ 新增：创建纯净数据加载器的辅助函数 +++
class _CleanDatasetWrapper(torch.utils.data.Dataset):
    """Re-read the samples of ``original_dataset`` and apply ``clean_transform``.

    Defined at module level (not inside a function) so instances can be
    pickled when DataLoader workers are started with the ``spawn`` method.
    """

    def __init__(self, original_dataset, clean_transform):
        self.original_dataset = original_dataset
        self.clean_transform = clean_transform
        # Keep the label list so code inspecting ``dataset.targets`` still works.
        self.targets = getattr(original_dataset, 'targets', None)

    def __len__(self):
        return len(self.original_dataset)

    def _raw_item(self, idx):
        """Return the untransformed ``(sample, target)`` pair at ``idx``."""
        dataset = self.original_dataset
        if hasattr(dataset, 'samples'):
            # ImageFolder-style dataset: (path, target) list plus a loader callable.
            path, target = dataset.samples[idx]
            return dataset.loader(path), target
        if hasattr(dataset, 'data'):
            # Dataset storing raw samples in ``data`` and labels in ``targets``.
            return dataset.data[idx], dataset.targets[idx]
        # Generic fallback: temporarily disable the dataset's own transform;
        # try/finally restores it even if indexing raises.
        original_transform = dataset.transform
        dataset.transform = None
        try:
            return dataset[idx]
        finally:
            dataset.transform = original_transform

    def __getitem__(self, idx):
        sample, target = self._raw_item(idx)
        if self.clean_transform is not None:
            sample = self.clean_transform(sample)
        return sample, target


def create_clean_dataloader_for_key_computation(train_loader, val_loader, args):
    """Build a non-shuffled loader over the training data using the validation transform.

    Used for domain-key computation, where training-time augmentation is
    undesirable. The training dataset is wrapped rather than deep-copied, and
    each raw sample is passed through ``val_loader.dataset.transform``.

    Args:
        train_loader: loader whose ``dataset`` supplies the samples.
        val_loader: loader whose ``dataset.transform`` is applied.
        args: must provide ``batch_size``, ``num_workers`` and ``pin_mem``.

    Returns:
        The clean DataLoader, or ``train_loader`` itself (after printing a
        warning) if the wrapper cannot be built or its first sample fails to load.
    """
    try:
        clean_dataset = _CleanDatasetWrapper(train_loader.dataset, val_loader.dataset.transform)
        if len(clean_dataset) > 0:
            # Probe one sample now: __getitem__ failures would otherwise only
            # surface during iteration (possibly in a worker) and bypass the fallback.
            clean_dataset[0]
        clean_loader = torch.utils.data.DataLoader(
            clean_dataset,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            shuffle=False,  # key computation does not need shuffling
        )
    except Exception as e:
        # Deliberate best-effort: fall back to the augmented training loader.
        print(f"Warning: Failed to create clean dataloader ({e}). Falling back to original train loader.")
        return train_loader
    return clean_loader

# DIL methods handled by train_and_evaluate_dil; validated before any training starts.
_SUPPORTED_DIL_METHODS = ('fft', 'key_value', 'key_value_learnable', 'kmeans')


def _snapshot_state_dict(module):
    """Return a CPU copy of ``module.state_dict()`` suitable for ``load_state_dict``.

    Unlike ``copy.deepcopy(module.state_dict())``, tensors are copied into host
    memory, so keeping a "best" snapshot does not hold a second full copy of
    the model on the accelerator. The state dict's ``_metadata`` (used by
    ``load_state_dict`` for module versioning) is preserved.
    """
    state = module.state_dict()
    snapshot = type(state)()
    for key, value in state.items():
        if torch.is_tensor(value):
            snapshot[key] = value.detach().to('cpu', copy=True)
        else:
            snapshot[key] = copy.deepcopy(value)
    metadata = getattr(state, '_metadata', None)
    if metadata is not None:
        snapshot._metadata = copy.deepcopy(metadata)
    return snapshot


def _restore_domain_keys_on_resume(args, model_without_ddp, start_task_id):
    """Reload (FFT) or verify (key-value / K-Means) routing keys of already-finished tasks.

    Returns:
        dict mapping domain name -> FFT amplitude key; empty unless
        ``args.dil_method == 'fft'``.
    """
    amp_keys = {}
    if args.dil_method == 'fft':
        print("Resuming: Loading FFT keys for previous tasks...")
        amp_keys_dir = os.path.join(args.output_dir, args.task, "amp_keys")
        for i in range(start_task_id):
            domain_name = args.domains[i]
            key_path = os.path.join(amp_keys_dir, f"{domain_name}_key.pt")
            if os.path.exists(key_path):
                amp_keys[domain_name] = torch.load(key_path, map_location='cpu')
            else:
                # Report missing keys instead of skipping them silently.
                print(f"Warning: FFT key for domain '{domain_name}' not found at {key_path}.")
        model_without_ddp.load_amp_keys(amp_keys)

    elif 'key_value' in args.dil_method:
        # Covers 'key_value' and 'key_value_learnable'; keys come back with the state_dict.
        print("Resuming: Domain keys for Key-Value method loaded automatically via state_dict.")
        loaded_keys_count = 0
        for i in range(start_task_id):
            domain_name = args.domains[i]
            if hasattr(model_without_ddp, 'domain_keys') and domain_name in model_without_ddp.domain_keys:
                key_norm = model_without_ddp.domain_keys[domain_name].abs().sum()
                if key_norm > 1e-6:
                    loaded_keys_count += 1
        print(f"Verified {loaded_keys_count}/{start_task_id} domain keys are valid.")

    elif args.dil_method == 'kmeans':
        print("Resuming: K-Means centers loaded automatically via state_dict.")
        if hasattr(model_without_ddp, 'all_kmeans_centers'):
            total_centers = model_without_ddp.all_kmeans_centers.shape[0]
            expected_centers = start_task_id * args.kmeans_n_clusters
            print(f"Verified {total_centers} K-Means centers loaded (expected: {expected_centers}).")

    return amp_keys


def _compute_domain_key(args, model_without_ddp, domain_name, task_id, key_loader, amp_keys):
    """Compute the routing key(s) of a just-finished domain according to ``args.dil_method``.

    For the FFT method, ``amp_keys`` is updated in place and reloaded into the model.

    Raises:
        NotImplementedError: if ``args.dil_method`` is not supported.
    """
    if args.dil_method == 'fft':
        print(f"Computing FFT key for domain '{domain_name}' using clean data...")
        amp_keys_dir = os.path.join(args.output_dir, args.task, "amp_keys")
        key_path = os.path.join(amp_keys_dir, f"{domain_name}_key.pt")
        compute_and_save_amp_key(key_loader, key_path)
        # Load onto CPU and let the model handle device placement.
        amp_keys[domain_name] = torch.load(key_path, map_location='cpu')
        model_without_ddp.load_amp_keys(amp_keys)

    elif args.dil_method == 'key_value':
        print(f"Computing feature vector key for domain '{domain_name}' using clean data...")
        # The key is stored directly in the model's parameters.
        model_without_ddp.compute_and_set_domain_key(domain_name, key_loader)

    elif args.dil_method == 'key_value_learnable':
        # Keys are optimised during training; nothing to compute afterwards.
        print("Domain keys are learned during training. No post-task computation needed.")

    elif args.dil_method == 'kmeans':
        print(f"Computing K-Means keys for domain '{domain_name}' using clean data...")
        model_without_ddp.update_kmeans_keys_for_task(
            task_id=task_id,
            data_loader=key_loader,
            n_clusters=args.kmeans_n_clusters
        )

    else:
        raise NotImplementedError(f"DIL method '{args.dil_method}' is not supported. Available methods: 'fft', 'key_value', 'key_value_learnable', 'kmeans'")


def train_and_evaluate_dil(model, model_without_ddp, continual_dataloader,
                           loss_scaler, device, log_writer, args):
    """Run the full domain-incremental (DIL) training and evaluation schedule.

    For each domain in ``args.domains`` (after any resumed tasks):
    1. activate the domain's modules and optionally transfer weights from the
       previous task;
    2. train for up to ``args.epochs_per_domain`` epochs, validating each epoch
       with routing forced to the current domain (optional early stopping on
       validation loss, restoring the best state);
    3. compute the domain's routing key on clean (validation-transformed)
       training data;
    4. evaluate all seen tasks, report average accuracy / forgetting and save a
       checkpoint.

    Args:
        model: model used for forward passes (presumably possibly DDP-wrapped).
        model_without_ddp: underlying module exposing the DIL API
            (``set_active_domain``, ``set_force_current_domain``, key computation).
        continual_dataloader: indexable by task id; each entry maps
            ``'train'``/``'val'``/``'test'`` to a DataLoader.
        loss_scaler, device, log_writer, args: forwarded to training/evaluation.

    Returns:
        Accuracy matrix of shape ``(len(args.domains), len(args.domains))``;
        entry ``[i, t]`` is the test accuracy on task ``i`` after training task ``t``.

    Raises:
        NotImplementedError: if ``args.dil_method`` is unsupported (checked
            before any training starts).
    """
    # Fail fast: an unsupported method used to be detected only after the
    # first task had been fully trained.
    if args.dil_method not in _SUPPORTED_DIL_METHODS:
        raise NotImplementedError(f"DIL method '{args.dil_method}' is not supported. Available methods: 'fft', 'key_value', 'key_value_learnable', 'kmeans'")

    metrics_manager = MetricsManager(
        num_classes=args.nb_classes,
        epochs_per_task=args.epochs_per_domain,
        output_dir=os.path.join(args.output_dir, args.task),
        tb_writer=log_writer
    )

    # 1. CL metrics and resume state.
    start_task_id = 0
    acc_matrix = np.zeros((len(args.domains), len(args.domains)))
    if args.resume_dir:
        model_without_ddp, start_task_id, acc_matrix = misc.load_model_dil(args, model_without_ddp)

    # 2. Restore the keys of already-finished tasks when resuming.
    amp_keys = {}
    if start_task_id > 0:
        amp_keys = _restore_domain_keys_on_resume(args, model_without_ddp, start_task_id)

    criterion = nn.CrossEntropyLoss()

    # 3. Main DIL loop: train each domain in order.
    for task_id in range(start_task_id, len(args.domains)):
        domain_name = args.domains[task_id]
        print(f"\n{'='*25} Starting Task {task_id + 1}/{len(args.domains)}: {domain_name} {'='*25}")

        # Activate this domain's modules before building the optimizer param groups.
        model_without_ddp.set_active_domain(domain_name)
        if args.transfer_weights and task_id > 0:
            model_without_ddp.transfer_weights_from_previous_task(task_id)

        params_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay, layer_decay=args.layer_decay)
        optimizer = optim.AdamW(params_groups, lr=args.lr)

        early_stopping = EarlyStopping(patience=args.early_stopping_patience) if args.use_early_stopping else None
        best_model_state_task, best_val_loss_task = None, float('inf')

        current_task_dataloaders = continual_dataloader[task_id]
        original_train_loader = current_task_dataloaders['train']
        train_loader_for_epoch = original_train_loader

        if args.use_class_balance_sampler:
            print(f"INFO: Using Class-balanced Sampler for domain '{domain_name}'.")
            train_targets = torch.tensor(original_train_loader.dataset.targets)
            class_counts = torch.bincount(train_targets, minlength=args.nb_classes)
            # Inverse square-root class frequency; clamp avoids division by zero.
            class_weights = 1. / torch.sqrt(class_counts.float().clamp(min=1))
            sample_weights = class_weights[train_targets]
            sampler_train = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
            train_loader_for_epoch = torch.utils.data.DataLoader(
                original_train_loader.dataset, sampler=sampler_train, batch_size=args.batch_size,
                num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True
            )

        # 4. Per-task training loop (optional early stopping).
        for epoch in range(args.epochs_per_domain):
            if args.distributed and hasattr(train_loader_for_epoch.sampler, 'set_epoch'):
                train_loader_for_epoch.sampler.set_epoch(epoch)

            train_one_epoch_dil(
                model, criterion, train_loader_for_epoch, optimizer, device, epoch,
                loss_scaler, log_writer, args, domain_name, metrics_manager
            )

            # Validate with routing forced to the current domain; finally
            # restores normal routing even if evaluation raises.
            model_without_ddp.set_force_current_domain(True)
            try:
                val_stats = evaluate_dil_with_metrics(
                    model, current_task_dataloaders['val'], device, epoch, args, task_id, 'val', metrics_manager
                )
            finally:
                model_without_ddp.set_force_current_domain(False)

            if early_stopping is not None:
                if val_stats['loss'] < best_val_loss_task:
                    best_val_loss_task = val_stats['loss']
                    best_model_state_task = _snapshot_state_dict(model_without_ddp)
                    print(f"  Task '{domain_name}' new best val_loss: {best_val_loss_task:.4f} at epoch {epoch+1}")
                early_stopping(val_stats['loss'])
                if early_stopping.early_stop:
                    print(f"Early stopping triggered at epoch {epoch+1} for task '{domain_name}'.")
                    break

        if early_stopping is not None and best_model_state_task is not None:
            print(f"Restoring best model state for task '{domain_name}'.")
            model_without_ddp.load_state_dict(best_model_state_task)

        # 5. Compute this domain's key on training data seen through the
        # validation transforms (no training-time augmentation).
        print("\n--- Preparing for domain key computation using clean data ---")
        key_computation_loader = create_clean_dataloader_for_key_computation(
            original_train_loader, current_task_dataloaders['val'], args
        )
        print("Created clean dataloader for key computation (using validation transforms on training data).")
        _compute_domain_key(args, model_without_ddp, domain_name, task_id, key_computation_loader, amp_keys)
        del key_computation_loader

        # 6. Evaluate all seen tasks; column task_id of acc_matrix is filled in.
        acc_matrix, _ = evaluate_all_seen_tasks_and_domain_accuracy(
            model, continual_dataloader, device, task_id, args, metrics_manager, acc_matrix
        )

        # 7. CL metrics.
        print(f"Accuracy Matrix after Task {task_id + 1}:\n{np.round(acc_matrix, 4)}")

        # acc_matrix[i, task_id]: test accuracy on task i after training task_id.
        avg_acc = np.mean(acc_matrix[:task_id + 1, task_id])
        print(f"Average Accuracy (all seen tasks): {avg_acc:.4f}")

        forgetting = 0.0
        if task_id > 0:
            # The max includes the current column, so each old task contributes
            # a non-negative drop from its best accuracy so far.
            max_prev_accs = np.max(acc_matrix[:task_id, :task_id + 1], axis=1)
            current_accs_on_old_tasks = acc_matrix[:task_id, task_id]
            forgetting = np.mean(max_prev_accs - current_accs_on_old_tasks)
            print(f"Forgetting: {forgetting:.4f}")

        if log_writer:
            log_writer.add_scalar('1_DIL_Performance/avg_acc', avg_acc, task_id)
            if task_id > 0:
                log_writer.add_scalar('1_DIL_Performance/forgetting', forgetting, task_id)

        # 8. Save the DIL checkpoint. NOTE: `epoch` is the last epoch run for
        # this task (stale/undefined if args.epochs_per_domain == 0).
        misc.save_model_dil(args, epoch, model, model_without_ddp, optimizer, loss_scaler, task_id, acc_matrix)

    print("\n======== DIL Training Finished ========")
    return acc_matrix