# An implementation of Experience Replay (ER) with reservoir sampling and without using tasks from Algorithm 4 of https://openreview.net/pdf?id=B1gTShAct7
#
# Copyright 2019-present, IBM Research
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn.modules.loss import CrossEntropyLoss
from torch.nn import functional as F

import torch.optim as optim
from torch.autograd import Variable

import numpy as np

import random
from random import shuffle

import sys
import warnings
import math

import model.meta.modelfactory_extention as mf
import model.meta.learner_extention as Learner

warnings.filterwarnings("ignore")

class Net(nn.Module):
    def __init__(self,
                 n_inputs,
                 n_outputs,
                 n_tasks,
                 args):
        """
        Build the learner network, its optimizers and an empty replay buffer.

        Args:
            n_inputs: size of the first entry of the layer-size list.
            n_outputs: size of the last entry of the layer-size list (number of logits).
            n_tasks: number of tasks; only used to split the outputs per task
                for the 'cifar100' / 'tinyimagenet' datasets.
            args: configuration object; the attributes read here are n_layers,
                n_hiddens, arch, dataset, lr, learn_lr, alpha_init, opt_lr,
                opt_lr_e, glances, memories, replay_batch_size and cuda.
        """
        super().__init__()

        self.args = args
        depth, width = args.n_layers, args.n_hiddens

        # Layer sizes: input, `depth` hidden layers of `width` units, output.
        layer_sizes = [n_inputs] + [width] * depth + [n_outputs]
        config = mf.ModelFactory.get_model(model_type=args.arch, sizes=layer_sizes,
                                           dataset=args.dataset, args=args)
        self.net = Learner.Learner(config, args)

        # The weight optimizer is created before the learning-rate parameters
        # are defined on the network.
        self.opt_wt = optim.SGD(self.parameters(), lr=args.lr)

        if self.args.learn_lr:
            # Two sets of learnable learning rates, each with its own optimizer.
            self.net.define_task_lr_params(alpha_init=args.alpha_init)
            self.opt_lr = torch.optim.SGD(list(self.net.alpha_lr.parameters()), lr=args.opt_lr)

            self.net.define_task_lr_params_e(alpha_init=args.alpha_init)
            self.opt_lr_e = torch.optim.SGD(list(self.net.alpha_lr_e.parameters()), lr=args.opt_lr_e)

        self.loss = CrossEntropyLoss()
        self.is_cifar = args.dataset in ('cifar100', 'tinyimagenet')
        self.glances = args.glances

        self.current_task = 0
        self.memories = args.memories
        self.batchSize = int(args.replay_batch_size)

        # Replay buffer (list of [x, y, t] entries) and count of samples seen.
        self.M = []
        self.age = 0

        # Move the network to the GPU if requested.
        # NOTE(review): this attribute shadows nn.Module.cuda() on the instance.
        self.cuda = args.cuda
        if self.cuda:
            self.net = self.net.cuda()

        self.n_outputs = n_outputs
        # For the multi-head datasets each task owns an equal slice of the outputs.
        self.nc_per_task = int(n_outputs / n_tasks) if self.is_cifar else n_outputs


    def compute_offsets(self, task):
        offset1 = task * self.nc_per_task
        offset2 = (task + 1) * self.nc_per_task
        return int(offset1), int(offset2)
            
    def take_multitask_loss(self, bt, logits, y):
        loss = 0.0
        for i, ti in enumerate(bt):
            offset1, offset2 = self.compute_offsets(ti)
            loss += self.loss(logits[i, offset1:offset2].unsqueeze(0), y[i].unsqueeze(0)-offset1)
        return loss/len(bt)

    def forward(self, x, t):
        output, e = self.net.forward(x)
        if self.is_cifar:
            # make sure we predict classes within the current task
            offset1, offset2 = self.compute_offsets(t)
            if offset1 > 0:
                output[:, :offset1].data.fill_(-10e10)
            if offset2 < self.n_outputs:
                output[:, offset2:self.n_outputs].data.fill_(-10e10)
        return output

    def getBatch(self, x, y, t):
        
        if(x is not None):
            mxi = np.array(x)
            myi = np.array(y)
            mti = np.ones(x.shape[0], dtype=int)*t            
        else:
            mxi = np.empty( shape=(0, 0) )
            myi = np.empty( shape=(0, 0) )
            mti = np.empty( shape=(0, 0) )

        bxs = []
        bys = []
        bts = []

        if len(self.M) > 0:
            order = [i for i in range(0,len(self.M))]
            osize = min(self.batchSize,len(self.M))
            for j in range(0,osize):
                shuffle(order)
                k = order[j]
                x,y,t = self.M[k]
                xi = np.array(x)
                yi = np.array(y)
                ti = np.array(t)
                
                bxs.append(xi)
                bys.append(yi)
                bts.append(ti)
        
        #print(bts)
        for i in range(len(myi)):
            bxs.append(mxi[i])
            bys.append(myi[i])
            bts.append(mti[i])


        bxs = Variable(torch.from_numpy(np.array(bxs))).float()
        bys = Variable(torch.from_numpy(np.array(bys))).long().view(-1)
        bts = Variable(torch.from_numpy(np.array(bts))).long().view(-1)
        
        	
        # handle gpus if specified
        if self.cuda:
            bxs = bxs.cuda()
            bys = bys.cuda()
            bts = bts.cuda()

        return bxs, bys, bts


    def observe(self, x, y, t):
        ### step through elements of x

        xi = x.data.cpu().numpy()
        yi = y.data.cpu().numpy()

        if t != self.current_task:
           self.current_task = t

        if self.args.learn_lr:
            loss = self.la_ER(x, y, t)
        else:
            loss = self.ER(xi, yi, t)

        for i in range(0, x.size()[0]):
            self.age += 1
            # Reservoir sampling memory update:
            if len(self.M) < self.memories:
                self.M.append([xi[i], yi[i], t])

            else:
                p = random.randint(0,self.age)
                if p < self.memories:
                    self.M[p] = [xi[i], yi[i], t]

        return loss.item()

    def batch_data_batch_mem(self, bxs, bys, bts):
        #self.batchSize
        data_plus_mem_sz = bxs.shape[0]
        batch_size = self.args.batch_size
        
        bxs_r = bxs[0:(data_plus_mem_sz-batch_size)] 
        bys_r = bys[0:(data_plus_mem_sz-batch_size)]
        bts_r = bts[0:(data_plus_mem_sz-batch_size)] 
        #print(bxs_d.size())        

        bxs_d = bxs[(data_plus_mem_sz - batch_size):data_plus_mem_sz]
        bys_d = bys[(data_plus_mem_sz - batch_size):data_plus_mem_sz]
        bts_d = bts[(data_plus_mem_sz - batch_size):data_plus_mem_sz]
        #print(bxs_r.size())

        return bxs_d, bys_d, bts_d, bxs_r, bys_r, bts_r

    def simdiff_inner_loss(self, e):
        num_of_blocks = 8

        e = F.normalize(e)
        e_r = torch.reshape(e, (e.size(0), 40, num_of_blocks) ) 

        e_r_d = e_r.permute(0, 2, 1)
        e_r_d = e_r_d.detach()
        e_all = torch.einsum('lij, ljk -> lki', e_r_d, e_r)
        e_all_d_minus = torch.triu(e_all, 1) + torch.tril(e_all, -1)        

        loss_ext = torch.einsum( 'ijk -> ', e_all_d_minus * e_all_d_minus)/(num_of_blocks)
        norm_val = ((num_of_blocks - 1) * num_of_blocks/2) 
        for ind_sum in range(0, num_of_blocks - 1):
            tmp = e_all_d_minus[:,:, ind_sum] * torch.sum(e_all_d_minus[:, :, (ind_sum + 1):num_of_blocks], dim=2)
            loss_ext = loss_ext - torch.einsum('ij -> ', tmp)/norm_val
        #print(loss_ext.item())

        return 100 * loss_ext/3

    def simdiff_between_loss(self, e, d):
        num_of_blocks = 8

        e = F.normalize(e)
        d = F.normalize(d)

        e_r = torch.reshape(e, (e.size(0), 40, num_of_blocks) ) 
        d_r = torch.reshape(d, (d.size(0), 40, num_of_blocks) ) 

        e_r_d = e_r.permute(0, 1, 2)
        e_r_d = e_r_d.detach()

        d_r_d = d_r.permute(0, 1, 2)
        d_r_d = d_r_d.detach()


        e_all = torch.einsum('ljk -> lk', d_r_d * e_r)
        e_all =  e_all * e_all 

        loss_ext_e = torch.einsum( 'ij -> ', e_all)/(num_of_blocks * e.size(0))
        #norm_val = ((num_of_blocks - 1) * num_of_blocks/2) * num_of_blocks * e.size(0)
        #for ind_sum in range(0, num_of_blocks - 1):
        #    tmp = e_all_d_minus[:,:, ind_sum] * torch.sum(e_all_d_minus[:, :, (ind_sum + 1):num_of_blocks], dim=2)
        #    loss_ext_e = loss_ext_e - torch.einsum('ij -> ', tmp)/norm_val


        d_all = torch.einsum('ljk -> lk', e_r_d * d_r)
        d_all = d_all * d_all

        loss_ext_d = torch.einsum( 'ij -> ', d_all)/(num_of_blocks * d.size(0))
        #norm_val = ((num_of_blocks - 1) * num_of_blocks/2) * num_of_blocks * d.size(0)
        #for ind_sum in range(0, num_of_blocks - 1):
        #    tmp = d_all_d_minus[:,:, ind_sum] * torch.sum(d_all_d_minus[:, :, (ind_sum + 1):num_of_blocks], dim=2)
        #    loss_ext_d = loss_ext_d - torch.einsum('ij -> ', tmp)/norm_val
        #print(loss_ext_e + loss_ext_d)
        
        return loss_ext_e + loss_ext_d

    def ER(self, x, y, t):
        #print("Experience replay ...")
        #print(len(x))
        for pass_itr in range(self.glances):

            self.net.zero_grad()
            
            # Draw batch from buffer:
            bx,by,bt = self.getBatch(x,y,t)
            bx_d, by_d, bt_d, bx_r, by_r, bt_r  = self.batch_data_batch_mem(bx,by,bt)

            bx_d = bx_d.squeeze()
            out_d, e_d = self.net.forward(bx_d)
            prediction_d = out_d
            loss = self.take_multitask_loss(bt_d, prediction_d, by_d) + 30 * self.simdiff_inner_loss(e_d)
            
            if bx_r.shape[0] == 10:
                #print("ok stage 1")
                bx_r = bx_r.squeeze()
                out_r, e_r = self.net.forward(bx_r)
                prediction_r = out_r
                loss = loss + self.take_multitask_loss(bt_r, prediction_r, by_r) + 30 * self.simdiff_inner_loss(e_r)


            if bx_r.shape[0] == 10:
                #print("ok stage 2")
                loss = loss + 60 * self.simdiff_between_loss(e_d, e_r)
 
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.args.grad_clip_norm)

            self.opt_wt.step()
        
        return loss

    def composite_loss(self, x, y, t, fast_weights):
   
  
        out_d, e_d = self.net.forward(bx_d, fast_weights)
        prediction_d = out_d
        loss = self.loss(prediction_d, by_d) + 30 * self.simdiff_inner_loss(e_d)
            
        if bx_r.shape[0] == 10:
            #print("ok stage 1")
            bx_r = bx_r.squeeze()
            out_r, e_r = self.net.forward(bx_r)
            prediction_r = out_r
            loss = loss + self.take_multitask_loss(bt_r, prediction_r, by_r) + 30 * self.simdiff_inner_loss(e_r)


        if bx_r.shape[0] == 10:
            #print("ok stage 2")
            loss = loss + 60 * self.simdiff_between_loss(e_d, e_r)

        return loss

    def inner_update(self, x, fast_weights, y, t):
        """
        Update the fast weights using the current samples and return the updated fast
        """

        if self.is_cifar:
            offset1, offset2 = self.compute_offsets(t)            
            out, _ = self.net.forward(x, fast_weights)
            logits = out[:, :offset2]
            loss = self.loss(logits[:, offset1:offset2], y-offset1) #+ 30 * self.simdiff_inner_loss(e)
        else:
            out, _ = self.net.forward(x, fast_weights) 
            logits = out
            loss = self.loss(logits, y) #+  30 * self.simdiff_inner_loss(e)  

        if fast_weights is None:
            fast_weights = self.net.parameters()

        graph_required = self.args.second_order
        grads = list(torch.autograd.grad(loss, fast_weights, create_graph=graph_required, retain_graph=graph_required))

        for i in range(len(grads)):
            grads[i] = torch.clamp(grads[i], min = -self.args.grad_clip_norm, max = self.args.grad_clip_norm)

        fast_weights = list(
            map(lambda p: p[1][0] - p[0] * p[1][1], zip(grads, zip(fast_weights, self.net.alpha_lr))))
        return fast_weights, loss.item()

    def inner_update_e(self, x, fast_weights_o, fast_weiths_e, y, t):
        """
        Update the fast weights using the current samples and return the updated fast
        """

        if self.is_cifar:
            offset1, offset2 = self.compute_offsets(t)            
            #out, _ = self.net.forward(x, fast_weights_o)
            out, _ = self.net.forward(x, fast_weights_e)
            logits = out[:, :offset2]
            #loss_o = self.loss(logits[:, offset1:offset2], y-offset1)  
            loss_e = self.simdiff_inner_loss(e) + 0 * self.loss(logits[:, offset1:offset2], y-offset1)  

        else:
            #out, _ = self.net.forward(x, fast_weights_o)
            out, e = self.net.forward(x, fast_weights_e)  
            logits = out
            #loss_o = self.loss(logits, y)  
            loss_e = self.simdiff_inner_loss(e) + 0 * self.loss(logits, y)  

        #if fast_weights is None:
        #    fast_weights = self.net.parameters()

        #graph_required = self.args.second_order
        #print(graph_required)
        #grads_o = list(torch.autograd.grad(loss_o, fast_weights_o, create_graph=graph_required, retain_graph=graph_required, allow_unused=True))

        graph_required = self.args.second_order
        grads_e = list(torch.autograd.grad(loss_e, fast_weights_e, create_graph=graph_required, retain_graph=graph_required, allow_unused=True))

        for i in range(len(grads_e)):
            #grads_o[i] = torch.clamp(grads_o[i], min = -self.args.grad_clip_norm, max = self.args.grad_clip_norm)
            grads_e[i] = torch.clamp(grads_e[i], min = -self.args.grad_clip_norm, max = self.args.grad_clip_norm)

        #fast_weights_o = list(map(lambda p: p[1][0] - p[0] * p[1][1], zip(grads_o, zip(fast_weights_o, self.net.alpha_lr))))
        fast_weights_e = list(map(lambda p_e: p_e[1][0] - p_e[0] * p_e[1][1], zip(grads_e, zip(fast_weights_e, self.net.alpha_lr_e))))

        return fast_weights_o, fast_weights_e,  loss_e.item()


    def inner_update_sup(self, x, fast_weights, y, t):
        """
        Update the fast weights using the current samples and return the updated fast
        """
        if self.is_cifar:
            offset1, offset2 = self.compute_offsets(t)            
            out, _ = self.net.forward(x, fast_weights)
            logits = out[:, :offset2]
            loss_o = self.loss(logits[:, offset1:offset2], y-offset1)  

        else:
            out, _ = self.net.forward(x, fast_weights)  
            loss_o = self.loss(out, y)  

        
        graph_required = self.args.second_order
        grads = list(torch.autograd.grad(loss_o, fast_weights, create_graph=graph_required, retain_graph=graph_required, allow_unused=True))

        for i in range(len(grads)):
            grads[i] = torch.clamp(grads[i], min = -self.args.grad_clip_norm, max = self.args.grad_clip_norm)
        fast_weights_o = list(map(lambda p: p[1][0] - p[0] * p[1][1], zip(grads, zip(fast_weights, self.net.alpha_lr))))

        if self.is_cifar:
            #offset1, offset2 = self.compute_offsets(t)            
            out, e = self.net.forward(x, fast_weights)
            logits = out[:, :offset2]
            loss_e = self.simdiff_inner_loss(e) + 0 * self.loss(logits[:, offset1:offset2], y-offset1)  

        else:
            out, e = self.net.forward(x, fast_weights)  
            loss_e =self.simdiff_inner_loss(e) + 0 *  self.loss(out, y) 

        graph_required = self.args.second_order
        grads = list(torch.autograd.grad(loss_e, fast_weights, create_graph=graph_required, retain_graph=graph_required, allow_unused=True))

        for i in range(len(grads)):
            grads[i] = torch.clamp(grads[i], min = -self.args.grad_clip_norm, max = self.args.grad_clip_norm)
        fast_weights = list(map(lambda p_e: p_e[1][0] - p_e[0] * p_e[1][1], zip(grads, zip(fast_weights_o, self.net.alpha_lr_e))))

        return fast_weights,  loss_e.item() + loss_e.item()

    def la_ER(self, x, y, t):
        """
        this ablation tests whether it suffices to just do the learning rate modulation
        guided by gradient alignment + clipping (that La-MAML does implciitly through autodiff)
        and use it with ER (therefore no meta-learning for the weights)

        """
        for pass_itr in range(self.glances):
            
            perm = torch.randperm(x.size(0))
            x = x[perm]
            y = y[perm]

            batch_sz = x.shape[0]
            n_batches = self.args.cifar_batches
            rough_sz = math.ceil(batch_sz/n_batches)
            fast_weights = None
            meta_losses = [0 for _ in range(n_batches)] 

            bx, by, bt = self.getBatch(x.cpu().numpy(), y.cpu().numpy(), t)
            bx = bx.squeeze()
            fast_weights_inic = self.net.parameters()
            for ind_per_loss in range(0, 1):
                fast_weights_o = fast_weights_inic
                fast_weights_e = fast_weights_inic

                meta_losses_o = [0 for _ in range(n_batches)] 
                meta_losses_e = [0 for _ in range(n_batches)] 
                for i in range(n_batches):

                    batch_x = x[i*rough_sz : (i+1)*rough_sz]
                    batch_y = y[i*rough_sz : (i+1)*rough_sz]

                    # assuming labels for inner update are from the same 
                    #if ind_per_loss == 0:
                    #fast_weights_o, inner_loss = self.inner_update(batch_x, fast_weights_o, batch_y, t)
                    #else:
                    #    fast_weights, inner_loss = self.inner_update_e(batch_x, fast_weights, batch_y, t)
                    #_, fast_weiths_e, _ = self.inner_update_e(batch_x, fast_weights_o, fast_weights_e, batch_y, t)
                    #_, fast_weiths_o, _ = self.inner_update_e(batch_x, fast_weights_o, fast_weights_o, batch_y, t)

                    fast_weiths_o, _ = self.inner_update_sup(batch_x, fast_weights_o, batch_y, t)
                    out, e = self.net.forward(bx, fast_weights_o)
                    #_, e   = self.net.forward(bx, fast_weights_e)
                    prediction = out
                    #if ind_per_loss == 0:
                    #    meta_loss = self.take_multitask_loss(bt, prediction, by)
                    #else:
                    #    meta_loss = self.simdiff_inner_loss(e)
                    meta_loss_o = self.take_multitask_loss(bt, prediction, by)
                    meta_loss_e = self.simdiff_inner_loss(e)
                    meta_losses_o[i] += meta_loss_o + meta_loss_e
                    #meta_losses_e[i] += meta_loss_e

                # update alphas
                self.net.zero_grad()
                self.opt_lr.zero_grad()
                self.opt_lr_e.zero_grad()

                meta_loss_o = meta_losses_o[-1] #sum(meta_losses)/len(meta_losses)
                meta_loss_o.backward()

                #meta_loss_e = meta_losses_e[-1] #sum(meta_losses)/len(meta_losses)
                #meta_loss_e.backward()

                torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.args.grad_clip_norm)
                torch.nn.utils.clip_grad_norm_(self.net.alpha_lr.parameters(), self.args.grad_clip_norm)
                torch.nn.utils.clip_grad_norm_(self.net.alpha_lr_e.parameters(), self.args.grad_clip_norm)

                # update the LRs (guided by meta-loss, but not the weights)
                self.opt_lr.step()
                self.opt_lr_e.step()



                # update weights
                self.net.zero_grad()

                # compute ER loss for network weights
                out, _ = self.net.forward(bx)
                prediction = out
                #if ind_per_loss == 0:
                loss_o = self.take_multitask_loss(bt, prediction, by)
                loss_o.backward()
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.args.grad_clip_norm)


                # update weights with grad from simple ER loss 
                # and LRs obtained from meta-loss guided by old and new tasks
                #if ind_per_loss == 0:
                #g_on_prev_task = []
                for i,p in enumerate(self.net.parameters()):
                    #g_on_prev_task.append(p.grad * nn.functional.relu(self.net.alpha_lr[i]) )
                    p.data = p.data -  (p.grad * nn.functional.relu(self.net.alpha_lr[i])) #- 0.5 * g_on_prev_task[i]
                
                #self.net.zero_grad()

                #out, e = self.net.forward(bx)
                #prediction = out
                #loss_e = 0 * self.take_multitask_loss(bt, prediction, by) + self.simdiff_inner_loss(e)
                #loss_e.backward()
                #torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.args.grad_clip_norm)

                #for i,p in enumerate(self.net.parameters()):                                 
                #    p.data = p.data -  (p.grad * nn.functional.relu(self.net.alpha_lr[i])) #- 0.5 * g_on_prev_task[i]

                self.net.zero_grad()
                self.net.alpha_lr.zero_grad()
                self.net.alpha_lr_e.zero_grad()
                #print("update done ...")
                #print(ind_per_loss)

        return .5 * ( loss_o )
