# from torch.nn import functional as F
from models.utils.continual_model import ContinualModel
from utils.buffer import Buffer
import torch
# from augmentations import get_aug

try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None
    
class Uniform(ContinualModel):
    """Continual model that trains on the incoming batch plus buffer replay.

    Each ``*_observe`` method performs one optimisation step. While the
    buffer reports that it is not yet full, the loss is computed on the
    incoming batch only. Once it is full, a batch drawn from the buffer is
    concatenated (``torch.cat``, first dimension) to the incoming batch
    before the forward pass. After every step the incoming examples are
    handed to the buffer via ``add_data``.

    The value returned by ``self.net(...)`` is used directly as the
    training loss.
    """

    # NAME = 'uniform'
    # COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform, logger):
        """Build the model and its replay buffer.

        All arguments are forwarded unchanged to ``ContinualModel``. The
        buffer is sized from ``args.model.buffer_size``.
        """
        super(Uniform, self).__init__(backbone, loss, args, transform, logger)
        self._uniform = self.args.hyperparameters.UNIFORM
        # logger.info(f'UNIFORM hyperparameter {self._uniform.ref_hyp}')
        logger.info('UNIFORM hyperparameter: 1.0 (default)')
        self.buffer = Buffer(self.args.model.buffer_size, self.device, configs=self.args)

    def _backward(self, loss):
        """Back-propagate ``loss``, using apex loss scaling unless AMP is "O0".

        Raises:
            RuntimeError: if an AMP level other than "O0" is requested but
                apex could not be imported.
        """
        if self.args.amp_opt_level == "O0":
            loss.backward()
            return
        if amp is None:
            # Without this guard the failure is an opaque AttributeError on None.
            raise RuntimeError(
                "amp_opt_level is %r but apex is not installed; install apex "
                "or set amp_opt_level to 'O0'" % (self.args.amp_opt_level,)
            )
        with amp.scale_loss(loss, self.opt) as scaled_loss:
            scaled_loss.backward()

    def _train_step(self, loss, batch_idx, epoch_idx, num_steps):
        """Run backward, optimizer and LR-scheduler updates for ``loss``.

        Returns:
            dict with keys ``penalty`` (always 0.0), ``loss``, ``rep_loss``
            (both the scalar value of ``loss``) and ``lr`` (whatever
            ``self.lr_scheduler.get_lr()`` returns after the update).
        """
        # .item() forces a host/device sync, so read the scalar only once.
        loss_value = loss.item()
        data_dict = {'penalty': 0.0, 'loss': loss_value, 'rep_loss': loss_value}

        self._backward(loss)
        self.opt.step()
        # Global step index; passed positionally so it does not depend on the
        # scheduler's parameter name.
        self.lr_scheduler.step_update(epoch_idx * num_steps + batch_idx)

        torch.cuda.synchronize()
        data_dict['lr'] = self.lr_scheduler.get_lr()
        return data_dict

    def supervised_observe(self, inputs, labels, notaug_inputs, image_paths, task_id, batch_idx=None, epoch_idx=None, num_steps=None):
        """Run one training step on a single-view batch.

        Args:
            inputs: batch fed to the network (moved to CUDA here).
            labels: accepted for interface compatibility; not used.
            notaug_inputs: examples stored in the buffer after the step.
            image_paths: paths stored in the buffer alongside the examples.
            task_id: not used by this method.
            batch_idx, epoch_idx, num_steps: combined as
                ``epoch_idx * num_steps + batch_idx`` for the LR scheduler;
                all three must be numbers despite the ``None`` defaults.

        Returns:
            dict of step metrics (``penalty``, ``loss``, ``rep_loss``, ``lr``).
        """
        self.opt.zero_grad()
        batch = inputs.cuda(non_blocking=True)

        if self.buffer.is_full(self.logger):
            replay = self.buffer.get_data(self.args.batch_size, transform=self.transform, get_mask=False)
            replay = replay.cuda(non_blocking=True)
            batch = torch.cat([batch, replay])

        loss = self.net(batch)
        data_dict = self._train_step(loss, batch_idx, epoch_idx, num_steps)
        self.buffer.add_data(examples=notaug_inputs, paths=image_paths)
        return data_dict

    def siamese_observe(self, inputs_a, inputs_b, image_paths, task_id, batch_idx=None, epoch_idx=None, num_steps=None):
        """Run one training step on a two-view batch.

        Args:
            inputs_a, inputs_b: the two views fed to the network (moved to
                CUDA here). ``inputs_a`` is also what gets stored in the
                buffer after the step.
            image_paths: paths stored in the buffer alongside the examples.
            task_id: not used by this method.
            batch_idx, epoch_idx, num_steps: combined as
                ``epoch_idx * num_steps + batch_idx`` for the LR scheduler;
                all three must be numbers despite the ``None`` defaults.

        Returns:
            dict of step metrics (``penalty``, ``loss``, ``rep_loss``, ``lr``).
        """
        self.opt.zero_grad()
        inputs_a = inputs_a.cuda(non_blocking=True)
        inputs_b = inputs_b.cuda(non_blocking=True)

        view_a, view_b = inputs_a, inputs_b
        if self.buffer.is_full(self.logger):
            replay_a, replay_b = self.buffer.get_data(self.args.batch_size, transform=self.transform, get_mask=False)
            # Each replay view is moved under its own name so that the first
            # view is not overwritten by the second.
            replay_a = replay_a.cuda(non_blocking=True)
            replay_b = replay_b.cuda(non_blocking=True)
            view_a = torch.cat([inputs_a, replay_a])
            view_b = torch.cat([inputs_b, replay_b])

        loss = self.net(view_a, view_b)
        data_dict = self._train_step(loss, batch_idx, epoch_idx, num_steps)
        # Only the incoming (non-replayed) first view is stored.
        self.buffer.add_data(examples=inputs_a, paths=image_paths)
        return data_dict

    def masked_observe(self, inputs, mask, image_paths, task_id, batch_idx=None, epoch_idx=None, num_steps=None):
        """Run one training step on a batch with an accompanying mask.

        Args:
            inputs: batch fed to the network (moved to CUDA here); also
                stored in the buffer after the step.
            mask: mask passed to the network as its second argument.
            image_paths: paths stored in the buffer alongside the examples.
            task_id: not used by this method.
            batch_idx, epoch_idx, num_steps: combined as
                ``epoch_idx * num_steps + batch_idx`` for the LR scheduler;
                all three must be numbers despite the ``None`` defaults.

        Returns:
            dict of step metrics (``penalty``, ``loss``, ``rep_loss``, ``lr``).
        """
        self.opt.zero_grad()
        inputs = inputs.cuda(non_blocking=True)
        mask = mask.cuda(non_blocking=True)
        # Optional sub-sampling of the incoming batch (currently disabled):
        # size = min(int(self.args.train.select_ratio*self.args.batch_size),len(inputs))
        # pick = torch.randperm(len(inputs))[:size]
        # inputs = inputs[pick]
        # mask = mask[pick]

        batch, batch_mask = inputs, mask
        if self.buffer.is_full(self.logger):
            replay, replay_mask = self.buffer.get_data(self.args.batch_size, transform=self.transform, get_mask=True)
            replay = replay.cuda(non_blocking=True)
            replay_mask = replay_mask.cuda(non_blocking=True)
            batch = torch.cat([inputs, replay])
            batch_mask = torch.cat([mask, replay_mask])

        loss = self.net(batch, batch_mask)
        data_dict = self._train_step(loss, batch_idx, epoch_idx, num_steps)
        # Only the incoming (non-replayed) inputs are stored.
        self.buffer.add_data(examples=inputs, paths=image_paths)
        return data_dict