import os
# import numpy as np
from models.utils.continual_model import ContinualModel
import torch
from utils.buffer import Buffer
import higher
import pdb
import torch.nn as nn
from torchvision.utils import save_image
from augmentations.imagenet_aug import MaskGenerator, NonBinaryMaskGenerator

try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None
    
class LAPS(ContinualModel):
    """Continual-learning model that replays buffered images with learned masks.

    Until the replay buffer is full, each batch is trained with the mask
    supplied by the data loader. Once the buffer is full, a short
    differentiable look-ahead (via ``higher``) is run on the current batch
    with randomly sampled masks. The gradient of the resulting buffer loss
    with respect to those sampled mask logits is average-pooled to
    mask-patch resolution, and the highest-scoring patches of each image
    form the mask used for the real training step on
    ``cat([inputs, bf_inputs])``.
    """

    NAME = 'laps'

    def __init__(self, backbone, loss, args, len_train_loader, transform, logger):
        super(LAPS, self).__init__(backbone, loss, args, len_train_loader, transform, logger)
        self._laps = self.args.hyperparameters.LAPS
        self.buffer = Buffer(self.args.model.buffer_size, self.device, configs=self.args)
        self._ps = self.net.patch_size

        # Incremented each time masked images are dumped to disk; -1 means
        # nothing has been dumped yet.
        self.curr_epoch = -1

        model_type = self.args.model.type
        if model_type == 'swin':
            model_patch_size = self.args.model.swin.patch_size
        elif model_type == 'vit':
            model_patch_size = self.args.model.vit.patch_size
        else:
            raise NotImplementedError(
                "LAPS supports model.type 'swin' or 'vit', got %r" % (model_type,))

        # Both generators share the same geometry; only the sampling differs.
        mask_kwargs = dict(
            input_size=self.args.dataset.image_size,
            mask_patch_size=self.args.model.mask_patch_size,
            model_patch_size=model_patch_size,
            mask_ratio=self.args.model.mask_ratio,
        )
        self.mg = MaskGenerator(**mask_kwargs)
        self.nbmg = NonBinaryMaskGenerator(**mask_kwargs, sample='uniform')
        # Number of model patches along each side of one mask patch.
        self.scale = self.args.model.mask_patch_size // model_patch_size

    def _sample_target_masks(self, n):
        """Draw ``n`` fresh look-ahead masks and stack them into one tensor.

        Uses the binary generator when ``la_mask_type == 'binary'`` and the
        non-binary generator otherwise.
        """
        generator = self.mg if self._laps.la_mask_type == 'binary' else self.nbmg
        return torch.stack([torch.Tensor(generator()) for _ in range(n)])

    def _lookahead_mask_grad(self, inputs, mask, bf_inputs, bf_masks):
        """Accumulate the gradient of the post-look-ahead buffer loss w.r.t. random masks.

        On each of ``inner_steps`` iterations a fresh mask is sampled for
        ``inputs`` and one differentiable optimizer step is taken on it with
        ``higher``. Then the loss on one slice of the buffer batch, plus a
        penalty keeping each sampled mask's total close to that of ``mask``,
        is differentiated with respect to the sampled (pre-sigmoid) mask.

        Returns:
            The summed gradients, one row per image in ``inputs``.
        """
        mask_grad_accum = 0
        inner_batch = self.args.batch_size // self._laps.inner_steps
        # Per-image total of the loader mask; reference for the size penalty.
        mask_total = mask.sum([1, 2]) + 1e-8
        with higher.innerloop_ctx(self.net.module,
                                  self.opt,
                                  copy_initial_weights=False,
                                  track_higher_grads=True) as (fmodel, diffopt):
            for step in range(self._laps.inner_steps):
                # Look-ahead step on the current batch with a fresh random mask.
                target_mask = nn.Parameter(
                    self._sample_target_masks(len(inputs)).cuda(non_blocking=True))
                la_loss = fmodel(inputs, torch.sigmoid(target_mask))
                diffopt.step(la_loss)

                # Loss on this step's slice of the buffer after the update.
                lo, hi = inner_batch * step, inner_batch * (step + 1)
                la_bf_loss = fmodel(bf_inputs[lo:hi], bf_masks[lo:hi])
                mask_loss = torch.norm(torch.sigmoid(target_mask).sum([1, 2]) / mask_total - 1.0)
                la_forget_loss = la_bf_loss + mask_loss
                mask_grad_accum += torch.autograd.grad(la_forget_loss, target_mask)[0]
        return mask_grad_accum

    @staticmethod
    def _topk_patch_mask(scores, k):
        """Return a boolean ``(batch, N)`` mask marking the ``k`` largest entries per row."""
        # ``order`` lists positions by descending score. Its inverse
        # permutation is each position's rank, so ``rank < k`` selects exactly
        # the top-k positions. Comparing ``order`` itself with ``k`` would
        # select positions unrelated to the scores.
        order = torch.sort(scores, dim=1, descending=True)[1]
        ranks = torch.argsort(order, dim=1)
        return ranks < k

    def _save_masked_images(self, inputs, token_mask, task_id, epoch_idx):
        """Write each input multiplied by its upsampled mask as a PNG."""
        full_mask = (token_mask.repeat_interleave(self._ps, 1)
                               .repeat_interleave(self._ps, 2)
                               .unsqueeze(1)
                               .contiguous())
        out_dir = '../images/0506_LAPS/%s/t%s_ep%s' % (self.args.name, task_id, epoch_idx)
        os.makedirs(out_dir, exist_ok=True)
        for i, (img, img_mask) in enumerate(zip(inputs, full_mask)):
            save_image(img * img_mask, os.path.join(out_dir, 'img%s.png' % (i)))

    def masked_observe(self, inputs, mask, image_paths, task_id, batch_idx=None, epoch_idx=None, num_steps=None):
        """Run one optimizer step on a masked batch, then add the batch to the buffer.

        Args:
            inputs: batch of images; moved to the GPU here.
            mask: loader-provided masks for ``inputs``; moved to the GPU here.
            image_paths: paths stored in the buffer alongside ``inputs``.
            task_id: current task index, used only in the image-dump path.
            batch_idx, epoch_idx, num_steps: combined into the scheduler index
                ``epoch_idx * num_steps + batch_idx``. Despite their ``None``
                defaults, they must be provided.

        Returns:
            dict with ``'penalty'``, ``'loss'``, ``'rep_loss'`` and ``'lr'``.
        """
        inputs = inputs.cuda(non_blocking=True)
        mask = mask.cuda(non_blocking=True)
        data_dict = {'penalty': 0.0, 'loss': 0.0, 'rep_loss': 0.0}

        if not self.buffer.is_full(self.logger):
            # Warm-up: plain masked training with the loader's masks.
            loss = self.net(inputs, mask)
        else:
            bf_inputs, bf_masks = self.buffer.get_data(self.args.batch_size, transform=self.transform, get_mask=True)
            bf_inputs = bf_inputs.cuda(non_blocking=True)
            bf_masks = bf_masks.cuda(non_blocking=True)

            mask_grad_accum = self._lookahead_mask_grad(inputs, mask, bf_inputs, bf_masks)

            # Pool gradients from model-patch to mask-patch resolution.
            adap_mask = nn.functional.avg_pool2d(mask_grad_accum,
                                                 kernel_size=self.nbmg.scale,
                                                 stride=self.nbmg.scale)
            # Use the real number of images; the last batch can be smaller
            # than args.batch_size.
            n_imgs = adap_mask.size(0)
            argsrt_mask = self._topk_patch_mask(adap_mask.reshape(n_imgs, -1),
                                                self.nbmg.mask_count).type_as(mask)
            argsrt_mask = argsrt_mask.reshape(n_imgs, self.nbmg.rand_size, self.nbmg.rand_size)
            # Upsample back to model-patch resolution.
            argsrt_mask = (argsrt_mask.repeat_interleave(self.nbmg.scale, 1)
                                      .repeat_interleave(self.nbmg.scale, 2)
                                      .contiguous())
            loss = self.net(torch.cat([inputs, bf_inputs]), torch.cat([argsrt_mask, bf_masks]))

            if self.buffer.is_full(self.logger) and (self.curr_epoch == -1 or epoch_idx == 0 or (epoch_idx + 1) % 100 == 0):
                self._save_masked_images(inputs, argsrt_mask, task_id, epoch_idx)
                self.curr_epoch += 1

        loss_value = loss.item()
        data_dict['loss'] = loss_value
        # NOTE(review): before the buffer fills this is the plain masked loss,
        # not a replay loss.
        data_dict['rep_loss'] = loss_value
        self.opt.zero_grad()
        if self.args.amp_opt_level != "O0":
            with amp.scale_loss(loss, self.opt) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.opt.step()

        self.lr_scheduler.step_update(epoch_idx * num_steps + batch_idx)
        data_dict.update({'lr': self.lr_scheduler.get_lr()})
        self.buffer.add_data(examples=inputs, paths=image_paths)
        return data_dict