import numpy as np
import torch, math
import torch.nn as nn
import torch.nn.functional as F
from lib.core.evaluate import accuracy, accuracy_shot, accuracy_shot_perclass
from torch.nn import functional as Func
import pdb
from lib.loss import CrossEntropy
from lib.loss.loss import loss_fn_kd, loss_fn_kd_binary, loss_fn_kd_KL
from lib.utils import to_one_hot

"""Combiner is only used for training model not for eval"""


class CIL_Combiner:
    def __init__(self, cfg, device):
        """Store the experiment configuration and the target device.

        Args:
            cfg: configuration object; the other methods read nested
                attributes from it (e.g. ``cfg.DISTILL``, ``cfg.Mixup``).
            device: device the batch tensors are moved to during training.
        """
        self.cfg, self.device = cfg, device

    def compute_distill_loss(self, output_for_distill, previous_task_model_output):
        """Dispatch to the distillation loss selected by ``cfg.DISTILL.softmax_sigmoid``.

        ``0`` selects ``loss_fn_kd``, ``1`` selects ``loss_fn_kd_binary`` and any
        other value selects a freshly constructed ``loss_fn_kd_KL``. All three
        receive the temperature ``cfg.DISTILL.TEMP``.

        Args:
            output_for_distill: current-model outputs to be distilled.
            previous_task_model_output: outputs of the previous-task model for
                the same inputs.

        Returns:
            Whatever the selected loss function returns.
        """
        distill_cfg = self.cfg.DISTILL
        mode = distill_cfg.softmax_sigmoid

        if mode == 0:
            return loss_fn_kd(output_for_distill, previous_task_model_output, distill_cfg.TEMP)
        if mode == 1:
            return loss_fn_kd_binary(output_for_distill, previous_task_model_output, distill_cfg.TEMP)

        # Every other setting falls back to the KL variant, which is a
        # callable object and takes the temperature by keyword.
        kl_criterion = loss_fn_kd_KL()
        return kl_criterion(output_for_distill, previous_task_model_output, T=distill_cfg.TEMP)

    '''self.model, criterion, current_image, current_label, active_classes_num,
    self.dataset_handler.classes_per_task, self.pre_tasks_model,
    task, label_weight, criterion_ib = criterion_ib, reach_IB = reach_IB'''

    def forward(self, model, criterion, current_image, current_label, active_classes_num, classes_per_task,
                previous_task_model, task, reach_IB=False, criterion_ib=None):

        if task == 1 or (self.cfg.use_base_half and task == int(self.cfg.DATASET.all_tasks / 2)):
            if self.cfg.first_task_mix:
                mixup_current_images, mixup_current_labels_a, mixup_current_labels_b, mixup_current_lams = \
                    self.mixup_data(current_image, current_label, alpha_1=self.cfg.Mixup.mixup_alpha1,
                                    alpha_2=self.cfg.Mixup.mixup_alpha2)
                mixup_current_task_output = model(mixup_current_images, classifier_flag=True)
                mixup_current_task_output = mixup_current_task_output[:, 0:active_classes_num]
                mixup_current_task_cls_loss = self.mixup_criterion(criterion, mixup_current_task_output,
                                                                   mixup_current_labels_a, mixup_current_labels_b,
                                                                   mixup_current_lams)
                loss = [mixup_current_task_cls_loss * self.cfg.CLASSIFIER.LOSS_FACTOR]

                now_acc = [0]
                now_cnt = [0]

                if self.cfg.CLASSIFIER.NECK.distance_loss:
                    # todo inter_class loss and intra_class loss
                    current_features = model(current_image, feature_flag=True)
                    distance_loss = self.compute_distance_loss_first_task(all_features=current_features,
                                                                          all_labels=current_label)
                    loss.append(distance_loss * self.cfg.CLASSIFIER.NECK.LOSS_FACTOR)

                return loss, now_acc, now_cnt
            else:
                image, label = current_image.to(self.device), current_label.to(self.device)
                output = model(image, classifier_flag=True)
                output = output[:, 0:active_classes_num]
                # return criterion(output, label)
                _, now_result = torch.max(output, 1)
                # now_result = torch.max(y_hat, 1)
                now_acc, now_cnt = accuracy(now_result.cpu().numpy(), label.cpu().numpy())
                cls_loss = criterion(output, label)
                loss = [cls_loss * self.cfg.CLASSIFIER.LOSS_FACTOR]

                now_acc = [now_acc]
                now_cnt = [now_cnt]

                if self.cfg.CLASSIFIER.NECK.distance_loss:
                    # todo inter_class loss and intra_class loss
                    current_features = model(current_image, feature_flag=True)
                    distance_loss = self.compute_distance_loss_first_task(all_features=current_features,
                                                                          all_labels=current_label)
                    loss.append(distance_loss * self.cfg.CLASSIFIER.NECK.LOSS_FACTOR)

                return loss, now_acc, now_cnt

        current_image, current_label = current_image.to(self.device), current_label.to(self.device)
        dpt_active_classes_num = active_classes_num - classes_per_task
        loss = []

        if self.cfg.use_IB and reach_IB:
            if self.cfg.re_mix and self.cfg.Mixup.all:
                mixup_current_images, mixup_current_labels_a, mixup_current_labels_b, mixup_current_lams = \
                    self.mixup_data(current_image, current_label, alpha_1=self.cfg.Mixup.mixup_alpha1,
                                    alpha_2=self.cfg.Mixup.mixup_alpha2)
                mixup_data_output, mixup_features = model(mixup_current_images, feature_flag=True,
                                                          classifier_flag=True)
                mixup_data_output = mixup_data_output[:, 0:active_classes_num]

                previous_task_model_mixup_data_output = previous_task_model(mixup_current_images, is_nograd=True,
                                                                            classifier_flag=True)  # 获取classifier_output
                mixup_data_output_for_distill = mixup_data_output[:, 0:dpt_active_classes_num]
                previous_task_model_mixup_data_output = previous_task_model_mixup_data_output[:,
                                                        0:dpt_active_classes_num]
                distill_loss = self.compute_distill_loss(mixup_data_output_for_distill,
                                                         previous_task_model_mixup_data_output)
                cls_loss = self.mixup_criterion_trade_off_ib(criterion_ib, mixup_data_output,
                                                             previous_task_model_mixup_data_output,
                                                             mixup_current_labels_a, mixup_current_labels_b,
                                                             mixup_current_lams,
                                                             mixup_features)
            else:
                all_data_output, all_data_features = model(current_image, feature_flag=True,
                                                           classifier_flag=True)
                all_data_output = all_data_output[:, 0:active_classes_num]
                output_for_distill = all_data_output[:, 0:dpt_active_classes_num]
                previous_task_model_mixup_data_output = previous_task_model(current_image, is_nograd=True,
                                                                            classifier_flag=True)  # 获取classifier_output
                pre_model_output_for_distill = previous_task_model_mixup_data_output[:, 0:dpt_active_classes_num]
                distill_loss = self.compute_distill_loss(output_for_distill, pre_model_output_for_distill)
                cls_loss = criterion_ib(all_data_output, current_label, all_data_features)
                '''cls_loss = criterion_ib(loss_vec=F.cross_entropy(all_data_output, all_data_labels, reduction='none', 
                                                                 weight=criterion_ib.weight), input=all_data_output, 
                                        previous_model_output=pre_task_data_output, target=all_data_labels, 
                                        features=all_data_features, reduction="mean")'''

            cls_loss = cls_loss / task
            distill_loss *= (task - 1) / task
            loss = [cls_loss, distill_loss]
        else:
            mixup_current_images, mixup_current_labels_a, mixup_current_labels_b, mixup_current_lams = \
                self.mixup_data(current_image, current_label, alpha_1=self.cfg.Mixup.mixup_alpha1,
                                alpha_2=self.cfg.Mixup.mixup_alpha2)
            mixup_data_output, mixup_features = model(mixup_current_images, feature_flag=True,
                                                      classifier_flag=True)
            mixup_data_output = mixup_data_output[:, 0:active_classes_num]

            previous_task_model_mixup_data_output = previous_task_model(mixup_current_images, is_nograd=True,
                                                                        classifier_flag=True)  # 获取classifier_output
            mixup_data_output_for_distill = mixup_data_output[:, 0:dpt_active_classes_num]
            previous_task_model_mixup_data_output = previous_task_model_mixup_data_output[:, 0:dpt_active_classes_num]
            distill_loss = self.compute_distill_loss(mixup_data_output_for_distill,
                                                     previous_task_model_mixup_data_output)
            if "binary" in self.cfg.LOSS.LOSS_TYPE:
                mixup_current_task_cls_loss = self.mixup_criterion_iCaRL(mixup_data_output,
                                                                         mixup_current_labels_a,
                                                                         mixup_current_labels_b,
                                                                         mixup_current_lams,
                                                                         classes_per_task)

            else:
                mixup_current_task_cls_loss = self.mixup_criterion(criterion, mixup_data_output,
                                                                   mixup_current_labels_a, mixup_current_labels_b,
                                                                   mixup_current_lams)

            # now_result = torch.argmax(current_task_output, 1)
            # now_acc, now_cnt = accuracy(now_result.cpu().numpy(), current_label.cpu().numpy()
            cls_loss = mixup_current_task_cls_loss * self.cfg.CLASSIFIER.LOSS_FACTOR / task
            loss.append(cls_loss)
            loss.append(distill_loss * self.cfg.DISTILL.LOSS_FACTOR * (task - 1) / task)
        now_acc, now_cnt = [0], [0]

        return loss, now_acc, now_cnt

    def compute_distance_loss_first_task(self, all_features, all_labels):
        class_mean_features = []
        class_feature_num = []
        class_intra_distance = []
        for l in all_labels.unique():
            label_index = all_labels == l
            center_feature = all_features[label_index].mean(0)
            class_mean_features.append(center_feature)
            if self.cfg.CLASSIFIER.NECK.INTRA_DISTANCE:
                class_intra_distance.append((1 - torch.nn.functional.cosine_similarity(center_feature[None, :],
                                                                                       all_features[
                                                                                           label_index])).mean())
            class_feature_num.append(label_index.sum())

        class_mean_features = torch.stack(class_mean_features)
        class_feature_num = torch.stack(class_feature_num)
        c_batch = class_mean_features.shape[0]
        if self.cfg.CLASSIFIER.NECK.INTRA_DISTANCE:
            class_intra_distance = torch.stack(class_intra_distance)
        ignore_mask = 1 - torch.eye(c_batch).to(class_mean_features.device)  # 将对角线置0,非对角线置1
        distance_loss = 0  # distance_loss = 0

        if self.cfg.CLASSIFIER.NECK.INTER_DISTANCE:
            class_mean_features_norm = class_mean_features / class_mean_features.norm(dim=1)[:, None]
            distance_matrix = (1 - torch.mm(class_mean_features_norm, class_mean_features_norm.t()))
            class_inter_loss = torch.nn.functional.hinge_embedding_loss(distance_matrix,
                                                                        -1 * torch.ones_like(
                                                                            distance_matrix).long(),
                                                                        margin=self.cfg.CLASSIFIER.NECK.MARGIN,
                                                                        reduction='none')

            reduction_factor = all_labels.unique().shape[0] ** 2 - all_labels.unique().shape[0]

            class_inter_loss = (class_inter_loss * ignore_mask).sum() / reduction_factor  # 计算平均class_inter_loss
            distance_loss += class_inter_loss

        if self.cfg.CLASSIFIER.NECK.INTRA_DISTANCE:
            if (class_feature_num > 1).any():
                class_intra_loss = class_intra_distance[class_feature_num > 1].mean()
            else:
                class_intra_loss = 0
            distance_loss += class_intra_loss
        return distance_loss

    def compute_distance_loss_LT(self, all_features, all_labels, previous_tasks_features, previous_tasks_labels,
                                 label_weight):
        all_features = torch.cat([all_features, previous_tasks_features])
        all_labels = torch.cat([all_labels, previous_tasks_labels])
        all_weights = label_weight

        class_mean_features = []
        class_feature_num = []
        class_feature_weight = []
        class_intra_distance = []
        for l in all_labels.unique():
            label_index = all_labels == l
            center_feature = all_features[label_index].mean(0)
            class_mean_features.append(center_feature)  # 在一个batch中每个label对应的center_feature
            if self.cfg.CLASSIFIER.NECK.INTRA_DISTANCE:
                class_intra_distance.append((1 - torch.nn.functional.cosine_similarity(center_feature[None, :],
                                                                                       all_features[
                                                                                           label_index])).mean())
            class_feature_num.append(label_index.sum())  # 在一个batch中每个label对应的imgs或者features数目
            class_feature_weight.append(all_weights[l])  # 在一个batch中每个imgs对应的weight

        class_mean_features = torch.stack(class_mean_features)
        class_feature_num = torch.stack(class_feature_num)
        class_feature_weight = torch.stack(class_feature_weight)
        c_batch = class_mean_features.shape[0]
        if self.cfg.CLASSIFIER.NECK.INTRA_DISTANCE:
            class_intra_distance = torch.stack(class_intra_distance)
        ignore_mask = 1 - torch.eye(c_batch).to(class_mean_features.device)  # 将对角线置0,非对角线置1
        distance_loss = 0  # distance_loss = 0

        if self.cfg.CLASSIFIER.NECK.INTER_DISTANCE:
            class_mean_features_norm = class_mean_features / class_mean_features.norm(dim=1)[:, None]
            distance_matrix = (1 - torch.mm(class_mean_features_norm, class_mean_features_norm.t()))
            class_inter_loss = torch.nn.functional.hinge_embedding_loss(distance_matrix,
                                                                        -1 * torch.ones_like(
                                                                            distance_matrix).long(),
                                                                        margin=self.cfg.CLASSIFIER.NECK.MARGIN,
                                                                        reduction='none')

            if self.cfg.CLASSIFIER.NECK.WEIGHT_INTER_LOSS:
                weight_norm_matrix = class_feature_weight[:, None].repeat(1, c_batch) + class_feature_weight[None,
                                                                                        :].repeat(c_batch, 1)
                weight_norm_matrix /= (weight_norm_matrix * ignore_mask).sum()
                class_inter_loss = class_inter_loss * weight_norm_matrix

            if self.cfg.CLASSIFIER.NECK.WEIGHT_INTER_LOSS:
                reduction_factor = 1
            else:
                reduction_factor = all_labels.unique().shape[0] ** 2 - all_labels.unique().shape[0]

            class_inter_loss = (class_inter_loss * ignore_mask).sum() / reduction_factor  # 计算平均class_inter_loss
            distance_loss += class_inter_loss

        if self.cfg.CLASSIFIER.NECK.INTRA_DISTANCE:
            if (class_feature_num > 1).any():
                class_intra_loss = class_intra_distance[class_feature_num > 1].mean()
            else:
                class_intra_loss = 0
            distance_loss += class_intra_loss
        return distance_loss

    def mixup_data(self, x, y, alpha_1=1.0, alpha_2=1.0):
        '''Returns mixed inputs, pairs of targets, and lambda'''
        if alpha_1 > 0:
            lam = np.random.beta(alpha_1, alpha_2)
            # lam = np.random.uniform(0, 1)
        else:
            lam = 1

        batch_size = x.size()[0]
        index = torch.randperm(batch_size).to(self.device)

        mixed_x = lam * x + (1 - lam) * x[index, :]
        y_a, y_b = y, y[index]
        all_lams = torch.ones_like(y) * lam
        return mixed_x, y_a, y_b, all_lams

    def all_tasks_mixup_data(self, pre_tasks_imgs, pre_tasks_labels, current_image, current_label, alpha_1=1,
                             alpha_2=1):
        '''Returns mixed inputs, pairs of targets, and lambda'''
        if alpha_1 > 0:
            lam = np.random.beta(alpha_1, alpha_2)
            # lam = np.random.uniform(0, 1)
        else:
            lam = 1

        batch_size = pre_tasks_imgs.size()[0]
        index = torch.randperm(batch_size).to(self.device)

        mixed_x = lam * pre_tasks_imgs + (1 - lam) * current_image[index, :]
        y_a, y_b = pre_tasks_labels, current_label[index]
        all_lams = torch.ones_like(pre_tasks_labels) * lam
        return mixed_x, y_a, y_b, all_lams

    def mixup_criterion_iCaRL(self, pred, y_a, y_b, lam, classes_per_task):
        """Mixup-weighted binary classification loss, averaged over the result.

        The per-sample loss from ``compute_cls_binary_loss`` is evaluated once
        against each of the two label sets; the two results are blended with
        ``lam`` and ``1 - lam`` and the mean of the blend is returned.

        Args:
            pred: model outputs for the mixed samples.
            y_a: labels of the first mixup component.
            y_b: labels of the second mixup component.
            lam: mixing coefficient(s) applied to the ``y_a`` term.
            classes_per_task: forwarded to ``compute_cls_binary_loss``.
        """
        loss_a = self.compute_cls_binary_loss(y_a, pred, classes_per_task)
        loss_b = self.compute_cls_binary_loss(y_b, pred, classes_per_task)
        blended = lam * loss_a + (1 - lam) * loss_b
        return blended.mean()



    def mixup_criterion_trade_off_ib(self, criterion_trade_off_ib, mixup_data_output, previous_model_output,
                                     mixup_labels_a, mixup_labels_b, all_lams, mixup_data_features):
        if self.cfg.use_weight:
            loss_vec = all_lams * F.cross_entropy(mixup_data_output, mixup_labels_a, reduction='none',
                                                  weight=criterion_trade_off_ib.weight) + \
                       (1 - all_lams) * F.cross_entropy(mixup_data_output, mixup_labels_b, reduction='none',
                                                        weight=criterion_trade_off_ib.weight)
        else:
            loss_vec = all_lams * F.cross_entropy(mixup_data_output, mixup_labels_a, reduction='none') + \
                       (1 - all_lams) * F.cross_entropy(mixup_data_output, mixup_labels_b, reduction='none')
        return criterion_trade_off_ib(loss_vec, mixup_data_output, previous_model_output,
                                      mixup_labels_a, mixup_labels_b, all_lams, mixup_data_features,
                                      reduction='none').mean()
        pass

    @staticmethod
    def mixup_criterion(criterion, output, y_a, y_b, lam):
        return (lam * criterion(output, y_a, reduction='none') +
                (1 - lam) * criterion(output, y_b, reduction='none')).mean()

    @staticmethod
    def compute_cls_binary_loss(labels, output, classes_per_task):
        """Per-sample binary cross-entropy restricted to the last ``classes_per_task`` columns.

        Targets are built with ``to_one_hot`` over ``output.size(1)`` classes;
        only the trailing ``classes_per_task`` columns of both the targets and
        the logits enter the loss.

        Args:
            labels: class labels of the samples.
            output: logits, one column per class.
            classes_per_task: number of trailing columns to keep.

        Returns:
            Tensor with one value per sample: the element-wise BCE-with-logits
            summed over the kept columns. No batch reduction is done here.
        """
        num_outputs = output.size(1)
        # The one-hot encoding is built from a CPU copy of the labels and then
        # moved back to the labels' device.
        one_hot = to_one_hot(labels.cpu(), num_outputs).to(labels.device)

        new_class_targets = one_hot[:, -classes_per_task:]
        new_class_logits = output[:, -classes_per_task:]

        elementwise = Func.binary_cross_entropy_with_logits(
            input=new_class_logits, target=new_class_targets, reduction='none'
        )
        # sum over classes; the caller is responsible for the batch reduction
        return elementwise.sum(dim=1)
