# coding=utf-8
import select
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd
from torch.autograd import Variable
import copy
import numpy as np
import math
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

from Alg.modelopera import get_fea
from Alg.opt import *
from network import common_network
from Alg.base import Algorithm
import datautil.imgdata.util as imgutil

from Alg.DGAlg import PCL
from Alg.PCDGAlg import SupPCL
from Alg import SupCon
from DataAug.ForAug.augnetV1 import AugNetV1
from Alg import LDAuCID

class PCDG(Algorithm):
    """
    Our own Practical Continual Domain Generation (PCDG) algorithm model.
    It has two training stages: 1. train on the labeled source domain; 2. adapt to unlabeled continual target domains.
    It contains our own algorithm and some baselines.
    """

    def __init__(self, args):
        """Build the backbone, heads and loss objects selected by ``args.sourceAlg``.

        Args:
            args: run configuration. Fields read here: ``sourceAlg``,
                ``targetAlg``, ``num_classes``, ``classifier``, ``PCL_scale``,
                ``proj_dim``, ``dataset`` and ``forAug``.

        NOTE(review): an ``args.sourceAlg`` that matches none of the branches
        below creates no classifier/network and is not reported here.
        """
        super(PCDG, self).__init__(args)
        self.args = args
        self.task_id = 0
        self.naug = 0    # augmentation data amount (incremented in select_aug)
        self.fea_rep = None      # encoder output cached by forward() for the PCL-family
        self.proj_rep = None     # projected encoder output cached by forward() for the PCL-family
        # Read by PCLupdate; never assigned in this file apart from the None
        # defaults -- presumably populated by external code (TODO confirm).
        self.replay_proxy = None
        self.replay_center = None
        self.featurizer = get_fea(args)

        # training algorithm model
        if args.sourceAlg == 'ERM':
            # The classifier head depends on the algorithm used later on the target domains.
            if args.targetAlg == 'LDAuCID':
                self.classifier = common_network.gmm_classifier_two(args.num_classes, self.featurizer.in_features)
            elif args.targetAlg == 'ERMPCL':  # source ERM, target PCL2
                # Bias-free linear head: ERMPCLupdate passes its weight to the proxy loss.
                self.classifier = nn.Linear(self.featurizer.in_features, args.num_classes, bias=False)
                self.proxycloss = PCL.ProxyPLoss2(num_classes=args.num_classes, scale=self.args.PCL_scale)
            else:
                self.classifier = common_network.feat_classifier(
                    args.num_classes, self.featurizer.in_features, args.classifier)
            self.network = nn.Sequential(
                self.featurizer, self.classifier)

        elif args.sourceAlg == 'ERM_bot':
            self.bottleneck = common_network.feat_bottleneck(self.featurizer.in_features, type='bn')
            # 256 is presumably the bottleneck output width -- TODO confirm against feat_bottleneck.
            self.classifier = common_network.feat_classifier(args.num_classes, 256, args.classifier)
            self.network = nn.Sequential(self.featurizer, self.bottleneck, self.classifier)

        elif args.sourceAlg in ['PCL', 'PCL2', 'SupPCL', 'SupPCL2', 'FP', 'MFP']:
            # PCL-family: featurizer -> encoder -> (feature projector, proxy projector).
            # No self.network is created in this branch.
            fea_dim = args.proj_dim[args.dataset]
            self.encoder = PCL.encoder(args, self.featurizer.in_features, fea_dim)
            self._initialize_weights(self.encoder)
            self.fea_proj, self.fc_proj = PCL.fea_proj(args, fea_dim)
            nn.init.kaiming_uniform_(self.fc_proj, mode='fan_out', a=math.sqrt(5))

            # The classifier is a raw weight matrix (num_classes x fea_dim) applied
            # with F.linear; it is allocated uninitialised and filled on the next line.
            self.classifier = nn.Parameter(torch.FloatTensor(args.num_classes, fea_dim))
            nn.init.kaiming_uniform_(self.classifier, mode='fan_out', a=math.sqrt(5))
            # Loss object(s) for the specific variant.
            if args.sourceAlg == 'PCL':
                self.proxycloss = PCL.ProxyPLoss(num_classes=args.num_classes, scale=self.args.PCL_scale)
            elif args.sourceAlg == 'PCL2':
                self.proxycloss = PCL.ProxyPLoss2(num_classes=args.num_classes, scale=self.args.PCL_scale)
            elif args.sourceAlg == 'SupPCL':
                self.suppcloss = SupPCL.SupPCLoss(args)
            elif args.sourceAlg == 'SupPCL2':
                self.proxycloss = PCL.ProxyOnlyLoss(num_classes=args.num_classes, scale=self.args.PCL_scale)
                self.supcon = SupCon.SupConLoss()
            elif args.sourceAlg == 'FP':
                self.proxycloss = PCL.FPLoss(num_classes=args.num_classes, scale=self.args.PCL_scale)
            elif args.sourceAlg == 'MFP':
                self.proxycloss = PCL.MFPLoss(num_classes=args.num_classes, scale=self.args.PCL_scale)

        elif args.sourceAlg == 'supcon':
            self.supcon = SupCon.SupConLoss()
            self.fea_proj = SupCon.fea_proj(args, self.featurizer.in_features)
            self.classifier = common_network.feat_classifier(
                args.num_classes, self.featurizer.in_features, args.classifier)
            # The projector is not part of self.network; SupConUpdate applies it separately.
            self.network = nn.Sequential(
                self.featurizer, self.classifier)

        # Data augment algorithm
        if self.args.forAug == 'v1':
            # Moved to GPU unconditionally, so this option requires a CUDA device.
            self.aug_fore = AugNetV1(1).cuda()
            # self.aug_fore_opt = torch.optim.SGD(self.aug_fore.parameters(), lr=self.args.lr_sc)
            # Normalisation re-applied to the sigmoid output of the augmentation
            # network (see train_source / select_aug).
            if args.dataset == 'dg5':
                self.aug_tran = transforms.Normalize([0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            else:
                self.aug_tran = transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        
    def forward(self, x):
        if self.args.sourceAlg in ['ERM', 'ERM_bot']:
            pred = self.network(x)
        elif self.args.sourceAlg in ['PCL', 'PCL2', 'SupPCL', 'SupPCL2', 'FP', 'MFP']:
            x = self.featurizer(x)
            x = self.encoder(x)
            self.fea_rep = x
            self.proj_rep = self.fea_proj(x)
            pred = F.linear(x, self.classifier)
        elif self.args.sourceAlg == 'supcon':
            pred = self.network(x)
        return pred

    def get_optimizer(self, lr_decay=1.0):
        # get optimizer only based on source algorithm
        if self.args.sourceAlg == 'ERM':
            if self.args.targetAlg == 'LDAuCID':
                self.optimizer = torch.optim.Adam([
                {'params': self.featurizer.parameters(), 'lr': lr_decay * self.args.lr},
                {'params': self.classifier.parameters()}
                ], lr=self.args.lr)
            else:
                self.optimizer = torch.optim.SGD([
                {'params': self.featurizer.parameters(), 'lr': lr_decay * self.args.lr},
                {'params': self.classifier.parameters()}
                ], lr=self.args.lr, momentum=self.args.momentum, weight_decay=self.args.weight_decay, nesterov=True)
        
        elif self.args.sourceAlg == 'ERM_bot':
            self.optimizer = torch.optim.SGD([
            {'params': self.featurizer.parameters(), 'lr': lr_decay * self.args.lr},
            {'params': self.bottleneck.parameters()},
            {'params': self.classifier.parameters()}
            ], lr=self.args.lr, momentum=self.args.momentum, weight_decay=self.args.weight_decay, nesterov=True)

        elif self.args.sourceAlg in ['PCL', 'PCL2', 'SupPCL', 'SupPCL2', 'FP', 'MFP']:
            self.optimizer = torch.optim.SGD([
            {'params': self.featurizer.parameters(), 'lr': lr_decay * self.args.lr},
            {'params': self.encoder.parameters()},
        	{'params': self.fea_proj.parameters()},
        	{'params': self.fc_proj},
        	{'params': self.classifier},
            ], lr=self.args.lr, weight_decay=self.args.weight_decay)
        
        elif self.args.sourceAlg == 'supcon':
            self.optimizer = torch.optim.SGD([
            {'params': self.featurizer.parameters(), 'lr': lr_decay * self.args.lr},
            {'params': self.classifier.parameters()},
            {'params': self.fea_proj.parameters()}
            ], lr=self.args.lr, momentum=self.args.momentum, weight_decay=self.args.weight_decay, nesterov=True)

################################################## train source and adapt ######################################################################

    def train_source(self, minibatches, task_id, epoch):
        self.task_id = task_id
        if (self.args.sourceAlg == 'supcon' or self.args.targetAlg == 'supcon') and self.args.forAug is None:
            all_x = torch.cat([minibatches[0][0], minibatches[0][1]]).cuda().float() 
            all_y = torch.cat([minibatches[1], minibatches[1]]).cuda().long()
        else:
            all_x = minibatches[0].cuda().float()   # torch tensor [batch, C, H, W]
            all_y = minibatches[1].cuda().long()    # torch tensor [batch] -  class labels

        # forward Augmentation
        if self.args.forAug == 'v1':
            ratio = epoch / self.args.max_epoch
            data_fore = self.aug_tran(torch.sigmoid(self.aug_fore(all_x, ratio=ratio)))
            all_x = torch.cat([all_x, data_fore])    # [original, aug]
            all_y = torch.cat([all_y, all_y])
        
        # if a batch only have 1 sample, it can't pass batch norm layer in resnet
        if all_x.size(0) == 1:
            all_x = torch.cat([all_x, all_x])
            all_y = torch.cat([all_y, all_y])

        loss_dict = None
        if self.args.sourceAlg in ['ERM', 'ERM_bot']:
            loss = self.ERMupdate(all_x, all_y)
        elif self.args.sourceAlg in ['PCL', 'PCL2', 'FP', 'MFP']:
            loss, loss_dict = self.PCLupdate(all_x, all_y)
        elif self.args.sourceAlg == 'supcon':
            loss = self.SupConUpdate(all_x, all_y)
        elif self.args.sourceAlg == 'SupPCL':
            loss = self.SupPCLUpdate(all_x, all_y)
        elif self.args.sourceAlg == 'SupPCL2':
            loss = self.SupPCL2Update(all_x, all_y)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # self.scheduler.step()
        if loss_dict is None:
            return {'loss': loss.item()}
        else: 
            loss_dict['loss'] = loss.item()
            return loss_dict
    
    def adapt(self, minibatches, task_id, epoch, replay_dataloader=None, old_model=None):
        # self.update(self, minibatches, opt, sch, task_id, old_model=old_model, replay_minibatches=replay_minibatches, dataAug=dataAug, **kwargs)
        self.task_id = task_id
        if (self.args.sourceAlg == 'supcon' or self.args.targetAlg == 'supcon') and self.args.forAug is None:
            all_x = torch.cat([minibatches[0][0], minibatches[0][1]]).cuda().float() 
            all_y = torch.cat([minibatches[1], minibatches[1]]).cuda().long()
        else:
            all_x = minibatches[0].cuda().float()   # torch tensor [batch, C, H, W]
            all_y = minibatches[1].cuda().long()    # torch tensor [batch] -  class labels

        # forward Augmentation
        if self.args.forAug == 'v1':
            all_x, all_y = self.select_aug(all_x, all_y, epoch)

        # if a batch only have 1 sample, it can't pass batch norm layer in resnet
        if all_x.size(0) == 1:
            all_x = torch.cat([all_x, all_x])
            all_y = torch.cat([all_y, all_y])

        loss_dict = None
        if self.args.targetAlg in ['ERM', 'ERM_bot']:
            loss = self.ERMupdate(all_x, all_y, old_model)
        elif self.args.targetAlg in ['PCL', 'PCL2', 'FP', 'MFP']:
            loss, loss_dict = self.PCLupdate(all_x, all_y, old_model)
        elif self.args.targetAlg == 'ERMPCL':
            loss = self.ERMPCLupdate(all_x, all_y, old_model)
        elif self.args.targetAlg == 'supcon':
            loss = self.SupConUpdate(all_x, all_y, old_model)
        elif self.args.targetAlg == 'SupPCL':
            loss = self.SupPCLUpdate(all_x, all_y, old_model)
        elif self.args.targetAlg == 'SupPCL2':
            loss = self.SupPCL2Update(all_x, all_y, old_model)
        elif self.args.targetAlg == 'LDAuCID':
            loss = self.LDAuCIDUpdate(all_x, replay_dataloader, epoch)

        if self.args.sourceAlg == 'ERM_bot' and self.args.pLabelAlg == 'SHOT':
            loss = loss * 0.3

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # self.scheduler.step()
        if loss_dict is None:
            return {'loss': loss.item()}
        else: 
            loss_dict['loss'] = loss.item()
            return loss_dict

################################################################ Algorithms ####################################################
    def ERMupdate(self, all_x, all_y, old_model=None):
        pred = self(all_x)
        loss = F.cross_entropy(pred, all_y)
        if self.args.SHOT_IM:
            loss = self.SHOT_IMloss(loss, pred)
        if self.args.distill and old_model is not None:
            loss += self.args.distill_alpha * self.distill_loss(pred, all_x, old_model)
        return loss

    def PCLupdate(self, all_x, all_y, old_model=None):
        '''
        Proxy contrastive loss update (original PCL plus the multi-proxy variants).

        Computes the classification loss on the logits of ``self(all_x)`` and adds
        ``args.loss_alpha1`` times a proxy-based contrastive term. Which contrastive
        term is used depends on ``args.classifier_proxy``, ``args.mix_proxy``,
        ``args.weight_pcl``, ``args.MPCL``, ``args.targetAlg`` and ``self.task_id``.
        Optional SHOT-IM and distillation terms are added at the end.

        Args:
            all_x: input batch.
            all_y: class labels for ``all_x``.
            old_model: model of the previous task; needed by the branches that
                read ``old_model.classifier`` / ``old_model.fc_proj`` and by the
                distillation terms.

        Returns:
            (loss, loss_dict): the total loss tensor and a dict of python floats
            with keys 'ce', 'pcl' and, when distillation is active, 'distill'.
        '''
        # The forward pass also refreshes self.fea_rep / self.proj_rep used below.
        pred = self(all_x)
        loss_cls = F.nll_loss(F.log_softmax(pred, dim=1), all_y)

        if self.args.classifier_proxy:  # directly use classifier as proxy
            proxy = self.classifier
            features = self.fea_rep
            if self.args.MPCL == 5 and self.task_id > 0:
                # NOTE(review): old_model is dereferenced without a None check and,
                # unlike the non-classifier_proxy branches, outside torch.no_grad().
                old_proxy = old_model.classifier
                loss_pcl = PCL.MPCLoss(self.args.num_classes, self.args.PCL_scale)(features, all_y, proxy, old_proxy, mweight=self.args.MPCL_alpha, v=1)
            else:
                loss_pcl = self.proxycloss(features, all_y, proxy)
        else:
            # Proxies are the classifier weights mapped through the proxy projector.
            proxy = F.linear(self.classifier, self.fc_proj)

            # different proxy algorithms in target domains
            if self.args.mix_proxy and old_model is not None:
                old_model.eval()
                with torch.no_grad():
                    # old_proxy = F.linear(old_model.classifier, old_model.fc_proj)
                    # Replay centers are projected with the *current* projector, without gradient.
                    proj_old_center = self.fea_proj(self.replay_center)
                # Convex mix of the current proxies and the projected replay centers.
                proxy = self.args.classifier_mix_tau * proxy + (1-self.args.classifier_mix_tau) * proj_old_center

            # Variant selection. Every MPCL variant is only active after the first
            # task (task_id > 0); otherwise control falls through to the
            # targetAlg checks / the plain proxy loss at the bottom of the chain.
            if self.args.weight_pcl and self.task_id > 0:
                loss_pcl = PCL.WeightPCL(self.args.num_classes, self.args.PCL_scale)(self.proj_rep, all_y, proxy, pred)
            elif self.args.MPCL == 1 and self.task_id > 0:
                # MPCL 1: replay centers projected WITH gradient through fea_proj.
                proj_old_center = self.fea_proj(self.replay_center)
                loss_pcl = self.args.MPCL_alpha*self.proxycloss(self.proj_rep, all_y, proxy) + (1-self.args.MPCL_alpha)*self.proxycloss(self.proj_rep, all_y, proj_old_center)
            elif self.args.MPCL == 2 and self.task_id > 0:
                # MPCL 2: same as MPCL 1 but the projected centers carry no gradient.
                with torch.no_grad():
                    proj_old_center = self.fea_proj(self.replay_center)
                loss_pcl = self.args.MPCL_alpha*self.proxycloss(self.proj_rep, all_y, proxy) + (1-self.args.MPCL_alpha)*self.proxycloss(self.proj_rep, all_y, proj_old_center)
            elif self.args.MPCL == 3 and self.task_id > 0:
                # MPCL 3: current-proxy loss plus MPCL_alpha times the loss against the old model's proxies.
                with torch.no_grad():
                    old_proxy = F.linear(old_model.classifier, old_model.fc_proj)
                loss_pcl = self.proxycloss(self.proj_rep, all_y, proxy) + self.args.MPCL_alpha * self.proxycloss(self.proj_rep, all_y, old_proxy)
            elif self.args.MPCL == 4 and self.task_id > 0:
                # MPCL 4/5/7/8 differ only in the `v` argument passed to MPCLoss
                # (omitted, 1, 2, 3 respectively).
                with torch.no_grad():
                    old_proxy = F.linear(old_model.classifier, old_model.fc_proj)
                loss_pcl = PCL.MPCLoss(self.args.num_classes, self.args.PCL_scale)(self.proj_rep, all_y, proxy, old_proxy, mweight=self.args.MPCL_alpha)
            elif self.args.MPCL == 5 and self.task_id > 0:
                with torch.no_grad():
                    old_proxy = F.linear(old_model.classifier, old_model.fc_proj)
                loss_pcl = PCL.MPCLoss(self.args.num_classes, self.args.PCL_scale)(self.proj_rep, all_y, proxy, old_proxy, mweight=self.args.MPCL_alpha, v=1)
            elif self.args.MPCL == 6 and self.task_id > 0:
                # MPCL 6: the "old" proxy is the average of the old and current
                # proxies; it is built under no_grad, so it carries no gradient.
                with torch.no_grad():
                    old_proxy = 0.5 * F.linear(old_model.classifier, old_model.fc_proj) + 0.5 * proxy
                loss_pcl = PCL.MPCLoss(self.args.num_classes, self.args.PCL_scale)(self.proj_rep, all_y, proxy, old_proxy, mweight=self.args.MPCL_alpha, v=1)
            elif self.args.MPCL == 7 and self.task_id > 0:
                with torch.no_grad():
                    old_proxy = F.linear(old_model.classifier, old_model.fc_proj)
                loss_pcl = PCL.MPCLoss(self.args.num_classes, self.args.PCL_scale)(self.proj_rep, all_y, proxy, old_proxy, mweight=self.args.MPCL_alpha, v=2)
            elif self.args.MPCL == 8 and self.task_id > 0:
                with torch.no_grad():
                    old_proxy = F.linear(old_model.classifier, old_model.fc_proj)
                loss_pcl = PCL.MPCLoss(self.args.num_classes, self.args.PCL_scale)(self.proj_rep, all_y, proxy, old_proxy, mweight=self.args.MPCL_alpha, v=3)

            # NOTE(review): these two branches test args.targetAlg even when this
            # method is called from train_source (where sourceAlg selected it).
            elif self.args.targetAlg == 'FP':
                loss_pcl = self.proxycloss(self.fea_rep, all_y)
            elif self.args.targetAlg == 'MFP':
                loss_pcl = self.proxycloss(self.fea_rep, all_y, self.replay_proxy, mweight=self.args.MPCL_alpha)
            else:
                loss_pcl = self.proxycloss(self.proj_rep, all_y, proxy)

        # Logged values are detached python floats; 'pcl' is logged after weighting.
        loss_dict = {'ce': loss_cls.item(), 'pcl': (self.args.loss_alpha1 * loss_pcl).item()}
        loss = loss_cls + self.args.loss_alpha1 * loss_pcl

        if self.args.SHOT_IM:
            loss = self.SHOT_IMloss(loss, pred)
        if self.args.distill and old_model is not None:
            distill_loss = self.args.distill_alpha * self.distill_loss(pred, all_x, old_model)
            loss += distill_loss
            loss_dict['distill'] = distill_loss.item()
        if self.args.distillProxy and old_model is not None:
            distill_loss = self.args.distill_alpha * self.distill_proxy_loss(proxy, self.proj_rep, all_x, old_model)
            loss += distill_loss
            # NOTE(review): reuses the 'distill' key, so when both distill and
            # distillProxy are enabled only the proxy value is reported.
            loss_dict['distill'] = distill_loss.item()
        return loss, loss_dict

    def ERMPCLupdate(self,all_x, all_y, old_model=None):
        '''
        source ERM, target PCL2, no encoder and projector head
        '''
        feas = self.featurizer(all_x)
        logits = self.classifier(feas)
        loss_cls = F.nll_loss(F.log_softmax(logits, dim=1), all_y)
        loss_pcl = self.proxycloss(feas, all_y, self.classifier.weight)
        loss = loss_cls + self.args.loss_alpha1 * loss_pcl
        if self.args.SHOT_IM:
            loss = self.SHOT_IMloss(loss, logits)
        if self.args.distill and old_model is not None:
            loss += self.args.distill_alpha * self.distill_loss(logits, all_x, old_model)
        return loss
    
    def SupConUpdate(self, all_x, all_y, old_model=None):
        '''
        orignal supervised contrastive loss
        all_x: [batch_size*2, C, H, W] 
        all_y: [batch_size*2]

        If args.forAug is None: 
        batch_size*2 is cancatenate of two imgutil.image_train transform of the same original image. So it will lead to twice trainig process: each original images is trained twice.
        If args.forAug is not None: 
        batch_size*2 is concatenate of imgutil.image_train transform and forAug images.     
        
        '''
        features = self.featurizer(all_x)
        pred = self.classifier(features)
        loss_ce = F.cross_entropy(pred, all_y)

        proj_features = self.fea_proj(features)
        f1, f2 = torch.split(proj_features, [int(proj_features.size(0)/2), int(proj_features.size(0)/2)], dim=0)
        con_y, _ = torch.split(all_y, [int(all_y.size(0)/2), int(all_y.size(0)/2)], dim=0)
        supconLoss = self.supcon(torch.cat([F.normalize(f1).unsqueeze(1), F.normalize(f2).unsqueeze(1)], dim=1), con_y)

        loss = loss_ce + self.args.loss_alpha1 * supconLoss
        if self.args.SHOT_IM:
            loss = self.SHOT_IMloss(loss, pred)
        if self.args.distill and old_model is not None:
            loss += self.args.distill_alpha * self.distill_loss(pred, all_x, old_model)
        return loss

    def SupPCLUpdate(self, all_x, all_y, old_model=None):
        '''
        Combination of orignal supervised contrastive loss and ProxyOnlyLoss. Use Alg.PCDGAlg.SupPCL.SupPCLoss
        if forAug:
            all_x: [N*2, C, H, W]   all_y: [N*2] 
        if not forAug:
            all_x: [N, C, H, W]   all_y: [N]     
        '''
        features = self.encoder(self.featurizer(all_x))
        pred = F.linear(features, self.classifier)
        loss_ce = F.cross_entropy(pred, all_y)

        proj_features = self.fea_proj(features)
        proj_proxy = F.linear(self.classifier, self.fc_proj)

        suppcloss = self.suppcloss(proj_features, proj_proxy, all_y)

        loss = loss_ce + self.args.loss_alpha1*suppcloss
        if self.args.SHOT_IM:
            loss = self.SHOT_IMloss(loss, pred)
        if self.args.distill and old_model is not None:
            loss += self.args.distill_alpha * self.distill_loss(pred, all_x, old_model)
        return loss
        
    
    def SupPCL2Update(self, all_x, all_y, old_model=None):
        '''
        combination of orignal supervised contrastive loss and ProxyOnlyLoss. forAug should not be None.
        all_x: [batch_size*2, C, H, W] 
        all_y: [batch_size*2]   
        '''
        features = self.encoder(self.featurizer(all_x))
        pred = F.linear(features, self.classifier)
        loss_ce = F.cross_entropy(pred, all_y)

        proj_features = self.fea_proj(features)

        f1, f2 = torch.split(proj_features, [int(proj_features.size(0)/2), int(proj_features.size(0)/2)], dim=0)
        con_y, _ = torch.split(all_y, [int(all_y.size(0)/2), int(all_y.size(0)/2)], dim=0)
        supconLoss = self.supcon(torch.cat([F.normalize(f1).unsqueeze(1), F.normalize(f2).unsqueeze(1)], dim=1), con_y)

        fc_proj = F.linear(self.classifier, self.fc_proj)
        assert fc_proj.requires_grad == True
        loss_pcl = self.proxycloss(proj_features, all_y, fc_proj)

        loss = loss_ce + self.args.loss_alpha1 * supconLoss + 1 * loss_pcl
        if self.args.SHOT_IM:
            loss = self.SHOT_IMloss(loss, pred)
        if self.args.distill and old_model is not None:
            loss += self.args.distill_alpha * self.distill_loss(pred, all_x, old_model)
        return loss

    def LDAuCIDUpdate(self, all_x, replay_dataloader, epoch):
        """LDAuCID loss for one target batch.

        Draws pseudo-source embeddings from ``self.gmm_model`` (fitted in
        ``after_train``), keeps the samples whose maximum posterior exceeds
        ``args.gmm_tau``, takes one replay batch (truncated to ``len(all_x)``),
        and combines
          * cross-entropy on the GMM samples and on the replay features, and
          * sliced-Wasserstein distances between the current, replay and GMM
            embeddings, weighted by 1e-2.

        Args:
            all_x: current target batch (already on the GPU).
            replay_dataloader: loader over the replay memory; a fresh iterator
                is created on every call and only its first batch is used.
            epoch: current epoch, used for the augmentation ``ratio``.

        Returns:
            The scalar loss tensor.
        """
        # assign pseudo label here. Use softmax to assign pseudo label for replay data.
        match_weight = 1e-2
        cfg = self.args

        drawn_x, drawn_y = self.gmm_model.sample(n_samples=20*cfg.batch_size)
        confidence = np.max(self.gmm_model.predict_proba(drawn_x), axis=1)

        # Keep only confidently assigned GMM samples.
        gmmX = torch.from_numpy(drawn_x[confidence > cfg.gmm_tau, :]).to(dtype=torch.float).cuda()
        gmmY = torch.from_numpy(drawn_y[confidence > cfg.gmm_tau]).to(dtype=torch.long).cuda()

        replay_batch = next(iter(replay_dataloader))
        n_current = len(all_x)
        replay_x = replay_batch[0][:n_current].cuda().float()
        replay_y = replay_batch[1][:n_current].cuda().long()

        # forward Augmentation (applied to the replay samples only)
        if cfg.forAug == 'v1':
            ratio = epoch / cfg.max_epoch
            data_fore = self.aug_tran(torch.sigmoid(self.aug_fore(replay_x, ratio=ratio)))
            replay_x = torch.cat([replay_x, data_fore])
            replay_y = torch.cat([replay_y, replay_y])

        theta_ = LDAuCID.generateTheta(100, cfg.num_classes).cuda()

        # network[0] is the featurizer, network[1] the classifier (fc0 -> ReLU -> fc1).
        backbone, head = self.network[0], self.network[1]
        curr_features = F.relu(head.fc0(backbone(all_x)))
        replay_features = F.relu(head.fc0(backbone(replay_x)))

        ce_pseudo = torch.mean(F.cross_entropy(head.fc1(gmmX), gmmY))
        ce_replay = torch.mean(F.cross_entropy(self.classifier.fc1(replay_features), replay_y))
        discrimination = ce_pseudo + ce_replay

        def sliced_w(p, q):
            return LDAuCID.sWasserstein(p, q, theta_, nclass=cfg.num_classes, Cp=None, Cq=None)

        matching = (sliced_w(curr_features, gmmX)
                    + sliced_w(curr_features, replay_features)
                    + sliced_w(replay_features, gmmX))
        return match_weight*matching + discrimination

    def SHOT_IMloss(self, loss, logist):
        '''
        Source Hypothesis Transfer with Information Maximization (SHOT-IM)
        '''
        gent = True
        epsilon = 1e-5
        ent_par = 1.0
        if self.task_id>0:
            softmax_out = nn.Softmax(dim=1)(logist)
            entropy_loss = torch.mean(Entropy_(softmax_out))
            if gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + epsilon))
                entropy_loss -= gentropy_loss
            im_loss = entropy_loss * ent_par
            loss += im_loss
        return loss
    
    def distill_loss(self, pred, all_x, old_model):
        '''
        distill loss 
        '''
        # all_x = all_x.cpu()
        old_model.cuda().eval()
        with torch.no_grad():
            old_logist = nn.Softmax(dim=1)(old_model(all_x))
        loss = F.cross_entropy(pred, old_logist)
        return loss

    def distill_proxy_loss(self, proxy, proj_fea, all_x, old_model):
        '''
        distill loss using proxy and projected features to predict
        '''
        proxy_pred = torch.matmul(proj_fea, proxy.T) * self.args.PCL_scale   # (N, C)
        
        old_model.cuda().eval()
        with torch.no_grad():
            old_pred = old_model(all_x)
            old_proxy = F.linear(old_model.classifier, old_model.fc_proj)
            old_proxy_pred = torch.matmul(old_model.proj_rep, old_proxy.T) * self.args.PCL_scale
            old_logist = nn.Softmax(dim=1)(old_proxy_pred)
            
        loss = F.cross_entropy(proxy_pred, old_logist)
        return loss
            
        
