from models.utils.continual_model import ContinualModel
import torch
# from models.optimizers import get_optimizer
# from models.optimizers.lr_scheduler import LR_Scheduler
import torch.distributed as dist

try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None
# import torch.cuda.amp as amp

class UnsupNaive(ContinualModel):
    """Plain fine-tuning learner: every ``*_observe`` call is one optimizer step.

    No regularisation term is added to the loss; the ``'penalty'`` entry of
    the returned statistics is always ``0.0``.

    The methods rely on ``self.net``, ``self.loss``, ``self.opt``,
    ``self.lr_scheduler`` and ``self.args``, none of which are assigned in
    this class (presumably provided by ``ContinualModel`` or the training
    script -- confirm there). All tensors passed in are moved to the current
    CUDA device.
    """
    NAME = 'unsupnaive'
    #COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform, logger):
        super(UnsupNaive, self).__init__(backbone, loss, args, transform, logger)

    # ------------------------------------------------------------------
    # Shared building blocks
    # ------------------------------------------------------------------
    def _random_subset(self, *tensors):
        """Randomly keep the same rows in every tensor of ``tensors``.

        Nothing is dropped unless ``args.train.select_ratio`` is below 1.
        Otherwise ``int(select_ratio * args.batch_size)`` rows are kept
        (capped at the length of the first tensor), chosen with a single
        random permutation so that paired tensors stay aligned.

        Returns:
            A tuple holding one (possibly sub-sampled) tensor per argument.
        """
        ratio = self.args.train.select_ratio
        if ratio < 1:
            available = len(tensors[0])
            keep = min(int(ratio * self.args.batch_size), available)
            pick = torch.randperm(available)[:keep]
            return tuple(t[pick] for t in tensors)
        return tensors

    def _backward_and_step(self, loss):
        """Clear gradients, back-propagate ``loss`` and apply one optimizer step.

        When ``args.amp_opt_level`` is anything other than ``"O0"`` the loss
        goes through apex ``amp.scale_loss`` before the backward pass.

        Raises:
            RuntimeError: a non-``"O0"`` level is configured but apex could
                not be imported (the module-level ``amp`` is ``None``).
        """
        self.opt.zero_grad()
        if self.args.amp_opt_level == "O0":
            loss.backward()
        else:
            if amp is None:
                # Fail with an actionable message instead of an AttributeError
                # on ``None.scale_loss``.
                raise RuntimeError(
                    "amp_opt_level=%r requires NVIDIA apex, which is not installed; "
                    "install apex or set amp_opt_level to 'O0'" % (self.args.amp_opt_level,)
                )
            with amp.scale_loss(loss, self.opt) as scaled_loss:
                scaled_loss.backward()
        self.opt.step()

    def _optimize(self, loss, batch_idx, epoch_idx, num_steps):
        """Run one optimisation step on ``loss`` and gather the step statistics.

        The scheduler is advanced with the global step index
        ``epoch_idx * num_steps + batch_idx``, so all three values must be
        numbers (``None`` raises ``TypeError`` in that arithmetic).

        Returns:
            dict with ``'loss'`` (Python number read before the backward
            pass), ``'penalty'`` (always ``0.0``) and ``'lr'`` (whatever
            ``self.lr_scheduler.get_lr()`` returns after the update).
        """
        stats = {'loss': loss.item(), 'penalty': 0.0}
        self._backward_and_step(loss)
        self.lr_scheduler.step_update(index=epoch_idx * num_steps + batch_idx)

        # Wait for the queued CUDA work before reading back the learning rate.
        torch.cuda.synchronize()
        stats['lr'] = self.lr_scheduler.get_lr()
        return stats

    # ------------------------------------------------------------------
    # Public training steps
    # ------------------------------------------------------------------
    def supervised_observe(self, inputs, labels, task_id, batch_idx=None, epoch_idx=None, num_steps=None):
        """Train one step on ``self.loss(self.net(inputs), labels)``.

        Args:
            inputs, labels: batch tensors; sub-sampled together when
                ``args.train.select_ratio < 1``.
            task_id: accepted for interface compatibility; not used here.
            batch_idx, epoch_idx, num_steps: combined into the global step
                index for the LR scheduler; must be provided despite the
                ``None`` defaults.

        Returns:
            dict with ``'loss'``, ``'penalty'`` and ``'lr'``.
        """
        # Gradients are cleared here as well as right before backward().
        self.opt.zero_grad()
        inputs = inputs.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        inputs, labels = self._random_subset(inputs, labels)

        loss = self.loss(self.net(inputs), labels)
        return self._optimize(loss, batch_idx, epoch_idx, num_steps)

    def siamese_observe(self, inputs_a, inputs_b, notaug_inputs, task_id, batch_idx=None, epoch_idx=None, num_steps=None):
        """Train one step on the loss returned by ``self.net(inputs_a, inputs_b)``.

        Args:
            inputs_a, inputs_b: the two input batches fed to the network;
                sub-sampled with the same indices when
                ``args.train.select_ratio < 1``.
            notaug_inputs, task_id: accepted for interface compatibility;
                not used here.
            batch_idx, epoch_idx, num_steps: combined into the global step
                index for the LR scheduler; must be provided despite the
                ``None`` defaults.

        Returns:
            dict with ``'loss'``, ``'penalty'`` and ``'lr'``.
        """
        self.opt.zero_grad()
        inputs_a = inputs_a.cuda(non_blocking=True)
        inputs_b = inputs_b.cuda(non_blocking=True)
        inputs_a, inputs_b = self._random_subset(inputs_a, inputs_b)

        loss = self.net(inputs_a, inputs_b)
        return self._optimize(loss, batch_idx, epoch_idx, num_steps)

    def masked_observe(self, inputs, mask, notaug_inputs, task_id, batch_idx=None, epoch_idx=None, num_steps=None):
        """Train one step on the loss returned by ``self.net(inputs, mask)``.

        Args:
            inputs, mask: batch tensors; sub-sampled with the same indices
                when ``args.train.select_ratio < 1``.
            notaug_inputs, task_id: accepted for interface compatibility;
                not used here.
            batch_idx, epoch_idx, num_steps: combined into the global step
                index for the LR scheduler; must be provided despite the
                ``None`` defaults.

        Returns:
            dict with ``'loss'``, ``'penalty'`` and ``'lr'``.
        """
        self.opt.zero_grad()
        inputs = inputs.cuda(non_blocking=True)
        mask = mask.cuda(non_blocking=True)
        inputs, mask = self._random_subset(inputs, mask)

        loss = self.net(inputs, mask)
        return self._optimize(loss, batch_idx, epoch_idx, num_steps)

    def mae_observe(self, inputs, notaug_inputs, task_id, batch_idx=None, epoch_idx=None, num_steps=None):
        """Train one step on the loss returned by ``self.net(inputs)``.

        Args:
            inputs: batch tensor; sub-sampled when
                ``args.train.select_ratio < 1``.
            notaug_inputs, task_id: accepted for interface compatibility;
                not used here.
            batch_idx, epoch_idx, num_steps: combined into the global step
                index for the LR scheduler; must be provided despite the
                ``None`` defaults.

        Returns:
            dict with ``'loss'``, ``'penalty'`` and ``'lr'``.
        """
        self.opt.zero_grad()
        inputs = inputs.cuda(non_blocking=True)
        (inputs,) = self._random_subset(inputs)

        loss = self.net(inputs)
        return self._optimize(loss, batch_idx, epoch_idx, num_steps)

    def gkeep_observe(self, inputs, mask):
        """Take one optimizer step on the loss returned by ``self.net(inputs, mask)``.

        Unlike the other ``*_observe`` methods this one does no sub-sampling,
        does not advance the LR scheduler and returns nothing.
        """
        inputs = inputs.cuda(non_blocking=True)
        mask = mask.cuda(non_blocking=True)
        loss = self.net(inputs, mask)

        self._backward_and_step(loss)

        # An interactive pdb breakpoint on rank 0 was removed from this spot:
        # it blocked every non-interactive / multi-process run.
        # NOTE(review): the barrier below probably only existed to hold the
        # other ranks while rank 0 sat in the debugger -- confirm and drop it.
        # It is skipped when no process group is initialised so the method
        # also works in single-process runs.
        if dist.is_available() and dist.is_initialized():
            dist.barrier()