from copy import deepcopy

import numpy as np
import torch
from sklearn.manifold import TSNE
from torch import nn
import dataloader.data_utils as data_utils

from utils import *
from tqdm import tqdm
import torch.nn.functional as F
import logging
from losses import SupConLoss, CosineDistanceSum, MaxMinCosineDistance, ContrastiveMarginLoss, AngularLoss, \
    prototypical_loss_cosine, AngularPenaltySMLoss
# import umap
from . import supcon
from .continual_learning import simple_reg_loss_l1, adjust_gradients, freeze_resnet_layers

from sklearn.metrics import accuracy_score

# from ..dualnet18_encoder import BasicBlock
from ..resnet18_encoder import resnet18, AdapterLayer, mode_context, BasicBlock


def custom_loss(output, target, gamma=2.0, alpha=None):
    """
    Focal-style cross-entropy that down-weights well-classified samples.

    Each sample's cross-entropy is multiplied by ``(1 - p_t) ** gamma``, where
    ``p_t`` is the softmax probability assigned to the true class, and
    optionally by a per-class weight ``alpha[target]``. The weighted losses are
    averaged over the batch.

    Parameters:
    - output: logits of shape (batch, num_classes).
    - target: integer class labels of shape (batch,).
    - gamma: focusing parameter; 0 recovers plain (optionally class-weighted) cross-entropy.
    - alpha: optional per-class weights (list, tuple, array or tensor) indexed by class id.

    Returns:
    - total_loss: scalar tensor, the mean of the weighted per-sample losses.
    """
    # Per-sample cross-entropy, i.e. -log p_t.
    ce = F.cross_entropy(output, target, reduction='none')

    # Probability assigned to the correct class for every sample.
    pt = F.softmax(output, dim=-1).gather(1, target.unsqueeze(1)).squeeze(1)
    focal_weight = (1 - pt) ** gamma  # large for misclassified samples, ~0 for confident ones

    if alpha is not None:
        # as_tensor accepts lists and tensors alike (no copy-construct warning) and puts
        # the weights on the logits' device/dtype so the product below never mixes devices.
        class_weights = torch.as_tensor(alpha, dtype=output.dtype, device=output.device)
        ce = ce * class_weights[target]

    return (ce * focal_weight).mean()


def label_smoothing_loss(output, target, args, epsilon=0.1):
    """
    Cross-entropy against label-smoothed targets.

    The target distribution is ``(1 - epsilon) * one_hot + epsilon / K``, where
    ``K`` is the number of columns in ``output``. ``K`` is taken from the logits
    themselves, so the target still sums to 1 when the caller passes a slice
    such as ``logits[:, :args.base_class]``.

    Parameters:
    - output: logits of shape (batch, K).
    - target: integer labels of shape (batch,).
    - args: unused; kept so existing call sites keep working.
    - epsilon: smoothing parameter.

    Returns:
    - loss: scalar tensor, the batch-mean smoothed cross-entropy.
    """
    num_classes = output.size(1)

    log_probs = F.log_softmax(output, dim=-1)

    # One-hot targets with the same dtype/device as the logits.
    one_hot = torch.zeros_like(output).scatter(1, target.view(-1, 1), 1)

    smoothed_labels = (1 - epsilon) * one_hot + epsilon / num_classes

    loss = -torch.sum(smoothed_labels * log_probs, dim=1)
    return loss.mean()


def advanced_custom_loss(output, target, args, session, gamma=0.1):
    """
    Cross-entropy plus a second cross-entropy on selectively rescaled logits.

    The second term multiplies only each sample's target-class logit by a
    per-class factor; every non-target logit is left unchanged:
    - classes with index >= ``args.base_class + (session - 1) * args.way`` are
      scaled by ``1 - max(0, 1 - session / 10)``;
    - all other classes are scaled by ``1e-6``.

    Parameters:
    - output: logits of shape (batch, num_classes).
    - target: integer labels on the same device as ``output``.
    - args: configuration object providing ``base_class`` and ``way``.
    - session: current session number.
    - gamma: weight of the second term.

    Returns:
    - total_loss: ``CE(output) + gamma * CE(rescaled output)``.
    """
    dev = output.device
    n_cls = output.size(1)

    plain_ce = F.cross_entropy(output, target)

    # Shrinks linearly with the session index and reaches 0 from session 10 on.
    soften = max(0, 1 - session / 10)
    first_new_class = args.base_class + (session - 1) * args.way

    # Per-class multiplier applied to the target logit: tiny for classes below
    # first_new_class, (1 - soften) for the rest.
    class_scale = torch.full((n_cls,), 1e-6, device=dev)
    class_scale[torch.arange(n_cls, device=dev) >= first_new_class] = 1 - soften

    one_hot = F.one_hot(target, num_classes=n_cls).to(device=dev).float()
    keep = 1 - one_hot
    rescaled = output * (keep + one_hot * class_scale)

    scaled_ce = F.cross_entropy(rescaled, target)
    return plain_ce + gamma * scaled_ce


def add_noise(model, scale=0.1):
    """
    Perturb the trainable encoder parameters in place with Gaussian noise.

    Every parameter of ``model.module.encoder`` with ``requires_grad=True``
    receives ``scale * N(0, 1)`` noise of its own shape. Frozen parameters are
    left untouched. The update runs under ``torch.no_grad()``, so autograd does
    not record it.
    """
    trainable = (p for p in model.module.encoder.parameters() if p.requires_grad)
    with torch.no_grad():
        for p in trainable:
            p.add_(torch.randn_like(p) * scale)


def calculate_prototype_deviation_loss(features, prototypes, train_labels):
    """
    Mean squared deviation of batch features from their class prototypes.

    For every prototype row ``c`` that has at least one matching label in
    ``train_labels``, the squared difference between the class-``c`` features
    and ``prototypes[c]`` is averaged over samples and elements. The per-class
    values are summed and divided by the total number of prototypes, not by the
    number of classes present in the batch.

    Parameters:
    - features: batch features; row ``i`` belongs to ``train_labels[i]``.
    - prototypes: tensor whose row ``c`` is the prototype of class ``c``. It is
      treated as a constant target, so no gradient flows back into it.
    - train_labels: integer labels aligned with ``features``.

    Returns:
    - A scalar tensor, or the float ``0.0`` when there are no prototypes or no
      label in the batch matches a prototype row.
    """
    if len(prototypes) == 0:
        return 0.0

    # detach() yields a constant target without the old deepcopy, which also
    # raised for non-leaf tensors; .to() is a no-op when the device already matches.
    prototypes = prototypes.detach().to(features.device)

    prototype_loss = 0.0
    for class_idx, prototype in enumerate(prototypes):
        in_class = train_labels == class_idx
        if in_class.any():
            prototype_loss += (features[in_class] - prototype).pow(2).mean()
    prototype_loss /= len(prototypes)
    return prototype_loss


def base_train_with_feature_deviation_constraint(model, trainloader, optimizer, scheduler, epoch, args, prototypes):
    """
    Run one base-session training epoch with a prototype-deviation penalty and gradient noise.

    Each batch goes through the model twice:
    - once with ``model.module.mode = 'encoder'``, whose output is used as the
      features for ``calculate_prototype_deviation_loss``;
    - once with ``model.module.mode = args.base_mode``, whose output is sliced to
      the first ``args.base_class`` columns and used as logits.

    The optimised loss is cross-entropy on those logits plus the
    prototype-deviation loss. After the backward pass, clamped Gaussian noise is
    added to every existing gradient, and then ``optimizer.step()`` runs.

    Parameters:
    - model: wrapped model exposing ``.module.mode`` (presumably DataParallel; TODO confirm).
    - trainloader: iterable of (data, label) batches; both are moved to the GPU with ``.cuda()``.
    - optimizer: optimiser stepped once per batch.
    - scheduler: only read here, to display the current learning rate.
    - epoch: epoch index, used only in the progress-bar text.
    - args: must provide ``base_mode``, ``base_class`` and ``noise_scale``.
    - prototypes: per-class prototypes passed to ``calculate_prototype_deviation_loss``.

    Returns:
    - (tl, ta): the epoch's loss and accuracy as reported by ``Averager.item()``
      (presumably running means; ``Averager`` comes from ``utils``).
    """
    tl = Averager()
    ta = Averager()
    model = model.train()

    tqdm_gen = tqdm(trainloader)

    for i, batch in enumerate(tqdm_gen, 1):
        data, train_label = [_.cuda() for _ in batch]
        # (Disabled) variant that perturbed the parameters instead of the gradients:
        # with torch.no_grad():
        #     for param in model.parameters():
        #         if param.requires_grad:
        #             # Calculate noise
        #             noise = torch.randn_like(param) * args.noise_scale
        #             # Calculate parameter norm
        #             param_norm = param.norm()
        #             # Clip noise, to not exceed 10% of parameter norm
        #             noise = torch.clamp(noise, -0.1 * param_norm, 0.1 * param_norm)
        #             # Add noise to parameters
        #             param.data = param.data + noise

        # First pass in 'encoder' mode: the output is used as the feature tensor.
        model.module.mode = 'encoder'
        features = model(data)
        model.module.mode = args.base_mode

        # Second pass in base_mode: the output is used as classification logits.
        logits = model(data)
        logits = logits[:, :args.base_class]
        loss = F.cross_entropy(logits, train_label)

        # Penalty for features drifting away from their class prototypes.
        proto_loss = calculate_prototype_deviation_loss(features, prototypes, train_label)

        # Combine the classification loss and the prototype deviation loss.
        total_loss = loss + proto_loss

        acc = count_acc(logits, train_label)

        lrc = scheduler.get_last_lr()[0]
        tqdm_gen.set_description(
            'Session 0, epo {}, lrc={:.4f},total loss={:.4f}, acc={:.4f}'.format(epoch, lrc, total_loss.item(), acc))
        tl.add(total_loss.item())
        ta.add(acc)

        optimizer.zero_grad()
        total_loss.backward()  # Backpropagate the combined loss.

        # Perturb every existing gradient with Gaussian noise before the optimizer step.
        with torch.no_grad():
            for param in model.parameters():
                if param.requires_grad and param.grad is not None:
                    # Sample noise with standard deviation args.noise_scale.
                    noise = torch.randn_like(param.grad) * args.noise_scale
                    # Norm of this parameter's gradient.
                    grad_norm = param.grad.norm()
                    # Clamp each noise element to [-0.1 * grad_norm, 0.1 * grad_norm].
                    # NOTE(review): this bounds individual elements, not the noise norm, so the
                    # clamped noise can still exceed 10% of grad_norm in norm. Rescale by the
                    # norm ratio instead if a norm bound was intended.
                    noise = torch.clamp(noise, -0.1 * grad_norm, 0.1 * grad_norm)
                    # Add the noise to the gradient in place.
                    param.grad += noise
        optimizer.step()

    # (Disabled) alternative that perturbs the encoder parameters once per epoch:
    # add_noise(model, scale=args.noise_scale)
    tl = tl.item()
    ta = ta.item()
    return tl, ta


# def inter_class_separation(features, labels):
#     unique_labels = labels.unique()
#     separation_loss = 0.0
#     centers = []
#     for label in unique_labels:
#         class_features = features[labels == label]
#         class_center = class_features.mean(dim=0)
#         centers.append(class_center)
#     for i in range(len(centers)):
#         for j in range(i + 1, len(centers)):
#             separation_loss += 1 - F.cosine_similarity(centers[i].unsqueeze(0), centers[j].unsqueeze(0))
#     return separation_loss / (len(centers) * (len(centers) - 1) / 2)
#
#
# def intra_class_compactness(features, labels):
#     unique_labels = labels.unique()
#     compactness_loss = 0.0
#     for label in unique_labels:
#         class_features = features[labels == label]
#         class_center = class_features.mean(dim=0)
#         compactness_loss += torch.mean(1 - F.cosine_similarity(class_features, class_center.unsqueeze(0)))
#     return compactness_loss / len(unique_labels)
#
#
# # Define the rotation-prediction task
# def rotate_images(images):
#     rotated_images = []
#     labels = []
#     for image in images:
#         angle = random.choice([0, 90, 180, 270])
#         if angle == 0:
#             rotated_image = image
#         elif angle == 90:
#             rotated_image = image.permute(0, 2, 1).flip(1)
#         elif angle == 180:
#             rotated_image = image.flip(1).flip(2)
#         elif angle == 270:
#             rotated_image = image.permute(0, 2, 1).flip(2)
#         rotated_images.append(rotated_image)
#         labels.append(angle // 90)
#     return torch.stack(rotated_images), torch.tensor(labels).cuda()
#
#
# # Loss function for the self-supervised (rotation-prediction) task
# def self_supervised_loss(model, images, args):
#     rotated_images, rotation_labels = rotate_images(images)
#     rotation_logits = model(rotated_images)
#     loss = F.cross_entropy(rotation_logits[:, :4], rotation_labels)
#     return loss


