import torch.nn as nn
# from utils.args import *
from .utils.continual_model import ContinualModel
from models.utils.continual_model import ContinualModel
from models.optimizers import get_optimizer
from models.optimizers.lr_scheduler import LR_Scheduler
import torch


try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None
    
class SI(ContinualModel):
    """Synaptic Intelligence (SI) regulariser on top of ``ContinualModel``.

    While training a task, per-weight importance ``small_omega`` is accumulated
    as ``lr * grad**2`` after every optimizer step. At the end of each task it is
    folded into ``big_omega``, normalised by the squared weight drift since the
    previous checkpoint plus the damping term ``xi``. The current weights then
    become the new ``checkpoint``. Every training step adds the quadratic penalty
    ``reg_hyp * sum(big_omega * (theta - checkpoint)**2)`` to the loss.

    Only the parameters collected in ``target_weights`` take part; see
    ``_EXCLUDED_NAME_PARTS`` for the name-based filter.
    """
    NAME = 'si'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    # Parameters whose name contains any of these substrings are excluded from SI.
    _EXCLUDED_NAME_PARTS = (
        'decoder',
        'norm',
        'mask',
        'bias',
        'predictor',  # simsiam
        'projector',  # simsiam
    )

    def __init__(self, backbone, loss, args, transform, logger):
        super(SI, self).__init__(backbone, loss, args, transform, logger)
        self.big_omega = None   # accumulated importance; allocated on the first end_task()
        self.small_omega = 0    # running importance estimate for the current task
        self._si = self.args.hyperparameters.SI
        self.xi = 0.1           # damping added to the squared drift in end_task()

        self.target_weights = [
            param for name, param in self.net.named_parameters()
            if not any(part in name for part in self._EXCLUDED_NAME_PARTS)
        ]
        # Reference weights the penalty pulls towards.
        self.checkpoint = self.get_weights_params().data.clone().cuda()
        self.seen_tasks = 0

    def get_weights_params(self):
        """Return all target weights flattened and concatenated into one 1-D tensor.

        The result is built from the live parameters (not detached), so
        ``penalty()`` can be back-propagated through it.
        """
        return torch.cat([pp.reshape(-1) for pp in self.target_weights])

    def get_weights_grads(self):
        """Return the gradients of the target weights as one flat 1-D tensor.

        A parameter whose ``grad`` is ``None`` (it did not contribute to the last
        backward pass) contributes zeros. The result therefore always has the
        same length and layout as ``get_weights_params()`` and ``big_omega``.
        """
        grads = []
        for pp in self.target_weights:
            if pp.grad is None:
                grads.append(pp.new_zeros(pp.numel()))
            else:
                grads.append(pp.grad.reshape(-1))
        return torch.cat(grads)

    def task_incremental_learning(self, output):
        """Mask, in place, the logits outside the current task's class range.

        Columns ``[task_id * n, (task_id + 1) * n)``, with
        ``n = args.dataset.n_classes_per_task``, are kept. All other columns are
        filled with ``-10e10``. Returns the same (modified) tensor.
        """
        n_cls = self.args.dataset.n_classes_per_task
        start = self.args.task_id * n_cls
        end = (self.args.task_id + 1) * n_cls
        output[:, :start].data.fill_(-10e10)
        output[:, end:].data.fill_(-10e10)
        return output

    def penalty(self):
        """Return the SI quadratic penalty as a scalar CUDA tensor.

        Returns a constant zero tensor before the first ``end_task()``, or when
        the penalty is not above 1e-8.
        """
        if self.big_omega is None:
            return torch.tensor(0.0).cuda()
        drift = self.get_weights_params() - self.checkpoint
        penalty = self._si.reg_hyp * (self.big_omega * drift ** 2).sum()
        # .item() forces a host/device sync on every step.
        if penalty.item() > 1e-8:
            return penalty
        return torch.tensor(0.0).cuda()

    def end_task(self, logger):
        """Fold this task's ``small_omega`` into ``big_omega`` and re-anchor ``checkpoint``.

        Logs the task counter and the norms of ``big_omega`` and ``checkpoint``
        (both before the update). Then resets ``small_omega`` to 0 and calls
        ``self.lr_scheduler.reset()``.
        """
        self.seen_tasks += 1
        if self.big_omega is None:
            self.big_omega = torch.zeros_like(self.get_weights_params()).cuda()
        logger.info(f'seen tasks {self.seen_tasks}')
        logger.info('big_omega_norm: %s' % torch.norm(self.big_omega).detach().data.cpu())
        logger.info('checkpoint_norm: %s' % torch.norm(self.checkpoint).detach().data.cpu())

        current = self.get_weights_params().data
        # Normalise importance by how far each weight moved during the task;
        # xi keeps the division finite when a weight barely moved.
        self.big_omega += self.small_omega / ((current - self.checkpoint) ** 2 + self.xi)
        self.checkpoint = current.clone().cuda()
        self.small_omega = 0
        self.lr_scheduler.reset()

    def _select_subset(self, *tensors):
        """Randomly keep a ``select_ratio`` fraction of the batch (same rows for every tensor).

        Returns the tensors unchanged (as a tuple) when ``select_ratio >= 1``.
        The kept size is ``int(select_ratio * args.batch_size)``, capped at the
        batch length.
        """
        if self.args.train.select_ratio < 1:
            batch_len = len(tensors[0])
            size = min(int(self.args.train.select_ratio * self.args.batch_size), batch_len)
            pick = torch.randperm(batch_len)[:size]
            return tuple(t[pick] for t in tensors)
        return tensors

    def _optimize(self, rep_loss, penalty, batch_idx, epoch_idx, num_steps, clip_value=None):
        """Back-propagate ``rep_loss + penalty``, step optimizer/scheduler, update ``small_omega``.

        Args:
            rep_loss: scalar task/representation loss tensor.
            penalty: scalar SI penalty tensor, as returned by ``penalty()``.
            batch_idx, epoch_idx, num_steps: used to compute the scheduler index
                ``epoch_idx * num_steps + batch_idx``.
            clip_value: if not ``None``, clip gradients element-wise to this value
                on the apex-amp path, after they have been unscaled.

        Returns:
            dict with ``loss``, ``rep_loss``, ``penalty`` and ``lr``.

        Raises:
            RuntimeError: if mixed precision is requested but apex is not installed.
        """
        loss = rep_loss + penalty
        data_dict = {'loss': loss.item()}
        data_dict['rep_loss'] = rep_loss.item()
        data_dict['penalty'] = penalty.item()

        self.opt.zero_grad()
        if self.args.amp_opt_level != "O0":
            if amp is None:
                raise RuntimeError(
                    "amp_opt_level=%r requires NVIDIA apex, which is not installed"
                    % (self.args.amp_opt_level,))
            with amp.scale_loss(loss, self.opt) as scaled_loss:
                scaled_loss.backward()
            if clip_value is not None:
                # Clip only after leaving scale_loss: inside the context the grads are
                # still loss-scaled, and clipping there would also mask overflows.
                nn.utils.clip_grad_value_(amp.master_params(self.opt), clip_value)
        else:
            # NOTE(review): this path is never clipped (same as before) — confirm intent.
            loss.backward()
        self.opt.step()
        self.lr_scheduler.step_update(index=epoch_idx * num_steps + batch_idx)
        torch.cuda.synchronize()

        lr = self.lr_scheduler.get_lr()
        data_dict.update({'lr': lr})
        # NOTE(review): get_lr() is read after step_update(); depending on
        # LR_Scheduler this may be the next step's rate rather than the one
        # opt.step() just used — confirm.
        self.small_omega += lr * self.get_weights_grads().data ** 2
        return data_dict

    def supervised_observe(self, inputs, labels, task_id, batch_idx=None, epoch_idx=None, num_steps=None):
        """One supervised step on ``self.loss(self.net(inputs), labels)`` plus the SI penalty.

        Returns a dict with ``loss``, ``rep_loss``, ``penalty`` and ``lr``.
        """
        inputs = inputs.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        inputs, labels = self._select_subset(inputs, labels)

        rep_loss = self.loss(self.net(inputs), labels)
        penalty = self.penalty()
        return self._optimize(rep_loss, penalty, batch_idx, epoch_idx, num_steps)

    def siamese_observe(self, inputs_a, inputs_b, notaug_inputs, task_id, batch_idx=None, epoch_idx=None, num_steps=None):
        """One step where ``self.net(inputs_a, inputs_b)`` returns the loss, plus the SI penalty.

        ``notaug_inputs`` and ``task_id`` are accepted but unused here.
        Returns a dict with ``loss``, ``rep_loss``, ``penalty`` and ``lr``.
        """
        inputs_a = inputs_a.cuda(non_blocking=True)
        inputs_b = inputs_b.cuda(non_blocking=True)
        inputs_a, inputs_b = self._select_subset(inputs_a, inputs_b)

        rep_loss = self.net(inputs_a, inputs_b)
        penalty = self.penalty()
        return self._optimize(rep_loss, penalty, batch_idx, epoch_idx, num_steps)

    def masked_observe(self, inputs, mask, image_paths, task_id, batch_idx=None, epoch_idx=None, num_steps=None, logger=None):
        """One step where ``self.net(inputs, mask)`` returns the loss, plus the SI penalty.

        On the apex-amp path, gradients are clipped element-wise to 1 after
        unscaling. ``image_paths``, ``task_id`` and ``logger`` are unused here.
        Returns a dict with ``loss``, ``rep_loss``, ``penalty`` and ``lr``.
        """
        inputs = inputs.cuda(non_blocking=True)
        mask = mask.cuda(non_blocking=True)

        penalty = self.penalty()
        rep_loss = self.net(inputs, mask)
        return self._optimize(rep_loss, penalty, batch_idx, epoch_idx, num_steps, clip_value=1)
