### This is a PyTorch implementation of A-GEM (Averaged Gradient Episodic Memory), based on https://github.com/facebookresearch/agem.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

import ipdb
import model.meta.learner as Learner
import model.meta.modelfactory as mf
import numpy as np
import random

# Auxiliary functions useful for AGEM's inner optimization.

def compute_offsets(task, nc_per_task, is_cifar):
    """
        Return the half-open output-column range [offset1, offset2) used for a task.

        With per-task heads (is_cifar=True) task k owns columns
        [k * nc_per_task, (k + 1) * nc_per_task); otherwise every task shares
        columns [0, nc_per_task).
    """
    if not is_cifar:
        return 0, nc_per_task
    return task * nc_per_task, (task + 1) * nc_per_task


def store_grad(pp, grads, grad_dims, tid):
    """
        Flatten the current parameter gradients into column `tid` of `grads`.

        pp: callable returning the parameters (e.g. `module.parameters`)
        grads: 2-D buffer, one row per scalar parameter, one column per task
        grad_dims: element count of each parameter, in the order pp() yields them
        tid: task id, i.e. the column to overwrite

        The column is zeroed first, so parameters without a gradient leave
        zeros in their slice.
    """
    column = grads[:, tid]
    column.fill_(0.0)
    start = 0
    for param, numel in zip(pp(), grad_dims):
        stop = start + numel
        if param.grad is not None:
            column[start:stop].copy_(param.grad.data.view(-1))
        start = stop


def overwrite_grad(pp, newgrad, grad_dims):
    """
        Write a flat gradient vector back into the parameters' `.grad` tensors.

        pp: callable returning the parameters (e.g. `module.parameters`)
        newgrad: flat (possibly non-contiguous) gradient vector, e.g. a column
                 of the task-gradient buffer after projection
        grad_dims: element count of each parameter, in the order pp() yields them

        Parameters whose `.grad` is None are skipped but still consume their
        slice of `newgrad`.
    """
    start = 0
    for param, numel in zip(pp(), grad_dims):
        stop = start + numel
        if param.grad is not None:
            target = param.grad.data
            chunk = newgrad[start:stop].contiguous().view(target.size())
            target.copy_(chunk)
        start = stop



def projectgrad(gradient, memories, margin=0.5, eps=1e-3, oiter=0):
    """
        A-GEM gradient projection, performed in place.

        The task gradients stored in `memories` are averaged into a single
        reference gradient g_ref. If the proposed update g = `gradient` has a
        negative dot product with g_ref (i.e. it would increase the average
        loss on past tasks), it is replaced by

            g - (g . g_ref / g_ref . g_ref) * g_ref

        otherwise it is left untouched. Nothing is returned.

        input:  gradient, (p, 1) tensor; overwritten in place when projected
        input:  memories, (p, t) tensor, one column per past task
        margin, eps: accepted for interface compatibility; currently unused
        oiter:  iteration counter; diagnostics are printed when oiter % 100 == 0
    """
    # Do the projection arithmetic in float64 numpy.
    memories_np = memories.cpu().t().double().numpy()                       # (t, p)
    gradient_np = gradient.cpu().contiguous().view(-1).double().numpy()     # (p,)

    # A-GEM uses one reference gradient: the mean over the stored task gradients.
    ref_grad = memories_np.mean(axis=0).reshape(1, memories_np.shape[1])    # (1, p)

    ref_mag = np.dot(ref_grad, ref_grad.transpose())
    dotp = np.dot(gradient_np.reshape(1, -1), ref_grad.transpose())

    if oiter % 100 == 0:
        # Only compute the (diagnostic) similarity when it is actually printed.
        similarity = torch.nn.functional.cosine_similarity(
            gradient.t(), memories.t().mean(dim=0).unsqueeze(0))
        print('similarity : ', similarity.item())
        print('dotp:', dotp)

    # A negative dot product implies dotp != 0 and hence ref_mag > 0,
    # so the division below is safe.
    if dotp[0, 0] < 0:
        proj = gradient_np.reshape(1, -1) - ((dotp / ref_mag) * ref_grad)
        # copy_ casts to gradient's dtype and moves to its device.
        gradient.copy_(torch.from_numpy(proj).view(-1, 1))