################################################################ Utils ####################################################
    def _initialize_weights(self, modules):
        '''
        initialize weights for PCL 
        '''
        for m in modules:
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    
    def after_train(self, data_loader, task_id):
        """Hook executed after training on a task.

        For the LDAuCID target algorithm a GMM is fitted with
        ``LDAuCID.myGMMfit`` on ``data_loader`` and stored in
        ``self.gmm_model`` (consumed by ``LDAuCIDUpdate``). Every other
        algorithm, including MFP, currently needs no post-processing.

        Args:
            data_loader: loader over the data of the task that just finished.
            task_id: index of the finished task (currently unused).
        """
        target_alg = self.args.targetAlg
        if target_alg == 'LDAuCID':
            self.gmm_model = LDAuCID.myGMMfit(self.args, self.network, data_loader)
        # TODO: MFP was meant to post-process the raw data of data_loader.dataset
        # here (re-wrapped in a utilDataset with test-time transforms); that
        # step is not implemented, so MFP is a no-op for now.
                    
        
    def balance_finetune(self, minibatches, opt, sch, task_id):
        pass

    def select_aug(self, all_x, all_y, epoch):
        ratio = epoch / self.args.max_epoch
        if self.args.aug_tau > 0:
            self.eval()
            with torch.no_grad():
                pred = nn.Softmax(dim=1)(self(all_x))
                ov, idx = torch.max(pred, 1)
                bool_index = ov > self.args.aug_tau
                data_fore = all_x[bool_index]
                y_fore = all_y[bool_index]
                data_fore = self.aug_tran(torch.sigmoid(self.aug_fore(data_fore, ratio=ratio)))
            self.train()
        else:
            data_fore = self.aug_tran(torch.sigmoid(self.aug_fore(all_x, ratio=ratio)))
            y_fore = all_y
        all_x = torch.cat([all_x, data_fore])    # [original, aug]
        all_y = torch.cat([all_y, y_fore])
        self.naug += len(y_fore)
        return all_x, all_y

    # def get_scheduler(self):
    #     if self.args.sourceAlg == 'ERM' or self.args.targetAlg == 'ERM':
    #         lr_gamma = 0.0003
    #         lr_decay = 0.75
    #         self.scheduler = torch.optim.lr_scheduler.LambdaLR(
    #             self.optimizer, lambda x:  self.args.lr * (1. + lr_gamma * float(x)) ** (-lr_decay))
    #     else:
    #         pass

