import torch
from utils.buffer import Buffer
from torch.nn import functional as F
from models.utils.continual_model import ContinualModel
from models.optimizers import get_optimizer
from models.optimizers.lr_scheduler import LR_Scheduler
from augmentations import get_aug
import copy
import pdb

try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None

class Der(ContinualModel):
    """Dark-Experience-Replay style continual learner.

    Each ``*_observe`` call first takes an optimizer step on the current batch.
    If the replay buffer already holds data, it then takes a second step on a
    replay penalty: ``args.hyperparameters.DER.ref_hyp`` times the MSE between
    the network's current embeddings of replayed examples and the embeddings
    stored when those examples were added. Finally, the current batch is added
    to the buffer together with its (detached) embeddings.
    """
    NAME = 'der'

    def __init__(self, backbone, loss, args, transform, logger):
        super().__init__(backbone, loss, args, transform, logger)
        # DER hyperparameter group; this class only reads ``ref_hyp`` (penalty weight).
        self._der = self.args.hyperparameters.DER
        self.buffer = Buffer(self.args.model.buffer_size, self.device)

    # ------------------------------------------------------------------ helpers

    def _backward_and_step(self, loss):
        """Zero grads, backpropagate ``loss`` (through apex AMP if enabled), step the optimizer."""
        self.opt.zero_grad()
        if self.args.amp_opt_level != "O0":
            # ``amp`` is None when apex failed to import; fail with a clear message
            # instead of an opaque AttributeError on ``None.scale_loss``.
            if amp is None:
                raise RuntimeError(
                    "amp_opt_level is %r but apex is not installed" % (self.args.amp_opt_level,))
            with amp.scale_loss(loss, self.opt) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.opt.step()

    def _replay_step(self, buf_embed, buf_logits, data_dict):
        """Apply the weighted MSE replay penalty, record it in ``data_dict`` and take an optimizer step."""
        penalty = self._der.ref_hyp * F.mse_loss(buf_embed, buf_logits)
        data_dict['penalty'] = penalty.item()
        data_dict['loss'] += data_dict['penalty']
        self._backward_and_step(penalty)

    def _finish_iteration(self, data_dict, batch_idx, epoch_idx, num_steps):
        """Advance the per-iteration LR schedule and record the current LR in ``data_dict``.

        The global iteration index is ``epoch_idx * num_steps + batch_idx``, so the
        three arguments must be provided (their ``None`` defaults would raise here).
        """
        self.lr_scheduler.step_update(index=epoch_idx * num_steps + batch_idx)
        torch.cuda.synchronize()
        data_dict.update({'lr': self.lr_scheduler.get_lr()})

    @staticmethod
    def _take(items, idx):
        """Return the entries of ``items`` at positions ``idx`` (a 1-D index tensor).

        Used to keep per-example metadata aligned with subsampled inputs.
        ``None`` is passed through unchanged.
        """
        if items is None:
            return None
        if torch.is_tensor(items):
            return items[idx.to(items.device)]
        return [items[i] for i in idx.tolist()]

    # ---------------------------------------------------------------- observers

    def supervised_observe(self, inputs, labels, image_paths, task_id, batch_idx=None, epoch_idx=None, num_steps=None):
        """Train on one labelled batch, then (if the buffer is non-empty) on the DER replay penalty.

        Args:
            inputs, labels: current batch; moved to CUDA here.
            image_paths: per-example identifiers stored in the buffer with the inputs.
            task_id: not used by this method.
            batch_idx, epoch_idx, num_steps: combine into the LR-scheduler iteration index.

        Returns:
            dict with keys 'penalty', 'loss' (task loss + penalty), 'rep_loss' (task loss), 'lr'.
        """
        inputs = inputs.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        pred, latent_emb = self.net(inputs, return_emb=True)
        loss = self.loss(pred, labels)

        data_dict = {'penalty': 0.0, 'loss': loss.item(), 'rep_loss': loss.item()}
        self._backward_and_step(loss)

        if not self.buffer.is_empty():
            # Assumes get_data yields (examples, stored embeddings) for buffers filled
            # by this method — TODO confirm against Buffer.get_data.
            buf_inputs, buf_logits = self.buffer.get_data(
                self.args.batch_size, transform=self.transform)
            buf_inputs = buf_inputs.cuda(non_blocking=True)
            buf_logits = buf_logits.cuda(non_blocking=True)

            _, buf_embed = self.net(buf_inputs, return_emb=True)
            self._replay_step(buf_embed, buf_logits, data_dict)

        self._finish_iteration(data_dict, batch_idx, epoch_idx, num_steps)
        self.buffer.add_data(examples=inputs.detach(), paths=image_paths, logits=latent_emb.detach())
        return data_dict

    def siamese_observe(self, inputs_a, inputs_b, image_paths, task_id, batch_idx=None, epoch_idx=None, num_steps=None):
        """Train on one pair of views, then (if the buffer is non-empty) on the DER replay penalty.

        Only the first view ``inputs_a`` and its embedding are stored in the buffer.
        Replayed examples are embedded directly with ``self.net.module.backbone``.

        Returns:
            dict with keys 'penalty', 'loss' (view loss + penalty), 'rep_loss', 'lr'.
        """
        inputs_a = inputs_a.cuda(non_blocking=True)
        inputs_b = inputs_b.cuda(non_blocking=True)

        loss, latent_emb = self.net(inputs_a, inputs_b, return_emb=True)
        data_dict = {'penalty': 0.0, 'loss': loss.item(), 'rep_loss': loss.item()}
        self._backward_and_step(loss)

        if not self.buffer.is_empty():
            buf_inputs, buf_logits = self.buffer.get_data(
                self.args.batch_size, transform=self.transform, get_mask=False)
            buf_inputs = buf_inputs.cuda(non_blocking=True)
            buf_logits = buf_logits.cuda(non_blocking=True)

            # NOTE(review): assumes self.net is a wrapper (e.g. DataParallel) exposing
            # the backbone at ``.module.backbone`` — confirm for single-device runs.
            buf_embed = self.net.module.backbone(buf_inputs)
            self._replay_step(buf_embed, buf_logits, data_dict)

        self._finish_iteration(data_dict, batch_idx, epoch_idx, num_steps)
        self.buffer.add_data(examples=inputs_a, paths=image_paths, logits=latent_emb.detach())
        return data_dict

    def masked_observe(self, inputs, mask, image_paths, task_id, batch_idx=None, epoch_idx=None, num_steps=None):
        """Train on a random subset of a masked batch, then (if possible) on the DER replay penalty.

        At most ``int(args.train.select_ratio * args.batch_size)`` examples are used.
        The same random subset is applied to ``inputs``, ``mask`` and ``image_paths``,
        so the buffer stores aligned triples.

        Returns:
            dict with keys 'loss' (masked loss + penalty), 'penalty', 'rep_loss', 'lr'.
        """
        inputs = inputs.cuda(non_blocking=True)
        mask = mask.cuda(non_blocking=True)

        size = min(int(self.args.train.select_ratio * self.args.batch_size), len(inputs))
        pick = torch.randperm(len(inputs))[:size]
        inputs = inputs[pick]
        mask = mask[pick]
        # Subsample paths with the same indices; otherwise the stored paths would not
        # correspond to the stored examples/masks.
        image_paths = self._take(image_paths, pick)

        loss, latent_emb = self.net(inputs, mask, return_emb=True)
        data_dict = {'loss': loss.item(), 'penalty': 0.0, 'rep_loss': loss.item()}
        self._backward_and_step(loss)

        if not self.buffer.is_empty():
            # Assumes get_data yields (examples, stored embeddings, masks) for buffers
            # filled by this method — TODO confirm against Buffer.get_data.
            buf_inputs, buf_logits, buf_masks = self.buffer.get_data(
                self.args.batch_size, transform=self.transform)
            buf_inputs = buf_inputs.cuda(non_blocking=True)
            buf_logits = buf_logits.cuda(non_blocking=True)
            buf_masks = buf_masks.cuda(non_blocking=True)

            _, buf_embed = self.net(buf_inputs, buf_masks, return_emb=True)
            self._replay_step(buf_embed, buf_logits, data_dict)

        self._finish_iteration(data_dict, batch_idx, epoch_idx, num_steps)
        self.buffer.add_data(examples=inputs.detach(), paths=image_paths,
                             logits=latent_emb.detach(), masks=mask)
        return data_dict