def supervised_contrastive_loss(features, labels, temperature=0.1):
    """
    Supervised contrastive (SupCon) loss over a single view of a batch.

    Let ``s`` be the cosine similarity and ``T`` the temperature. Each anchor
    ``i`` that has at least one other same-label sample in the batch
    contributes the negative mean, over those positives ``p``, of
    ``log( exp(s_ip / T) / sum_{a != i} exp(s_ia / T) )``. The denominator runs
    over every other sample, positives and negatives alike. The result is
    averaged over the anchors that have positives.

    Parameters:
    - features: (batch, feat_dim) embeddings; they are L2-normalised here.
    - labels: integer labels with one entry per row of ``features``.
    - temperature: softmax temperature ``T``.

    Returns:
    - A scalar tensor. When no anchor has a positive, the result is zero and
      stays attached to the graph of ``features``.

    Raises:
    - ValueError: if the number of labels differs from the number of features.
    """
    device = features.device
    features = F.normalize(features, dim=1)
    batch_size = features.shape[0]
    labels = labels.contiguous().view(-1, 1).to(device)
    if labels.shape[0] != batch_size:
        raise ValueError('Number of labels does not match number of features')

    self_mask = torch.eye(batch_size, dtype=torch.bool, device=device)
    # Positives: same label, excluding the anchor itself.
    positive_mask = torch.eq(labels, labels.T) & ~self_mask
    num_positives = positive_mask.sum(dim=1)
    has_positive = num_positives > 0
    if not has_positive.any():
        # Nothing to contrast; avoid 0/0 and keep a differentiable zero.
        return features.sum() * 0.0

    logits = torch.matmul(features, features.T) / temperature
    # The anchor never appears in its own denominator; logsumexp is numerically stable.
    logits = logits.masked_fill(self_mask, float('-inf'))
    log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)

    # Keep only anchors with positives. torch.where (rather than multiplying by
    # the mask) keeps the -inf diagonal out of both the sum and its gradient.
    log_prob = log_prob[has_positive]
    positive_mask = positive_mask[has_positive]
    pos_log_prob = torch.where(positive_mask, log_prob, torch.zeros_like(log_prob))
    mean_log_prob_pos = pos_log_prob.sum(dim=1) / num_positives[has_positive]
    return -mean_log_prob_pos.mean()


def fusion_aug_generate_label(y_a, y_b, session, args):
    """
    Map an unordered pair of distinct class labels to a synthetic "fusion" label.

    In the base session (``session == 0``), the pair comes from the first
    ``args.base_class`` classes. In later sessions it comes from the ``args.way``
    classes of the current session; labels are first shifted to session-local
    indices. Pairs are ranked in lexicographic order ``(0,1), (0,2), ..., (1,2), ...``.
    That rank is then offset by the current total class count
    ``args.base_class + session * args.way``.

    Raises:
    - AssertionError: if the two labels are equal.
    """
    total_classes = args.base_class + session * args.way
    if session == 0:
        pool_size, first_label = total_classes, 0
    else:
        pool_size, first_label = args.way, total_classes - args.way

    low, high = y_a - first_label, y_b - first_label
    assert low != high
    if low > high:
        low, high = high, low

    # Pairs whose smaller member is below `low`, plus the position inside row `low`.
    rank = ((2 * pool_size - low - 1) * low) // 2 + (high - low) - 1
    return int(rank + total_classes)


def fusion_aug_one_image(x, y, session, args, alpha=20.0, mix_times=4):  # mixup based
    """
    Append mixup samples whose labels are synthetic "pair" classes.

    For each of ``mix_times`` random permutations of the batch, every sample
    ``i`` whose label differs from that of its partner ``j = perm[i]`` is
    blended as ``lam * x[i] + (1 - lam) * x[j]``. ``lam`` is drawn from
    ``Beta(alpha, alpha)`` and replaced by 0.5 when it falls outside
    [0.4, 0.6]. The blended sample's label is
    ``fusion_aug_generate_label(y[i], y[j], session, args)``.

    Parameters:
    - x: batch of inputs; the first dimension is the batch.
    - y: integer labels, one per sample.
    - session, args: forwarded to ``fusion_aug_generate_label``.
    - alpha: Beta distribution parameter.
    - mix_times: number of random permutations to draw.

    Returns:
    - (x, y) with the mixed samples and their labels appended after the originals,
      on the same devices as the inputs.
    """
    batch_size = x.size(0)
    # One host transfer instead of a device sync for every comparison.
    labels = y.tolist()
    mix_data = []
    mix_target = []

    for _ in range(mix_times):
        # Drawn on the CPU generator as before, so seeded runs reproduce the same permutations.
        perm = torch.randperm(batch_size).tolist()
        for i in range(batch_size):
            j = perm[i]
            if labels[i] != labels[j]:
                new_label = fusion_aug_generate_label(labels[i], labels[j], session, args)
                lam = np.random.beta(alpha, alpha)
                if lam < 0.4 or lam > 0.6:
                    lam = 0.5
                # x[j] is a direct row lookup; the previous x[index, :][i] gathered the
                # whole permuted batch for every sample.
                mix_data.append(lam * x[i] + (1 - lam) * x[j])
                mix_target.append(new_label)

    new_target = torch.tensor(mix_target, dtype=torch.long, device=y.device)
    y = torch.cat((y, new_target), 0)
    if mix_data:
        # A single concatenation instead of one torch.cat per mixed sample.
        x = torch.cat((x, torch.stack(mix_data)), 0)

    return x, y


class CenterLoss(nn.Module):
    """
    Cosine-distance center loss with one learnable center per class.

    The loss is the batch mean of ``1 - cos(x_i, center[label_i])``, computed
    on L2-normalised features and centers.

    Args:
        num_classes: number of class centers.
        feat_dim: dimensionality of the features and centers.
        device: device on which the centers are created (default ``'cuda'``).
    """

    def __init__(self, num_classes, feat_dim, device='cuda'):
        super().__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.device = device
        # Standard-normal initialisation, placed on `device` before being wrapped as a parameter.
        initial_centers = torch.randn(num_classes, feat_dim).to(device)
        self.centers = nn.Parameter(initial_centers)

    def forward(self, x, labels):
        """Return the mean cosine distance between each feature and its class center."""
        unit_feats = F.normalize(x, p=2, dim=1)
        unit_centers = F.normalize(self.centers, p=2, dim=1)

        # One center row per sample, chosen by its label.
        matched = unit_centers.index_select(0, labels)

        cos_sim = (unit_feats * matched).sum(dim=1)
        return (1 - cos_sim).mean()


# def minimize_similarity_loss(features, batch_classes, all_classes, session, args, prototypes, temperature=0.1):
#     # features: (batch_size, feature_dim)
#     # batch_classes: list of class indices in the current batch
#     # all_classes: list of all class indices
#
#     # Ensure labels is a tensor
#     if isinstance(batch_classes, list):
#         labels = torch.tensor(batch_classes).to(features.device)
#
#     # Compute within-batch mean features
#     batch_means = []
#     unique_classes = list(set(labels.tolist()))
#     for cls in unique_classes:
#         cls_features = features[labels == cls]
#         cls_mean = cls_features.mean(dim=0)
#         batch_means.append(cls_mean)
#     batch_means = torch.stack(batch_means)  # (num_classes_in_batch, feature_dim)
#     # print(f'batch_means size: {batch_means.size()}')  # Debugging line
#
#     # Compute pseudo targets for unseen classes
#     unseen_classes = list(set(all_classes) - set(batch_classes))
#     if session > 0:
#         pseudo_targets = prototypes[:args.base_class + (session - 1) * args.way]
#     else:
#         pseudo_targets = torch.rand(len(unseen_classes), features.size(1)).to(features.device)
#
#     # Combine within-batch means, pseudo targets, and all features
#     combined_features = torch.cat([batch_means, pseudo_targets, features], dim=0)
#
#     # Compute pairwise cosine similarities
#     similarities = F.cosine_similarity(combined_features.unsqueeze(1), combined_features.unsqueeze(0), dim=-1)
#
#     # Compute orthogonality loss
#     log_similarities = F.log_softmax(similarities / temperature, dim=-1)
#     orthogonality_loss = -torch.mean(log_similarities)
#
#     return orthogonality_loss
def maximize_similarity_loss(features, batch_classes, all_classes, session, args, prototypes, temperature=0.1):
    """
    Softmax similarity loss over class means, pseudo targets and raw features.

    Builds ``combined = [per-class batch means; pseudo targets; features]`` and
    returns ``-mean(log_softmax(cos_sim(combined, combined) / temperature) * mask)``.
    ``mask`` zeroes the diagonal (self-similarity) entries, and the mean is
    taken over all entries, including the zeroed ones.

    Pseudo targets:
    - ``session > 0``: ``prototypes[:args.base_class + (session - 1) * args.way]``;
    - ``session == 0``: the prototypes of the classes in ``all_classes`` that do
      not appear in the batch.

    Parameters:
    - features: batch features, presumably (batch, feat_dim); rows align with the labels.
    - batch_classes: batch labels, given either as a list of ints or as a tensor.
    - all_classes: iterable of candidate class ids (only read in session 0).
    - session: current session index.
    - args: must provide ``base_class`` and ``way``.
    - prototypes: tensor indexable by class id; moved to the features' device.
    - temperature: softmax temperature.

    Returns:
    - A scalar tensor loss.
    """
    if isinstance(batch_classes, list):
        labels = torch.tensor(batch_classes).to(features.device)
    else:
        labels = batch_classes

    # Mean feature of each class present in the batch.
    unique_classes = list(set(labels.tolist()))
    batch_means = torch.stack([features[labels == cls].mean(dim=0) for cls in unique_classes])

    if session > 0:
        pseudo_targets = prototypes[:args.base_class + (session - 1) * args.way].to(features.device)
    else:
        # Compare plain ints: set() over a tensor hashes 0-d tensors by identity,
        # which previously marked every class as unseen when a tensor was passed.
        unseen_classes = list(set(all_classes) - set(unique_classes))
        pseudo_targets = prototypes[unseen_classes].to(features.device)

    combined_features = torch.cat([batch_means, pseudo_targets, features], dim=0)

    # Pairwise cosine similarities between all combined rows.
    similarities = F.cosine_similarity(combined_features.unsqueeze(1), combined_features.unsqueeze(0), dim=-1)

    # Zero out self-similarity terms before averaging.
    mask = torch.ones_like(similarities)
    mask.fill_diagonal_(0)

    log_similarities = F.log_softmax(similarities / temperature, dim=-1) * mask
    return -torch.mean(log_similarities)


def classwise_similarity_loss(features, batch_classes, all_classes, session, args, prototypes, temperature=0.1):
    """
    Similarity loss between per-class batch means and the batch features.

    Nodes are ``[per-class batch means; features]``. A pair of nodes is
    "positive" when both share a label and at least one of them is a class
    mean. Every other off-diagonal pair is "negative". The loss is
    ``-mean(log_softmax(pos / T)) + mean(log_softmax(neg / T))``, where the
    softmax runs over the flattened positive (resp. negative) similarity
    vector.

    Only ``features``, ``batch_classes`` and ``temperature`` are read.
    ``all_classes``, ``session``, ``args`` and ``prototypes`` are accepted for
    signature compatibility.
    """
    dev = features.device
    labels = torch.tensor(batch_classes).to(dev) if isinstance(batch_classes, list) else batch_classes

    unique_classes = list(set(labels.tolist()))
    n_means = len(unique_classes)
    class_means = torch.stack([features[labels == c].mean(dim=0) for c in unique_classes])

    # Labels for the rows/columns of [class means; samples].
    node_labels = torch.cat([torch.tensor(unique_classes, device=dev), labels], dim=0)
    nodes = torch.cat([class_means, features], dim=0)
    sim = F.cosine_similarity(nodes.unsqueeze(1), nodes.unsqueeze(0), dim=-1)

    n_nodes = nodes.size(0)
    is_mean = torch.arange(n_nodes, device=dev) < n_means
    same_label = node_labels.unsqueeze(1) == node_labels.unsqueeze(0)
    # Same label, with a class mean at either end of the pair.
    positive = same_label & (is_mean.unsqueeze(1) | is_mean.unsqueeze(0))
    off_diag = ~torch.eye(n_nodes, device=dev).bool()
    negative = ~positive & off_diag
    positive = positive & off_diag

    pos_logits = sim[positive] / temperature
    neg_logits = sim[negative] / temperature

    positive_loss = -torch.mean(F.log_softmax(pos_logits, dim=-1))
    negative_loss = torch.mean(F.log_softmax(neg_logits, dim=-1))
    return positive_loss + negative_loss


def base_pretrain(model, trainloader, train_set, optimizer, scheduler, epoch, args, session):
    """
    Run one base-session pre-training epoch.

    The per-batch loss is cross-entropy on the first ``args.base_class`` logits
    plus ``args.balance`` times ``maximize_similarity_loss`` on the encoder
    features. The class prototypes that loss compares against are computed once
    at the start of the epoch by ``calculate_class_prototypes``.

    Returns:
    - (loss, accuracy) for the epoch as reported by ``Averager.item()``.
    """
    loss_meter = Averager()
    prototypes = calculate_class_prototypes(train_set, session, trainloader.dataset.transform,
                                            model,
                                            args)
    acc_meter = Averager()
    model = model.train()
    candidate_classes = list(range(args.base_class))  # loop-invariant

    progress = tqdm(trainloader)
    for batch in progress:
        data, train_label = [t.cuda() for t in batch]

        # Encoder-mode pass: features for the similarity term.
        model.module.mode = 'encoder'
        features = model(data)
        sim_loss = maximize_similarity_loss(features, train_label.tolist(), candidate_classes, session, args,
                                            prototypes=prototypes)

        # Classification pass restricted to the base classes.
        model.module.mode = args.base_mode
        logits = model(data)[:, :args.base_class]

        loss = F.cross_entropy(logits, train_label) + args.balance * sim_loss
        acc = count_acc(logits, train_label)

        lrc = scheduler.get_last_lr()[0]
        progress.set_description(
            'Session 0, epo {}, lrc={:.4f},total loss={:.4f} acc={:.4f}'.format(epoch, lrc, loss.item(), acc))
        loss_meter.add(loss.item())
        acc_meter.add(acc)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss_meter.item(), acc_meter.item()