def Entropy_(input_, epsilon=1e-5):
    """Row-wise entropy of a batch of probability vectors.

    Computes ``-sum_j p_ij * log(p_ij + epsilon)`` over dimension 1.

    Args:
        input_: tensor with at least two dimensions whose dimension 1 holds
            the probabilities (e.g. a softmax output).
        epsilon: small constant added inside the logarithm to avoid log(0).

    Returns:
        Tensor with dimension 1 reduced away.
    """
    return torch.sum(-input_ * torch.log(input_ + epsilon), dim=1)


class utilDataset(Dataset):
    '''
    Construct a pseudo dataset from parallel sequences of items and labels.

    Args:
        images_dict: indexable collection of raw items; each item is passed to
            ``loader`` on access (presumably image paths or raw images -- confirm
            with the caller).
        class_labels: indexable collection of class labels; defines the length.
        domain_labels: indexable collection of domain labels.
        loader: callable turning a raw item into an image.
        transform: optional callable applied to the loaded image.
        target_transform: optional callable applied to the class label
            (previously accepted but silently ignored).
    '''
    def __init__(self, images_dict, class_labels, domain_labels, loader, transform=None, target_transform=None):
        self.x = images_dict                 # raw items handed to the loader
        self.labels = class_labels           # numpy array
        self.dlabels = domain_labels         # numpy array
        self.loader = loader
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(image, class_label, domain_label)`` for ``index``."""
        img = self.loader(self.x[index])
        if self.transform is not None:
            img = self.transform(img)
        label = self.labels[index]
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label, self.dlabels[index]

    def get_raw_data(self):
        """Return the untransformed ``(items, class_labels, domain_labels)``."""
        return self.x, self.labels, self.dlabels