import numpy as np
import torch, math
import torch.nn as nn
import torch.nn.functional as F
from lib.core.evaluate import accuracy, accuracy_shot, accuracy_shot_perclass
from torch.nn import functional as Func
import pdb
from lib.loss import CrossEntropy
from lib.loss.loss import loss_fn_kd, loss_fn_kd_binary, loss_fn_kd_KL
from lib.utils import to_one_hot

"""The combiner is only used for training the model, not for evaluation."""


class MCFM_Combiner:
    def __init__(self, cfg, device):
        """Remember the configuration object and the target device.

        Args:
            cfg: Configuration object whose nested attributes (for example
                ``cfg.DISTILL``, ``cfg.Mixup``, ``cfg.CLASSIFIER``) are read
                by the other methods of this class.
            device: Device that ``forward`` moves the input batches to.
        """
        self.cfg, self.device = cfg, device

    def compute_distill_loss(self, output_for_distill, previous_task_model_output):
        """Compute the distillation loss of the current model against the previous one.

        The loss function is selected by ``cfg.DISTILL.softmax_sigmoid``:
        ``0`` uses ``loss_fn_kd``, ``1`` uses ``loss_fn_kd_binary`` and any
        other value uses a freshly constructed ``loss_fn_kd_KL`` object.
        ``cfg.DISTILL.TEMP`` is forwarded in every case (as the third
        positional argument, or as ``T=`` for the KL variant).

        Args:
            output_for_distill: Output of the model being trained.
            previous_task_model_output: Output of the previous-task model on
                the same inputs.

        Returns:
            Whatever the selected loss function returns.
        """
        distill_cfg = self.cfg.DISTILL
        mode = distill_cfg.softmax_sigmoid
        if mode == 0:
            return loss_fn_kd(output_for_distill, previous_task_model_output, distill_cfg.TEMP)
        if mode == 1:
            return loss_fn_kd_binary(output_for_distill, previous_task_model_output, distill_cfg.TEMP)
        kl_criterion = loss_fn_kd_KL()
        return kl_criterion(output_for_distill, previous_task_model_output, T=distill_cfg.TEMP)

    def forward(self, model, criterion, current_image, current_label, active_classes_num, classes_per_task,
                previous_task_model, pre_tasks_imgs, pre_tasks_labels, balance_multiple, task, label_weight=None,
                reach_IB=False, criterion_ib=None):
        """Build the list of loss terms for one training batch.

        Two paths exist:

        * First-task path, taken when ``task == 1`` or when ``cfg.use_base_half``
          is set and ``task == int(cfg.DATASET.all_tasks / 2)``: only a
          classification loss (plus the optional distance loss) is produced and
          the method returns early.
        * Later-task path: previous-task samples are mixed (with themselves and
          with the current batch), a distillation loss against
          ``previous_task_model`` is computed, and a classification loss is
          computed according to ``cfg.Mixup.all``, ``cfg.use_IB`` / ``reach_IB``,
          ``cfg.use_mix_cls`` and ``cfg.LOSS.LOSS_TYPE``.

        Args:
            model: Model being trained; called with the keyword flags
                ``classifier_flag`` and/or ``feature_flag``.
            criterion: Classification criterion; called as
                ``criterion(output, label)`` and, through ``mixup_criterion``,
                with ``reduction='none'``.
            current_image: Input batch of the current task.
            current_label: Labels of ``current_image``.
            active_classes_num: Number of leading output columns that are kept
                (``output[:, 0:active_classes_num]``).
            classes_per_task: Subtracted from ``active_classes_num`` to obtain the
                number of columns used for distillation.
            previous_task_model: Model used as distillation teacher; called with
                ``is_nograd=True, classifier_flag=True``.
            pre_tasks_imgs: Input batch drawn from previous tasks.
            pre_tasks_labels: Labels of ``pre_tasks_imgs``.
            balance_multiple: Controls how many extra mixed batches are
                concatenated (it is decremented by one before use).
            task: Index of the current task.
            label_weight: Weights indexed by label, used only when
                ``cfg.CLASSIFIER.NECK.distance_loss`` is set on the later-task
                path. NOTE(review): the default ``None`` fails at
                ``label_weight.to(...)`` in that case.
            reach_IB: Together with ``cfg.use_IB`` selects the ``criterion_ib``
                branch.
            criterion_ib: Criterion used in that branch.

        Returns:
            Tuple ``(loss, now_acc, now_cnt)``. ``loss`` is a list of loss terms
            (classification first, then distillation on the later-task path, then
            the optional distance loss). ``now_acc`` and ``now_cnt`` are
            single-element lists holding the values returned by ``accuracy`` or
            the placeholder ``0`` on the branches that skip the accuracy
            computation.
        """

        # ---- First-task path: no previous model, hence no distillation. ----
        if task == 1 or (self.cfg.use_base_half and task == int(self.cfg.DATASET.all_tasks / 2)):
            if self.cfg.first_task_mix:
                # NOTE(review): the inputs are not moved to self.device on this branch;
                # presumably the caller already placed them there - confirm.
                mixup_current_images, mixup_current_labels_a, mixup_current_labels_b, mixup_current_lams = \
                    self.mixup_data(current_image, current_label, alpha_1=self.cfg.Mixup.mixup_alpha1,
                                    alpha_2=self.cfg.Mixup.mixup_alpha2)
                mixup_current_task_output = model(mixup_current_images, classifier_flag=True)
                # Keep only the columns of the classes that are active so far.
                mixup_current_task_output = mixup_current_task_output[:, 0:active_classes_num]
                mixup_current_task_cls_loss = self.mixup_criterion(criterion, mixup_current_task_output,
                                                                   mixup_current_labels_a, mixup_current_labels_b,
                                                                   mixup_current_lams)
                loss = [mixup_current_task_cls_loss * self.cfg.CLASSIFIER.LOSS_FACTOR]

                # Placeholders: accuracy is not computed on mixed inputs.
                now_acc = [0]
                now_cnt = [0]

                if self.cfg.CLASSIFIER.NECK.distance_loss:
                    # todo inter_class loss and intra_class loss
                    # The distance loss is computed on the un-mixed batch.
                    current_features = model(current_image, feature_flag=True)
                    distance_loss = self.compute_distance_loss_first_task(all_features=current_features,
                                                                          all_labels=current_label)
                    loss.append(distance_loss * self.cfg.CLASSIFIER.NECK.LOSS_FACTOR)

                return loss, now_acc, now_cnt
            else:
                image, label = current_image.to(self.device), current_label.to(self.device)
                output = model(image, classifier_flag=True)
                output = output[:, 0:active_classes_num]
                # return criterion(output, label)
                _, now_result = torch.max(output, 1)
                # now_result = torch.max(y_hat, 1)
                now_acc, now_cnt = accuracy(now_result.cpu().numpy(), label.cpu().numpy())
                cls_loss = criterion(output, label)
                loss = [cls_loss * self.cfg.CLASSIFIER.LOSS_FACTOR]

                now_acc = [now_acc]
                now_cnt = [now_cnt]

                if self.cfg.CLASSIFIER.NECK.distance_loss:
                    # todo inter_class loss and intra_class loss
                    # NOTE(review): this uses `current_image` / `current_label`, not the
                    # device-moved `image` / `label` created above - confirm this is intended.
                    current_features = model(current_image, feature_flag=True)
                    distance_loss = self.compute_distance_loss_first_task(all_features=current_features,
                                                                          all_labels=current_label)
                    loss.append(distance_loss * self.cfg.CLASSIFIER.NECK.LOSS_FACTOR)

                return loss, now_acc, now_cnt

        # ---- Later-task path. ----
        current_image, current_label = current_image.to(self.device), current_label.to(self.device)
        previous_tasks_imgs, previous_tasks_labels = pre_tasks_imgs.to(self.device), pre_tasks_labels.to(self.device)
        # mixup_imgs, mixup_labels_a, mixup_labels_b, all_lams = \
        #     self.mixup_data(previous_tasks_imgs, previous_tasks_labels, alpha_1=-1.0, alpha_2=-1.0)
        # First mixed batch: previous-task samples mixed among themselves.
        mixup_imgs, mixup_labels_a, mixup_labels_b, all_lams = \
            self.mixup_data(previous_tasks_imgs, previous_tasks_labels, alpha_1=self.cfg.Mixup.mixup_alpha1,
                            alpha_2=self.cfg.Mixup.mixup_alpha2)
        # One mixed batch already exists, so one fewer remains to be generated
        # (presumably the reason for the decrement - confirm).
        balance_multiple -= 1
        if task > 2 or self.cfg.exemplar_manager.fixed_exemplar_num > 0 or self.cfg.Mixup.mix_balance:
            # Each iteration appends two batches: one previous-with-current mix and
            # one previous-with-previous mix, hence the step of 2.
            for i in range(0, balance_multiple, 2):
                mixup_imgs_temp, mixup_labels_a_temp, mixup_labels_b_temp, all_lams_temp = \
                    self.all_tasks_mixup_data(previous_tasks_imgs, previous_tasks_labels, current_image, current_label,
                                              alpha_1=self.cfg.Mixup.mixup_alpha1,
                                              alpha_2=self.cfg.Mixup.mixup_alpha2)
                mixup_imgs = torch.cat([mixup_imgs, mixup_imgs_temp], dim=0)
                mixup_labels_a = torch.cat([mixup_labels_a, mixup_labels_a_temp], dim=0)
                mixup_labels_b = torch.cat([mixup_labels_b, mixup_labels_b_temp], dim=0)
                all_lams = torch.cat([all_lams, all_lams_temp], dim=0)

                previous_tasks_self_mixup_imgs, previous_tasks_self_mixup_labels_a, previous_tasks_self_mixup_labels_b, \
                all_lams_temp = self.mixup_data(previous_tasks_imgs, previous_tasks_labels,
                                                alpha_1=self.cfg.Mixup.mixup_alpha1,
                                                alpha_2=self.cfg.Mixup.mixup_alpha2)
                mixup_imgs = torch.cat([mixup_imgs, previous_tasks_self_mixup_imgs], dim=0)
                mixup_labels_a = torch.cat([mixup_labels_a, previous_tasks_self_mixup_labels_a], dim=0)
                mixup_labels_b = torch.cat([mixup_labels_b, previous_tasks_self_mixup_labels_b], dim=0)
                all_lams = torch.cat([all_lams, all_lams_temp], dim=0)
        else:
            # One batch per iteration: the iterations up to the threshold below append
            # previous-with-previous mixes, the remaining ones previous-with-current mixes.
            for i in range(0, balance_multiple):
                # if i > int(balance_multiple * (task - 1) / task):
                if i <= int(balance_multiple * (task - 1) / task):
                    # if i > int(balance_multiple * 2 / 3):
                    previous_tasks_self_mixup_imgs, previous_tasks_self_mixup_labels_a, previous_tasks_self_mixup_labels_b, \
                    all_lams_temp = self.mixup_data(previous_tasks_imgs, previous_tasks_labels,
                                                    alpha_1=self.cfg.Mixup.mixup_alpha1,
                                                    alpha_2=self.cfg.Mixup.mixup_alpha2)
                    mixup_imgs = torch.cat([mixup_imgs, previous_tasks_self_mixup_imgs], dim=0)
                    mixup_labels_a = torch.cat([mixup_labels_a, previous_tasks_self_mixup_labels_a], dim=0)
                    mixup_labels_b = torch.cat([mixup_labels_b, previous_tasks_self_mixup_labels_b], dim=0)
                    all_lams = torch.cat([all_lams, all_lams_temp], dim=0)
                else:
                    mixup_imgs_temp, mixup_labels_a_temp, mixup_labels_b_temp, all_lams_temp = \
                        self.all_tasks_mixup_data(previous_tasks_imgs, previous_tasks_labels, current_image,
                                                  current_label,
                                                  alpha_1=self.cfg.Mixup.mixup_alpha1,
                                                  alpha_2=self.cfg.Mixup.mixup_alpha2)
                    mixup_imgs = torch.cat([mixup_imgs, mixup_imgs_temp], dim=0)
                    mixup_labels_a = torch.cat([mixup_labels_a, mixup_labels_a_temp], dim=0)
                    mixup_labels_b = torch.cat([mixup_labels_b, mixup_labels_b_temp], dim=0)
                    all_lams = torch.cat([all_lams, all_lams_temp], dim=0)

        # Number of output columns compared against the previous-task model
        # (presumably the classes that model was trained on - confirm).
        dpt_active_classes_num = active_classes_num - classes_per_task
        if self.cfg.Mixup.all:
            # ---- Branch A: the current-task batch is mixed as well. ----
            mixup_current_images, mixup_current_labels_a, mixup_current_labels_b, mixup_current_lams = \
                self.mixup_data(current_image, current_label, alpha_1=self.cfg.Mixup.mixup_alpha1,
                                alpha_2=self.cfg.Mixup.mixup_alpha2)
            mixup_current_task_output, mixup_current_features = model(mixup_current_images, feature_flag=True,
                                                                      classifier_flag=True)
            mixup_current_task_output = mixup_current_task_output[:, 0:active_classes_num]
            mixup_data_output, mixup_data_features = model(mixup_imgs, feature_flag=True, classifier_flag=True)
            mixup_data_output = mixup_data_output[:, 0:active_classes_num]
            previous_task_model_mixup_data_output = previous_task_model(mixup_imgs, is_nograd=True,
                                                                        classifier_flag=True)  # get classifier_output
            # Distillation compares only the first dpt_active_classes_num columns of both models.
            mixup_data_output_for_distill = mixup_data_output[:, 0:dpt_active_classes_num]
            previous_task_model_mixup_data_output = previous_task_model_mixup_data_output[:, 0:dpt_active_classes_num]
            distill_loss = self.compute_distill_loss(mixup_data_output_for_distill,
                                                     previous_task_model_mixup_data_output)
            loss = []
            # Fraction of all samples that come from the mixed previous-task data, and its complement.
            mix_rate = (mixup_data_output.size(0) / (mixup_current_task_output.size(0) + mixup_data_output.size(0)))
            current_rate = 1 - mix_rate

            mixup_current_task_distill_loss = 0
            if task > 2 or self.cfg.use_current_task_for_distill:
                # Additionally distil on the (mixed) current-task batch.
                pre_task_model_current_data_output = previous_task_model(mixup_current_images, is_nograd=True,
                                                                         classifier_flag=True)  # get classifier_output
                pre_task_model_current_data_output = pre_task_model_current_data_output[:, 0:dpt_active_classes_num]
                mixup_current_task_output_for_distill = mixup_current_task_output[:, 0:dpt_active_classes_num]
                mixup_current_task_distill_loss = self.compute_distill_loss(mixup_current_task_output_for_distill,
                                                                            pre_task_model_current_data_output)

            if self.cfg.use_IB and reach_IB:
                #assert pre_task_model_current_data_output is not None
                if self.cfg.use_current_task_for_distill:
                    # Append the current-task tensors so criterion_ib sees both data sources.
                    # pre_task_model_current_data_output exists here because the block above
                    # runs whenever use_current_task_for_distill is set.
                    mixup_data_output = torch.cat([mixup_data_output, mixup_current_task_output], dim=0)
                    previous_task_model_mixup_data_output = torch.cat([previous_task_model_mixup_data_output,
                                                                       pre_task_model_current_data_output], dim=0)
                    mixup_labels_a = torch.cat([mixup_labels_a, mixup_current_labels_a], dim=0)
                    mixup_labels_b = torch.cat([mixup_labels_b, mixup_current_labels_b], dim=0)
                    mixup_data_features = torch.cat([mixup_data_features, mixup_current_features], dim=0)
                    all_lams = torch.cat([all_lams, mixup_current_lams], dim=0)

                cls_loss = self.mixup_criterion_trade_off_ib(criterion_ib, mixup_data_output,
                                                             previous_task_model_mixup_data_output,
                                                             mixup_labels_a, mixup_labels_b, all_lams,
                                                             mixup_data_features)
                # Task-dependent weighting: 1/task for classification, (task-1)/task for distillation.
                # The cfg LOSS_FACTOR values are not applied on this branch.
                cls_loss = cls_loss / task
                distill_loss += mixup_current_task_distill_loss
                distill_loss *= (task - 1) / task
                loss = [cls_loss, distill_loss]
            else:
                if "binary" in self.cfg.LOSS.LOSS_TYPE:
                    mixup_current_task_cls_loss = self.mixup_criterion_iCaRL(mixup_current_task_output,
                                                                             mixup_current_labels_a,
                                                                             mixup_current_labels_b,
                                                                             mixup_current_lams,
                                                                             classes_per_task)
                    mixup_data_cls_loss = self.mixup_criterion_iCaRL(mixup_data_output,
                                                                     mixup_labels_a, mixup_labels_b, all_lams,
                                                                     classes_per_task)

                else:
                    mixup_current_task_cls_loss = self.mixup_criterion(criterion, mixup_current_task_output,
                                                                       mixup_current_labels_a, mixup_current_labels_b,
                                                                       mixup_current_lams)
                    mixup_data_cls_loss = self.mixup_criterion(criterion, mixup_data_output,
                                                               mixup_labels_a, mixup_labels_b, all_lams)

                if self.cfg.pre_current_loss_balance:
                    # Weight each term by the share of samples it was computed on.
                    distill_loss *= mix_rate
                    distill_loss += mixup_current_task_distill_loss * current_rate
                    mixup_current_task_cls_loss *= current_rate
                    mixup_data_cls_loss *= mix_rate
                else:
                    distill_loss += mixup_current_task_distill_loss

                # now_result = torch.argmax(current_task_output, 1)
                # now_acc, now_cnt = accuracy(now_result.cpu().numpy(), current_label.cpu().numpy())
                if "binary" in self.cfg.LOSS.LOSS_TYPE:
                    # Binary loss type: no task-dependent scaling.
                    cls_loss = (mixup_current_task_cls_loss + mixup_data_cls_loss) * self.cfg.CLASSIFIER.LOSS_FACTOR
                    loss.append(cls_loss)
                    loss.append(distill_loss * self.cfg.DISTILL.LOSS_FACTOR)
                else:
                    cls_loss = (
                                       mixup_current_task_cls_loss + mixup_data_cls_loss) * self.cfg.CLASSIFIER.LOSS_FACTOR / task
                    loss.append(cls_loss)
                    loss.append(distill_loss * self.cfg.DISTILL.LOSS_FACTOR * (task - 1) / task)
            # Placeholders: accuracy is not computed when every input is mixed.
            now_acc, now_cnt = [0], [0]

        else:
            # ---- Branch B: the current-task batch is fed un-mixed. ----
            current_task_output, current_task_features = model(current_image, feature_flag=True, classifier_flag=True)
            current_task_output = current_task_output[:, 0:active_classes_num]
            mixup_data_output = model(mixup_imgs, classifier_flag=True)
            mixup_data_output = mixup_data_output[:, 0:active_classes_num]
            previous_task_model_mixup_data_output = previous_task_model(mixup_imgs, is_nograd=True,
                                                                        classifier_flag=True)  # get classifier_output
            mixup_data_output_for_distill = mixup_data_output[:, 0:dpt_active_classes_num]
            previous_task_model_mixup_data_output = previous_task_model_mixup_data_output[:, 0:dpt_active_classes_num]
            distill_loss = self.compute_distill_loss(mixup_data_output_for_distill,
                                                     previous_task_model_mixup_data_output)

            # Fraction of all samples that come from the mixed previous-task data, and its complement.
            mix_rate = (mixup_data_output.size(0) / (current_task_output.size(0) + mixup_data_output.size(0)))
            current_rate = 1 - mix_rate
            current_task_distill_loss = 0
            if task > 2 or self.cfg.use_current_task_for_distill or self.cfg.DATASET.all_tasks < 3:
                # Additionally distil on the current-task batch.
                pre_task_model_current_data_output = previous_task_model(current_image, is_nograd=True,
                                                                         classifier_flag=True)  # get classifier_output
                pre_task_model_current_data_output = pre_task_model_current_data_output[:, 0:dpt_active_classes_num]
                current_task_output_for_distill = current_task_output[:, 0:dpt_active_classes_num]
                current_task_distill_loss = self.compute_distill_loss(current_task_output_for_distill,
                                                                      pre_task_model_current_data_output)
            if self.cfg.use_IB and reach_IB:
                # previous_tasks_imgs, previous_tasks_labels
                # criterion_ib is evaluated on the un-mixed previous + current samples.
                pre_task_data_output, pre_task_data_features = model(previous_tasks_imgs, feature_flag=True,
                                                                     classifier_flag=True)
                pre_task_data_output = pre_task_data_output[:, 0:active_classes_num]
                all_data_output = torch.cat([pre_task_data_output, current_task_output], dim=0)
                all_data_features = torch.cat([pre_task_data_features, current_task_features], dim=0)
                all_data_labels = torch.cat([previous_tasks_labels, current_label], dim=0)
                cls_loss = criterion_ib(all_data_output, all_data_labels, all_data_features)
                '''cls_loss = criterion_ib(loss_vec=F.cross_entropy(all_data_output, all_data_labels, reduction='none', 
                                                                 weight=criterion_ib.weight), input=all_data_output, 
                                        previous_model_output=pre_task_data_output, target=all_data_labels, 
                                        features=all_data_features, reduction="mean")'''
                mixup_current_task_cls_loss = 0
                mixup_data_cls_loss = 0
                if self.cfg.plus_mix_cls:
                    # Optional extra mixup classification terms, scaled by cfg.mix_cls_alpha below.
                    mixup_data_cls_loss = self.mixup_criterion(criterion, mixup_data_output,
                                                               mixup_labels_a, mixup_labels_b, all_lams)
                    mixup_current_images, mixup_current_labels_a, mixup_current_labels_b, mixup_current_lams = \
                        self.mixup_data(current_image, current_label, alpha_1=self.cfg.Mixup.mixup_alpha1,
                                        alpha_2=self.cfg.Mixup.mixup_alpha2)
                    mixup_current_task_output, mixup_current_features = model(mixup_current_images, feature_flag=True,
                                                                              classifier_flag=True)
                    mixup_current_task_output = mixup_current_task_output[:, 0:active_classes_num]
                    mixup_current_task_cls_loss = self.mixup_criterion(criterion, mixup_current_task_output,
                                                                       mixup_current_labels_a, mixup_current_labels_b,
                                                                       mixup_current_lams)
                if self.cfg.pre_current_loss_balance:
                    distill_loss *= mix_rate
                    distill_loss += current_task_distill_loss * current_rate
                    mixup_current_task_cls_loss *= current_rate
                    mixup_data_cls_loss *= mix_rate
                else:
                    distill_loss += current_task_distill_loss
                cls_loss += (mixup_data_cls_loss + mixup_current_task_cls_loss) * self.cfg.mix_cls_alpha
                # Task-dependent weighting; the cfg LOSS_FACTOR values are not applied on this branch.
                cls_loss = cls_loss / task
                distill_loss *= (task - 1) / task
                loss = [cls_loss, distill_loss]
            else:
                cls_loss = 0
                if self.cfg.use_mix_cls:
                    # Classification loss from mixed data only (previous mix + freshly mixed current batch).
                    mixup_data_cls_loss = self.mixup_criterion(criterion, mixup_data_output,
                                                               mixup_labels_a, mixup_labels_b, all_lams)
                    mixup_current_images, mixup_current_labels_a, mixup_current_labels_b, mixup_current_lams = \
                        self.mixup_data(current_image, current_label, alpha_1=self.cfg.Mixup.mixup_alpha1,
                                        alpha_2=self.cfg.Mixup.mixup_alpha2)
                    mixup_current_task_output, mixup_current_features = model(mixup_current_images, feature_flag=True,
                                                                              classifier_flag=True)
                    mixup_current_task_output = mixup_current_task_output[:, 0:active_classes_num]
                    mixup_current_task_cls_loss = self.mixup_criterion(criterion, mixup_current_task_output,
                                                                       mixup_current_labels_a, mixup_current_labels_b,
                                                                       mixup_current_lams)
                    if self.cfg.pre_current_loss_balance:
                        distill_loss *= mix_rate
                        distill_loss += current_task_distill_loss * current_rate
                        mixup_current_task_cls_loss *= current_rate
                        mixup_data_cls_loss *= mix_rate
                    else:
                        distill_loss += current_task_distill_loss
                    cls_loss = (mixup_data_cls_loss + mixup_current_task_cls_loss)
                    cls_loss = cls_loss * self.cfg.CLASSIFIER.LOSS_FACTOR / task
                else:
                    # Plain criterion on the un-mixed current batch plus mixup loss on the mixed data.
                    current_task_cls_loss = criterion(current_task_output, current_label)

                    mixup_data_cls_loss = self.mixup_criterion(criterion, mixup_data_output,
                                                               mixup_labels_a, mixup_labels_b, all_lams)

                    if self.cfg.pre_current_loss_balance:
                        distill_loss *= mix_rate
                        distill_loss += current_task_distill_loss * current_rate
                        current_task_cls_loss *= current_rate
                        mixup_data_cls_loss *= mix_rate
                    else:
                        distill_loss += current_task_distill_loss
                    # Scale by the share of columns that are excluded from distillation.
                    cls_loss = (current_task_cls_loss + mixup_data_cls_loss) * self.cfg.CLASSIFIER.LOSS_FACTOR * \
                               (active_classes_num - dpt_active_classes_num) / active_classes_num
                # Scale distillation by the share of columns that take part in it.
                distill_loss *= self.cfg.DISTILL.LOSS_FACTOR * dpt_active_classes_num / active_classes_num
                loss = [cls_loss, distill_loss]

            # Accuracy on the un-mixed current-task batch.
            now_result = torch.argmax(current_task_output, 1)
            now_acc, now_cnt = accuracy(now_result.cpu().numpy(), current_label.cpu().numpy())
            now_acc, now_cnt = [now_acc], [now_cnt]

        if self.cfg.CLASSIFIER.NECK.distance_loss:
            # todo inter_class loss and intra_class loss
            # Extra forward passes to obtain features of the un-mixed current and previous batches.
            current_features = model(current_image, feature_flag=True)
            previous_tasks_features = model(previous_tasks_imgs, feature_flag=True)
            # NOTE(review): label_weight defaults to None, which would raise AttributeError here.
            label_weight = label_weight.to(current_image.device)
            # previous_tasks_imgs, previous_tasks_labels

            distance_loss = self.compute_distance_loss_LT(all_features=current_features, all_labels=current_label,
                                                          previous_tasks_features=previous_tasks_features,
                                                          previous_tasks_labels=previous_tasks_labels,
                                                          label_weight=label_weight)
            loss.append(distance_loss * self.cfg.CLASSIFIER.NECK.LOSS_FACTOR)

        return loss, now_acc, now_cnt

    def compute_distance_loss_first_task(self, all_features, all_labels):
        class_mean_features = []
        class_feature_num = []
        class_intra_distance = []
        for l in all_labels.unique():
            label_index = all_labels == l
            center_feature = all_features[label_index].mean(0)
            class_mean_features.append(center_feature)
            if self.cfg.CLASSIFIER.NECK.INTRA_DISTANCE:
                class_intra_distance.append((1 - torch.nn.functional.cosine_similarity(center_feature[None, :],
                                                                                       all_features[
                                                                                           label_index])).mean())
            class_feature_num.append(label_index.sum())

        class_mean_features = torch.stack(class_mean_features)
        class_feature_num = torch.stack(class_feature_num)
        c_batch = class_mean_features.shape[0]
        if self.cfg.CLASSIFIER.NECK.INTRA_DISTANCE:
            class_intra_distance = torch.stack(class_intra_distance)
        ignore_mask = 1 - torch.eye(c_batch).to(class_mean_features.device)  # 将对角线置0,非对角线置1
        distance_loss = 0  # distance_loss = 0

        if self.cfg.CLASSIFIER.NECK.INTER_DISTANCE:
            class_mean_features_norm = class_mean_features / class_mean_features.norm(dim=1)[:, None]
            distance_matrix = (1 - torch.mm(class_mean_features_norm, class_mean_features_norm.t()))
            class_inter_loss = torch.nn.functional.hinge_embedding_loss(distance_matrix,
                                                                        -1 * torch.ones_like(
                                                                            distance_matrix).long(),
                                                                        margin=self.cfg.CLASSIFIER.NECK.MARGIN,
                                                                        reduction='none')

            reduction_factor = all_labels.unique().shape[0] ** 2 - all_labels.unique().shape[0]

            class_inter_loss = (class_inter_loss * ignore_mask).sum() / reduction_factor  # 计算平均class_inter_loss
            distance_loss += class_inter_loss

        if self.cfg.CLASSIFIER.NECK.INTRA_DISTANCE:
            if (class_feature_num > 1).any():
                class_intra_loss = class_intra_distance[class_feature_num > 1].mean()
            else:
                class_intra_loss = 0
            distance_loss += class_intra_loss
        return distance_loss

    def compute_distance_loss_LT(self, all_features, all_labels, previous_tasks_features, previous_tasks_labels,
                                 label_weight):
        all_features = torch.cat([all_features, previous_tasks_features])
        all_labels = torch.cat([all_labels, previous_tasks_labels])
        all_weights = label_weight

        class_mean_features = []
        class_feature_num = []
        class_feature_weight = []
        class_intra_distance = []
        for l in all_labels.unique():
            label_index = all_labels == l
            center_feature = all_features[label_index].mean(0)
            class_mean_features.append(center_feature)  # 在一个batch中每个label对应的center_feature
            if self.cfg.CLASSIFIER.NECK.INTRA_DISTANCE:
                class_intra_distance.append((1 - torch.nn.functional.cosine_similarity(center_feature[None, :],
                                                                                       all_features[
                                                                                           label_index])).mean())
            class_feature_num.append(label_index.sum())  # 在一个batch中每个label对应的imgs或者features数目
            class_feature_weight.append(all_weights[l])  # 在一个batch中每个imgs对应的weight

        class_mean_features = torch.stack(class_mean_features)
        class_feature_num = torch.stack(class_feature_num)
        class_feature_weight = torch.stack(class_feature_weight)
        c_batch = class_mean_features.shape[0]
        if self.cfg.CLASSIFIER.NECK.INTRA_DISTANCE:
            class_intra_distance = torch.stack(class_intra_distance)
        ignore_mask = 1 - torch.eye(c_batch).to(class_mean_features.device)  # 将对角线置0,非对角线置1
        distance_loss = 0  # distance_loss = 0

        if self.cfg.CLASSIFIER.NECK.INTER_DISTANCE:
            class_mean_features_norm = class_mean_features / class_mean_features.norm(dim=1)[:, None]
            distance_matrix = (1 - torch.mm(class_mean_features_norm, class_mean_features_norm.t()))
            class_inter_loss = torch.nn.functional.hinge_embedding_loss(distance_matrix,
                                                                        -1 * torch.ones_like(
                                                                            distance_matrix).long(),
                                                                        margin=self.cfg.CLASSIFIER.NECK.MARGIN,
                                                                        reduction='none')

            if self.cfg.CLASSIFIER.NECK.WEIGHT_INTER_LOSS:
                weight_norm_matrix = class_feature_weight[:, None].repeat(1, c_batch) + class_feature_weight[None,
                                                                                        :].repeat(c_batch, 1)
                weight_norm_matrix /= (weight_norm_matrix * ignore_mask).sum()
                class_inter_loss = class_inter_loss * weight_norm_matrix

            if self.cfg.CLASSIFIER.NECK.WEIGHT_INTER_LOSS:
                reduction_factor = 1
            else:
                reduction_factor = all_labels.unique().shape[0] ** 2 - all_labels.unique().shape[0]

            class_inter_loss = (class_inter_loss * ignore_mask).sum() / reduction_factor  # 计算平均class_inter_loss
            distance_loss += class_inter_loss

        if self.cfg.CLASSIFIER.NECK.INTRA_DISTANCE:
            if (class_feature_num > 1).any():
                class_intra_loss = class_intra_distance[class_feature_num > 1].mean()
            else:
                class_intra_loss = 0
            distance_loss += class_intra_loss
        return distance_loss

    def mixup_data(self, x, y, alpha_1=1.0, alpha_2=1.0):
        '''Returns mixed inputs, pairs of targets, and lambda'''
        if alpha_1 > 0:
            lam = np.random.beta(alpha_1, alpha_2)
            # lam = np.random.uniform(0, 1)
        else:
            lam = 1

        batch_size = x.size()[0]
        index = torch.randperm(batch_size).to(self.device)

        mixed_x = lam * x + (1 - lam) * x[index, :]
        y_a, y_b = y, y[index]
        all_lams = torch.ones_like(y) * lam
        return mixed_x, y_a, y_b, all_lams

    def all_tasks_mixup_data(self, pre_tasks_imgs, pre_tasks_labels, current_image, current_label, alpha_1=1,
                             alpha_2=1):
        '''Returns mixed inputs, pairs of targets, and lambda'''
        if alpha_1 > 0:
            lam = np.random.beta(alpha_1, alpha_2)
            # lam = np.random.uniform(0, 1)
        else:
            lam = 1

        batch_size = pre_tasks_imgs.size()[0]
        index = torch.randperm(batch_size).to(self.device)

        mixed_x = lam * pre_tasks_imgs + (1 - lam) * current_image[index, :]
        y_a, y_b = pre_tasks_labels, current_label[index]
        all_lams = torch.ones_like(pre_tasks_labels) * lam
        return mixed_x, y_a, y_b, all_lams

    def mixup_criterion_iCaRL(self, pred, y_a, y_b, lam, classes_per_task):
        """Blend the binary classification losses of both mixup label sets.

        Args:
            pred: Model output for the mixed inputs.
            y_a: First label set of the mixed inputs.
            y_b: Second label set of the mixed inputs.
            lam: Mixing coefficient(s) applied to the ``y_a`` loss; ``1 - lam``
                is applied to the ``y_b`` loss.
            classes_per_task: Forwarded to ``compute_cls_binary_loss``.

        Returns:
            Mean of the blended per-sample losses.
        """
        loss_a = self.compute_cls_binary_loss(y_a, pred, classes_per_task)
        loss_b = self.compute_cls_binary_loss(y_b, pred, classes_per_task)
        blended = lam * loss_a + (1 - lam) * loss_b
        return blended.mean()


    def mixup_criterion_trade_off_ib(self, criterion_trade_off_ib, mixup_data_output, previous_model_output,
                                     mixup_labels_a, mixup_labels_b, all_lams, mixup_data_features):
        if self.cfg.use_weight:
            loss_vec = all_lams * F.cross_entropy(mixup_data_output, mixup_labels_a, reduction='none',
                                                  weight=criterion_trade_off_ib.weight) + \
                       (1 - all_lams) * F.cross_entropy(mixup_data_output, mixup_labels_b, reduction='none',
                                                        weight=criterion_trade_off_ib.weight)
        else:
            loss_vec = all_lams * F.cross_entropy(mixup_data_output, mixup_labels_a, reduction='none') + \
                       (1 - all_lams) * F.cross_entropy(mixup_data_output, mixup_labels_b, reduction='none')
        return criterion_trade_off_ib(loss_vec, mixup_data_output, previous_model_output,
                                      mixup_labels_a, mixup_labels_b, all_lams, mixup_data_features, reduction='none').mean()
        pass

    @staticmethod
    def mixup_criterion(criterion, output, y_a, y_b, lam):
        return (lam * criterion(output, y_a, reduction='none') +
                (1 - lam) * criterion(output, y_b, reduction='none')).mean()

    @staticmethod
    def compute_cls_binary_loss(labels, output, classes_per_task):
        """Per-sample binary cross-entropy restricted to the last ``classes_per_task`` columns.

        The labels are converted with ``to_one_hot`` (width ``output.size(1)``)
        on the CPU and moved back to the labels' device; both the targets and
        the logits are then cut down to their last ``classes_per_task`` columns.

        Args:
            labels: Label tensor.
            output: Logits, one column per class.
            classes_per_task: Number of trailing columns that are scored.

        Returns:
            One loss value per sample: the element-wise BCE-with-logits summed
            over the selected columns. No averaging over the batch happens here.
        """
        one_hot_targets = to_one_hot(labels.cpu(), output.size(1)).to(labels.device)
        new_class_targets = one_hot_targets[:, -classes_per_task:]
        new_class_logits = output[:, -classes_per_task:]
        elementwise = Func.binary_cross_entropy_with_logits(
            input=new_class_logits, target=new_class_targets, reduction='none'
        )
        # Sum over the class dimension; the caller averages over the batch.
        return elementwise.sum(dim=1)