def base_train(model, trainloader, train_set, optimizer, scheduler, epoch, args, session):
    tl = Averager()
    ta = Averager()
    mode_context.set('normal')
    model.module.train_backbone()
    model.module.mode = args.base_mode
    tmp_prototypes = calculate_class_prototypes(train_set, session, trainloader.dataset.transform,
                                                model,
                                                args)
    angular_loss = AngularPenaltySMLoss(args)
    model = model.train()
    # standard classification for pretrain
    tqdm_gen = tqdm(trainloader)
    if args.class_relation == 'None':
        class_relations = None
    else:
        if args.class_relation == 'feat':
            class_protos = class_protos = model.module.fc.weight.detach()  # [num_all_classes, c_b]
        else:
            assert args.class_relation == 'wg'
            class_protos = model.module.fc.weight.detach()  # [num_all_classes, c_b]

        class_protos = class_protos[:args.base_class]
        class_protos = F.normalize(class_protos, dim=-1)  # [num_classes, c_b]
        class_relations = torch.mm(class_protos, class_protos.t())  # [num_classes, num_classes]

        average_relation = (torch.sum(class_relations) - class_relations.shape[0]) / (
                class_relations.shape[0] * class_relations.shape[1] - class_relations.shape[0])

    if args.in_domain_feat_cls_weight != 0.0:
        if args.in_domain_class_relation == 'None':
            in_domain_class_relations = None
        else:
            if args.in_domain_class_relation == 'feat':
                in_domain_class_protos = model.in_domain_fc.weight.detach()  # [num_all_classes, c_d]
            else:
                assert args.in_domain_class_relation == 'wg'
                in_domain_class_protos = model.in_domain_fc_base.weight.detach()  # [num_all_classes, c_d]

            in_domain_class_protos = in_domain_class_protos[:args.base_class]
            in_domain_class_protos = F.normalize(in_domain_class_protos, dim=-1)  # [num_classes, c_d]
            in_domain_class_relations = torch.mm(in_domain_class_protos,
                                                 in_domain_class_protos.t())  # [num_classes, num_classes]

            in_domain_average_relation = (torch.sum(in_domain_class_relations) - in_domain_class_relations.shape[0]) / (
                    in_domain_class_relations.shape[0] * in_domain_class_relations.shape[1] -
                    in_domain_class_relations.shape[0])
    for i, batch in enumerate(tqdm_gen, 1):
        data, train_label = [_.cuda() for _ in batch]

        # with torch.no_grad():
        #     for param in model.module.encoder.parameters():
        #         if param.requires_grad and param.grad is not None:
        #             # Calculate noise
        #             noise = torch.randn_like(param) * args.noise_scale
        #             # Calculate parameter norm
        #             param_norm = param.norm()
        #             # Clip noise, to not exceed 10% of parameter norm
        #             noise = torch.clamp(noise, -0.1 * param_norm, 0.1 * param_norm)
        #             # Add noise to parameters
        #             param.data = param.data + noise

        # logits = model(data)
        # logits = logits[:, :args.base_class]
        # loss = F.cross_entropy(logits, train_label)
        # # loss = custom_loss(logits, train_label)
        # acc = count_acc(logits, train_label)
        #
        # total_loss = loss
        #
        # optimizer.zero_grad()
        # loss.backward()
        # 正向传播和计算损失

        optimizer.zero_grad()
        model.module.mode = 'encoder'
        features = model(data)

        # sup_loss = supervised_contrastive_loss(features, train_label)
        # pseudo_labels = int(args.base_class * (args.base_class - 1) / 2)

        model.module.mode = args.base_mode
        # logits = model(data)
        # logits = logits[:, :args.base_class]
        # 计算Orthogonality Loss
        batch_classes = train_label.tolist()
        all_classes = list(range(args.base_class))
        ortho_loss = maximize_similarity_loss(features, batch_classes, all_classes, session, args, prototypes=tmp_prototypes)
        #
        logits = model(data)
        logits = logits[:, :args.base_class]
        # # loss = F.cross_entropy(logits[:, :args.base_class], train_label)
        #

        #
        loss = F.cross_entropy(logits, train_label) + args.balance * ortho_loss
        # # loss = label_smoothing_loss(logits, train_label,args)
        # loss = angular_loss(logits, train_label)
        # # loss = custom_loss(logits[:, :args.base_class], train_label)

        loss.backward()  # 计算梯度
        #
        # # 计算扰动并应用到参数
        with torch.no_grad():
            for name, module in model.named_modules():
                if isinstance(module, BasicBlock) and module.include_adapter:
                    # if isinstance(module, BasicBlock):
                    for param_name, param in module.named_parameters():
                        if 'conv1.weight' in param_name or 'conv2.weight' in param_name or 'conv3.weight' in param_name:
                            norm = param.grad.norm()
                            if norm != 0 and not torch.isnan(norm):
                                perturbation = args.rho * param.grad / norm
                                # perturbation = (args.rho * param.grad / norm) * (param ** 2 / (param ** 2).norm())
                                param.add_(perturbation)
                elif isinstance(module, torch.nn.Linear) and name == "fc":
                    for param_name, param in module.named_parameters():
                        if 'weight' in param_name:  # You can also check for 'bias' if you want to perturb it
                            norm = param.grad.norm()
                            if norm != 0 and not torch.isnan(norm):
                                perturbation = args.rho * param.grad / norm
                                # perturbation = (args.rho * param.grad / norm) * (param ** 2 / (param ** 2).norm())
                                param.add_(perturbation)

        # 重新计算扰动后的损失并进行第二次反向传播
        optimizer.zero_grad()
        # data, train_label = fusion_aug_one_image(data, train_label, session, args, alpha=20.0, mix_times=1)
        model.module.mode = 'encoder'
        features = model(data)

        # sup_loss = supervised_contrastive_loss(features, train_label)
        # pseudo_labels = int(args.base_class * (args.base_class - 1) / 2)

        model.module.mode = args.base_mode
        logits = model(data)
        logits = logits[:, :args.base_class]
        # 计算Orthogonality Loss
        batch_classes = train_label.tolist()
        all_classes = list(range(args.base_class))
        ortho_loss = maximize_similarity_loss(features, batch_classes, all_classes, session, args, prototypes=tmp_prototypes)

        # if args.cosMargin != 0.0:
        #     label_mask = F.one_hot(train_label, args.base_class)
        #     if class_relations is None:
        #         logits = logits - label_mask * args.cosMargin * args.temperature
        #     else:
        #         label_relations = torch.mm(label_mask.float(), class_relations)  # [b, num_classes]
        #         lower_bound = args.cosMargin
        #         average = args.average_cosMargin  # -0.3
        #         adj_cosMargin = average + (lower_bound - average) / (1.0 - average_relation) * (
        #                 label_relations - average_relation)
        #         logits = logits + (1 - label_mask) * adj_cosMargin * args.temperature  # [b, num_classes]

        perturbed_loss = F.cross_entropy(logits, train_label) + args.balance * ortho_loss
        # perturbed_loss = angular_loss(logits, train_label)
        # if args.in_domain_feat_cls_weight != 0.0:
        #     backbone_feat = model.end_points['final_feature']
        #     in_domain_feat = model.in_domain_forward(backbone_feat)
        #
        #     if model.in_domain_dropout_fn is None:
        #         in_domain_logits = F.linear(F.normalize(in_domain_feat, p=2, dim=-1), F.normalize(
        #             model.in_domain_fc.weight if is_base == False else model.in_domain_fc_base.weight, p=2, dim=-1))
        #     else:
        #         in_domain_logits = F.linear(model.in_domain_dropout_fn(F.normalize(in_domain_feat, p=2, dim=-1)),
        #                                     F.normalize(
        #                                         model.in_domain_fc.weight if is_base == False else model.in_domain_fc_base.weight,
        #                                         p=2, dim=-1))
        #
        #     in_domain_logits = in_domain_logits[:, :args.base_class]
        #     in_domain_logits = args.temperature * in_domain_logits
        #
        #     if args.in_domain_feat_cosMargin != 0.0:
        #         label_mask = F.one_hot(train_label, args.base_class)
        #         if in_domain_class_relations is None:
        #             in_domain_logits = in_domain_logits - label_mask * args.in_domain_feat_cosMargin * args.temperature
        #         else:
        #             in_domain_label_relations = torch.mm(label_mask.float(),
        #                                                  in_domain_class_relations)  # [b, num_classes]
        #             in_domain_lower_bound = args.in_domain_feat_cosMargin
        #             in_domain_average = args.in_domain_average_cosMargin  # -0.3
        #             in_domain_adj_cosMargin = in_domain_average + (in_domain_lower_bound - in_domain_average) / (
        #                         1.0 - in_domain_average_relation) * (
        #                                                   in_domain_label_relations - in_domain_average_relation)
        #             in_domain_logits = in_domain_logits + (
        #                         1 - label_mask) * in_domain_adj_cosMargin * args.temperature  # [b, num_classes]
        #
        #     in_domain_loss = F.cross_entropy(in_domain_logits, train_label)
        #     in_domain_acc = count_acc(in_domain_logits, train_label)
        # perturbed_loss = F.cross_entropy(logits[:, :args.base_class], train_label)
        # perturbed_loss = label_smoothing_loss(logits, train_label, args)
        # perturbed_loss = angularPenaltySMLoss(logits[:, :args.base_class], train_label)
        # perturbed_loss = custom_loss(logits[:, :args.base_class], train_label)

        perturbed_loss.backward()
        # 恢复原始参数
        with torch.no_grad():
            for name, module in model.named_modules():
                if isinstance(module, BasicBlock) and module.include_adapter:
                    # if isinstance(module, BasicBlock):
                    for param_name, param in module.named_parameters():
                        if 'conv1.weight' in param_name or 'conv2.weight' in param_name or 'conv3.weight' in param_name:
                            norm = param.grad.norm()
                            if norm != 0 and not torch.isnan(norm):
                                perturbation = args.rho * param.grad / norm
                                # perturbation = (args.rho * param.grad / norm) * (param ** 2 / (param ** 2).norm())
                                param.sub_(perturbation)
                elif isinstance(module, torch.nn.Linear) and name == "fc":
                    for param_name, param in module.named_parameters():
                        if 'weight' in param_name:  # You can also check for 'bias' if you want to perturb it
                            norm = param.grad.norm()
                            if norm != 0 and not torch.isnan(norm):
                                perturbation = args.rho * param.grad / norm
                                # perturbation = (args.rho * param.grad / norm) * (param ** 2 / (param ** 2).norm())
                                param.sub_(perturbation)
        optimizer.step()
        acc = count_acc(logits, train_label)

        total_loss = perturbed_loss

        # with torch.no_grad():
        #     for name, module in model.named_modules():
        #         if isinstance(module, BasicBlock) and module.include_adapter:
        #             for param_name, param in module.named_parameters():
        #                 if 'conv1.weight' in param_name or 'conv2.weight' in param_name:
        #                     grad = param.grad
        #                     if grad is not None:
        #                         grad_norm = grad.norm()
        #                         if grad_norm != 0 and not torch.isnan(grad_norm):
        #                             # Normalize the gradient to get the direction
        #                             perturbation_direction = grad / grad_norm
        #                             # Scale the direction by rho to get the perturbation
        #                             perturbation = perturbation_direction * args.rho
        #                             # Apply perturbation
        #                             param.data.add_(perturbation)
        #             # Add perturbation to the fc layer
        #         if isinstance(module, torch.nn.Linear) and name == "fc":
        #             for param_name, param in module.named_parameters():
        #                 if 'weight' in param_name:  # You can also check for 'bias' if you want to perturb it
        #                     grad = param.grad
        #                     if grad is not None:
        #                         grad_norm = grad.norm()
        #                         if grad_norm != 0 and not torch.isnan(grad_norm):
        #                             # Normalize the gradient to get the direction
        #                             perturbation_direction = grad / grad_norm
        #                             # Scale the direction by rho to get the perturbation
        #                             perturbation = perturbation_direction * args.rho
        #                             # Apply perturbation
        #                             param.data.add_(perturbation)

        # # Add noise to the gradients if they exist
        # if epoch > 200:
        #     with torch.no_grad():
        #         for param in model.parameters():
        #             if param.requires_grad and param.grad is not None:
        #                 # 计算噪声
        #                 noise = torch.randn_like(param.grad) * args.noise_scale
        #                 # 计算梯度的范数
        #                 grad_norm = param.grad.norm()
        #                 # 裁剪噪声，使其不超过梯度范数的10%
        #                 noise = torch.clamp(noise, -0.1 * grad_norm, 0.1 * grad_norm)
        #                 # 添加噪声
        #                 param.grad += noise
        # optimizer.step()
        lrc = scheduler.get_last_lr()[0]
        tqdm_gen.set_description(
            'Session 0, epo {}, lrc={:.4f},total loss={:.4f} acc={:.4f}'.format(epoch, lrc, total_loss.item(), acc))
        tl.add(total_loss.item())
        ta.add(acc)

        # add_noise(model, scale=args.noise_scale)
    # add_noise(model, scale=args.noise_scale)
    tl = tl.item()
    ta = ta.item()
    return tl, ta


