# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import numpy as np
from models.utils.continual_model import ContinualModel
from models.optimizers import get_optimizer, get_laps_inner_optimizer
from models.optimizers.lr_scheduler import LR_Scheduler
import torch
from utils.buffer import Buffer
import higher
import pdb
import copy
import torch.nn as nn
from torchvision.utils import save_image
from augmentations.imagenet_aug import MaskGenerator, NonBinaryMaskGenerator
import torch.distributed as dist


try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None
    
class LAPS(ContinualModel):
    NAME = 'laps'
    def __init__(self, backbone, loss, args, len_train_loader, transform, logger):
        """Set up the LAPS continual learner.

        On top of the base ``ContinualModel`` state this builds the replay
        buffer and two mask generators (binary, and non-binary with
        ``'uniform'`` sampling) configured from ``args.dataset`` and
        ``args.model``.

        Raises:
            NotImplementedError: if ``args.model.type`` is neither ``'swin'``
                nor ``'vit'`` (the patch size is read from that sub-config).
        """
        super(LAPS, self).__init__(backbone, loss, args, len_train_loader, transform, logger)
        self._laps = self.args.hyperparameters.LAPS
        self.buffer = Buffer(self.args.model.buffer_size, self.device, configs=self.args)
        self.curr_epoch = -1

        model_cfg = self.args.model
        if model_cfg.type not in ('swin', 'vit'):
            raise NotImplementedError
        # The backbone's patch size lives under the sub-config named after the type.
        model_patch_size = getattr(model_cfg, model_cfg.type).patch_size

        generator_kwargs = dict(
            input_size=self.args.dataset.image_size,
            mask_patch_size=model_cfg.mask_patch_size,
            model_patch_size=model_patch_size,
            mask_ratio=model_cfg.mask_ratio,
        )
        self.mask_generator = MaskGenerator(**generator_kwargs)
        self.non_binary_mask_generator = NonBinaryMaskGenerator(**generator_kwargs, sample='uniform')
        # Model patches per mask patch along each side; used to upsample
        # patch-level masks with repeat_interleave.
        self.scale = model_cfg.mask_patch_size // model_patch_size

    # def set_task(self, dummy, task_id):
    #     task_init_lr = self.args.init_lr * self.args.train.task_lr_decay**task_id                        
    #     final_lr = self.args.train.min_lr if hasattr(self.args.train, 'min_lr') else self.args.init_lr*self.args.final_lr_decay
        
    #     if dist.get_rank() == 0:
    #         print('warmup_lr: %.4f, base_lr: %.4f, final_lr: %.4f'%(self.args.train.warmup_lr*self.args.batch_size/256,
    #                                                                 task_init_lr*self.args.batch_size/256,
    #                                                                 final_lr*self.args.batch_size/256,))
    #     self.lr_scheduler = LR_Scheduler(
    #         optimizer=self.opt,
    #         warmup_epochs=self.args.train.warmup_epochs,
    #         warmup_lr=self.args.train.warmup_lr*self.args.batch_size/256,
    #         num_epochs=self.args.num_epochs,
    #         base_lr=task_init_lr*self.args.batch_size/256,
    #         final_lr=final_lr*self.args.batch_size/256,
    #         iter_per_epoch=self.len_train_lodaer,
    #         constant_predictor_lr=True # see the end of section 4.2 predictor
    #     )        
        
    #     if task_id > 0 and self.args.reinit_opt_per_task:            
    #         self.opt = get_optimizer(
    #             self.args.optimizer, self.net,
    #             lr=task_init_lr*self.args.batch_size/256,
    #             momentum=self.args.momentum,
    #             weight_decay=self.args.weight_decay
    #         )

    # def end_task(self, logger):
    #     self.lr_scheduler.reset()
        
    # CMAML     --sync_update --use_old_task_memory
    # Sync      --learn_lr --sync_update --use_old_task_memory
    # LaMAML    --learn_lr --use_old_task_memory
    # def observe(self, inputs, mask, notaug_inputs, task_id, batch_idx=None, epoch_idx=None, num_steps=None):
    #     self.opt.zero_grad()
    #     inputs = inputs.cuda(non_blocking=True)
    #     mask = mask.cuda(non_blocking=True)
    #     # select samples (random)
    #     # notaug_inputs = notaug_inputs.cuda(non_blocking=True)
    #     size = min(int(self.args.train.select_ratio*self.args.batch_size),len(inputs))
    #     pick = torch.randperm(len(inputs))[:size]
    #     inputs = inputs[pick]
    #     mask = mask[pick]                
    #     data_dict = {}
    #     if self.buffer.is_empty():
    #         #pdb.set_trace() # 1891
    #         loss = self.net(inputs, mask) # 13463, 8833
    #         data_dict['loss'] = loss.item()
    #         data_dict['penalty'] = 0.0
    #         data_dict['rep_loss'] = loss.item()
                        
    #         if self.args.train.accumulation_steps > 1:
    #             loss = loss / self.args.train.accumulation_step
    #             if self.args.amp_opt_level != "O0":
    #                 with amp.scale_loss(loss, self.opt) as scaled_loss:
    #                     scaled_loss.backward() # 19xxx->12xxx->4093
    #             else:
    #                 loss.backward()
    #             if (batch_idx + 1) % self.args.train.accumulation_steps == 0:
    #                 self.opt.step()
    #                 self.opt.zero_grad()
    #                 self.lr_scheduler.step_update(epoch_idx * num_steps + batch_idx)
    #         else:
    #             self.opt.zero_grad()
    #             if self.args.amp_opt_level != "O0":
    #                 with amp.scale_loss(loss, self.opt) as scaled_loss:
    #                     scaled_loss.backward()
    #             else:
    #                 loss.backward()
    #             self.opt.step()
    #             self.lr_scheduler.step_update(epoch_idx * num_steps + batch_idx)
    #     else:              
    #         _ps, _inc = self.net.module.patch_size, self.net.module.in_chans            
    #         full_mask = mask.repeat_interleave(_ps, 1).repeat_interleave(_ps, 2).unsqueeze(1).contiguous()
    #         norm_scale = (full_mask.sum() + 1e-5) * _inc
    #         bf_inputs, _ = self.buffer.get_data(self.args.batch_size, transform=self.transform)
    #         bf_inputs = bf_inputs.cuda(non_blocking=True)

    #         if self.args.train.accumulation_steps > 1:
    #             pass
    #         else:
    #             self.opt.zero_grad()
    #             with higher.innerloop_ctx(self.net.module, self.opt, 
    #                                 copy_initial_weights=False, track_higher_grads=False) as (fmodel, diffopt):                     
    #                 la_loss = fmodel(inputs, mask) / norm_scale
    #                 diffopt.step(la_loss)
    #                 la_bf_loss = fmodel(bf_inputs, mask, reduction='none').detach()

    #             task_px_loss = self.net.module(inputs, mask, reduction='none')
    #             task_loss = task_px_loss.sum() / norm_scale 

    #             if self.args.amp_opt_level != "O0":
    #                 with amp.scale_loss(task_loss, self.opt) as scaled_loss:
    #                     scaled_loss.backward(retain_graph=True)
    #             else:
    #                 task_loss.backward(retain_graph=True)

    #             bf_loss = self.net.module(bf_inputs, mask, reduction='none')
                
    #             repr_forget_loss = la_bf_loss - bf_loss#.detach()
    #             # rfl_mean = torch.mean(repr_forget_loss)
    #             # rfl_std = torch.std(repr_forget_loss)
    #             # # emphasize to learn pixels which are not harmful
    #             # att = - (repr_forget_loss - rfl_mean) / (rfl_std + 1e-9)
                                
    #             if self._laps.type == 'pixel':
    #                 # att /= torch.std(-repr_forget_loss)
    #                 att = self.minmax_scale(-repr_forget_loss) * 2
                    
    #             elif self._laps.type == 'patch':
    #                 # _mps = self.args.model.mask_patch_size                                                            
    #                 _mps = self._laps.mask_patch_size
    #                 # att = nn.functional.avg_pool2d(att, kernel_size=_mps, stride=_mps) # (6, 6)
    #                 att = nn.functional.avg_pool2d(-repr_forget_loss, kernel_size=_mps, stride=_mps) # (6, 6)
    #                 # att /= torch.std(att)
    #                 att = self.minmax_scale(att) * 2
    #                 # NOTE RGB channel wise?
    #                 att = att.repeat_interleave(_mps, 2).repeat_interleave(_mps, 3).contiguous() # (192, 192)

    #             elif self._laps.type == 'patch_selection':
    #                 raise NotImplementedError
    #             else:
    #                 raise NotImplementedError('wrong laps type: %s'%self._laps.type)
                
    #             rep_loss = (task_px_loss * self._laps.att_hyp * att.data).sum() / norm_scale                
    #             penalty = self._laps.buf_hyp * bf_loss.sum() / norm_scale                
    #             loss = rep_loss + penalty
                
    #             if self.args.amp_opt_level != "O0":
    #                 with amp.scale_loss(loss, self.opt) as scaled_loss:
    #                     scaled_loss.backward()
    #             else:
    #                 loss.backward()
    #             self.opt.step()
    #             self.lr_scheduler.step_update(epoch_idx * num_steps + batch_idx)
            
    #         data_dict['loss'] = loss.item()            
    #         data_dict['rep_loss'] = rep_loss.item()
    #         data_dict['penalty'] = penalty.item()
                    
    #     torch.cuda.synchronize()    
        
    #     data_dict.update({'lr': self.lr_scheduler.get_lr()})
    #     self.buffer.add_data(examples=notaug_inputs, task_labels=torch.ones(self.args.batch_size, dtype=torch.int)*task_id)        
    #     return data_dict       
    
    def masked_observe(self, inputs, mask, image_paths, task_id, batch_idx=None, epoch_idx=None, num_steps=None):
        """Run one optimization step on a masked-image-modeling minibatch, then grow the buffer.

        A random subset of the minibatch (``train.select_ratio * batch_size``
        samples, randomly permuted) is used. While the replay buffer is not
        full, the step is a plain reconstruction step on that subset. Once it
        is full, a look-ahead copy of the network (``higher.innerloop_ctx``)
        shapes the current-task objective according to ``self._laps.type``:

        * ``'patch'``: adapt on buffer samples, then weight the per-pixel
          current-task loss by a patch-level map of how the look-ahead changed it.
        * ``'patch_selection'``: adapt on buffer samples, accumulate the
          look-ahead per-pixel loss of the current batch, and build the mask
          from the per-sample rank of each patch's accumulated loss.
        * ``'patch_selection_orignal'``: differentiate a look-ahead objective on
          buffer samples w.r.t. a soft mask, turn the accumulated gradient into
          a per-sample binary mask, and train on current + buffer samples.

        Args:
            inputs: current minibatch images (moved to CUDA here).
            mask: per-sample masks for ``inputs`` (moved to CUDA here).
            image_paths: per-sample paths aligned with ``inputs``; stored via
                ``self.buffer.add_data(paths=...)``. Assumed to be an indexable
                sequence (e.g. the list of strings from the default collate) or
                ``None`` — TODO confirm.
            task_id: current task index; only used in debug image paths.
            batch_idx, epoch_idx, num_steps: combine into the global step
                ``epoch_idx * num_steps + batch_idx`` for the LR scheduler.

        Returns:
            dict with ``'loss'``, ``'rep_loss'``, ``'penalty'`` (always 0.0
            here) and ``'lr'``.
        """
        self.opt.zero_grad()
        inputs = inputs.cuda(non_blocking=True)
        mask = mask.cuda(non_blocking=True)
        size = min(int(self.args.train.select_ratio*self.args.batch_size), len(inputs))
        pick = torch.randperm(len(inputs))[:size]
        inputs = inputs[pick]
        mask = mask[pick]
        if image_paths is not None:
            # ``pick`` permutes (and may sub-sample) the batch; apply the same
            # selection to the paths so buffer entries stay aligned with images.
            image_paths = [image_paths[i] for i in pick.tolist()]
        data_dict = {}
        data_dict['penalty'] = 0.0
        if not self.buffer.is_full(self.logger):
            # Warm-up phase: plain masked reconstruction on the current batch.
            loss = self.net(inputs, mask)
            data_dict['loss'] = loss.item()
            data_dict['rep_loss'] = loss.item()

            self.opt.zero_grad()
            if self.args.amp_opt_level != "O0":
                with amp.scale_loss(loss, self.opt) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            self.opt.step()
            self.lr_scheduler.step_update(epoch_idx * num_steps + batch_idx)
        else:
            _ps, _inc = self.net.module.patch_size, self.net.module.in_chans
            # Expand the patch-level mask to pixel resolution to count masked pixels.
            full_mask = mask.repeat_interleave(_ps, 1).repeat_interleave(_ps, 2).unsqueeze(1).contiguous()
            norm_scale = (full_mask.sum() + 1e-5) * _inc
            self.opt.zero_grad()
            # Number of current-task samples actually kept (may be < batch_size).
            n_cur = len(inputs)

            if self._laps.type == 'patch':
                with higher.innerloop_ctx(self.net.module, self.opt,
                                          copy_initial_weights=False, track_higher_grads=False) as (fmodel, diffopt):
                    for _ in range(self._laps.inner_steps):
                        bf_inputs = self.buffer.get_data(self.args.batch_size, transform=self.transform)
                        bf_inputs = bf_inputs[0].cuda(non_blocking=True)
                        new_mask = [torch.Tensor(self.mask_generator()) for _ in range(self.args.batch_size)]
                        new_mask = torch.stack(new_mask).cuda(non_blocking=True)

                        bf_loss = fmodel(bf_inputs, new_mask) / norm_scale
                        diffopt.step(bf_loss)
                    la_task_loss = fmodel(inputs, mask, reduction='none').detach()

                task_loss = self.net.module(inputs, mask, reduction='none')
                repr_forget_loss = la_task_loss - task_loss

                _mps = self._laps.mask_patch_size
                att = nn.functional.avg_pool2d(-repr_forget_loss, kernel_size=_mps, stride=_mps)
                att = self.minmax_scale(att) * 2
                # NOTE RGB channel wise?
                att = att.repeat_interleave(_mps, 2).repeat_interleave(_mps, 3).contiguous()

                rep_loss = (task_loss * self._laps.att_hyp * att.data).sum() / norm_scale
                loss = rep_loss  # + penalty

            elif self._laps.type == 'patch_selection':
                # ``inputs`` is already on CUDA, so zeros_like stays there too.
                la_task_loss = torch.zeros_like(inputs)

                with higher.innerloop_ctx(self.net.module, self.opt,
                                          copy_initial_weights=False, track_higher_grads=False) as (fmodel, diffopt):
                    for _ in range(self._laps.inner_steps):
                        bf_inputs, bf_masks = self.buffer.get_data(self.args.batch_size, transform=self.transform, get_mask=True)
                        bf_inputs = bf_inputs.cuda(non_blocking=True)
                        bf_masks = bf_masks.cuda(non_blocking=True)

                        bf_loss = fmodel(bf_inputs, bf_masks)
                        diffopt.step(bf_loss)

                        # One random mask per *kept* current sample, on the same
                        # device as ``inputs``.
                        tmp_mask = torch.stack([torch.Tensor(self.mask_generator()) for _ in range(n_cur)])
                        tmp_mask = tmp_mask.cuda(non_blocking=True)
                        la_task_loss += fmodel(inputs, tmp_mask, reduction='none').detach()

                _mps = self._laps.mask_patch_size
                patch_loss = nn.functional.avg_pool2d(la_task_loss.sum(1), kernel_size=_mps, stride=_mps)
                sms = patch_loss.shape
                flat = patch_loss.flatten(1)
                # argsort returns the indices that would sort each row; a second
                # argsort turns them into each patch's rank at its own position.
                ranks = torch.argsort(torch.argsort(flat, 1), 1)
                threshold = np.prod(sms[1:]) * (1 - self.args.model.mask_ratio)
                selected_mask = torch.where(ranks > threshold, torch.ones_like(ranks), torch.zeros_like(ranks)).view(sms)
                selected_mask = selected_mask.repeat_interleave(self.scale, 1).repeat_interleave(self.scale, 2).contiguous()
                rep_loss = self.net(inputs, selected_mask)
                loss = rep_loss  # + penalty

                if self.buffer.is_full(self.logger) and (self.curr_epoch == -1 or epoch_idx == 0 or (epoch_idx+1) % 100 == 0):
                    full_mask = selected_mask.repeat_interleave(_ps, 1).repeat_interleave(_ps, 2).unsqueeze(1).contiguous()
                    out_dir = '../images/0501_patch_select/%s/t%s_ep%s/'%(self.args.name, task_id, epoch_idx)
                    os.makedirs(out_dir, exist_ok=True)
                    for _i in range(len(inputs)):
                        save_image(inputs[_i]*full_mask[_i], out_dir + 'img%s.png'%_i)
                    self.curr_epoch += 1

            elif self._laps.type == 'patch_selection_orignal':
                bf_inputs, bf_masks = self.buffer.get_data(self.args.batch_size, transform=self.transform, get_mask=True)
                bf_inputs = bf_inputs.cuda(non_blocking=True)
                bf_masks = bf_masks.cuda(non_blocking=True)
                mask_grad_accum = 0
                with higher.innerloop_ctx(self.net.module, self.opt,
                                          copy_initial_weights=False, track_higher_grads=True) as (fmodel, diffopt):
                    inner_batch = self.args.batch_size // self._laps.inner_steps
                    for _i in range(self._laps.inner_steps):
                        # Look-ahead step on the current batch with a fresh learnable soft mask.
                        target_mask = torch.stack([torch.Tensor(self.buffer.mask_generator())
                                                   if self._laps.la_mask_type == 'binary' else
                                                   torch.Tensor(self.buffer.non_binary_mask_generator()) for _ in range(n_cur)])
                        target_mask = nn.Parameter(target_mask.cuda(non_blocking=True))
                        la_loss = fmodel(inputs, torch.sigmoid(target_mask))
                        diffopt.step(la_loss)

                        # Loss on a slice of buffer samples after that step, plus a
                        # term keeping the soft mask's total close to the given mask's.
                        bf_slice = slice(inner_batch*_i, inner_batch*(_i+1))
                        la_bf_loss = fmodel(bf_inputs[bf_slice], bf_masks[bf_slice])
                        mask_loss = torch.norm(torch.sigmoid(target_mask).sum([1,2]) / (mask.sum([1,2]) + 1e-8) - 1.0)
                        la_forget_loss = la_bf_loss + mask_loss
                        mask_grad_accum += torch.autograd.grad(la_forget_loss, target_mask)[0]

                adap_mask = -mask_grad_accum
                # Per-sample selection: entries ranked at or above ``boundary``
                # (largest ``adap_mask`` values within each sample) become 0, the
                # rest 1. A disabled variant thresholded over the whole minibatch
                # at int(adap_mask.numel() * mask_ratio) instead.
                # NOTE(review): ``boundary`` is hard-coded and presumably tuned to
                # one mask resolution / mask_ratio — confirm before reuse.
                boundary = 1408
                # Double argsort -> per-row ranks evaluated at each entry's own position.
                ranks = torch.argsort(torch.argsort(adap_mask.flatten(1), 1), 1).view_as(adap_mask)
                selected_mask = torch.where(ranks >= boundary, torch.zeros_like(adap_mask), torch.ones_like(adap_mask))
                rep_loss = self.net(torch.cat([inputs, bf_inputs]), torch.cat([selected_mask, bf_masks]))
                loss = rep_loss  # + penalty

                if self.buffer.is_full(self.logger) and (self.curr_epoch == -1 or epoch_idx == 0 or (epoch_idx+1) % 100 == 0):
                    mask_ratios = selected_mask.sum([1,2])/np.prod(selected_mask.shape[1:])
                    full_mask = selected_mask.repeat_interleave(_ps, 1).repeat_interleave(_ps, 2).unsqueeze(1).contiguous()
                    out_dir = '../images/0504_patch_select_org/%s/t%s_ep%s/'%(self.args.name, task_id, epoch_idx)
                    os.makedirs(out_dir, exist_ok=True)
                    for _i in range(len(inputs)):
                        mr = ('%.2f'%mask_ratios[_i]).replace('.','p')
                        save_image(inputs[_i]*full_mask[_i], out_dir + 'img%s_mr%s.png'%(_i, mr))
                    self.curr_epoch += 1
            else:
                raise NotImplementedError('wrong laps type: %s'%self._laps.type)

            if self.args.amp_opt_level != "O0":
                with amp.scale_loss(loss, self.opt) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            self.opt.step()
            self.lr_scheduler.step_update(epoch_idx * num_steps + batch_idx)
            data_dict['loss'] = loss.item()
            data_dict['rep_loss'] = rep_loss.item()

        torch.cuda.synchronize()

        data_dict.update({'lr': self.lr_scheduler.get_lr()})
        self.buffer.add_data(examples=inputs, paths=image_paths)
        return data_dict
    
    
    
    def observe_eval(self, inputs, mask, notaug_inputs, task_id, batch_idx=None, epoch_idx=None, num_steps=None):
        """Analysis helper: dump attention maps derived from a look-ahead step on the current batch.

        When the buffer is non-empty and gradient accumulation is disabled, a
        functional copy of the network takes one optimizer step on the current
        batch. The resulting change in per-pixel loss on buffer samples becomes
        an attention map (per ``self._laps.type``). A few hand-picked images and
        their maps are saved under ``../images/<args.tag>/``, and the process
        terminates via ``sys.exit()``.

        Args:
            inputs: current minibatch images (moved to CUDA here).
            mask: masks for ``inputs`` (moved to CUDA here); also reused for
                the buffer samples.
            notaug_inputs: images passed to ``self.buffer.add_data``.
            task_id: task index stored as the buffer task label.
            batch_idx, epoch_idx, num_steps: unused.

        Returns:
            dict with ``'lr'``. It is only returned when the buffer is empty or
            gradient accumulation is enabled; neither path computes a loss.
        """
        self.opt.zero_grad()
        self.net.zero_grad()
        inputs = inputs.cuda(non_blocking=True)
        mask = mask.cuda(non_blocking=True)
        size = min(int(self.args.train.select_ratio*self.args.batch_size), len(inputs))
        pick = torch.randperm(len(inputs))[:size]
        inputs = inputs[pick]
        mask = mask[pick]
        data_dict = {}
        if self.buffer.is_empty():
            pass
        else:
            _ps, _inc = self.net.module.patch_size, self.net.module.in_chans
            full_mask = mask.repeat_interleave(_ps, 1).repeat_interleave(_ps, 2).unsqueeze(1).contiguous()
            norm_scale = (full_mask.sum() + 1e-5) * _inc
            bf_inputs, _ = self.buffer.get_data(self.args.batch_size, transform=self.transform)
            bf_inputs = bf_inputs.cuda(non_blocking=True)

            if self.args.train.accumulation_steps > 1:
                # Not implemented for this analysis path: no loss is computed,
                # so no loss metrics are reported.
                pass
            else:
                self.opt.zero_grad()
                with higher.innerloop_ctx(self.net.module, self.opt,
                                          copy_initial_weights=False, track_higher_grads=False) as (fmodel, diffopt):
                    la_loss = fmodel(inputs, mask) / norm_scale
                    diffopt.step(la_loss)
                    la_bf_loss = fmodel(bf_inputs, mask, reduction='none').detach()

                task_px_loss = self.net.module(inputs, mask, reduction='none')
                task_loss = task_px_loss.sum() / norm_scale

                if self.args.amp_opt_level != "O0":
                    with amp.scale_loss(task_loss, self.opt) as scaled_loss:
                        scaled_loss.backward(retain_graph=True)
                else:
                    task_loss.backward(retain_graph=True)

                bf_loss = self.net.module(bf_inputs, mask, reduction='none')
                repr_forget_loss = la_bf_loss - bf_loss

                if self._laps.type == 'pixel':
                    att = self.minmax_scale(-repr_forget_loss) * 2

                elif self._laps.type == 'patch':
                    _mps = self._laps.mask_patch_size
                    att = nn.functional.avg_pool2d(-repr_forget_loss, kernel_size=_mps, stride=_mps)
                    att = self.minmax_scale(att) * 2
                    att = att.repeat_interleave(_mps, 2).repeat_interleave(_mps, 3).contiguous()

                elif self._laps.type == 'patch_selection':
                    raise NotImplementedError
                else:
                    raise NotImplementedError('wrong laps type: %s'%self._laps.type)

                rep_loss = (task_px_loss * self._laps.att_hyp * att).sum() / norm_scale
                penalty = self._laps.buf_hyp * bf_loss.sum() / norm_scale
                loss = rep_loss + penalty
                # The metrics only exist on this path, so record them here.
                data_dict['loss'] = loss.item()
                data_dict['rep_loss'] = rep_loss.item()
                data_dict['penalty'] = penalty.item()

                # NOTE repr_forget_loss???

                for sub_dir in ('color', 'red', 'green', 'blue'):
                    os.makedirs('../images/%s/%s/' % (self.args.tag, sub_dir), exist_ok=True)
                for j in range(len(inputs)):
                    if j not in (0, 2, 6, 19, 29):  # hand-picked samples to visualize
                        continue

                    # NOTE(review): the colour map uses -att while the per-channel
                    # maps use +att; confirm which sign is intended.
                    norm_img = self.minmax_scale(inputs[j])
                    norm_att = self.minmax_scale(-att[j], rescale=True)
                    norm_att_1d = self.minmax_scale(torch.mean(att[j], dim=0, keepdim=True), rescale=True)

                    norm_img_r = self.minmax_scale(inputs[j][0])
                    norm_att_r = self.minmax_scale(att[j][0], rescale=True)
                    norm_img_g = self.minmax_scale(inputs[j][1])
                    norm_att_g = self.minmax_scale(att[j][1], rescale=True)
                    norm_img_b = self.minmax_scale(inputs[j][2])
                    norm_att_b = self.minmax_scale(att[j][2], rescale=True)

                    save_image(norm_img, '../images/%s/color/img%s_img_t%s.png'%(self.args.tag, j, self.args.eval.target_task))
                    save_image(norm_att, '../images/%s/color/att%s_img_t%s.png'%(self.args.tag, j, self.args.eval.target_task))
                    save_image(norm_att*norm_img, '../images/%s/color/imgxatt%s_img_t%s.png'%(self.args.tag, j, self.args.eval.target_task))
                    save_image(norm_att_1d*norm_img, '../images/%s/color/img_x_att_1d_%s_img_t%s.png'%(self.args.tag, j, self.args.eval.target_task))
                    # Use this sample's mask only (not the whole batch's).
                    save_image(norm_att_1d*norm_img*full_mask[j], '../images/%s/color/masked_img_x_att_1d_%s_img_t%s.png'%(self.args.tag, j, self.args.eval.target_task))

                    save_image(norm_img_r, '../images/%s/red/img_%s_img_t%s.png'%(self.args.tag, j, self.args.eval.target_task))
                    save_image(norm_att_r, '../images/%s/red/att_%s_img_t%s.png'%(self.args.tag, j, self.args.eval.target_task))
                    save_image(norm_att_r*norm_img_r, '../images/%s/red/img_x_att_%s_img_t%s.png'%(self.args.tag, j, self.args.eval.target_task))

                    save_image(norm_img_g, '../images/%s/green/img_%s_img_t%s.png'%(self.args.tag, j, self.args.eval.target_task))
                    save_image(norm_att_g, '../images/%s/green/att_%s_img_t%s.png'%(self.args.tag, j, self.args.eval.target_task))
                    save_image(norm_att_g*norm_img_g, '../images/%s/green/img_x_att_%s_img_t%s.png'%(self.args.tag, j, self.args.eval.target_task))

                    save_image(norm_img_b, '../images/%s/blue/img_%s_img_t%s.png'%(self.args.tag, j, self.args.eval.target_task))
                    save_image(norm_att_b, '../images/%s/blue/att_%s_img_t%s.png'%(self.args.tag, j, self.args.eval.target_task))
                    save_image(norm_att_b*norm_img_b, '../images/%s/blue/img_x_att_%s_img_t%s.png'%(self.args.tag, j, self.args.eval.target_task))

                # Analysis run: stop the process once the images are written.
                import sys
                sys.exit()

        torch.cuda.synchronize()

        data_dict.update({'lr': self.lr_scheduler.get_lr()})
        self.buffer.add_data(examples=notaug_inputs, task_labels=torch.ones(self.args.batch_size, dtype=torch.int)*task_id)
        return data_dict
    
    def minmax_scale(self, input, rescale=False):
        if rescale:
            max_scale = float(self.args.eval.minmax_max_scale)
            min_scale = float(self.args.eval.minmax_min_scale)
            return (max_scale-min_scale)*(input-torch.min(input))/(torch.max(input)-torch.min(input)+1e-9) + min_scale
        else:
            return (input-torch.min(input))/(torch.max(input)-torch.min(input)+1e-9)
        
        
        
        
        
        
        
        
        
    def current_task_adaptation_observe_eval(self, inputs, mask, notaug_inputs, task_id, att_type='la_sim', logger=None):
        """Visualize how look-ahead adaptation on buffer samples changes the current-task loss.

        Once the replay buffer is full (and gradient accumulation is disabled),
        a functional copy of the network takes ``self._laps.inner_steps``
        optimizer steps on buffer samples with freshly generated masks. The
        per-pixel loss of the current batch under that adapted copy is compared
        with the live network's loss and turned into an attention map. Images
        and maps are then written under ``../images/<args.tag>/`` and the
        process terminates via ``sys.exit()``.

        Args:
            inputs: current minibatch images (moved to CUDA here).
            mask: masks for ``inputs`` produced by a mask generator.
            notaug_inputs: images handed to ``self.buffer.add_data`` (author's
                note: needs to be updated to the version without cropping).
            task_id: current task index, stored as the buffer task label.
            att_type: currently unused.
            logger: forwarded to ``self.buffer.is_full``.

        Returns:
            dict with the current learning rate under ``'lr'``. It is only
            returned when the buffer is not yet full or gradient accumulation
            is enabled.
        """
        self.opt.zero_grad()
        self.net.zero_grad()
        inputs = inputs.cuda(non_blocking=True)
        mask = mask.cuda(non_blocking=True)
        keep = min(int(self.args.train.select_ratio * self.args.batch_size), len(inputs))
        order = torch.randperm(len(inputs))[:keep]
        inputs, mask = inputs[order], mask[order]

        if self.buffer.is_full(logger):
            patch, chans = self.net.module.patch_size, self.net.module.in_chans
            full_mask = mask.repeat_interleave(patch, 1).repeat_interleave(patch, 2).unsqueeze(1).contiguous()
            norm_scale = (full_mask.sum() + 1e-5) * chans
            if self.args.train.accumulation_steps <= 1:
                self.opt.zero_grad()
                with higher.innerloop_ctx(self.net.module, self.opt,
                                          copy_initial_weights=False, track_higher_grads=False) as (fmodel, diffopt):
                    for _step in range(self._laps.inner_steps):
                        bf_inputs, _ = self.buffer.get_data(self.args.batch_size, transform=self.transform)
                        bf_inputs = bf_inputs.cuda(non_blocking=True)
                        bf_masks = torch.stack(
                            [torch.Tensor(self.mask_generator()) for _k in range(self.args.batch_size)]
                        ).cuda(non_blocking=True)
                        diffopt.step(fmodel(bf_inputs, bf_masks) / norm_scale)
                    la_task_loss = fmodel(inputs, mask, reduction='none').detach()

                task_loss = self.net.module(inputs, mask, reduction='none')
                changed_task_loss = la_task_loss - task_loss

                laps_type = self._laps.type
                if laps_type == 'pixel':
                    att = self.minmax_scale(-changed_task_loss) * 2
                elif laps_type == 'patch':
                    pool = self._laps.mask_patch_size
                    pooled = nn.functional.avg_pool2d(-changed_task_loss, kernel_size=pool, stride=pool)
                    att = (self.minmax_scale(pooled) * 2).repeat_interleave(pool, 2).repeat_interleave(pool, 3).contiguous()
                elif laps_type == 'patch_selection':
                    raise NotImplementedError
                else:
                    raise NotImplementedError('wrong laps type: %s' % laps_type)

                tag = self.args.tag
                target = self.args.eval.target_task
                channel_dirs = ('red', 'green', 'blue')

                def _out(sub_dir, stem, idx):
                    # e.g. ../images/<tag>/color/img3_img_t<target>.png
                    return '../images/%s/%s/%s%s_img_t%s.png' % (tag, sub_dir, stem, idx, target)

                for sub_dir in ('color',) + channel_dirs:
                    os.makedirs('../images/%s/%s/' % (tag, sub_dir), exist_ok=True)

                for j, img in enumerate(inputs):
                    att_j = att[j]
                    mask_j = full_mask[j]
                    norm_img = self.minmax_scale(img)
                    norm_att = self.minmax_scale(-att_j, rescale=True)
                    norm_att_1d = self.minmax_scale(torch.mean(att_j, dim=0, keepdim=True), rescale=True)

                    save_image(norm_img, _out('color', 'img', j))
                    save_image(norm_att, _out('color', 'att', j))
                    save_image(norm_att * norm_img, _out('color', 'imgxatt', j))
                    save_image(norm_att_1d * norm_img, _out('color', 'img_x_att_1d_', j))
                    save_image(norm_att_1d * norm_img * mask_j, _out('color', 'masked_img_x_att_1d_', j))
                    save_image(norm_att * norm_img * mask_j, _out('color', 'masked_img_x_att_', j))

                    for c, sub_dir in enumerate(channel_dirs):
                        chan_img = self.minmax_scale(img[c])
                        chan_att = self.minmax_scale(att_j[c], rescale=True)
                        save_image(chan_img, _out(sub_dir, 'img_', j))
                        save_image(chan_att, _out(sub_dir, 'att_', j))
                        save_image(chan_att * chan_img, _out(sub_dir, 'img_x_att_', j))
                        save_image(norm_att_1d * chan_img * mask_j, _out(sub_dir, 'masked_img_x_att_1d_', j))

                import sys
                sys.exit()

        torch.cuda.synchronize()

        data_dict = {'lr': self.lr_scheduler.get_lr()}
        self.buffer.add_data(examples=notaug_inputs, task_labels=torch.ones(self.args.batch_size, dtype=torch.int)*task_id)
        return data_dict