from utils.buffer import Buffer
from torch.nn import functional as F
from models.utils.continual_model import ContinualModel
from augmentations import get_aug
import numpy as np
import torch
try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None

class Mixup(ContinualModel):
    """Continual model that mixes each incoming batch with replayed buffer samples.

    Every ``*observe`` method performs one optimisation step. While the
    buffer is empty the network is trained on the incoming batch alone;
    afterwards the batch is linearly interpolated with samples drawn from
    the buffer using a coefficient ``lam ~ Beta(alpha, alpha)``. The
    (un-mixed) incoming samples are then added to the buffer.
    """

    NAME = 'mixup'

    def __init__(self, backbone, loss, args, transform, logger):
        super(Mixup, self).__init__(backbone, loss, args, transform, logger)
        self.buffer = Buffer(self.args.model.buffer_size, self.device)
        # Mixup hyperparameters; only ``alpha`` is read by this class.
        self._lump = self.args.hyperparameters.MIXUP

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _sample_lam(self):
        """Draw the mixing coefficient from Beta(alpha, alpha)."""
        return np.random.beta(self._lump.alpha, self._lump.alpha)

    def _random_subset(self, *tensors):
        """Return the same random subset of rows of every tensor.

        When ``args.train.select_ratio`` is below 1, at most
        ``int(select_ratio * args.batch_size)`` rows are kept (capped by the
        number of rows available); otherwise the tensors are returned as-is.
        """
        if self.args.train.select_ratio < 1:
            total = len(tensors[0])
            size = min(int(self.args.train.select_ratio * self.args.batch_size), total)
            pick = torch.randperm(total)[:size]
            return tuple(t[pick] for t in tensors)
        return tensors

    def _backward_and_step(self, loss, batch_idx, epoch_idx, num_steps):
        """Back-propagate ``loss`` (optionally through apex amp), then step
        the optimiser and the learning-rate scheduler."""
        if self.args.amp_opt_level != "O0":
            if amp is None:
                # The module-level import of apex failed; fail with a clear
                # message instead of an AttributeError on None.
                raise RuntimeError(
                    "amp_opt_level is %r but NVIDIA apex is not installed"
                    % (self.args.amp_opt_level,))
            with amp.scale_loss(loss, self.opt) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.opt.step()
        self.lr_scheduler.step_update(index=epoch_idx * num_steps + batch_idx)
        torch.cuda.synchronize()

    # ------------------------------------------------------------------
    # Training steps
    # ------------------------------------------------------------------
    def observe(self, inputs, labels, notaug_inputs):
        """Run one supervised mixup step.

        Args:
            inputs: batch fed to ``self.net.module.backbone``.
            labels: targets passed to ``self.loss`` for ``inputs``.
            notaug_inputs: samples stored in the buffer for later replay
                (presumably the non-augmented version of ``inputs`` --
                TODO confirm against the caller).

        Returns:
            dict with keys ``'loss'``, ``'penalty'`` (always 0) and ``'lr'``.
        """
        self.opt.zero_grad()
        inputs = inputs.to(self.device)

        if self.buffer.is_empty():
            labels = labels.to(self.device)
            outputs = self.net.module.backbone(inputs)
            loss = self.loss(outputs, labels)
        else:
            buf_inputs, buf_labels = self.buffer.get_data(
                self.args.train.batch_size, transform=self.transform)
            n = inputs.shape[0]
            labels = labels.to(self.device).long()
            # The buffer may return more rows than the current batch holds.
            buf_inputs = buf_inputs[:n].to(self.device)
            buf_labels = buf_labels[:n].to(self.device).long()

            # ``self.c`` is not set anywhere in this class; honour it if the
            # base class provides it, otherwise use the same alpha as the
            # other observe methods.
            alpha = getattr(self, 'c', self._lump.alpha)
            lam = np.random.beta(alpha, alpha)

            mixed_x = lam * inputs + (1 - lam) * buf_inputs
            net_output = self.net.module.backbone(mixed_x)
            # Mixup criterion: both targets are weighted by their share of
            # the interpolated input.
            loss = (lam * self.loss(net_output, labels)
                    + (1 - lam) * self.loss(net_output, buf_labels))

        data_dict = {'loss': loss.item(), 'penalty': 0.0}

        loss.backward()
        self.opt.step()
        data_dict['lr'] = self.args.train.base_lr

        # Labels are stored in the buffer's ``logits`` slot, which is what
        # ``get_data`` hands back as the second element above.
        self.buffer.add_data(examples=notaug_inputs, logits=labels)
        return data_dict

    def siamese_observe(self, inputs_a, inputs_b, image_paths, task_id, batch_idx=None, epoch_idx=None, num_steps=None):
        """Run one two-view mixup step.

        Args:
            inputs_a, inputs_b: the two inputs passed to ``self.net``, which
                returns the loss directly.
            image_paths: forwarded to the buffer together with the samples.
            task_id: unused here.
            batch_idx, epoch_idx, num_steps: used to compute the scheduler
                index ``epoch_idx * num_steps + batch_idx``; all three must
                be numbers despite the ``None`` defaults.

        Returns:
            dict with keys ``'penalty'`` (always 0.0), ``'loss'`` and ``'lr'``.
        """
        self.opt.zero_grad()
        inputs_a = inputs_a.cuda(non_blocking=True)
        inputs_b = inputs_b.cuda(non_blocking=True)

        # select samples (random)
        inputs_a, inputs_b = self._random_subset(inputs_a, inputs_b)

        if self.buffer.is_empty():
            loss = self.net(inputs_a, inputs_b)
        else:
            buf_a, buf_b = self.buffer.get_data(
                self.args.batch_size, transform=self.transform)
            n = inputs_a.shape[0]
            lam = self._sample_lam()
            mixed_x = lam * inputs_a + (1 - lam) * buf_a[:n].cuda(non_blocking=True)
            mixed_x_aug = lam * inputs_b + (1 - lam) * buf_b[:n].cuda(non_blocking=True)
            loss = self.net(mixed_x, mixed_x_aug)

        data_dict = {'penalty': 0.0, 'loss': loss.item()}

        self._backward_and_step(loss, batch_idx, epoch_idx, num_steps)
        data_dict['lr'] = self.lr_scheduler.get_lr()

        # NOTE(review): when select_ratio < 1 the inputs are subsampled but
        # image_paths is not, so paths may not line up with the stored
        # examples -- confirm how Buffer.add_data uses ``paths``.
        self.buffer.add_data(examples=inputs_a.data, paths=image_paths, logits=inputs_b.data)
        return data_dict

    def masked_observe(self, inputs, mask, image_paths, task_id, batch_idx=None, epoch_idx=None, num_steps=None):
        """Run one masked-input mixup step.

        Args:
            inputs: batch passed (possibly mixed) as first argument to
                ``self.net``, which returns the loss directly.
            mask: second argument to ``self.net`` while the buffer is empty;
                also stored in the buffer alongside ``inputs``.
            image_paths: forwarded to the buffer together with the samples.
            task_id: unused here.
            batch_idx, epoch_idx, num_steps: used to compute the scheduler
                index ``epoch_idx * num_steps + batch_idx``; all three must
                be numbers despite the ``None`` defaults.

        Returns:
            dict with keys ``'penalty'`` (always 0.0), ``'loss'`` and ``'lr'``.
        """
        self.opt.zero_grad()
        inputs = inputs.cuda(non_blocking=True)
        mask = mask.cuda(non_blocking=True)

        inputs, mask = self._random_subset(inputs, mask)

        if self.buffer.is_empty():
            loss = self.net(inputs, mask)
        else:
            buf_inputs, buf_masks = self.buffer.get_data(
                self.args.batch_size, transform=self.transform)
            # Truncate the replayed samples to the current batch size so the
            # interpolation below is shape-compatible (the buffer may return
            # more rows than a subsampled or final batch contains).
            n = inputs.shape[0]
            buf_inputs = buf_inputs[:n].cuda(non_blocking=True)
            buf_masks = buf_masks[:n].cuda(non_blocking=True)

            lam = self._sample_lam()
            mixed_x = lam * inputs + (1 - lam) * buf_inputs
            # NOTE(review): the replayed masks, not ``mask``, are applied to
            # the mixed input -- confirm this is intended.
            loss = self.net(mixed_x, buf_masks)

        data_dict = {'penalty': 0.0, 'loss': loss.item()}

        self._backward_and_step(loss, batch_idx, epoch_idx, num_steps)
        data_dict['lr'] = self.lr_scheduler.get_lr()

        # NOTE(review): image_paths is not subsampled together with inputs
        # when select_ratio < 1 -- confirm how Buffer.add_data uses ``paths``.
        self.buffer.add_data(examples=inputs.data, paths=image_paths, masks=mask)
        return data_dict