def self_supervised_contrastive_loss(features, temperature=0.07, n_views=3):
    """Multi-view self-supervised (SimCLR / NT-Xent style) contrastive loss.

    ``features`` is laid out as ``n_views`` consecutive blocks of the same
    ``batch_size // n_views`` samples, i.e. rows ``i``, ``i + B``, ``i + 2B``, ...
    (with ``B = batch_size // n_views``) are different views of sample ``i``.
    Each row is pulled towards the other views of its own sample and pushed away
    from every other row in the batch.

    Args:
        features: 2-D tensor ``(batch_size, dim)``; normalised to unit L2 norm here.
        temperature: softmax temperature applied to the cosine similarities.
        n_views: number of views per sample (default 3, the original layout).

    Returns:
        Scalar tensor with the loss averaged over all anchors.

    Raises:
        ValueError: if ``n_views < 2`` or ``batch_size`` is not a multiple of ``n_views``.
    """
    batch_size = features.shape[0]
    if n_views < 2:
        raise ValueError(f"n_views must be >= 2, got {n_views}")
    if batch_size % n_views != 0:
        raise ValueError(f"batch size {batch_size} is not divisible by n_views={n_views}")
    device = features.device

    features = F.normalize(features, dim=1)

    # Row j belongs to sample j % B, so views of the same sample share an id.
    sample_ids = torch.arange(batch_size // n_views, device=device).repeat(n_views).view(-1, 1)

    # 1 everywhere except the diagonal: an anchor is never contrasted with itself.
    not_self = 1.0 - torch.eye(batch_size, device=device)
    # Positives are the other views of the same sample (self excluded).
    positive_mask = torch.eq(sample_ids, sample_ids.T).float() * not_self

    anchor_dot_contrast = torch.matmul(features, features.T) / temperature

    # Subtracting the row max keeps exp() stable and does not change log-softmax.
    logits_max, _ = anchor_dot_contrast.max(dim=1, keepdim=True)
    logits = anchor_dot_contrast - logits_max.detach()

    # The softmax denominator runs over ALL other rows (positives and negatives);
    # restricting it to positives would remove the repulsive term entirely.
    exp_logits = torch.exp(logits) * not_self
    log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-9)

    # Average log-probability of the positives for each anchor.
    mean_log_prob_pos = (positive_mask * log_prob).sum(1) / positive_mask.sum(1)

    return -mean_log_prob_pos.mean()


# def base_train(model, trainloader, optimizer, scheduler, epoch, args, mask):
#     tl = Averager()
#     ta = Averager()
#     mode_context.set('normal')
#     model.module.train_backbone()
#     model.module.mode = args.base_mode
#     angularPenaltySMLoss = AngularPenaltySMLoss(args)
#     model = model.train()
#     # standard classification for pretrain
#     # _, sup_trainloader, _ = data_utils.get_supcon_dataloader(args)
#     tqdm_gen = tqdm(trainloader)
#     # tqdm_gen = tqdm(sup_trainloader)
#     # _, sup_trainloader, _ = data_utils.get_supcon_dataloader(args)
#     # sc_criterion = supcon.SupConLoss()
#     for i, batch in enumerate(tqdm_gen, 1):
#         # images, label = batch
#         # images = torch.cat([images[0], images[1]], dim=0)
#         # if torch.cuda.is_available():
#         #     images = images.cuda(non_blocking=True)
#         #     label = label.cuda(non_blocking=True)
#         beta = torch.distributions.beta.Beta(args.alpha, args.alpha).sample([]).item()
#         data, train_label = [_ for _ in batch]
#         data[0] = data[0].cuda(non_blocking=True)
#         data[1] = data[1].cuda(non_blocking=True)
#         data[2] = data[2].cuda(non_blocking=True)
#         train_label = train_label.cuda(non_blocking=True)
#
#         # optimizer.zero_grad()
#         #
#         # logits = model(data[0])
#         # loss = F.cross_entropy(logits[:, :args.base_class], train_label)
#         # # loss = label_smoothing_loss(logits, train_label,args)
#         # # loss = angularPenaltySMLoss(logits[:, :args.base_class], train_label)
#         # # loss = custom_loss(logits[:, :args.base_class], train_label)
#         #
#         # loss.backward()  # 计算梯度
#         #
#         # # 计算扰动并应用到参数
#         # with torch.no_grad():
#         #     for name, module in model.named_modules():
#         #         # if isinstance(module, BasicBlock) and module.include_adapter:
#         #         if isinstance(module, BasicBlock):
#         #             for param_name, param in module.named_parameters():
#         #                 if 'conv1.weight' in param_name or 'conv2.weight' in param_name:
#         #                     norm = param.grad.norm()
#         #                     if norm != 0 and not torch.isnan(norm):
#         #                         perturbation = args.rho * param.grad / norm
#         #                         param.add_(perturbation)
#                 # elif isinstance(module, torch.nn.Linear) and name == "fc":
#                 #     for param_name, param in module.named_parameters():
#                 #         if 'weight' in param_name:  # You can also check for 'bias' if you want to perturb it
#                 #             norm = param.grad.norm()
#                 #             if norm != 0 and not torch.isnan(norm):
#                 #                 perturbation = args.rho * param.grad / norm
#                 #                 param.add_(perturbation)
#
#         # 重新计算扰动后的损失并进行第二次反向传播
#         optimizer.zero_grad()
#         model.module.mode = 'encoder'
#         # Get the features from the model for both augmented images
#         feature_original = model(data[0])
#         features1 = model(data[1])
#         features2 = model(data[2])
#         labels = torch.cat(
#             [train_label, train_label, train_label],
#             dim=0)
#
#         # Concatenate the features from the augmented images and the prototypes
#         features = torch.cat([feature_original, features1, features2], dim=0)
#
#         sup_loss = supervised_contrastive_loss(features, labels)
#         self_sup_loss = self_supervised_contrastive_loss(features)
#
#         model.module.mode = args.base_mode
#         logits = model(data[0])
#
#         perturbed_loss = F.cross_entropy(logits[:, :args.base_class], train_label) + (
#                     1 - args.alpha) * self_sup_loss + args.alpha * sup_loss
#         # if epoch >= args.loss_iter:
#         #     logits_masked = logits.masked_fill(F.one_hot(train_label, num_classes=model.module.pre_allocate) == 1, -1e9)
#         #     logits_masked_chosen = logits_masked * mask[train_label]
#         #     pseudo_label = torch.argmax(logits_masked_chosen[:, args.base_class:], dim=-1) + args.base_class
#         #     # pseudo_label = torch.argmax(logits_masked[:,args.base_class:], dim=-1) + args.base_class
#         #     loss2 = F.cross_entropy(logits_masked, pseudo_label)
#         #
#         #     index = torch.randperm(data.size(0)).cuda()
#         #     pre_emb1 = model.module.pre_encode(data)
#         #     mixed_data = beta * pre_emb1 + (1 - beta) * pre_emb1[index]
#         #     mixed_logits = model.module.post_encode(mixed_data)
#         #
#         #     newys = train_label[index]
#         #     idx_chosen = newys != train_label
#         #     mixed_logits = mixed_logits[idx_chosen]
#         #
#         #     pseudo_label1 = torch.argmax(mixed_logits[:, args.base_class:], dim=-1) + args.base_class  # new class label
#         #     pseudo_label2 = torch.argmax(mixed_logits[:, :args.base_class], dim=-1)  # old class label
#         #     loss3 = F.cross_entropy(mixed_logits, pseudo_label1)
#         #     novel_logits_masked = mixed_logits.masked_fill(
#         #         F.one_hot(pseudo_label1, num_classes=model.module.pre_allocate) == 1, -1e9)
#         #     loss4 = F.cross_entropy(novel_logits_masked, pseudo_label2)
#         #     perturbed_loss = perturbed_loss + args.balance * (loss2 + loss3 + loss4)
#         # else:
#         #     perturbed_loss = perturbed_loss
#         # perturbed_loss = F.cross_entropy(logits[:, :args.base_class], train_label)
#         # perturbed_loss = label_smoothing_loss(logits, train_label, args)
#         # perturbed_loss = angularPenaltySMLoss(logits[:, :args.base_class], train_label)
#         # perturbed_loss = custom_loss(logits[:, :args.base_class], train_label)
#
#         perturbed_loss.backward()
#         # 恢复原始参数
#         # with torch.no_grad():
#         #     for name, module in model.named_modules():
#         #         # if isinstance(module, BasicBlock) and module.include_adapter:
#         #         if isinstance(module, BasicBlock):
#         #             for param_name, param in module.named_parameters():
#         #                 if 'conv1.weight' in param_name or 'conv2.weight' in param_name:
#         #                     norm = param.grad.norm()
#         #                     if norm != 0 and not torch.isnan(norm):
#         #                         perturbation = args.rho * param.grad / norm
#         #                         param.sub_(perturbation)
#                 # elif isinstance(module, torch.nn.Linear) and name == "fc":
#                 #     for param_name, param in module.named_parameters():
#                 #         if 'weight' in param_name:  # You can also check for 'bias' if you want to perturb it
#                 #             norm = param.grad.norm()
#                 #             if norm != 0 and not torch.isnan(norm):
#                 #                 perturbation = args.rho * param.grad / norm
#                 #                 param.sub_(perturbation)
#         optimizer.step()
#         acc = count_acc(logits, train_label)
#
#         total_loss = perturbed_loss
#
#         lrc = scheduler.get_last_lr()[0]
#         tqdm_gen.set_description(
#             'Session 0, epo {}, lrc={:.4f},total loss={:.4f} acc={:.4f}'.format(epoch, lrc, total_loss.item(), acc))
#         tl.add(total_loss.item())
#         ta.add(acc)
#
#         # add_noise(model, scale=args.noise_scale)
#     # add_noise(model, scale=args.noise_scale)
#     tl = tl.item()
#     ta = ta.item()
#     return tl, ta


def supcon_train(model, trainloader, optimizer, scheduler, epoch, args, mask):
    """Run one epoch of supervised-contrastive training for session 0.

    For every batch, the three image tensors ``data[0..2]`` are encoded with
    ``model.module.mode = 'encoder'``. Their concatenated features are trained with
    ``supervised_contrastive_loss`` using the batch labels repeated once per view.
    A classifier forward pass in ``args.base_mode`` on ``data[0]`` is used only to
    report accuracy.

    Args:
        model: wrapper exposing ``.module`` (``train_backbone()``, ``mode``).
        trainloader: iterable of ``(data, labels)`` batches; ``data`` is indexable
            and holds at least three image tensors.
        optimizer: optimizer stepped once per batch.
        scheduler: only read via ``get_last_lr()`` for the progress bar; not stepped here.
        epoch: epoch index, shown in the progress-bar description.
        args: config namespace; reads ``base_mode`` and ``rho``.
        mask: unused; kept for signature compatibility with the other training loops.

    Returns:
        ``(mean loss, mean accuracy)`` over the epoch, as reported by ``Averager``.
    """
    tl = Averager()
    ta = Averager()
    mode_context.set('normal')
    model.module.train_backbone()
    model.module.mode = args.base_mode
    model = model.train()

    tqdm_gen = tqdm(trainloader)
    for batch in tqdm_gen:
        data, train_label = batch
        # Three views of the same images (original + two augmentations).
        views = [data[k].cuda(non_blocking=True) for k in range(3)]
        train_label = train_label.cuda(non_blocking=True)

        optimizer.zero_grad()

        # Supervised contrastive loss on the encoder features of all views.
        model.module.mode = 'encoder'
        features = torch.cat([model(view) for view in views], dim=0)
        labels = train_label.repeat(len(views))
        sup_loss = supervised_contrastive_loss(features, labels)

        # Classifier logits are only used for the accuracy readout, so no
        # autograd graph is needed for them.
        model.module.mode = args.base_mode
        with torch.no_grad():
            logits = model(views[0])

        sup_loss.backward()

        # NOTE(review): this shifts every BasicBlock conv1/conv2 weight by
        # -rho * grad / ||grad|| before the optimizer step. No matching "+rho"
        # perturbation is applied earlier in this loop, so it acts as an extra
        # normalised gradient step rather than a SAM-style restore — confirm intent.
        with torch.no_grad():
            for module in model.modules():
                if not isinstance(module, BasicBlock):
                    continue
                for param_name, param in module.named_parameters():
                    if 'conv1.weight' not in param_name and 'conv2.weight' not in param_name:
                        continue
                    if param.grad is None:  # e.g. frozen / unused parameter
                        continue
                    norm = param.grad.norm()
                    if norm != 0 and not torch.isnan(norm):
                        param.sub_(args.rho * param.grad / norm)

        optimizer.step()
        acc = count_acc(logits, train_label)
        loss_value = sup_loss.item()

        lrc = scheduler.get_last_lr()[0]
        tqdm_gen.set_description(
            'Session 0, epo {}, lrc={:.4f},total loss={:.4f} acc={:.4f}'.format(epoch, lrc, loss_value, acc))
        tl.add(loss_value)
        ta.add(acc)

    return tl.item(), ta.item()


