import math
import sys
import os
import random
import datetime
import json
from torch.utils.data import ConcatDataset, DataLoader
from typing import Iterable
import torchaudio
from collections import defaultdict
from pathlib import Path
from copy import deepcopy
import re
import matplotlib.pyplot as plt
from collections import defaultdict

import torch.nn as nn
import torch
from torch.nn import functional as F
import tempfile
import torch.distributed as dist
import numpy as np

from timm.utils import accuracy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from torch import optim
import utils
from torch.distributions.multivariate_normal import MultivariateNormal
import copy

import sys
sys.path.append('/public/home/xxx/projects/cl/cl/ACL')
from acil import ACILLearner
sys.path.append('/public/home/xxx/projects/cl/cl/ACL/EAT/models')
from modules import AltAttention
sys.path.append('/public/home/xxx/projects/cl/cl/ACL/trainers')
from ranpac_trainer import create_optimizer_2

def train_one_epoch_default(model: torch.nn.Module, criterion, data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0,
                    set_training_mode=True, task_id=-1, class_mask=None, args=None, ):
    """Train ``model`` for one epoch on one task with a plain classification loss.

    Every batch is moved to ``device``, optionally masked by
    ``spectrogram_augment`` (with probability ``args.random_rate`` and only
    while ``task_id < args.aug_tasks``) and fed through
    ``model.module.naive_classification``. When ``args.train_mask`` is set and
    ``class_mask`` is given, the logits of every class outside
    ``class_mask[task_id]`` are filled with ``-inf`` before the loss.

    Args:
        model: Wrapped model; it must expose ``.module.naive_classification``.
        criterion: Loss called as ``criterion(logits, target)``.
        data_loader: Iterable of ``(input, target)`` batches that supports ``len``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index (used for logging and sampler seeding).
        max_norm: Accepted for interface compatibility; gradient clipping is
            currently disabled in this function, so the value is unused.
        set_training_mode: Forwarded to ``model.train`` and to the model call.
        task_id: Index of the task being trained.
        class_mask: Optional per-task collections of class indices.
        args: Run configuration namespace.

    Returns:
        dict mapping every meter name to its ``global_avg``.

    Note:
        The process is terminated with ``sys.exit(1)`` if the loss is not finite.
    """
    model.train(set_training_mode)

    if args.distributed and utils.get_world_size() > 1:
        data_loader.sampler.set_epoch(epoch)

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('Lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('Loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    # Pad the epoch counter to the number of digits of args.epochs. Counting
    # digits directly matches int(log10(n)) + 1 for n >= 1 and, unlike log10,
    # does not raise when args.epochs is 0 (this function is also driven by
    # loops whose length comes from other settings).
    epoch_width = len(str(args.epochs))
    header = f'Train: Epoch[{epoch + 1:{epoch_width}}/{args.epochs}]'

    # The set of classes to mask out does not change within the epoch, so the
    # index tensor is built once instead of once per batch.
    masked_out = None
    if args.train_mask and class_mask is not None:
        other_classes = np.setdiff1d(np.arange(args.nb_classes), class_mask[task_id])
        masked_out = torch.tensor(other_classes, dtype=torch.int64).to(device)

    num_batches = len(data_loader)
    batches = metric_logger.log_every(data_loader, args.print_freq, header)

    for step, (samples, target) in enumerate(batches, start=1):
        # Optionally train on only the first fraction of the epoch.
        if args.percentage_epoch != 1 and float(step / num_batches) > args.percentage_epoch:
            break
        samples = samples.to(device, non_blocking=True)
        if random.random() < args.random_rate and task_id < args.aug_tasks:
            print('perform aug')
            # NOTE(review): the arguments are passed as (time_aug, freq_aug)
            # while spectrogram_augment declares (freqm, timem) -- confirm
            # the order is intended.
            samples = spectrogram_augment(samples, args.time_aug, args.freq_aug)
        target = target.to(device, non_blocking=True)
        output = model.module.naive_classification(samples, task_id=task_id, train=set_training_mode)
        logits = output['logits']

        # Trick: mask out the classes that do not belong to the current task.
        if masked_out is not None:
            logits = logits.index_fill(dim=1, index=masked_out, value=float('-inf'))

        loss = criterion(logits, target)  # base criterion (CrossEntropyLoss)

        acc1, acc5 = accuracy(logits, target, topk=(1, 5))

        if not math.isfinite(loss.item()):
            print("Loss is {}, stopping training".format(loss.item()))
            sys.exit(1)

        optimizer.zero_grad()
        loss.backward()
        # Gradient clipping is disabled here, which is why max_norm is unused.
        optimizer.step()

        # Only synchronize when CUDA exists so the loop can also run on CPU.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        metric_logger.update(Loss=loss.item())
        metric_logger.update(Lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters['Acc@1'].update(acc1.item(), n=samples.shape[0])
        metric_logger.meters['Acc@5'].update(acc5.item(), n=samples.shape[0])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {name: meter.global_avg for name, meter in metric_logger.meters.items()}

def train_one_epoch_additional(model: torch.nn.Module, criterion, data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0,
                    set_training_mode=True, task_id=-1, class_mask=None, args=None, total_neg=None, total_pos=None):
    """Train one epoch with the classification loss plus an optional margin term.

    On top of ``criterion(logits, target)`` this adds, when
    ``args.contrastive_pretraining == "True"`` and ``args.tl_ratio != 0``, a
    ``MarginRankingLoss`` between two per-sample distances computed on the
    L2-normalised ``pre_logits`` features:

    * positive distance: distance to the mean of ``total_pos[label]``;
    * negative distance: distance to the closest entry of ``total_neg[label]``
      (or ``positive + args.margin`` when the label has no negatives, which
      makes that sample's margin term zero).

    Args:
        model: Wrapped model; it must expose ``.module.naive_classification``
            returning ``'logits'`` and ``'pre_logits'``.
        criterion: Loss called as ``criterion(logits, target)``.
        data_loader: Iterable of ``(input, target)`` batches that supports ``len``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index (used for logging and sampler seeding).
        max_norm: Accepted for interface compatibility; gradient clipping is
            currently disabled in this function, so the value is unused.
        set_training_mode: Forwarded to ``model.train`` and to the model call.
        task_id: Index of the task being trained; together with
            ``args.num_classes`` / ``args.num_tasks`` it selects the label range
            whose banks are stacked.
        class_mask: Optional per-task collections of class indices.
        args: Run configuration namespace.
        total_neg: Mapping label -> negative feature bank. Required.
        total_pos: Mapping label -> positive feature bank. Required.

    Returns:
        dict mapping every meter name to its ``global_avg``.

    Raises:
        ValueError: If ``total_neg`` or ``total_pos`` is ``None``.
        AssertionError: If a sample's label has an empty positive bank.

    Note:
        The process is terminated with ``sys.exit(1)`` if the loss is not finite.
    """
    # Both banks are dereferenced unconditionally below; fail early with a
    # clear message instead of an AttributeError on None.
    if total_neg is None or total_pos is None:
        raise ValueError('train_one_epoch_additional requires both total_neg and total_pos')

    model.train(set_training_mode)

    if args.distributed and utils.get_world_size() > 1:
        data_loader.sampler.set_epoch(epoch)

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('Lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('Loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    # Digit count of args.epochs; equals int(log10(n)) + 1 for n >= 1 and,
    # unlike log10, does not raise when args.epochs is 0.
    epoch_width = len(str(args.epochs))
    header = f'Train: Epoch[{epoch + 1:{epoch_width}}/{args.epochs}]'

    # Work on private copies so the caller's banks are never modified.
    total_neg = copy.deepcopy(total_neg)
    total_pos = copy.deepcopy(total_pos)
    for bank in (total_neg, total_pos):
        for key, value in bank.items():
            if isinstance(value, torch.Tensor):
                bank[key] = value.clone().detach()

    # Stack the banks of the labels that belong to the current task.
    first_label = int(task_id * args.num_classes / args.num_tasks)
    end_label = int((task_id + 1) * args.num_classes / args.num_tasks)
    for label_id in range(first_label, end_label):
        if len(total_neg[label_id]) == 0:
            total_neg[label_id] = []
        elif len(total_neg[label_id]) == 1:
            # Kept separate from the stack below: indexing also works when the
            # entry is already a tensor with a single row.
            total_neg[label_id] = total_neg[label_id][0].unsqueeze(0)
        else:
            total_neg[label_id] = torch.stack(total_neg[label_id])
        total_pos[label_id] = torch.stack(total_pos[label_id])

    calc_tl = nn.MarginRankingLoss(margin=args.margin)
    use_margin_term = args.contrastive_pretraining == "True" and args.tl_ratio != 0

    # The set of classes to mask out does not change within the epoch, so the
    # index tensor is built once instead of once per batch.
    masked_out = None
    if args.train_mask and class_mask is not None:
        other_classes = np.setdiff1d(np.arange(args.nb_classes), class_mask[task_id])
        masked_out = torch.tensor(other_classes, dtype=torch.int64).to(device)

    num_batches = len(data_loader)
    batches = metric_logger.log_every(data_loader, args.print_freq, header)

    for step, (samples, target) in enumerate(batches, start=1):
        # Optionally train on only the first fraction of the epoch.
        if args.percentage_epoch != 1 and float(step / num_batches) > args.percentage_epoch:
            break
        samples = samples.to(device, non_blocking=True)
        if random.random() < args.random_rate:
            print('perform aug')
            # NOTE(review): the arguments are passed as (time_aug, freq_aug)
            # while spectrogram_augment declares (freqm, timem) -- confirm
            # the order is intended.
            samples = spectrogram_augment(samples, args.time_aug, args.freq_aug)
        target = target.to(device, non_blocking=True)
        output = model.module.naive_classification(samples, task_id=task_id, train=set_training_mode)
        logits = output['logits']
        features = F.normalize(output['pre_logits'], dim=-1, eps=1e-12)

        if use_margin_term:
            dist_hardest_negative = []
            dist_hardest_positive = []

            for row in range(features.shape[0]):
                label = int(target[row])

                if len(total_pos[label]) == 0:
                    message = f'do not have any positive label for label {label}'
                    print(message)
                    raise AssertionError(message)
                # Distance from the sample to the mean positive feature.
                max_dist = torch.norm(features[row] - torch.mean(total_pos[label], dim=0))
                dist_hardest_positive.append(max_dist)

                if len(total_neg[label]) == 0:
                    print(f'do not have any negative label for label {label}')
                    # No negatives: place the "negative" exactly one margin
                    # away so this sample contributes zero to the margin loss.
                    min_dist = max_dist + args.margin
                else:
                    dists_neg = torch.norm(total_neg[label] - features[row].unsqueeze(0), dim=1)  # [N]
                    k = min(1, dists_neg.shape[0])
                    topk_vals, _ = torch.topk(dists_neg, k, largest=False)  # take the k smallest
                    min_dist = torch.mean(topk_vals)
                dist_hardest_negative.append(min_dist)
                print(f'min_dist : {min_dist}, max_dist : {max_dist}')

            dist_hardest_positive = torch.stack(dist_hardest_positive, dim=0)
            dist_hardest_negative = torch.stack(dist_hardest_negative, dim=0)
            assert dist_hardest_positive.shape == dist_hardest_negative.shape
            # With target +1 the loss is max(0, positive - negative + margin).
            yy = torch.ones_like(dist_hardest_positive)
            loss_tl = calc_tl(dist_hardest_negative, dist_hardest_positive, yy)

        # Trick: mask out the classes that do not belong to the current task.
        if masked_out is not None:
            logits = logits.index_fill(dim=1, index=masked_out, value=float('-inf'))

        loss = criterion(logits, target)  # base criterion (CrossEntropyLoss)

        if use_margin_term:
            print(f'loss_cls : {loss}, loss_fea : {loss_tl}')
            loss = loss + args.tl_ratio * loss_tl
        else:
            print(f'loss_cls : {loss}')
        acc1, acc5 = accuracy(logits, target, topk=(1, 5))

        if not math.isfinite(loss.item()):
            print("Loss is {}, stopping training".format(loss.item()))
            sys.exit(1)

        optimizer.zero_grad()
        loss.backward()
        # Gradient clipping is disabled here, which is why max_norm is unused.
        optimizer.step()

        # Only synchronize when CUDA exists so the loop can also run on CPU.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        metric_logger.update(Loss=loss.item())
        metric_logger.update(Lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters['Acc@1'].update(acc1.item(), n=samples.shape[0])
        metric_logger.meters['Acc@5'].update(acc5.item(), n=samples.shape[0])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {name: meter.global_avg for name, meter in metric_logger.meters.items()}

@torch.no_grad()
def evaluate(model: torch.nn.Module, output_layer, data_loader,
             device, i=-1, task_id=-1, class_mask=None, target_task_map=None, args=None, cls_mean=None):
    """Evaluate ``model`` on one task's loader and return the averaged meters.

    The prediction path is selected by ``args.run_analytic``:

    * ``"True"``: features from ``model.module.forward_rp(...)["x"]`` are
      scored on CPU with ``F.linear(features, output_layer)``. Per-class
      accuracy and a confusion breakdown are printed afterwards.
    * ``"ncm"``: the same features are assigned to the nearest entry of
      ``cls_mean`` (Euclidean distance). Acc@5 counts a hit when the target is
      among the five nearest class means.
    * anything else: logits from ``model.module.naive_classification`` are
      used directly; only this path updates the ``Loss`` meter.

    Args:
        model: Wrapped model exposing ``.module.forward_rp`` /
            ``.module.naive_classification``.
        output_layer: Weight matrix whose first dimension is the class count.
        data_loader: Iterable of ``(input, target)`` batches.
        device: Device the batches are moved to.
        i: Index of the evaluated task (used in the log header and for
            ``class_mask[i]``).
        task_id: Task index forwarded to the model.
        class_mask: Optional per-task collections of class indices.
        target_task_map: Unused in this function.
        args: Run configuration namespace.
        cls_mean: Mapping class -> mean feature; read only in ``"ncm"`` mode.

    Returns:
        dict mapping every meter name to its ``global_avg``.
    """
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test: [Task {}]'.format(i + 1)

    # switch to evaluation mode
    model.eval()

    num_classes = output_layer.size(0)
    cm = torch.zeros((num_classes, num_classes), dtype=torch.int64)  # rows: true, cols: predicted
    all_correct = defaultdict(int)
    all_total = defaultdict(int)

    # The class means do not change while iterating, so stack them once.
    if args.run_analytic == "ncm":
        class_ids = sorted(cls_mean.keys())
        class_means = torch.stack([cls_mean[c].detach().cpu() for c in class_ids], dim=0)
        class_id_tensor = torch.tensor(class_ids)
        top_k = min(5, len(class_ids))

    # Gradients are already disabled by the decorator above.
    for samples, target in metric_logger.log_every(data_loader, args.print_freq, header):
        samples = samples.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        batch_size = samples.shape[0]

        if args.run_analytic == "True":
            feats = model.module.forward_rp(samples, task_id=task_id, drs_lora=0)["x"]
            logits = F.linear(feats.cpu(), output_layer)
            target_cpu = target.cpu()

            acc1, acc5 = accuracy(logits, target_cpu, topk=(1, 5))
            metric_logger.meters['Acc@1'].update(acc1.item(), n=batch_size)
            metric_logger.meters['Acc@5'].update(acc5.item(), n=batch_size)

            pred = logits.argmax(dim=1)
            for true_cls, pred_cls in zip(target_cpu.tolist(), pred.tolist()):
                all_total[true_cls] += 1
                if true_cls == pred_cls:
                    all_correct[true_cls] += 1

            # Accumulate the confusion matrix through its flattened view.
            flat_idx = target_cpu * num_classes + pred
            cm.view(-1).index_add_(0, flat_idx, torch.ones_like(flat_idx, dtype=torch.int64))

        elif args.run_analytic == "ncm":
            feats = model.module.forward_rp(samples, task_id=task_id, drs_lora=0)["x"]
            feats_cpu = feats.detach().cpu()
            target_cpu = target.cpu()

            dists = torch.cdist(feats_cpu, class_means)
            preds = class_id_tensor[dists.argmin(dim=1)]
            # Labels of the top_k nearest class means, shape [batch, top_k].
            nearest = class_id_tensor[dists.topk(top_k, dim=1, largest=False).indices]

            correct = (preds == target_cpu).sum().item()
            correct_top5 = (nearest == target_cpu.unsqueeze(1)).any(dim=1).sum().item()
            acc1 = correct / target_cpu.size(0) * 100.0
            # Report a real top-5 value rather than repeating Acc@1.
            acc5 = correct_top5 / target_cpu.size(0) * 100.0
            metric_logger.meters['Acc@1'].update(acc1, n=batch_size)
            metric_logger.meters['Acc@5'].update(acc5, n=batch_size)

        else:
            output = model.module.naive_classification(samples, task_id=task_id, drs_lora=0)
            logits = output['logits']

            # Task-incremental masking is rejected by this assert, so the
            # branch below only executes when assertions are disabled
            # (python -O).
            assert (args.task_inc and class_mask is not None) == False
            if args.task_inc and class_mask is not None:
                # adding mask to output logits
                mask = class_mask[i]
                mask = torch.tensor(mask, dtype=torch.int64).to(device)
                logits_mask = torch.ones_like(logits, device=device) * float('-inf')
                logits_mask = logits_mask.index_fill(1, mask, 0.0)
                logits = logits + logits_mask

            loss = criterion(logits, target)

            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            metric_logger.meters['Loss'].update(loss.item())
            metric_logger.meters['Acc@1'].update(acc1.item(), n=batch_size)
            metric_logger.meters['Acc@5'].update(acc5.item(), n=batch_size)

    if args.run_analytic == "True":
        # Per-class accuracy; classes without samples are skipped silently.
        for cls in range(num_classes):
            total = all_total.get(cls, 0)
            correct = all_correct.get(cls, 0)
            if total > 0:
                acc = correct / total
                print(f"[Class {cls}] Acc={acc:.4f}  Total={total}  Correct={correct}")

        print("\n=== Misclassification breakdown per class (true -> predicted: count) ===")
        for cls in range(num_classes):
            if all_total.get(cls, 0) == 0:
                continue

            row_counts = cm[cls].tolist()
            mistakes = [(j, count) for j, count in enumerate(row_counts) if j != cls and count > 0]
            if not mistakes:
                print(f"Class {cls}: no misclassifications")
                continue

            # Most frequent confusion first.
            mistakes.sort(key=lambda item: item[1], reverse=True)
            top_items = ", ".join([f"{cls}->{j}: {cnt}" for j, cnt in mistakes])
            print(f"Class {cls}: {top_items}")

    metric_logger.synchronize_between_processes()
    print(
        '* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}'
        .format(top1=metric_logger.meters['Acc@1'], top5=metric_logger.meters['Acc@5']))

    return {name: meter.global_avg for name, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate_till_now(model: torch.nn.Module, output_layer, data_loader,
                      device, task_id=-1, class_mask=None, target_task_map=None, acc_matrix=None, evaluate_all=False, args=None, cls_mean=None):
    """Evaluate every task seen so far and print the averaged accuracies.

    ``evaluate`` is run on ``data_loader[i]['val']`` for each task ``i`` in
    ``range(task_id + 1)`` (or in ``range(args.num_tasks)`` when
    ``evaluate_all`` is ``True``). Each Acc@1 is written to
    ``acc_matrix[i, task_id]``, so ``acc_matrix`` is modified in place and must
    be provided. From the second task on, forgetting and backward transfer
    derived from ``acc_matrix`` are appended to the printed summary.

    Args:
        model: Model forwarded to ``evaluate``.
        output_layer: Forwarded to ``evaluate``.
        data_loader: Indexable per-task collection whose items hold a ``'val'`` loader.
        device: Forwarded to ``evaluate``.
        task_id: Index of the most recently trained task.
        class_mask: Forwarded to ``evaluate``.
        target_task_map: Forwarded to ``evaluate``.
        acc_matrix: 2-D array updated in place with Acc@1 values.
        evaluate_all: When ``True``, evaluate all ``args.num_tasks`` tasks.
        args: Run configuration namespace.
        cls_mean: Forwarded to ``evaluate``.

    Returns:
        The stats dict returned by ``evaluate`` for the last evaluated task.
    """
    # Only rows 0 (Acc@1) and 1 (Acc@5) are filled; the other rows stay zero.
    stat_matrix = np.zeros((4, args.num_tasks))

    total_eva = task_id + 1
    # Compared against True on purpose: other truthy values (for instance
    # non-empty strings) must not enable this mode.
    if evaluate_all == True:
        total_eva = args.num_tasks
    for i in range(total_eva):
        test_stats = evaluate(model=model, output_layer=output_layer, data_loader=data_loader[i]['val'],
                              device=device, i=i, task_id=task_id, class_mask=class_mask, target_task_map=target_task_map,
                              args=args, cls_mean=cls_mean)

        stat_matrix[0, i] = test_stats['Acc@1']
        stat_matrix[1, i] = test_stats['Acc@5']

        acc_matrix[i, task_id] = test_stats['Acc@1']

    avg_stat = np.divide(np.sum(stat_matrix, axis=1), total_eva)
    # From here on task_id refers to the last evaluated task.
    task_id = total_eva - 1
    diagonal = np.diag(acc_matrix)

    # A real tab separates the fields (an escaped backslash here used to print
    # a literal "\t" after the task number).
    result_str = "[Average accuracy till task{}]\tAcc@1: {:.4f}\tAcc@5: {:.4f}".format(
        task_id + 1,
        avg_stat[0],
        avg_stat[1],)
    if task_id > 0:
        # Forgetting: best accuracy ever reached on each earlier task minus
        # its current accuracy, averaged over the earlier tasks.
        forgetting = np.mean((np.max(acc_matrix, axis=1) -
                              acc_matrix[:, task_id])[:task_id])
        # Backward transfer: current accuracy minus the accuracy measured
        # right after each earlier task was learned (the diagonal).
        backward = np.mean((acc_matrix[:, task_id] - diagonal)[:task_id])

        result_str += "\tForgetting: {:.4f}\tBackward: {:.4f}".format(forgetting, backward)
    print(result_str)

    return test_stats

def spectrogram_augment(spec, freqm, timem):
    """Apply torchaudio frequency masking, then time masking, to ``spec``.

    The last two axes are swapped and a leading axis is added before masking;
    both are undone before returning, so the result has the shape of ``spec``.

    Args:
        spec: Spectrogram tensor. Presumably laid out as [..., T, F] so that
            the swapped view is [..., F, T] -- confirm against the data loader.
        freqm: Mask parameter passed to ``FrequencyMasking``.
        timem: Mask parameter passed to ``TimeMasking``.

    Returns:
        The masked spectrogram.
    """
    # View with the last two axes swapped and an extra leading axis.
    swapped = spec.transpose(-2, -1).unsqueeze(0)

    mask_freq = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)
    mask_time = torchaudio.transforms.TimeMasking(timem, iid_masks=True)

    # Frequency mask first, then time mask.
    masked = mask_time(mask_freq(swapped))

    # Restore the original axis order and drop the leading axis.
    return masked.transpose(-2, -1)[0]

# def optimise_ridge_parameter(Features, Y):
#     ridges = 10.0 ** np.arange(-8,9)
#     num_val_samples = int(Features.shape[0] * 0.8)
#     losses = []
#     Q_val = Features[0:num_val_samples,:].T @ Y[0:num_val_samples,:]
#     G_val = Features[0:num_val_samples,:].T @ Features[0:num_val_samples,:]
#     for ridge in ridges:
#         Wo = torch.linalg.solve(G_val + ridge * torch.eye(G_val.size(dim = 0)), Q_val).T 
#         Y_train_pred = Features[num_val_samples::,:] @ Wo.T
#         losses.append(F.mse_loss(Y_train_pred, Y[num_val_samples::, :]))
#     ridge = ridges[np.argmin(np.array(losses))]
#     return ridge

def optimise_ridge_parameter(Features: torch.Tensor,
                             Y: torch.Tensor,
                             device: str = "cuda",
                             dtype: torch.dtype = torch.float32):
    """Pick the ridge strength with the lowest hold-out MSE from a fixed grid.

    The first 80% of the rows are used to fit a ridge-regression solution
    ``W = (X^T X + ridge * I)^-1 X^T Y`` for each of the 17 candidates
    ``1e-8, 1e-7, ..., 1e8``; the remaining 20% of the rows are used to score
    each candidate with the mean squared error. On ties the smallest candidate
    wins.

    Args:
        Features: 2-D feature matrix, one row per sample.
        Y: 2-D target matrix with the same number of rows.
        device: Device to compute on; ``None`` keeps ``Features``' device.
        dtype: Floating dtype used for the whole computation. Both inputs are
            cast to it, so integer targets or inputs of mixed precision are
            accepted.

    Returns:
        The selected ridge value as a Python float.
    """
    dev = torch.device(device) if device is not None else Features.device

    # Only a scalar leaves this function, so no autograd graph is needed.
    with torch.no_grad():
        # Cast both operands: the identity and the grid below use ``dtype``,
        # and the matrix products require matching dtypes.
        X = Features.to(device=dev, dtype=dtype)
        T = Y.to(device=dev, dtype=dtype)

        # First 80% of the rows fit the solution, the rest validate it.
        n_fit = int(X.shape[0] * 0.8)
        X_fit, T_fit = X[:n_fit], T[:n_fit]
        X_val, T_val = X[n_fit:], T[n_fit:]

        G = X_fit.T @ X_fit
        Q = X_fit.T @ T_fit
        eye = torch.eye(G.size(0), device=dev, dtype=dtype)

        ridges = torch.logspace(-8, 8, steps=17, base=10.0, device=dev, dtype=dtype)

        losses = torch.stack([
            F.mse_loss(X_val @ torch.linalg.solve(G + lam * eye, Q), T_val)
            for lam in ridges
        ])
        # A NaN loss would otherwise be selected by argmin; treat every
        # non-finite loss as the worst possible score instead.
        losses = torch.where(torch.isfinite(losses), losses,
                             torch.full_like(losses, float("inf")))

        best_idx = int(torch.argmin(losses).item())
        return ridges[best_idx].item()

def train_and_evaluate(model: torch.nn.Module, model_without_ddp: torch.nn.Module, 
                       criterion, data_loader: Iterable, data_loader_per_cls: Iterable,
                       optimizer: torch.optim.Optimizer,
                       lr_scheduler,
                       device: torch.device,
                       class_mask=None, target_task_map=None, args=None, ):
    # create matrix to save end-of-task accuracies
    acc_matrix = np.zeros((args.num_tasks, args.num_tasks))
    pre_ca_acc_matrix = np.zeros((args.num_tasks, args.num_tasks))
    global cls_mean
    global cls_cov
    cls_mean = dict()
    cls_cov = dict()

    Q_matrix = torch.zeros(args.rp_dim, args.num_classes)
    G_matrix = torch.zeros(args.rp_dim, args.rp_dim)
    fake_fc = nn.Parameter(torch.Tensor(args.num_classes, args.rp_dim))
    fea_in = {}

    for task_id in range(args.num_tasks):
        # Create new optimizer for each task to clear optimizer status
        # import pdb; pdb.set_trace()
        print(f'currently processing task {task_id}')
        # if 1:
        # if args.save_pt == "True" and task_id == 0:
        if 0:
            # drs_lora: whether to use lora subtraction
            # task_id: -1: pretrained model

            # keys = ['train']
            keys = ['val']
            show_list = [0,1]

            with torch.no_grad():
                print(f'start extracting feature for task {task_id}')
                features_all = []
                ys = []
                for _show_list in show_list:
                    for _keys in keys:
                        for X, y in data_loader[_show_list][_keys]:
                            # X_aug = spectrogram_augment(X, 15, 15)
                            X_aug = X
                            features_all.append(model.module.naive_classification(X_aug.to(model.device), drs_lora=0, task_id=task_id)['layer_results'])
                            # features_all_1_ori.append(model.module.naive_classification(X.to(model.device), drs_lora=0, task_id=task_id)['layer_results'])
                            ys.append(y)
                for layer_id in range(0, 12):
                    print(f'start extracting feature for layer {layer_id}')
                    features = []
                    labels = []
                    cnt = -1
                    for _show_list in show_list:
                        for _keys in keys:
                            for X, y in data_loader[_show_list][_keys]:
                                cnt += 1
                                features_1 = features_all[cnt][layer_id]
                                label_1 = ys[cnt]
                                features.append(features_1)
                                labels.append(label_1)
                    print('extracted')
                    features_tensor = torch.cat(features, dim=0)
                    labels_tensor = torch.cat(labels, dim=0)
                        
                    torch.save({"features": features_tensor, "labels": labels_tensor}, f"xtmp_pts_tmp/dataset_task_-1_layer_{layer_id}.pt")
                    print(f'saved to xtmp_pts_tmp/dataset_task_-1_layer_{layer_id}.pt')
    
        if 1:
        # if 0:
            if task_id == 0:
                # TODO: change it back

                if args.learn_and_frozen == "True":
                    for name, param in model.named_parameters(): 
                        param.requires_grad = False
                        if 'layer_norm' in name or 'cls_head' in name:
                            param.requires_grad = True
                            print(f'activate param: {name}')

                    for epoch in range(args.xianxuejige):
                        train_stats = train_one_epoch_default(model=model, criterion=criterion, data_loader=data_loader[task_id]['train'], optimizer=optimizer, 
                                                    device=device, epoch=epoch, max_norm=args.clip_grad, set_training_mode=True, task_id=task_id, class_mask=class_mask, args = args)
                                
                        if lr_scheduler:
                            lr_scheduler.step(epoch)

                    for name, param in model.named_parameters(): 
                        param.requires_grad = False
                        if 'lora_layer' in name and 'for_task' not in name:
                            param.requires_grad = True
                            print(f'activate param: {name}')

                    if args.epochs > 0:
                        for epoch in range(args.epochs):
                            train_stats = train_one_epoch_default(model=model, criterion=criterion, data_loader=data_loader[task_id]['train'], optimizer=optimizer, 
                                                        device=device, epoch=epoch, max_norm=args.clip_grad, set_training_mode=True, task_id=task_id, class_mask=class_mask, args = args)
                                    
                            if lr_scheduler:
                                lr_scheduler.step(epoch)
                
                else:

                    for name, param in model.named_parameters(): 
                        param.requires_grad = False
                        if (f'lora_layer' in name or 'layer_norm' in name or 'cls_head' in name) and 'for_task' not in name:
                        # if 'layer_norm' in name or 'cls_head' in name:
                            param.requires_grad = True
                            print(f'activate param: {name}')
                    if args.epochs > 0:
                        for epoch in range(args.epochs):
                            train_stats = train_one_epoch_default(model=model, criterion=criterion, data_loader=data_loader[task_id]['train'], optimizer=optimizer, 
                                                        device=device, epoch=epoch, max_norm=args.clip_grad, set_training_mode=True, task_id=task_id, class_mask=class_mask, args = args)
                                    
                            if lr_scheduler:
                                lr_scheduler.step(epoch)

            # --- test for lora_drs ---
            # if 0:
            if 1:
                if task_id >= 1:
                    
                    for name, param in model.named_parameters(): 
                        param.requires_grad = False
                        if f'lora_layer_for_task_{task_id}.' in name or 'layer_norm' in name or 'cls_head' in name:
                        # if 'layer_norm' in name or 'cls_head' in name:
                            param.requires_grad = True
                            print(f'activate param: {name}')

                    cur_dataloader = data_loader[task_id]['train']
                    with torch.no_grad():
                        print('start running drs_lora cov feature extraction')
                        features = []
                        labels = []
                        _cnt = 0
                        for X, y in cur_dataloader:
                            _cnt += 1
                            print(f'processing batch {_cnt} for drs space')
                            _ = model.module.naive_classification(X.to(model.device), task_id=task_id, drs_lora=1)
                        for name, module in model.named_modules():
                            if isinstance(module, AltAttention):
                                # model.module.lora_layer_for_task_1.v_lora_A
                                # name: module.blocks.k.attn
                                match = re.search(r'\.(\d+)\.', name)
                                if match:
                                    k = int(match.group(1)) 
                                    print(f"extracted layer index k: {k}")
                                else:
                                    print("no number found in the name.")

                                # import pdb; pdb.set_trace()
                                layer = getattr(model.module, f"lora_layer_for_task_{task_id}")
                                # param = getattr(layer, f"k_lora_A_{k}")
                                fea_in[getattr(layer, f"k_lora_A_{k}")] = deepcopy(module.cur_matrix).to(model.device)
                                fea_in[getattr(layer, f"k_lora_B_{k}")] = deepcopy(module.cur_matrix).to(model.device)
                                fea_in[getattr(layer, f"v_lora_A_{k}")] = deepcopy(module.cur_matrix).to(model.device)
                                fea_in[getattr(layer, f"v_lora_B_{k}")] = deepcopy(module.cur_matrix).to(model.device)
                                # import pdb; pdb.set_trace()
                                module.cur_matrix.zero_()
                                module.n_cur_matrix = 0
                        
                    optimizer = create_optimizer_2(args, model, is_first_epoch=False)
                    if args.sched != 'constant':
                        lr_scheduler, _ = create_scheduler(args, optimizer)
                    elif args.sched == 'constant':
                        lr_scheduler = None

                    try:
                        optimizer.get_eigens(fea_in)
                        optimizer.get_transforms()
                    except:
                        pass
                    
                    fea_in = {}

                    if args.epoch_for_other_tasks > 0:
                        for epoch in range(args.epoch_for_other_tasks):
                            train_stats = train_one_epoch_default(model=model, criterion=criterion, data_loader=data_loader[task_id]['train'], optimizer=optimizer, 
                                                        device=device, epoch=epoch, max_norm=args.clip_grad, set_training_mode=True, task_id=task_id, class_mask=class_mask, args = args)
                                    
                            if lr_scheduler:
                                lr_scheduler.step(epoch)

                elif task_id != 0:
                    import pdb; pdb.set_trace()
            # --- end of test ---
        if args.run_analytic == "True":
            print('run_analytic')
            dataset = data_loader[task_id]['train']
            X_list = []
            Y_list = []

            for X, y in dataset:
                X_list.append(X.to(device))
                Y_list.append(torch.nn.functional.one_hot(y, num_classes=args.num_classes).to(device))

            X_all = torch.cat(X_list, dim=0)  
            Y_all = torch.cat(Y_list, dim=0)  
            # import pdb; pdb.set_trace()

            print('extracting the projected feature')
            # X = model.module.forward_rp(X_all).to('cpu')
            batch_size = 16
            num_batches = X_all.size(0) // batch_size + (X_all.size(0) % batch_size != 0)

            with tempfile.TemporaryDirectory() as output_dir:
                torch.cuda.empty_cache()
                for i in range(num_batches):
                    batch = X_all[i * batch_size: (i + 1) * batch_size]
                    batch_features = model.module.forward_rp(batch, task_id=task_id, drs_lora=0)["x"].to('cpu')
                    torch.save(batch_features, os.path.join(output_dir, f"batch_{i}.pt"))
                    del batch, batch_features
                    torch.cuda.empty_cache()
                    print(f"Processed batch {i+1}/{num_batches}")

                all_features = []
                for i in range(num_batches):
                    batch_features = torch.load(os.path.join(output_dir, f"batch_{i}.pt"))
                    all_features.append(batch_features)
                    del batch_features
                    torch.cuda.empty_cache()

                X = torch.cat(all_features, dim=0).detach()
            print('extracted')
            Y = torch.cat(Y_list, dim=0).to('cpu').to(torch.float32)
            # import pdb; pdb.set_trace()
            print(X.shape, Y.shape)
            print('solving Q')
            Q_matrix_tmp = Q_matrix + X.T @ Y 
            print('solving G')
            G_matrix_tmp = G_matrix + X.T @ X
            print('solving ridge')
            ridge_tmp = optimise_ridge_parameter(X, Y, device='cuda')
            print('solving Wo')
            Wo_tmp = torch.linalg.solve(G_matrix_tmp + ridge_tmp * torch.eye(G_matrix_tmp.size(dim=0)), Q_matrix_tmp).T 
            # import pdb; pdb.set_trace()
            fake_fc.data = Wo_tmp[0:fake_fc.shape[0],:]

            del dataset
            del X_list
            del Y_list
            del X_all
            del Y_all
            del X
            del Y
            torch.cuda.empty_cache()
            # import pdb; pdb.set_trace()
            if not(args.contrastive_pretraining == "True" and task_id >= 1):
                print('not jicheng')
                Q_matrix = Q_matrix_tmp
                G_matrix = G_matrix_tmp
                ridge = ridge_tmp
                Wo = Wo_tmp
            elif args.contrastive_pretraining == "True" and task_id >= 1:
                model.eval()

                output_layer = fake_fc
                num_classes = output_layer.size(0)  # [C, D] -> 类别数
                cm = torch.zeros((num_classes, num_classes), dtype=torch.int64)  # 混淆矩阵: [true, pred]
                all_correct = defaultdict(int)
                all_total = defaultdict(int)
                tmp_dataloader = data_loader[task_id]['train']

                with torch.no_grad():
                    total_selected_neg = defaultdict(list)
                    total_selected_pos = defaultdict(list)
                    replay_buffer_features = []
                    replay_buffer_labels = []
                    cmt = 0
                    for input, target in tmp_dataloader:
                        print(cmt / len(tmp_dataloader))
                        input = input.to(device, non_blocking=True)
                        target = target.to(device, non_blocking=True)

                        # import pdb; pdb.set_trace()
                        input_augs =  [spectrogram_augment(input, args.time_aug, args.freq_aug) for _ in range(args.aug_num)]

                        output = model.module.forward_rp(input, task_id=task_id, drs_lora=0)["pre_logits"].detach()
                        target_cpu = target.cpu()
                        logits_augs = []

                        valid_labels = torch.tensor([__ for __ in range(int(args.num_classes / args.num_tasks * task_id))], device=target_cpu.device)

                        for _ in range(args.aug_num):
                            print(_)
                            # import pdb; pdb.set_trace()
                            output_aug_tmp = model.module.forward_rp(input_augs[_], task_id=task_id, drs_lora=0)["x"].detach().cpu()
                            logits_tmp = F.linear(output_aug_tmp, output_layer)  # [B, C]
                            logits_augs.append(logits_tmp)  # logits 在 CPU
                            probs = F.softmax(logits_tmp, dim=-1)  # [B, C]
                            conf, pred = probs.max(dim=-1)
                            # is_negative = pred != target
                            # print(f'conf : {conf}')
                            # high_conf = conf > args.replay_conf_thres
                            # mask_tmp = is_negative & high_conf
                            mask_tmp = (pred.unsqueeze(-1) == valid_labels).any(-1)
                            if mask_tmp.any():
                                neg_features = output_aug_tmp[mask_tmp] 
                                pred_classes = pred[mask_tmp]
                                replay_buffer_features.append(neg_features)
                                replay_buffer_labels.append(pred_classes)

                        logits_aug = torch.stack(logits_augs, dim=0)     
                        pred_aug = logits_aug.argmax(-1) 
                        # import pdb; pdb.set_trace() 
                        valid_labels = torch.tensor([__ for __ in range(int(args.num_classes / args.num_tasks * task_id))], device=pred_aug.device)
                        print(f'current valid labels: { valid_labels}')
                        mask = (pred_aug.unsqueeze(-1) == valid_labels)   # [N, B, L]
                        # mask = (pred_aug.unsqueeze(-1) != target.unsqueeze(0).unsqueeze(-1).to(pred_aug.device))
                        mask_any = mask.any(-1)
                        counts = mask_any.sum(dim=0)
                        print(f'counts : {counts}')
                        mask2 = counts > args.choose_k 
                        selected_indices = mask2.nonzero(as_tuple=True)[0]
                        for xtmp in selected_indices:
                            total_selected_neg[int(target[xtmp])].append(output[xtmp] / output[xtmp].norm(dim=-1, keepdim=True))
                            
                        for xtmp in range(input.size(0)):
                            total_selected_pos[int(target[xtmp])].append(output[xtmp] / output[xtmp].norm(dim=-1, keepdim=True))
                    
                    try:
                        replay_buffer_features = torch.cat(replay_buffer_features, dim=0)
                        replay_buffer_labels = torch.cat(replay_buffer_labels, dim=0)
                        assert replay_buffer_features.shape[0] == replay_buffer_labels.shape[0]
                        print(f'replay number : {replay_buffer_features.shape[0]}')
                    except:
                        print(f'replay number : 0')

                for cls_id in range(int(task_id * args.num_classes / args.num_tasks), int((task_id + 1) * args.num_classes / args.num_tasks)):
                    print('---------------------------------------------------------------------------------------------------------')
                    print(f'selected {len(total_selected_neg[cls_id])} samples severd as hardest negative for cls {cls_id}')
                    print(f'selected {len(total_selected_pos[cls_id])} samples severd as hardest positive for cls {cls_id}')
                print('---------------------------------------------------------------------------------------------------------')
                # import pdb; pdb.set_trace()

                print('start training with feature push')

                for epoch in range(args.epoch_for_additional_training):
                    # import pdb; pdb.set_trace()
                    train_stats = train_one_epoch_additional(model=model, criterion=criterion, data_loader=data_loader[task_id]['train'], optimizer=optimizer, 
                                                device=device, epoch=epoch, max_norm=args.clip_grad, set_training_mode=True, task_id=task_id, class_mask=class_mask, args = args, total_neg = total_selected_neg, total_pos = total_selected_pos)
                            
                    if lr_scheduler:
                        lr_scheduler.step(epoch)

                print('start training analytic classification head for another time')

                # for another time
                dataset = data_loader[task_id]['train']
                X_list = []
                Y_list = []

                for X, y in dataset:
                    X_list.append(X.to(device))
                    Y_list.append(torch.nn.functional.one_hot(y, num_classes=args.num_classes).to(device))

                X_all = torch.cat(X_list, dim=0)  
                Y_all = torch.cat(Y_list, dim=0)  
                # import pdb; pdb.set_trace()

                print('extracting the projected feature')
                # X = model.module.forward_rp(X_all).to('cpu')
                batch_size = 16
                num_batches = X_all.size(0) // batch_size + (X_all.size(0) % batch_size != 0)

                with tempfile.TemporaryDirectory() as output_dir:
                    torch.cuda.empty_cache()
                    for i in range(num_batches):
                        batch = X_all[i * batch_size: (i + 1) * batch_size]
                        batch_features = model.module.forward_rp(batch, task_id=task_id, drs_lora=0)["x"].to('cpu')
                        torch.save(batch_features, os.path.join(output_dir, f"batch_{i}.pt"))
                        del batch, batch_features
                        torch.cuda.empty_cache()
                        print(f"Processed batch {i+1}/{num_batches}")

                    all_features = []
                    for i in range(num_batches):
                        batch_features = torch.load(os.path.join(output_dir, f"batch_{i}.pt"))
                        all_features.append(batch_features)
                        del batch_features
                        torch.cuda.empty_cache()

                    X = torch.cat(all_features, dim=0).detach()
                print('extracted')
                Y = torch.cat(Y_list, dim=0).to('cpu').to(torch.float32)
                # import pdb; pdb.set_trace()
                print(X.shape, Y.shape)
                # print('Catting replayed features')
                # X = torch.cat([X, replay_buffer_features], dim=0)
                # replay_buffer_labels = torch.nn.functional.one_hot(replay_buffer_labels, num_classes=args.num_classes)
                # Y = torch.cat([Y, replay_buffer_labels], dim=0)
                # print('Catted')
                print(X.shape, Y.shape)
                print('solving Q')
                Q_matrix = Q_matrix + X.T @ Y 
                print('solving G')
                G_matrix = G_matrix + X.T @ X
                print('solving ridge')
                ridge = optimise_ridge_parameter(X, Y)
                print('solving Wo')
                Wo = torch.linalg.solve(G_matrix + ridge * torch.eye(G_matrix.size(dim=0)), Q_matrix).T 
                # import pdb; pdb.set_trace()
                fake_fc.data = Wo[0:fake_fc.shape[0],:]

                del dataset
                del X_list
                del Y_list
                del X_all
                del Y_all
                del X
                del Y
                torch.cuda.empty_cache()
        elif args.run_analytic == 'ncm':
            print('running ncm')
            dataset = data_loader[task_id]['train']
            X_list = []
            Y_list = []

            for X, y in dataset:
                X_list.append(X.to(device))
                Y_list.append(torch.nn.functional.one_hot(y, num_classes=args.num_classes).to(device))

            X_all = torch.cat(X_list, dim=0)  
            Y_all = torch.cat(Y_list, dim=0)  
            batch_size = 16
            num_batches = X_all.size(0) // batch_size + (X_all.size(0) % batch_size != 0)

            with torch.inference_mode(): 
                with tempfile.TemporaryDirectory() as output_dir:
                    torch.cuda.empty_cache()
                    features_list = []
                    labels_list = []

                    for i in range(num_batches):
                        batch = X_all[i * batch_size: (i + 1) * batch_size]            
                        batch_labels = Y_all[i * batch_size: (i + 1) * batch_size]

                        feats = model.module.forward_rp(batch, task_id=task_id, drs_lora=0)["x"]

                        feats_cpu = feats.detach().to("cpu", copy=True)
                        labels_cpu = batch_labels.detach().to("cpu", copy=True) if batch_labels.is_floating_point() or batch_labels.requires_grad else batch_labels.cpu()

                        features_list.append(feats_cpu)
                        labels_list.append(labels_cpu)


                        del feats, feats_cpu, labels_cpu, batch, batch_labels
                        torch.cuda.synchronize()            
                        torch.cuda.empty_cache()
                        print(f"Processed batch {i+1}/{num_batches})")

                # Concatenate all features and labels
                all_features = torch.cat(features_list, dim=0)  # [N, D]
                all_labels = torch.cat(labels_list, dim=0)      # [N, C]

                # Compute the per-class mean
                for c in range(args.num_classes):
                    mask = all_labels[:, c] == 1
                    if mask.sum() > 0:
                        mean_vec = all_features[mask].mean(dim=0)  # [D]
                        cls_mean[c] = mean_vec
                        print(f"Class {c}: mean size {mean_vec.shape}")
                    else:
                        print(f"Class {c}: no samples found")

        
        if task_id == args.num_tasks - 1:
        # if 1:
            test_stats = evaluate_till_now(model=model, output_layer=fake_fc, data_loader=data_loader,
                                        device=device,
                                        task_id=task_id, class_mask=class_mask, target_task_map=target_task_map,
                                        acc_matrix=acc_matrix, args=args, cls_mean=cls_mean)

            # if not args.trained_backbone_and_pet:
            if 1:
                if args.output_dir and utils.is_main_process():
                    Path(os.path.join(args.output_dir, 'checkpoint')).mkdir(parents=True, exist_ok=True)

                    checkpoint_path = os.path.join(args.output_dir, 'checkpoint/task{}_checkpoint.pth'.format(task_id + 1))
                    state_dict = {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'args': args,
                    }
                    if args.sched is not None and args.sched != 'constant':
                        state_dict['lr_scheduler'] = lr_scheduler.state_dict()

                    utils.save_on_master(state_dict, checkpoint_path)

                log_stats = {
                            **{f'test_{k}': v for k, v in test_stats.items()},
                            }

                if args.output_dir and utils.is_main_process():
                    with open(os.path.join(args.output_dir,
                                        '{}_stats.txt'.format(datetime.datetime.now().strftime('log_%Y_%m_%d_%H_%M'))),
                            'a') as f:
                        f.write(json.dumps(log_stats) + '\n')
        # if 0:
        # if 1:
        if args.save_pt == "True":
            # drs_lora: whether to use lora subtraction
            # task_id: -1: pretrained model

            # keys = ['train']
            keys = ['val']
            show_list = [task_id]

            with torch.no_grad():
                print(f'start extracting feature for task {task_id}')
                features_all = []
                features_all_before = []
                ys = []
                for _show_list in show_list:
                    for _keys in keys:
                        for X, y in data_loader[_show_list][_keys]:
                            # X_aug = spectrogram_augment(X, 15, 15)
                            X_aug = X
                            features_all.append(model.module.naive_classification(X_aug.to(model.device), drs_lora=0, task_id=task_id)['layer_results'])
                            features_all_before.append(model.module.naive_classification(X_aug.to(model.device), drs_lora=0, task_id=task_id-1)['layer_results'])
                            
                            # features_all_1_ori.append(model.module.naive_classification(X.to(model.device), drs_lora=0, task_id=task_id)['layer_results'])
                            ys.append(y)
                for layer_id in range(11, 12):
                    print(f'start extracting feature for layer {layer_id}')
                    features = []
                    features_before = []
                    labels = []
                    cnt = -1
                    for _show_list in show_list:
                        for _keys in keys:
                            for X, y in data_loader[_show_list][_keys]:
                                cnt += 1
                                features_1 = features_all[cnt][layer_id]
                                features_1_before = features_all_before[cnt][layer_id]
                                label_1 = ys[cnt]
                                features.append(features_1)
                                features_before.append(features_1_before)
                                labels.append(label_1)
                    print('extracted')
                
                    features_tensor = torch.cat(features, dim=0)
                    features_tensor_before = torch.cat(features_before, dim=0)
                    labels_tensor = torch.cat(labels, dim=0)
                        
                    torch.save({"features": features_tensor, "features_before": features_tensor_before, "labels": labels_tensor}, f"xtmp_pts_tmp/ranpac_dataset_{args.dataset}_task_{task_id}_layer_{layer_id}.pt")
                    print(f'saved to xtmp_pts_tmp/ranpac_dataset_{args.dataset}_task_{task_id}_layer_{layer_id}.pt')

        if task_id + 1 >= args.learn_tasks:
            if args.run_analytic == "True":
                for name, param in model.named_parameters(): 
                    param.requires_grad = False
                print('start ranpac')

                print(f'currently processing task {task_id}')
                print('start merging dataset')
                task_num = len(data_loader)
                if args.dataset == 'timit':
                    datasets = [data_loader[task_id_inner]['train'].dataset for task_id_inner in range(task_id+1, task_num)]
                    merged_dataset = ConcatDataset(datasets)
                    merged_loader = DataLoader(merged_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
                    print('merged')

                    dataset = merged_loader

                    X_list = []
                    Y_list = []

                    for X, y in dataset:
                        X_list.append(X)
                        Y_list.append(torch.nn.functional.one_hot(y, num_classes=args.num_classes))
                    X_all = torch.cat(X_list, dim=0).to(device)
                    Y_all = torch.cat(Y_list, dim=0).to(device)

                    # import pdb; pdb.set_trace()

                    print('extracting the projected feature')
                    # X = model.module.forward_rp(X_all).to('cpu')
                    batch_size = 64
                    num_batches = X_all.size(0) // batch_size + (X_all.size(0) % batch_size != 0)

                    with tempfile.TemporaryDirectory() as output_dir:
                        for i in range(num_batches):
                            batch = X_all[i * batch_size: (i + 1) * batch_size]
                            batch_features = model.module.forward_rp(batch, task_id=task_id, drs_lora=0)["x"].to('cpu')
                            torch.save(batch_features, os.path.join(output_dir, f"batch_{i}.pt"))
                            del batch, batch_features
                            torch.cuda.empty_cache()
                            print(f"Processed batch {i+1}/{num_batches}")

                        all_features = []
                        for i in range(num_batches):
                            batch_features = torch.load(os.path.join(output_dir, f"batch_{i}.pt"))
                            all_features.append(batch_features)

                    X = torch.cat(all_features, dim=0).detach()
                    print('extracted')
                    Y = torch.cat(Y_list, dim=0).to('cpu').to(torch.float32)
                    # import pdb; pdb.set_trace()
                    print(X.shape, Y.shape)
                    print('solving Q')
                    Q_matrix = Q_matrix + X.T @ Y 
                    print('solving G')
                    G_matrix = G_matrix + X.T @ X
                    print('solving ridge')
                    ridge = optimise_ridge_parameter(X, Y)
                    print('solving Wo')
                    Wo = torch.linalg.solve(G_matrix + ridge*torch.eye(G_matrix.size(dim=0)), Q_matrix).T 
                    # import pdb; pdb.set_trace()
                    fake_fc.data = Wo[0:fake_fc.shape[0],:]
                else:
                    for task_id_inner in range(task_id+1, task_num):
                        print(f'now processing task: {task_id_inner}')
                        datasets = [data_loader[task_id_inner]['train'].dataset]
                        # datasets = [data_loader[task_id_inner]['train'].dataset for task_id_inner in range(task_id+1, task_num)]

                        merged_dataset = ConcatDataset(datasets)
                        merged_loader = DataLoader(merged_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
                        print('merged')

                        dataset = merged_loader

                        X_list = []
                        Y_list = []

                        for X, y in dataset:
                            X_list.append(X)
                            Y_list.append(torch.nn.functional.one_hot(y, num_classes=args.num_classes))
                        X_all = torch.cat(X_list, dim=0).to(device)
                        Y_all = torch.cat(Y_list, dim=0).to(device)

                        # import pdb; pdb.set_trace()

                        print('extracting the projected feature')
                        # X = model.module.forward_rp(X_all).to('cpu')
                        batch_size = 64
                        num_batches = X_all.size(0) // batch_size + (X_all.size(0) % batch_size != 0)

                        with tempfile.TemporaryDirectory() as output_dir:
                            for i in range(num_batches):
                                batch = X_all[i * batch_size: (i + 1) * batch_size]
                                batch_features = model.module.forward_rp(batch, task_id=task_id, drs_lora=0)["x"].to('cpu')
                                torch.save(batch_features, os.path.join(output_dir, f"batch_{i}.pt"))
                                del batch, batch_features
                                torch.cuda.empty_cache()
                                print(f"Processed batch {i+1}/{num_batches}")

                            all_features = []
                            for i in range(num_batches):
                                batch_features = torch.load(os.path.join(output_dir, f"batch_{i}.pt"))
                                all_features.append(batch_features)

                        X = torch.cat(all_features, dim=0).detach()
                        print('extracted')
                        Y = torch.cat(Y_list, dim=0).to('cpu').to(torch.float32)
                        # import pdb; pdb.set_trace()
                        print(X.shape, Y.shape)
                        print('solving Q')
                        Q_matrix = Q_matrix + X.T @ Y 
                        print('solving G')
                        G_matrix = G_matrix + X.T @ X
                        print('solving ridge')
                        ridge = optimise_ridge_parameter(X, Y)
                        print('solving Wo')
                        Wo = torch.linalg.solve(G_matrix + ridge*torch.eye(G_matrix.size(dim=0)), Q_matrix).T 
                        # import pdb; pdb.set_trace()
                        fake_fc.data = Wo[0:fake_fc.shape[0],:]
        
                test_stats = evaluate_till_now(model=model, output_layer=fake_fc, data_loader=data_loader,
                                            device=device,
                                            task_id=task_id, class_mask=class_mask, target_task_map=target_task_map,
                                            acc_matrix=acc_matrix, evaluate_all=True, args=args)

                if args.output_dir and utils.is_main_process():
                    Path(os.path.join(args.output_dir, 'checkpoint')).mkdir(parents=True, exist_ok=True)

                    checkpoint_path = os.path.join(args.output_dir, 'checkpoint/task{}_checkpoint.pth'.format(task_id + 1))
                    state_dict = {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'args': args,
                    }
                    if args.sched is not None and args.sched != 'constant':
                        state_dict['lr_scheduler'] = lr_scheduler.state_dict()

                    utils.save_on_master(state_dict, checkpoint_path)

                log_stats = {
                            **{f'test_{k}': v for k, v in test_stats.items()},
                            }

                if args.output_dir and utils.is_main_process():
                    with open(os.path.join(args.output_dir,
                                        '{}_stats.txt'.format(datetime.datetime.now().strftime('log_%Y_%m_%d_%H_%M'))),
                            'a') as f:
                        f.write(json.dumps(log_stats) + '\n')
                return
            elif args.run_analytic == 'ncm':
                task_num = len(data_loader)
                for task_id_inner in range(task_id+1, task_num):
                    dataset = data_loader[task_id_inner]['train'].dataset

                    X_list = []
                    Y_list = []
                    # import pdb; pdb.set_trace()

                    for X, y in dataset:
                        # import pdb; pdb.set_trace()
                        X_list.append(X)
                        y = torch.tensor(y).to(X.device)
                        Y_list.append(torch.nn.functional.one_hot(y, num_classes=args.num_classes).unsqueeze(0))
                    

                    X_all = torch.cat(X_list, dim=0).to(device)
                    Y_all = torch.cat(Y_list, dim=0).to(device)

                    cls_sum = None     # shape [C, D] on cpu
                    cls_cnt = None     # shape [C] on cpu

                    del X_list, Y_list

                    amp_ctx = torch.autocast(device_type="cuda", dtype=torch.float16) if torch.cuda.is_available() else nullcontext()

                    with torch.inference_mode(), amp_ctx:
                        for i in range(num_batches):
                            batch = X_all[i * batch_size: (i + 1) * batch_size].to("cuda", non_blocking=True)
                            if batch.ndim == 3:        # [B, F, T]
                                batch = batch.unsqueeze(1) # -> [B, 1, F, T]
                            batch_labels = Y_all[i * batch_size: (i + 1) * batch_size]  # one-hot [B, C] 或 索引 [B]
                            # import pdb; pdb.set_trace()

                            feats = model.module.forward_rp(batch, task_id=task_id, drs_lora=0)["x"]  # [B, D]
                            feats = feats.detach().to("cpu", dtype=torch.float32, copy=True)  # CPU 累计
                            B, D = feats.shape

                            # Initialize the accumulators
                            if cls_sum is None:
                                C = args.num_classes
                                cls_sum = torch.zeros(C, D, dtype=torch.float32)  # cpu
                                cls_cnt = torch.zeros(C, dtype=torch.long)        # cpu

                            # Support both one-hot and index labels
                            if batch_labels.dim() > 1:
                                y_idx = batch_labels.argmax(dim=1).to(torch.long)
                            else:
                                y_idx = batch_labels.to(torch.long)

                            # Accumulate
                            for c in y_idx.unique().tolist():
                                mask = (y_idx == c).cpu()
                                if feats.shape[0] != mask.shape[0]:
                                    import pdb; pdb.set_trace()
                                if mask.any():
                                    cls_sum[c] += feats[mask].sum(dim=0)
                                    cls_cnt[c] += int(mask.sum().item())

                            del batch, batch_labels, feats, y_idx
                            torch.cuda.synchronize()
                            torch.cuda.empty_cache()
                            if (i + 1) % 10 == 0:
                                print(f"Processed {i+1}/{num_batches}")

                    del X_all, Y_all

                    for c in range(args.num_classes):
                        if cls_cnt[c] > 0:
                            mean_vec = cls_sum[c] / cls_cnt[c].item()
                            cls_mean[c] = mean_vec
                            print(f"Class {c}: mean size {mean_vec.shape}, count={cls_cnt[c].item()}")
                        else:
                            print(f"Class {c}: no samples found")

                test_stats = evaluate_till_now(model=model, output_layer=fake_fc, data_loader=data_loader,
                                            device=device,
                                            task_id=task_id, class_mask=class_mask, target_task_map=target_task_map,
                                            acc_matrix=acc_matrix, evaluate_all=True, args=args, cls_mean=cls_mean)

                if args.output_dir and utils.is_main_process():
                    Path(os.path.join(args.output_dir, 'checkpoint')).mkdir(parents=True, exist_ok=True)

                    checkpoint_path = os.path.join(args.output_dir, 'checkpoint/task{}_checkpoint.pth'.format(task_id + 1))
                    state_dict = {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'args': args,
                    }
                    if args.sched is not None and args.sched != 'constant':
                        state_dict['lr_scheduler'] = lr_scheduler.state_dict()

                    utils.save_on_master(state_dict, checkpoint_path)

                log_stats = {
                            **{f'test_{k}': v for k, v in test_stats.items()},
                            }

                if args.output_dir and utils.is_main_process():
                    with open(os.path.join(args.output_dir,
                                        '{}_stats.txt'.format(datetime.datetime.now().strftime('log_%Y_%m_%d_%H_%M'))),
                            'a') as f:
                        f.write(json.dumps(log_stats) + '\n')
                return