class Net(nn.Module):
    """
        A-GEM continual learner.

        Keeps a per-task episodic memory (ring buffer) of inputs/labels. On
        every update, gradients on the memories of previously observed tasks are
        stored per task, and the current-task gradient is projected by
        `projectgrad` (against their average) when it conflicts with them.
    """

    def __init__(self,
                 n_inputs,
                 n_outputs,
                 n_tasks,
                 args):
        super(Net, self).__init__()
        self.args = args

        nl, nh = args.n_layers, args.n_hiddens
        # NOTE(review): passed to projectgrad, which currently ignores it.
        self.margin = args.memory_strength
        # For these datasets the output layer is split into per-task heads.
        self.is_cifar = ((args.dataset == 'cifar100') or (args.dataset == 'tinyimagenet'))

        config = mf.ModelFactory.get_model(model_type=args.arch, sizes=[n_inputs] + [nh] * nl + [n_outputs],
                                           dataset=args.dataset, args=args)
        self.net = Learner.Learner(config, args)

        self.ce = nn.CrossEntropyLoss()
        self.bce = torch.nn.CrossEntropyLoss()
        self.n_outputs = n_outputs
        self.glances = args.glances

        self.opt = optim.SGD(self.parameters(), args.lr)

        self.n_memories = args.n_memories
        self.gpu = args.cuda

        # Reservoir-sampled list of (x, y, task) items, capped at self.memories.
        self.age = 0
        self.M = []
        self.memories = args.memories
        self.grad_align = []
        self.grad_task_align = {}
        self.current_task = 0

        # allocate episodic memory; zero-initialised so that slots that were
        # never written hold zeros rather than arbitrary (possibly NaN) values
        self.memory_data = torch.zeros(
            n_tasks, self.n_memories, n_inputs, dtype=torch.float32)
        self.memory_labs = torch.zeros(
            n_tasks, self.n_memories, dtype=torch.long)
        if args.cuda:
            self.memory_data = self.memory_data.cuda()
            self.memory_labs = self.memory_labs.cuda()

        # allocate temporary synaptic memory: one flattened-gradient column per task
        self.grad_dims = [param.data.numel() for param in self.parameters()]
        self.grads = torch.zeros(sum(self.grad_dims), n_tasks)
        if args.cuda:
            self.grads = self.grads.cuda()

        # allocate counters
        self.observed_tasks = []
        self.mem_cnt = 0
        if self.is_cifar:
            self.nc_per_task = int(n_outputs / n_tasks)
        else:
            self.nc_per_task = n_outputs

        if args.cuda:
            self.cuda()

        self.iter = 0

    def forward(self, x, t):
        """
            Run the network on x; for per-task-head datasets, mask every
            output column outside task t's range with a large negative value.
        """
        if self.args.dataset == 'tinyimagenet':
            x = x.view(-1, 3, 64, 64)
        elif self.args.dataset == 'cifar100':
            x = x.view(-1, 3, 32, 32)
        output = self.net.forward(x)
        if self.is_cifar:
            # make sure we predict classes within the current task
            offset1 = int(t * self.nc_per_task)
            offset2 = int((t + 1) * self.nc_per_task)
            if offset1 > 0:
                output[:, :offset1].data.fill_(-10e10)
            if offset2 < self.n_outputs:
                output[:, offset2:self.n_outputs].data.fill_(-10e10)
        return output

    def observe(self, x, y, t):
        """
            Perform `self.glances` A-GEM updates on the minibatch (x, y) of task t.

            Returns the current-task loss (a float) from the last glance.
        """
        self.iter += 1
        x = x.view(x.size(0), -1)
        # Register a new task. The empty-list check guarantees the very first
        # task is recorded even when its id equals the initial current_task (0);
        # without it, task 0 would never be used as a constraint later on.
        if not self.observed_tasks or t != self.current_task:
            self.observed_tasks.append(t)
            self.current_task = t
            self.grad_align.append([])

        for pass_itr in range(self.glances):

            if pass_itr == 0:
                # Update ring buffer storing examples from current task.
                # NOTE(review): mem_cnt is shared across tasks and not reset on a
                # task switch (same as the GEM reference code).
                bsz = y.data.size(0)
                endcnt = min(self.mem_cnt + bsz, self.n_memories)
                effbsz = endcnt - self.mem_cnt
                self.memory_data[t, self.mem_cnt: endcnt].copy_(
                    x.data[: effbsz])
                if bsz == 1:
                    self.memory_labs[t, self.mem_cnt] = y.data[0]
                else:
                    self.memory_labs[t, self.mem_cnt: endcnt].copy_(
                        y.data[: effbsz])
                self.mem_cnt += effbsz
                if self.mem_cnt == self.n_memories:
                    self.mem_cnt = 0

            # compute gradient on previous tasks
            if len(self.observed_tasks) > 1:
                for past_task in self.observed_tasks[:-1]:
                    self.zero_grad()
                    # fwd/bwd on the examples in the memory
                    offset1, offset2 = compute_offsets(past_task, self.nc_per_task,
                                                       self.is_cifar)
                    ptloss = self.ce(
                        self.forward(
                            self.memory_data[past_task],
                            past_task)[:, offset1: offset2],
                        self.memory_labs[past_task] - offset1)
                    ptloss.backward()
                    torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.args.grad_clip_norm)

                    store_grad(self.parameters, self.grads, self.grad_dims,
                               past_task)

            # now compute the grad on the current minibatch
            self.zero_grad()

            offset1, offset2 = compute_offsets(t, self.nc_per_task, self.is_cifar)
            loss = self.ce(self.forward(x, t)[:, offset1: offset2], y - offset1)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.args.grad_clip_norm)

            # check if gradient violates constraints
            if len(self.observed_tasks) > 1:
                # copy gradient
                store_grad(self.parameters, self.grads, self.grad_dims, t)
                indx = torch.tensor(self.observed_tasks[:-1], dtype=torch.long,
                                    device=self.grads.device)

                # projectgrad writes through the column view into self.grads
                projectgrad(self.grads[:, t].unsqueeze(1),
                            self.grads.index_select(1, indx), self.margin, oiter=self.iter)
                # copy gradients back
                overwrite_grad(self.parameters, self.grads[:, t],
                               self.grad_dims)

            self.opt.step()

        xi = x.data.cpu().numpy()
        yi = y.data.cpu().numpy()
        for x_item, y_item in zip(xi, yi):
            self.age += 1
            # Reservoir sampling memory update:
            if len(self.M) < self.memories:
                self.M.append([x_item, y_item, t])
            else:
                # Algorithm R: the age-th item replaces a random slot with
                # probability memories / age, hence randrange(age) = [0, age).
                p = random.randrange(self.age)
                if p < self.memories:
                    self.M[p] = [x_item, y_item, t]

        return loss.item()