def forward_with_temp_params(model, x, temp_state_dict, session, args):
    """Run ``model(x)`` with some parameters temporarily replaced.

    For every parameter of ``model`` whose name appears in ``temp_state_dict``,
    the parameter's ``.data`` is pointed at the replacement tensor for the
    duration of the forward pass. It is restored afterwards, even if the forward
    raises. The forward itself runs with autograd enabled, so gradients still
    accumulate on the (swapped) parameters.

    Args:
        model: module to evaluate.
        x: input batch passed to ``model``.
        temp_state_dict: mapping from parameter name (as produced by
            ``model.named_parameters()``) to a replacement tensor.
        session: index of the current (fake) incremental session.
        args: config namespace; ``args.fake_way`` is the number of classes per session.

    Returns:
        Model outputs truncated to the first ``(session + 1) * args.fake_way`` columns.
    """
    swapped = []  # (parameter, its original .data) pairs to put back afterwards
    try:
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in temp_state_dict:
                    swapped.append((param, param.data))
                    param.data = temp_state_dict[name].data
        outputs = model(x)
    finally:
        with torch.no_grad():
            for param, original_data in swapped:
                param.data = original_data

    return outputs[:, :(session + 1) * args.fake_way]


def meta_train_reptile(model, pretrain_loader, postrain_loader, train_loader, train_set, test_loader, support_loader,
                       query_loader,
                       args):
    """Meta-train the adapter parameters with a Reptile-style outer update.

    Each epoch draws fresh support tasks with ``get_meta_data(..., seed=epoch)``.
    It then fine-tunes on each task for ``args.fast_adaptation_steps`` optimizer
    steps (cross-entropy on the first ``args.base_class`` logits). Afterwards it
    moves every parameter whose name contains ``'adapter'`` from its
    start-of-epoch value towards the adapted value by
    ``args.meta_step_size / num_tasks`` of the difference. Training stops early
    when the epoch's mean inner-loop loss has not improved for ``patience``
    consecutive epochs.

    NOTE(review): only adapter parameters are interpolated. Any other parameter
    the meta optimizer updates keeps its fully adapted value.

    Args:
        model: wrapper exposing ``.module`` (``train_adapter()``, optional
            ``encoder.train_normal()`` / ``encoder.train_meta_block()``, ``mode``).
        pretrain_loader, postrain_loader, train_loader, test_loader, query_loader:
            unused; kept for signature compatibility.
        train_set: dataset from which meta tasks are sampled.
        support_loader: ignored; replaced every epoch by a freshly sampled loader.
        args: config namespace; reads ``base_mode``, ``dataset``, ``epochs_meta``,
            ``base_class``, ``fast_adaptation_steps`` and ``meta_step_size`` (plus
            whatever ``get_optimizer_meta`` / ``get_meta_data`` read).

    Returns:
        ``(mean inner-loop loss, mean inner-loop accuracy)`` of the last epoch run.
    """
    if hasattr(model.module, 'encoder') and hasattr(model.module.encoder, 'train_normal'):
        model.module.encoder.train_normal()
    if hasattr(model.module, 'mode'):
        model.module.mode = args.base_mode

    mode_context.set('parallel_adapters')
    try:
        model.train()
        model.module.train_adapter()

        # Returned as-is if no meta epoch runs.
        meta_loss_avg = Averager()
        meta_acc_avg = Averager()

        if hasattr(model.module, 'encoder') and hasattr(model.module.encoder, 'train_meta_block'):
            model.module.encoder.train_meta_block()
        if hasattr(model.module, 'mode'):
            model.module.mode = args.base_mode

        optimizer, _ = get_optimizer_meta(args, model)

        if args.dataset == 'cub200':
            # Reset every adapter parameter to zero before meta-training on CUB-200.
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if 'adapter' in name:
                        param.zero_()

        best_loss = float('inf')
        no_improve_epochs = 0  # consecutive epochs without a new best loss
        patience = 10  # early-stopping threshold

        for epoch in range(args.epochs_meta):
            support_loader, query_loader = get_meta_data(train_set, np.arange(args.base_class), args, seed=epoch)
            model.train()

            # Snapshot of the adapter parameters before this epoch's inner loop.
            start_adapters = {name: param.detach().clone()
                              for name, param in model.named_parameters() if 'adapter' in name}

            meta_loss_avg = Averager()
            meta_acc_avg = Averager()
            num_tasks = len(support_loader)
            support_progress = tqdm(enumerate(support_loader, 1), total=num_tasks,
                                    desc=f'Epoch {epoch + 1}/{args.epochs_meta}')

            for task_idx, (task_data, task_labels) in support_progress:
                task_data, task_labels = task_data.cuda(), task_labels.cuda()
                for _ in range(args.fast_adaptation_steps):
                    # Clear gradients every step; otherwise step k would apply the
                    # sum of the gradients from steps 1..k.
                    optimizer.zero_grad()
                    task_outputs = model(task_data)[:, :args.base_class]
                    task_loss = F.cross_entropy(task_outputs, task_labels)
                    task_loss.backward()
                    optimizer.step()

                    meta_loss_avg.add(task_loss.item())
                    predicted = task_outputs.detach().argmax(dim=1)
                    meta_acc_avg.add((predicted == task_labels).sum().item() / task_labels.size(0))

                support_progress.set_description(
                    f'Epoch {epoch + 1}/{args.epochs_meta} - Task {task_idx} - Loss: {meta_loss_avg.item():.4f} - Acc: {meta_acc_avg.item():.4f}')

            current_loss = meta_loss_avg.item()
            print(
                f'\nEpoch {epoch + 1}/{args.epochs_meta} - Average Loss: {current_loss:.4f} - Average Acc: {meta_acc_avg.item():.4f}')

            # Reptile outer step. It runs before the early-stopping check so that
            # the last epoch also ends with interpolated (not raw adapted) weights.
            if num_tasks > 0:
                with torch.no_grad():
                    for name, param in model.named_parameters():
                        if name in start_adapters:
                            start = start_adapters[name]
                            param.copy_(start + args.meta_step_size * (param - start) / num_tasks)

            if current_loss < best_loss:
                best_loss = current_loss
                no_improve_epochs = 0
            else:
                no_improve_epochs += 1

            if no_improve_epochs >= patience:
                print("Early stopping triggered due to no improvement in loss for", patience, "epochs.")
                break
    finally:
        mode_context.set('normal')

    return meta_loss_avg.item(), meta_acc_avg.item()


import torch.nn.functional as F


def zero_out_adapter_weights(model):
    """Zero the ``adapter.weight`` of every ``AdapterLayer`` inside ``model``.

    Only ``adapter.weight`` is saved and zeroed; no other tensor is touched.
    Each weight is cloned immediately before it is zeroed.

    Args:
        model: module whose sub-modules are scanned via ``named_modules()``.

    Returns:
        dict mapping each AdapterLayer's qualified module name to a clone of its
        weight as it was before zeroing.
    """
    saved = {}
    adapter_layers = ((layer_name, layer) for layer_name, layer in model.named_modules()
                      if isinstance(layer, AdapterLayer))
    for layer_name, layer in adapter_layers:
        weight = layer.adapter.weight.data
        saved[layer_name] = weight.clone()
        weight.zero_()
    return saved


def restore_adapter_weights(model, adapter_weights):
    """Copy previously saved weights back into the matching ``AdapterLayer`` modules.

    Args:
        model: module tree containing ``AdapterLayer`` submodules.
        adapter_weights: mapping from qualified module name to a weight tensor, as
            returned by ``zero_out_adapter_weights``. AdapterLayers whose name is not a
            key are left untouched.
    """
    for module_name, layer in model.named_modules():
        if not isinstance(layer, AdapterLayer):
            continue
        if module_name not in adapter_weights:
            continue
        # In-place copy keeps the existing Parameter object (and optimizer references) intact.
        layer.adapter.weight.data.copy_(adapter_weights[module_name])


# def integrate_adapters(model, original_state_dict, args):
#     """
#     这个函数遍历ResNet模型，将每个AdapterLayer的权重进行填充后加到其对应的BasicBlock的对应的conv层上。
#     然后，将original_state_dict中的adapter参数重新加载回AdapterLayer。
#     """
#     alpha = 1.0
#     # print(alpha)
#     for name, module in model.named_modules():  # 使用named_modules遍历所有子模块
#         if isinstance(module, BasicBlock):
#
#             # 处理第一个Adapter和conv层
#             if hasattr(module, 'adapter1') and module.adapter1 is not None:
#                 # 先对权重进行填充
#                 padded_weight = F.pad(module.adapter1.adapter.weight, [1, 1, 1, 1], 'constant', 0)
#                 # padded_weight = module.adapter1.adapter.weight
#                 assert module.conv1.weight.size() == padded_weight.size(), \
#                     "Padded Adapter1 and Conv1 weight size mismatch"
#                 module.conv1.weight.data += alpha * padded_weight
#                 # 恢复adapter1的参数
#                 adapter1_keys = {k: v for k, v in original_state_dict.items() if 'adapter1.adapter.' in k}
#                 for key, value in adapter1_keys.items():
#                     attr_name = key.split('.')[-1]
#                     getattr(module.adapter1.adapter, attr_name).data.copy_(value)
#
#             # 处理第二个Adapter和conv层
#             if hasattr(module, 'adapter2') and module.adapter2 is not None:
#                 # 先对权重进行填充
#                 padded_weight = F.pad(module.adapter2.adapter.weight, [1, 1, 1, 1], 'constant', 0)
#                 # padded_weight = module.adapter2.adapter.weight
#                 assert module.conv2.weight.size() == padded_weight.size(), \
#                     "Padded Adapter2 and Conv2 weight size mismatch"
#                 module.conv2.weight.data += alpha * padded_weight
#                 # # 恢复adapter2的参数
#                 adapter2_keys = {k: v for k, v in original_state_dict.items() if 'adapter2.adapter.' in k}
#                 for key, value in adapter2_keys.items():
#                     attr_name = key.split('.')[-1]
#                     getattr(module.adapter2.adapter, attr_name).data.copy_(value)
def integrate_adapters(model, original_state_dict, args):
    """Merge each BasicBlock's adapter weights into its conv weights, then reset the adapters.

    For every ``BasicBlock`` in ``model`` that has both an ``alpha`` and a ``beta``
    attribute, and for each index ``k`` in 1, 2, 3 where ``adapter{k}`` exists and is not
    None, the adapter's ``adapter.weight`` is zero-padded by one element on each spatial
    side and added in place to ``conv{k}.weight``, scaled by ``module.alpha``::

        conv{k}.weight += alpha * pad(adapter{k}.adapter.weight, [1, 1, 1, 1])

    ``beta`` only gates whether a block takes part; it is not used in the merge. Because
    of the fixed one-element padding, an adapter kernel must be two smaller than its conv
    kernel in each spatial dimension (e.g. 1x1 vs 3x3); other pairings are rejected.

    Afterwards, every parameter whose name contains ``'adapter'`` is overwritten with the
    tensor stored under the same name in ``original_state_dict``.

    Args:
        model: module tree containing ``BasicBlock`` instances; modified in place.
        original_state_dict: mapping from ``model.named_parameters()`` names to tensors.
            Must contain an entry for every adapter parameter.
        args: unused; kept for call-site compatibility.

    Raises:
        ValueError: if a padded adapter weight does not match its conv weight's shape, or
            a stored adapter tensor does not match the current parameter's shape.
        KeyError: if an adapter parameter name is missing from ``original_state_dict``.
    """
    # Pure weight arithmetic: no autograd graph is needed for the merge or the restore.
    with torch.no_grad():
        for _, module in model.named_modules():
            if not isinstance(module, BasicBlock):
                continue
            if not (hasattr(module, 'alpha') and hasattr(module, 'beta')):
                continue
            alpha = module.alpha

            for k in (1, 2, 3):
                adapter = getattr(module, f'adapter{k}', None)
                if adapter is None:
                    continue
                conv_weight = getattr(module, f'conv{k}').weight
                padded_weight = F.pad(adapter.adapter.weight, [1, 1, 1, 1], 'constant', 0)
                # Explicit check instead of `assert`, which is stripped under `python -O`.
                if conv_weight.size() != padded_weight.size():
                    raise ValueError(f"Padded Adapter{k} and Conv{k} weight size mismatch")
                conv_weight.data.add_(alpha * padded_weight)

        # Reset adapter parameters to the values captured in original_state_dict.
        for name, param in model.named_parameters():
            if 'adapter' not in name:
                continue
            original_param = original_state_dict[name]
            if original_param.size() != param.size():
                raise ValueError(f"Original and current weight size mismatch for {name}")
            param.data.copy_(original_param)


def update_original_adapters(original_state_dict, model, meta_step_size):
    """Take one gradient step on the adapter entries of a saved state dict.

    For each parameter of ``model`` whose name contains ``'adapter'`` and which has a
    gradient, the returned mapping holds
    ``original_state_dict[name] - meta_step_size * param.grad``. Adapter parameters
    without a gradient are reported on stdout and keep their original value. All other
    entries are carried over unchanged.

    Args:
        original_state_dict (dict): parameter snapshot keyed like ``model.named_parameters()``.
        model (nn.Module): model whose ``.grad`` fields hold the gradients to apply.
        meta_step_size (float): step size for the update.

    Returns:
        dict: a shallow copy of ``original_state_dict`` (same mapping type) with the
        updated adapter tensors; the input mapping is not modified.
    """
    updated = original_state_dict.copy()
    with torch.no_grad():
        for param_name, param in model.named_parameters():
            if 'adapter' not in param_name:
                continue
            grad = param.grad
            if grad is None:
                print(f"No gradient for {param_name}, skipping update.")
                continue
            updated[param_name] = original_state_dict[param_name] - meta_step_size * grad
    return updated


def supervised_contrastive_loss(features, labels, temperature=0.1):
    """Supervised contrastive (SupCon-style) loss over a single batch of features.

    Each row of ``features`` is L2-normalised, and pairwise dot products are divided by
    ``temperature`` to give similarities ``s_ij``. For every anchor ``i``, the
    log-probability of another sample ``j`` is ``s_ij - log(sum_{k != i} exp(s_ik))``.
    The anchor's term is the mean of these values over its positives (same label,
    ``j != i``). The loss is the negative mean of the anchor terms.

    The log-normaliser is computed with ``torch.logsumexp``, so small temperatures do
    not overflow ``exp``. The self-similarity is masked with the dtype's most negative
    finite value rather than ``-1e9``, which is not representable in fp16.

    Args:
        features: 2-D tensor ``(N, D)`` of embeddings.
        labels: ``N`` class labels (any shape that flattens to ``N``).
        temperature: similarity scaling; smaller values sharpen the distribution.

    Returns:
        Scalar loss tensor. An anchor with no positive in the batch contributes 0 but is
        still counted in the mean over anchors.
    """
    features = F.normalize(features, p=2, dim=1)
    sim_matrix = torch.matmul(features, features.T) / temperature

    num_samples = sim_matrix.size(0)
    self_mask = torch.eye(num_samples, dtype=torch.bool, device=sim_matrix.device)
    # Remove self-similarity from the normaliser (exp of this value underflows to 0).
    sim_matrix = sim_matrix.masked_fill(self_mask, torch.finfo(sim_matrix.dtype).min)

    # Positive pairs: same label, excluding the anchor itself.
    labels = labels.contiguous().view(-1, 1)
    pos_mask = torch.eq(labels, labels.T) & ~self_mask

    # Numerically stable log-softmax over each row.
    log_prob = sim_matrix - torch.logsumexp(sim_matrix, dim=1, keepdim=True)

    # torch.where (not multiplication) so masked-out extreme values cannot yield 0 * -inf = NaN.
    pos_log_prob = torch.where(pos_mask, log_prob, torch.zeros_like(log_prob))
    pos_count = pos_mask.sum(1).to(log_prob.dtype).clamp_min(1)
    mean_log_prob_pos = pos_log_prob.sum(1) / pos_count

    return -mean_log_prob_pos.mean()


# def base_train(model, trainloader, optimizer, scheduler, epoch, args):
#     tl = Averager()
#     ta = Averager()
#     mode_context.set('normal')
#     model.module.train_backbone()
#     model.module.mode = args.base_mode
#     model = model.train()
#     # standard classification for pretrain
#     tqdm_gen = tqdm(trainloader)
#     lam = args.lam
#     for i, batch in enumerate(tqdm_gen, 1):
#         data, train_label = [_ for _ in batch]
#
#         # model.module.mode = 'encoder'
#         # b, c, h, w = data[1].shape
#         # original = data[0].cuda(non_blocking=True)
#         # data[1] = data[1].cuda(non_blocking=True)
#         # data[2] = data[2].cuda(non_blocking=True)
#         # train_label = train_label.cuda(non_blocking=True)
#         #
#         # # Get the features from the model for both augmented images
#         # feature_original = model(original)
#         # features1 = model(data[1])
#         # features2 = model(data[2])
#         # labels = torch.cat(
#         #     [train_label, train_label, train_label],
#         #     dim=0)
#         #
#         # # Concatenate the features from the augmented images and the prototypes
#         # features = torch.cat([feature_original, features1, features2], dim=0)
#         # sup_loss = supervised_contrastive_loss(features, labels, temperature=args.temperature_sup)
#         # model.module.mode = args.base_mode
#         logits = model(data)
#         logits = logits[:, :args.base_class]
#         cross_entropy_loss = F.cross_entropy(logits, train_label)
#         loss = cross_entropy_loss
#         optimizer.zero_grad()
#         loss.backward()
#         # optimizer.step()
#
#         # acc = count_acc(logits, train_label)
#         # #
#         # total_loss = loss
#         #         # 计算扰动并应用到参数
#         with torch.no_grad():
#                     for name, module in model.named_modules():
#                         if isinstance(module, BasicBlock) and module.include_adapter:
#                             for param_name, param in module.named_parameters():
#                                 if 'conv1.weight' in param_name or 'conv2.weight' in param_name:
#                                     norm = param.grad.norm()
#                                     if norm != 0 and not torch.isnan(norm):
#                                         perturbation = args.rho * param.grad / norm
#                                         param.add_(perturbation)
#                         # elif isinstance(module, torch.nn.Linear) and name == "fc":
#                         #     for param_name, param in module.named_parameters():
#                         #         if 'weight' in param_name:  # You can also check for 'bias' if you want to perturb it
#                         #             norm = param.grad.norm()
#                         #             if norm != 0 and not torch.isnan(norm):
#                         #                 perturbation = args.rho * param.grad / norm
#                         #                 param.add_(perturbation)
#
#                 # 重新计算扰动后的损失并进行第二次反向传播
#         optimizer.zero_grad()
#         logits = model(data)
#         perturbed_loss = F.cross_entropy(logits[:, :args.base_class], train_label)
#         perturbed_loss.backward()
#         optimizer.step()
#         acc = count_acc(logits, train_label)
#
#         total_loss = loss
#
#         lrc = scheduler.get_last_lr()[0]
#         tqdm_gen.set_description(
#             f'Epoch {epoch}, Lr: {lrc:.4f}, Total Loss: {total_loss.item():.4f}, CE Loss: {cross_entropy_loss.item():.4f}, Sup Loss: {sup_loss.item():.4f}, Acc: {acc:.4f}')
#         tl.add(total_loss.item())
#         ta.add(acc)
#
#     tl = tl.item()
#     ta = ta.item()
#     return tl, ta


def incremental_train(prototypes, model, trainloader, optimizer, session, args, optimal_params):
    """Fine-tune for one incremental session, then fold the adapters into the backbone.

    For every batch (``num_epochs`` passes over ``trainloader``):
      1. A "teacher" pass runs under ``torch.no_grad()``. The current model state is
         saved with ``deepcopy(model.state_dict())`` and ``optimal_params`` is loaded.
         Logits are computed with ``mode_context`` set to ``'normal'``, and the saved
         state is then loaded back.
      2. Student logits are computed with ``mode_context`` set to ``'parallel_adapters'``.
      3. ``loss = CE(logits of classes seen so far) + args.w_kd * soft-target distillation``
         towards the teacher logits, followed by one optimizer step.

    After training, ``integrate_adapters(model, original_state_dict, args)`` is called.
    ``original_state_dict`` is the parameter snapshot taken on entry, so adapter
    parameters are reset to their entry values after being merged.

    Args:
        prototypes: tensor moved to CUDA each batch. NOTE(review): not used by any active
            computation in this function.
        model: accessed through ``model.module`` throughout, so presumably a
            ``DataParallel``-wrapped network — TODO confirm.
        trainloader: yields ``(data, label)`` batches.
        optimizer: stepped once per batch.
        session: incremental session index. Logits are sliced to the first
            ``args.base_class + session * args.way`` classes for the CE term.
        args: reads ``new_mode``, ``base_class``, ``way``, ``w_kd``, ``num_classes`` (plus
            whatever ``AngularPenaltySMLoss`` and ``integrate_adapters`` read).
        optimal_params: state dict loaded into ``model`` for the teacher pass. Its keys
            must match ``model.state_dict()``.

    Returns:
        None. ``model`` (and the module-level ``mode_context``) are modified in place.

    NOTE(review): several values below are computed but never used: ``tl``, ``ta``,
    ``gamma``, ``lambda_alpha``, ``lambda_beta``, ``angular_loss``, ``labels``,
    ``adjusted_params``, ``features_old``, ``feature_layers``, ``distillation_losses``,
    ``features``, ``loss_reg``, ``batch_classes`` and ``all_classes``. ``reg_loss`` is
    printed but not added to ``loss``. Some of these involve extra model forward passes
    while the model is in train mode, which may update BatchNorm running statistics;
    confirm the intended behavior before removing them.
    """
    tl = Averager()
    ta = Averager()
    # Set the mode: adapters run in parallel with the backbone for the student passes.
    mode_context.set('parallel_adapters')
    model = model.train()
    num_epochs = 1  # number of passes over the training set
    gamma = 1.0
    # Only adapter parameters are meant to be trained (see model.module.train_adapter).
    model.module.train_adapter()
    # Snapshot of all parameters on entry; used by integrate_adapters to reset the adapters.
    original_state_dict = {name: param.clone() for name, param in model.named_parameters()}
    lambda_alpha = 1.0
    lambda_beta = 1.0
    angular_loss = AngularPenaltySMLoss(args)
    # Outer loop over epochs
    for epoch in range(num_epochs):
        tqdm_gen = tqdm(trainloader, desc=f"Epoch {epoch + 1}/{num_epochs}")

        for i, batch in enumerate(tqdm_gen, 1):
            data, train_label = [_.cuda() for _ in batch]

            model.module.mode = 'encoder'

            prototypes = prototypes.cuda(non_blocking=True)
            train_label = train_label.cuda(non_blocking=True)

            # NOTE(review): `labels` is never used below.
            labels = torch.cat(
                [train_label, train_label, train_label],
                dim=0)

            # Presumably an L1 penalty towards optimal_params (per the helper's name).
            # NOTE(review): only printed below; it is not part of `loss`.
            reg_loss = simple_reg_loss_l1(model, optimal_params, lambda_reg=1e-1)

            # Teacher pass: model outputs under optimal_params, without gradient tracking.
            with torch.no_grad():
                # Save the current (student) parameters and buffers.
                current_params = deepcopy(model.state_dict())

                # NOTE(review): adjusted_params is never used below.
                if isinstance(model, torch.nn.DataParallel):
                    adjusted_params = {f'module.{k}': v for k, v in current_params.items()}
                else:
                    adjusted_params = current_params

                # Load the teacher weights; assumes optimal_params keys already match model.state_dict().

                model.load_state_dict(optimal_params)  # loaded into `model` itself, not `model.module`
                model.module.mode = 'encoder'  # the mode flag lives on the wrapped module
                mode_context.set('normal')
                # NOTE(review): features_old is unused. The model is still in train mode here,
                # so BatchNorm layers (if any) use batch statistics for the teacher pass.
                features_old = model(data)
                model.module.mode = args.new_mode
                optimal_logits = model(data)

                # Restore the student parameters/buffers saved above.
                model.load_state_dict(current_params)

            # Student passes run with the adapters enabled again.
            mode_context.set('parallel_adapters')

            model.module.mode = args.new_mode

            # Student logits over all classes (unsliced), used for the distillation term.
            logits = model(data)
            feature_layers = ['m1', 'm2', 'm3', 'm4']
            distillation_losses = {}
            # Soft-target cross-entropy between the teacher probabilities and the student
            # log-probabilities. It differs from KL(teacher || student) only by the teacher's
            # entropy, which does not depend on the student.
            logits_log_softmax = F.log_softmax(logits, dim=1)
            optimal_logits_softmax = F.softmax(optimal_logits.detach(), dim=1)
            distillation_loss = -torch.sum(optimal_logits_softmax * logits_log_softmax, dim=1).mean()

            # Labels are used as-is (no per-session offset), matching logits sliced from column 0.
            adjusted_labels = train_label
            # NOTE(review): `features` is computed in encoder mode but never used.
            model.module.mode = 'encoder'
            features = model(data)
            model.module.mode = args.new_mode
            mode_context.set('parallel_adapters')

            # NOTE(review): same mode as the distillation forward above; logits are recomputed for CE.
            logits = model(data)
            # Keep the logits of every class seen up to and including this session.
            logits = logits[:, :args.base_class + session * args.way]
            cross_entropy = F.cross_entropy(logits, adjusted_labels)
            # Per-batch debug output.
            print("cross entropy: ", cross_entropy)
            print('reg_loss:', reg_loss)
            print("distillation loss: ", distillation_loss)
            loss_reg = 0

            # NOTE(review): batch_classes / all_classes are unused.
            batch_classes = train_label.tolist()
            all_classes = list(range(args.num_classes))
            # Total objective: CE on seen classes + weighted distillation towards the teacher.
            loss = cross_entropy + args.w_kd * distillation_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Fold the trained adapters into the conv weights and reset the adapters to their entry values.
    integrate_adapters(model, original_state_dict, args)


def replace_base_fc(trainset, transform, model, args):
    """Overwrite the base-class classifier rows with per-class mean embeddings.

    ``trainset`` is re-transformed with ``transform`` and run through the model in
    ``'encoder'`` mode, with ``mode_context`` set to ``'normal'`` and the model in eval
    mode. For each base class ``0 .. args.base_class - 1``, the mean embedding is written
    into the matching row of ``model.module.fc.weight``.

    A base class with no samples in ``trainset`` keeps its existing classifier row, and a
    warning is logged. Overwriting it with the mean of an empty set would put NaNs into
    the classifier.

    Side effects: ``trainset.transform`` is replaced by ``transform``; the model is left
    in eval mode with ``model.module.mode == 'encoder'``.

    Args:
        trainset: dataset yielding ``(image, label)`` pairs and exposing ``.transform``.
        transform: transform applied to ``trainset`` for embedding extraction.
        model: presumably a ``DataParallel``-wrapped network (``model.module`` is used) —
            TODO confirm.
        args: reads ``args.base_class``.

    Returns:
        The same ``model`` object.
    """
    mode_context.set('normal')
    model = model.eval()

    trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=128,
                                              num_workers=8, pin_memory=True, shuffle=False)
    # NOTE: this mutates the caller's dataset object, not a copy.
    trainloader.dataset.transform = transform

    embedding_list = []
    label_list = []
    with torch.no_grad():
        for batch in trainloader:
            data, label = [_.cuda() for _ in batch]
            model.module.mode = 'encoder'
            embedding = model(data)
            embedding_list.append(embedding.cpu())
            label_list.append(label.cpu())
    embedding_list = torch.cat(embedding_list, dim=0)
    label_list = torch.cat(label_list, dim=0)

    fc_weight = model.module.fc.weight.data
    for class_index in range(args.base_class):
        class_mask = label_list == class_index
        if not class_mask.any():
            logging.warning("replace_base_fc: no samples for base class %d; keeping its fc row.",
                            class_index)
            continue
        fc_weight[class_index] = embedding_list[class_mask].mean(0)

    return model


def replace_base_fc_and_visualize(trainset, transform, model, args):
    """Set base-class classifier rows to class-mean embeddings and save a t-SNE plot.

    Does what ``replace_base_fc`` does and, in addition, projects every training
    embedding to 2-D with t-SNE. The resulting scatter plot of the base classes
    is written to ``tsne_visualization.png`` in the current working directory.

    Side effects: sets ``mode_context`` to ``'normal'``, switches the model to
    eval mode, sets ``model.module.mode = 'encoder'``, replaces
    ``trainset.transform`` and writes the image file.

    Returns:
        The same model object, with ``model.module.fc.weight[:args.base_class]``
        replaced by the per-class mean embeddings.
    """
    mode_context.set('normal')
    model = model.eval()

    trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=128,
                                              num_workers=8, pin_memory=True, shuffle=False)
    trainloader.dataset.transform = transform
    embedding_list = []
    label_list = []
    with torch.no_grad():
        for batch in trainloader:
            data, label = [_.cuda() for _ in batch]
            model.module.mode = 'encoder'
            embedding = model(data)

            embedding_list.append(embedding.cpu())
            label_list.append(label.cpu())
    embedding_list = torch.cat(embedding_list, dim=0)
    label_list = torch.cat(label_list, dim=0)

    # t-SNE visualization; hand scikit-learn / matplotlib plain NumPy arrays.
    embedding_2d = TSNE(n_components=2, random_state=42).fit_transform(embedding_list.numpy())
    label_array = label_list.numpy()

    fig = plt.figure(figsize=(10, 10))
    try:
        for class_index in range(args.base_class):
            idx = label_array == class_index
            plt.scatter(embedding_2d[idx, 0], embedding_2d[idx, 1], label=f'Class {class_index}', alpha=0.5)
        plt.legend()
        plt.title('t-SNE visualization of embeddings')
        plt.savefig('tsne_visualization.png')  # save the figure
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)

    # Mean embedding per base class becomes that class's classifier row.
    proto_list = []
    for class_index in range(args.base_class):
        data_index = (label_list == class_index).nonzero()
        proto_list.append(embedding_list[data_index.squeeze(-1)].mean(0))
    proto_list = torch.stack(proto_list, dim=0)

    model.module.fc.weight.data[:args.base_class] = proto_list

    return model


# def replace_base_fc_and_visualize_umap(trainset, transform, model, args):
#     # replace fc.weight with the embedding average of train data
#     mode_context.set('normal')
#     model = model.eval()
#
#     trainloader = DataLoader(dataset=trainset, batch_size=128, num_workers=8, pin_memory=True, shuffle=False)
#     trainloader.dataset.transform = transform
#     embedding_list = []
#     label_list = []
#
#     with torch.no_grad():
#         for i, batch in enumerate(trainloader):
#             data, label = [_.cuda() for _ in batch]
#             model.module.mode = 'encoder'
#             embedding = model(data)
#
#             embedding_list.append(embedding.cpu())
#             label_list.append(label.cpu())
#
#     embedding_list = torch.cat(embedding_list, dim=0)
#     label_list = torch.cat(label_list, dim=0)
#
#     # UMAP visualization
#     umap_model = umap.UMAP(n_components=2, random_state=42)
#     embedding_2d = umap_model.fit_transform(embedding_list)
#
#     plt.figure(figsize=(10, 10))
#     for class_index in range(args.base_class):
#         idx = label_list == class_index
#         plt.scatter(embedding_2d[idx, 0], embedding_2d[idx, 1], label=f'Class {class_index}', alpha=0.5)
#     plt.legend()
#     plt.title('UMAP visualization of embeddings')
#     plt.savefig('umap_visualization.png')  # 保存图像
#
#     proto_list = []
#     for class_index in range(args.base_class):
#         data_index = (label_list == class_index).nonzero()
#         embedding_this = embedding_list[data_index.squeeze(-1)]
#         embedding_this = embedding_this.mean(0)
#         proto_list.append(embedding_this)
#
#     proto_list = torch.stack(proto_list, dim=0)
#     model.module.fc.weight.data[:args.base_class] = proto_list
#
#     return model


def calculate_class_prototypes(trainset, session, transform, model, args):
    """Return the mean embedding of every base class, using ``session``'s training loader.

    NOTE(review): ``trainset`` is not used. The data comes from
    ``get_dataloader(args, session)``. Prototypes are computed only for class
    ids ``0 .. args.base_class - 1``, whatever ``session`` is.

    For ``session != 0``, every element of a batch is indexed with ``[0]``
    before being moved to the GPU. Presumably the incremental loader yields
    nested batch elements; TODO confirm against ``get_dataloader``.

    Side effects: sets ``mode_context`` to ``'normal'``, switches the model to
    eval mode, sets ``model.module.mode = 'encoder'`` and replaces the loader
    dataset's ``transform``.

    Returns:
        A CPU tensor stacking one mean embedding per base class.
    """
    mode_context.set('normal')
    model = model.eval()

    _, loader, _ = get_dataloader(args, session)
    loader.dataset.transform = transform

    feats, targets_all = [], []
    with torch.no_grad():
        for batch in loader:
            if session != 0:
                images, targets = [item[0].cuda() for item in batch]
            else:
                images, targets = [item.cuda() for item in batch]
            model.module.mode = 'encoder'
            feats.append(model(images).cpu())
            targets_all.append(targets.cpu())
    all_feats = torch.cat(feats, dim=0)
    all_targets = torch.cat(targets_all, dim=0)

    per_class = []
    for cls in range(args.base_class):
        rows = torch.nonzero(all_targets == cls).squeeze(-1)
        per_class.append(all_feats[rows].mean(0))
    return torch.stack(per_class, dim=0)


def calculate_class_prototypes_incremental(session, transform, model, args):
    """Return the mean embedding of each class introduced in incremental ``session``.

    The classes covered are ids ``args.base_class + (session - 1) * args.way``
    up to, but excluding, ``args.base_class + session * args.way``. Embeddings
    are taken from the training loader returned by ``get_dataloader(args, session)``,
    with a tqdm progress bar.

    NOTE(review): for ``session == 0`` that range lies below ``args.base_class``.
    Presumably this is only called with ``session >= 1``; confirm with callers.

    Side effects: sets ``mode_context`` to ``'normal'``, switches the model to
    eval mode, sets ``model.module.mode = 'encoder'`` and replaces the loader
    dataset's ``transform``.

    Returns:
        A CPU tensor with one mean embedding per class in the range above.
    """
    mode_context.set('normal')
    model = model.eval()

    _, loader, _ = get_dataloader(args, session)
    loader.dataset.transform = transform

    embeds, targets = [], []
    progress = tqdm(loader)
    with torch.no_grad():
        for batch in progress:
            inputs, batch_labels = list(batch)
            inputs = inputs.cuda(non_blocking=True)
            batch_labels = batch_labels.cuda(non_blocking=True)
            model.module.mode = 'encoder'
            embeds.append(model(inputs).cpu())
            targets.append(batch_labels.cpu())
    embeds = torch.cat(embeds, dim=0)
    targets = torch.cat(targets, dim=0)

    first = args.base_class + (session - 1) * args.way
    stop = args.base_class + session * args.way
    return torch.stack(
        [embeds[torch.nonzero(targets == c).squeeze(-1)].mean(0) for c in range(first, stop)],
        dim=0,
    )


def test(model, testloader, epoch, args, session, result_list=None):
    """Evaluate ``model`` on ``testloader`` over every class seen up to ``session``.

    Logits are truncated to the first ``args.base_class + session * args.way``
    classes. Cross-entropy loss, top-1 and top-5 accuracy are averaged over
    batches (``Averager``) and logged.

    Args:
        model: wrapped model, accessed through ``model.module``.
        testloader: yields ``(data, label)`` batches.
        epoch: only used in the log line.
        args: needs ``base_class``, ``way`` and ``new_mode``.
        session: 0 for the base session, > 0 for incremental sessions.
        result_list: optional list. For ``session > 0`` a summary string of the
            seen/unseen/current-session accuracies is appended to it. With the
            default ``None`` the summary is simply not recorded.

    Returns:
        ``(loss, acc)`` for ``session == 0``. Otherwise returns
        ``(loss, (base_acc, novel_acc, acc))``. Here ``base_acc`` and
        ``novel_acc`` average the diagonal of ``confmatrix(lgt, lbs)`` over the
        base and the non-base classes (presumably per-class accuracies; see
        ``confmatrix``).

    Side effects: sets ``mode_context`` to ``'normal'``, sets
    ``model.module.mode = args.new_mode`` and switches the model to eval mode.
    """
    test_class = args.base_class + session * args.way
    mode_context.set('normal')
    model.module.mode = args.new_mode
    model = model.eval()
    vl = Averager()
    va = Averager()
    va5 = Averager()
    lgt = torch.tensor([])
    lbs = torch.tensor([])
    with torch.no_grad():
        for batch in testloader:
            data, test_label = [_.cuda() for _ in batch]
            # Keep only the classes introduced so far.
            logits = model(data)[:, :test_class]

            loss = F.cross_entropy(logits, test_label)
            acc = count_acc(logits, test_label)
            top5acc = count_acc_topk(logits, test_label)

            vl.add(loss.item())
            va.add(acc)
            va5.add(top5acc)

            lgt = torch.cat([lgt, logits.cpu()])
            lbs = torch.cat([lbs, test_label.cpu()])
        vl = vl.item()
        va = va.item()
        va5 = va5.item()

        logging.info('epo {}, test, loss={:.4f} acc={:.4f}, acc@5={:.4f}'.format(epoch, vl, va, va5))

        lgt = lgt.view(-1, test_class)
        lbs = lbs.view(-1)

        if session > 0:
            cm = confmatrix(lgt, lbs)
            perclassacc = cm.diagonal()
            baseac = np.mean(perclassacc[:args.base_class])
            unseenac = np.mean(perclassacc[args.base_class:])
            # Classes added in the current session only.
            cur_sessac = np.mean(perclassacc[args.base_class + (session - 1) * args.way:])

            # result_list is optional; calling with the default None used to crash here.
            if result_list is not None:
                result_list.append(f"Seen Acc:{baseac}  Unseen Acc:{unseenac}  Cur Acc:{cur_sessac}")
            return vl, (baseac, unseenac, va)
        return vl, va


import matplotlib.pyplot as plt
import torch
import numpy as np
from sklearn.manifold import TSNE
import random
import os
from sklearn.metrics.pairwise import euclidean_distances

# Seed every random number generator so results are reproducible.
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch, plus every CUDA device when CUDA is available."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Compute pairwise Euclidean distances.
def calculate_euclidean_distances(features, prototype):
    """Return the Euclidean distance matrix between the rows of ``features`` and ``prototype``.

    Thin wrapper around ``sklearn.metrics.pairwise.euclidean_distances``.
    Entry ``[i, j]`` is the distance between ``features[i]`` and ``prototype[j]``.
    """
    distance_matrix = euclidean_distances(features, prototype)
    return distance_matrix

def plot_tsne(model, testloader, args, session, seed=24, num_samples_per_class=100, perplexity=20):
    """Save a t-SNE plot of test features for up to 5 base and 5 novel classes.

    Class selection uses the classifier weights ``model.module.fc.weight`` as
    prototypes:
    - a random anchor base class is drawn (seeded by ``seed``);
    - the 4 base classes whose prototypes are farthest from it are added;
    - the 5 novel classes whose prototypes are closest to it are added.

    For each selected class, the ``num_samples_per_class`` test features nearest
    to its prototype are embedded with t-SNE. Base samples are drawn as dots and
    novel samples as crosses. Each prototype is drawn as a star at the 2-D
    position of the plotted sample nearest to it in feature space.

    The figure is written to
    ``metaadapter_tsne_session_{session}_seed_{seed}.pdf`` in the current
    working directory.

    Args:
        model: wrapped model. ``model.module.mode`` is set to ``'encoder'`` so
            the forward pass returns features, and the model is switched to
            eval mode.
        testloader: yields ``(data, label)`` batches.
        args: needs ``base_class`` and ``way``.
        session: incremental session index. Classes
            ``args.base_class .. args.base_class + session * args.way - 1``
            count as novel.
        seed: seed for ``random``/NumPy/PyTorch and for t-SNE.
        num_samples_per_class: maximum number of samples plotted per class.
        perplexity: t-SNE perplexity. It is clamped below the number of plotted
            samples, because scikit-learn requires ``perplexity < n_samples``.
    """
    set_seed(seed)

    # Inference mode so any BatchNorm/Dropout layers do not act (or update
    # running statistics) during the feature-extraction passes below.
    model = model.eval()
    model.module.mode = 'encoder'  # make the model return features
    test_class = args.base_class + session * args.way

    # Classifier weights of all base and novel classes serve as prototypes.
    prototypes = model.module.fc.weight.data.cpu()[:test_class]
    base_classes = list(range(args.base_class))
    novel_classes = list(range(args.base_class, test_class))

    first_base_class = random.choice(base_classes)
    anchor = prototypes[first_base_class].unsqueeze(0)

    # The 4 base classes whose prototypes are farthest from the anchor class.
    base_distances = calculate_euclidean_distances(prototypes[base_classes], anchor)
    farthest_base_classes = np.argsort(-base_distances[:, 0])[:4]
    selected_base_classes = [first_base_class] + [int(c) for c in farthest_base_classes]

    # The 5 novel classes whose prototypes are closest to the anchor class
    # (none in the base session; sklearn rejects empty inputs).
    if novel_classes:
        novel_distances = calculate_euclidean_distances(prototypes[novel_classes], anchor)
        closest_novel_classes = np.argsort(novel_distances[:, 0])[:5]
        selected_novel_classes = [int(c) + args.base_class for c in closest_novel_classes]
    else:
        selected_novel_classes = []

    selected_classes = selected_base_classes + selected_novel_classes

    # Extract features and labels for the whole test set.
    features, labels = [], []
    with torch.no_grad():
        for batch in testloader:
            data, test_label = [_.cuda() for _ in batch]
            features.append(model(data).cpu())
            labels.append(test_label.cpu())
    features = torch.cat(features, dim=0)
    labels = torch.cat(labels, dim=0).numpy()

    # For each selected class keep the samples nearest to its prototype.
    final_features, final_labels = [], []
    for cls in selected_classes:
        cls_indices = np.flatnonzero(labels == cls)
        if cls_indices.size == 0:
            continue  # class absent from the test set; nothing to plot
        cls_features = features[torch.as_tensor(cls_indices)]
        distances = calculate_euclidean_distances(cls_features, prototypes[cls].unsqueeze(0))
        nearest = np.argsort(distances[:, 0])[:num_samples_per_class]
        final_features.append(cls_features[torch.as_tensor(nearest)])
        final_labels.append(labels[cls_indices][nearest])

    final_features = torch.cat(final_features, dim=0).numpy()
    final_labels = np.concatenate(final_labels)

    # Reduce to 2-D with t-SNE. The iteration count is left at scikit-learn's
    # default (1000, the value previously passed via ``n_iter``, which was
    # deprecated in 1.5 in favour of ``max_iter``).
    effective_perplexity = min(perplexity, max(len(final_features) - 1, 1))
    reducer = TSNE(n_components=2, perplexity=effective_perplexity, random_state=seed)
    tsne_features = reducer.fit_transform(final_features)

    # ``plt.get_cmap`` instead of ``plt.cm.get_cmap`` (removed in Matplotlib 3.9).
    colors = plt.get_cmap('tab10', len(selected_classes))

    fig = plt.figure(figsize=(8, 8))
    try:
        class_colors = {}

        # Base test samples: dots.
        for i, base_class in enumerate(selected_base_classes):
            class_colors[base_class] = colors(i)
            mask = final_labels == base_class
            plt.scatter(tsne_features[mask, 0], tsne_features[mask, 1],
                        s=30, color=class_colors[base_class])

        # Novel test samples: crosses.
        for i, novel_class in enumerate(selected_novel_classes):
            class_colors[novel_class] = colors(len(selected_base_classes) + i)
            mask = final_labels == novel_class
            plt.scatter(tsne_features[mask, 0], tsne_features[mask, 1],
                        s=30, marker='x', color=class_colors[novel_class])

        # Place each prototype at the 2-D position of the plotted sample that is
        # nearest to it in the original feature space.
        for cls in selected_classes:
            distances_to_prototype = calculate_euclidean_distances(
                final_features, prototypes[cls].unsqueeze(0).numpy())
            prototype_position = tsne_features[np.argmin(distances_to_prototype)]
            # Star in the class colour with a black outline.
            plt.scatter(prototype_position[0], prototype_position[1], marker='*', s=200,
                        color=class_colors[cls], edgecolor='black', linewidths=1.5)

        # Custom legend: only base samples, novel samples and prototypes.
        legend_elements = [
            plt.Line2D([0], [0], marker='o', color='w', label='Base test sample', markerfacecolor='blue', markersize=10, linestyle='None'),
            plt.Line2D([0], [0], marker='x', color='red', label='Novel test sample', markersize=10, linestyle='None'),
            plt.Line2D([0], [0], marker='*', color='black', label='Weights/Prototypes', markersize=15, markerfacecolor='black', linestyle='None')
        ]
        plt.legend(handles=legend_elements, loc='best')

        # Hide axis ticks.
        plt.gca().set_xticks([])
        plt.gca().set_yticks([])

        save_path_pdf = os.path.join(os.getcwd(), f'metaadapter_tsne_session_{session}_seed_{seed}.pdf')
        plt.savefig(save_path_pdf, format='pdf', bbox_inches='tight')
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    print(f"T-SNE 图已保存至 {save_path_pdf}")

import torch
import torch.nn.functional as F
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import os

def test_and_plot_confusion_matrix(model, testloader, args, session):
    """Collect test-set logits for all classes seen so far and plot their confusion matrix.

    The model is put into the same inference configuration that ``test`` uses:
    ``mode_context`` ``'normal'``, ``model.module.mode = args.new_mode`` and
    eval mode. This makes the forward pass return classifier logits even if an
    earlier call (e.g. ``plot_tsne``) left ``model.module.mode`` at
    ``'encoder'``.

    Logits are truncated to the first ``args.base_class + session * args.way``
    classes and handed to ``plot_confusion_matrix_from_test``.

    Side effects: reseeds the global PyTorch and NumPy RNGs with 42 and writes
    the plot via ``plot_confusion_matrix_from_test``.
    """
    test_class = args.base_class + session * args.way
    mode_context.set('normal')
    model.module.mode = args.new_mode
    model = model.eval()

    # Fixed seeds so repeated runs are comparable.
    torch.manual_seed(42)
    np.random.seed(42)

    logit_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch in testloader:
            data, test_label = [_.cuda() for _ in batch]
            # Keep only the classes introduced so far.
            logits = model(data)[:, :test_class]
            logit_chunks.append(logits.cpu())
            label_chunks.append(test_label.cpu())

    # Concatenate once instead of re-allocating the accumulated tensor per batch.
    lgt = torch.cat(logit_chunks, dim=0)
    lbs = torch.cat(label_chunks, dim=0)

    plot_confusion_matrix_from_test(lgt, lbs, session, test_class)


def _row_percentages(cm):
    """Scale each row of ``cm`` to percentages of its row sum; all-zero rows stay 0."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    return np.divide(cm * 100.0, row_sums, out=np.zeros_like(cm), where=row_sums > 0)


def plot_confusion_matrix_from_test(logits, labels, session, class_count, base_class=60, tick_step=20):
    """Plot and save the row-normalised confusion matrix of ``logits`` against ``labels``.

    Args:
        logits: 2-D tensor of predicted logits, one row per sample.
        labels: tensor of true class ids, one per sample.
        session: session index, used in the output file name.
        class_count: total number of classes (base + novel). The matrix is
            always ``class_count x class_count``, and row/column ``i`` is class
            ``i`` even if that class never occurs.
        base_class: where the cyan base/novel separator lines are drawn. The
            default 60 keeps the previous hard-coded value.
        tick_step: spacing of the axis ticks. The default 20 gives the previous
            ticks 0, 20, 40, 60, 80 for 100 classes.

    Returns:
        None. Saves ``metaadapter_confusion_matrix_session_{session}.pdf`` in
        the current working directory and calls ``plt.show()``.
    """
    # Convert logits to predicted class ids.
    preds = torch.argmax(logits, dim=1).cpu().numpy()
    targets = torch.as_tensor(labels).long().cpu().numpy()

    # Pin the label set so indices never shift when some class is missing from
    # both targets and predictions.
    cm = confusion_matrix(targets, preds, labels=np.arange(class_count))

    # Percentages per true class; rows with no samples stay 0 instead of NaN.
    cm_normalized = _row_percentages(cm)

    plt.figure(figsize=(10, 8), dpi=300)
    ax = sns.heatmap(cm_normalized, annot=False, fmt='.2f', cmap='magma', cbar=True, square=True, vmax=100)

    plt.xlabel('Predicted Classes', fontsize=18)
    plt.ylabel('True Classes', fontsize=18)

    # Ticks derived from the class count. Fixed ticks beyond the last class
    # would make matplotlib expand the axes to show them.
    tick_labels = list(range(0, class_count, tick_step))
    plt.xticks(tick_labels, tick_labels, fontsize=16)
    plt.yticks(tick_labels, tick_labels, fontsize=16)

    # Cyan lines separating the first ``base_class`` classes from the novel ones.
    plt.axhline(y=base_class, color='cyan', linewidth=1)
    plt.axvline(x=base_class, color='cyan', linewidth=1)

    plt.gca().set_aspect('equal', adjustable='box')

    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=16)

    save_filename = os.path.join(os.getcwd(), f'metaadapter_confusion_matrix_session_{session}.pdf')
    plt.savefig(save_filename, bbox_inches='tight', format='pdf', dpi=300)

    plt.show()

    print(f"Confusion matrix saved as {save_filename}")


import torch
import matplotlib.pyplot as plt
import numpy as np
import random


def _grid_shape_for(count, columns=10):
    """Return ``(rows, columns)``: the fewest rows of ``columns`` cells holding ``count`` images (at least 1)."""
    return max(1, -(-count // columns)), columns


def _first_image_per_class(testloader, base_class, test_class, novel_target):
    """Collect the first image of each class id below ``test_class``, split into (base, novel) lists.

    Ids below ``base_class`` are base classes. Labels ``>= test_class`` are
    ignored. Iteration stops early once ``base_class`` base images and
    ``novel_target`` novel images have been found.
    """
    base_images, novel_images = [], []
    seen = set()
    for data, test_label in testloader:
        for img, label in zip(data, test_label):
            label = label.item()
            if label in seen or label >= test_class:
                continue
            seen.add(label)
            (base_images if label < base_class else novel_images).append(img)
            if len(base_images) == base_class and len(novel_images) == novel_target:
                return base_images, novel_images
    return base_images, novel_images


def visualize_classes(testloader, args, session):
    """Display the first test image of every base class next to every novel class.

    Images are taken in loader order, one per class id in
    ``[0, args.base_class + session * args.way)``. Each group is tiled with 10
    images per row via ``make_image_grid`` (60 base classes -> 6x10 grid,
    40 novel classes -> 4x10), so any class count fits. The figure is shown
    with ``plt.show()``.
    """
    test_class = args.base_class + session * args.way  # base + novel classes so far
    base_images, novel_images = _first_image_per_class(
        testloader, args.base_class, test_class, session * args.way)

    base_image_grid = make_image_grid(base_images, grid_size=_grid_shape_for(len(base_images)))
    novel_image_grid = make_image_grid(novel_images, grid_size=_grid_shape_for(len(novel_images)))

    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(base_image_grid)
    plt.title('Base Classes')

    plt.subplot(1, 2, 2)
    plt.imshow(novel_image_grid)
    plt.title('Novel Classes')

    plt.show()


def make_image_grid(images, grid_size=(8, 8), img_size=(84, 84)):
    """Tile a list of CHW image tensors into one HWC ``uint8`` grid image, row-major.

    Args:
        images: iterable of tensors shaped ``(C, img_h, img_w)`` with values
            expected in ``[0, 1]``. Values outside that range (e.g. from
            mean/std-normalised inputs) are clipped rather than wrapping around
            in the ``uint8`` cast. Single-channel tiles broadcast across the
            three RGB channels.
        grid_size: ``(rows, cols)`` of the grid.
        img_size: ``(img_h, img_w)`` of each tile.

    Returns:
        ``np.ndarray`` of shape ``(rows * img_h, cols * img_w, 3)`` and dtype
        ``uint8``; unused cells stay black.

    Raises:
        ValueError: if there are more images than grid cells.
    """
    images = list(images)
    grid_h, grid_w = grid_size
    img_h, img_w = img_size
    if len(images) > grid_h * grid_w:
        raise ValueError(f"{len(images)} images do not fit in a {grid_h}x{grid_w} grid")
    grid_image = np.zeros((grid_h * img_h, grid_w * img_w, 3), dtype=np.uint8)

    for i, img in enumerate(images):
        # CHW tensor -> HWC NumPy array; detach/cpu so GPU or grad-tracking tensors work too.
        tile = img.detach().cpu().permute(1, 2, 0).numpy()
        tile = np.clip(tile * 255, 0, 255).astype(np.uint8)
        row, col = divmod(i, grid_w)
        grid_image[row * img_h:(row + 1) * img_h, col * img_w:(col + 1) * img_w, :] = tile

    return grid_image

# Example call (assumes testloader, args and session are already defined):
# visualize_classes(testloader, args, session)
