# An implementation of Experience Replay (ER) with reservoir sampling and without using tasks from Algorithm 4 of https://openreview.net/pdf?id=B1gTShAct7
#
# Copyright 2019-present, IBM Research
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn.modules.loss import CrossEntropyLoss
from torch.nn import functional as F

import torch.optim as optim
from torch.autograd import Variable

import numpy as np

import random
from random import shuffle

import sys
import warnings
import math

import model.meta.modelfactory_extention as mf
import model.meta.learner_extention as Learner

warnings.filterwarnings("ignore")

class Net(nn.Module):
    """Experience-replay learner backed by a reservoir-sampled memory buffer.

    The wrapped network (``self.net``) is trained by ``observe``: every
    incoming batch is mixed with samples drawn from the memory ``self.M``,
    a manual SGD step is applied for ``args.glances`` passes, and the batch
    is then (optionally) pushed into the memory with reservoir sampling.

    NOTE: ``self.cuda`` is a boolean flag copied from ``args.cuda``; it
    shadows ``nn.Module.cuda()`` on instances of this class. The attribute
    is kept as-is because code outside this file may read it.
    """

    def __init__(self,
                 n_inputs,
                 n_outputs,
                 n_tasks,
                 args):
        """Build the wrapped learner, its optimizers and the memory buffers.

        Args:
            n_inputs: size of the input layer passed to the model factory.
            n_outputs: total number of output units (all tasks together).
            n_tasks: number of tasks; used to size the per-task memory and,
                for the CIFAR-like datasets, the number of classes per task.
            args: experiment configuration namespace. Attributes read here:
                ``n_layers``, ``n_hiddens``, ``arch``, ``dataset``, ``lr``,
                ``learn_lr``, ``alpha_init``, ``opt_lr``, ``opt_lr_e``,
                ``glances``, ``memories``, ``replay_batch_size``, ``cuda``.
        """
        super(Net, self).__init__()

        self.args = args
        nl, nh = args.n_layers, args.n_hiddens

        config = mf.ModelFactory.get_model(model_type=args.arch,
                                           sizes=[n_inputs] + [nh] * nl + [n_outputs],
                                           dataset=args.dataset, args=args)
        self.net = Learner.Learner(config, args)

        self.opt_wt = optim.SGD(self.parameters(), lr=args.lr)

        if self.args.learn_lr:
            self.net.define_task_lr_params(alpha_init=args.alpha_init)
            self.opt_lr = torch.optim.SGD(list(self.net.alpha_lr.parameters()), lr=args.opt_lr)

            self.net.define_task_lr_params_e(alpha_init=args.alpha_init)
            self.opt_lr_e = torch.optim.SGD(list(self.net.alpha_lr_e.parameters()), lr=args.opt_lr_e)

            self.net.define_reg_param(gamma_init=10.0)
            self.opt_gamma_r = torch.optim.SGD(list(self.net.gamma_r.parameters()), lr=1)

        self.loss = CrossEntropyLoss()
        self.is_cifar = args.dataset in ('cifar100', 'tinyimagenet', 'cifar10')
        self.glances = args.glances

        self.current_task = 0
        self.memories = args.memories
        self.batchSize = int(args.replay_batch_size)

        # Task-agnostic reservoir memory: list of [x, y, t] entries.
        self.M = []
        # Number of samples offered to the memory so far.
        self.age = 0

        # Task-aware memory: one list of [x, y, t] entries per task.
        self.samples_per_task = int(args.memories / n_tasks)
        self.M_per_task = [[] for _ in range(n_tasks)]
        self.n_all_tasks = n_tasks

        # handle gpus if specified
        self.cuda = args.cuda
        if self.cuda:
            self.net = self.net.cuda()

        self.n_outputs = n_outputs
        if self.is_cifar:
            self.nc_per_task = int(n_outputs / n_tasks)
        else:
            self.nc_per_task = n_outputs
        self.n_tasks = n_tasks

        self.pass_itr = 0
        self.iter_number_current = 0
        # Read by observe() when args.selective_mem_push == 'per_first_epoch'.
        # Initialised here so that mode works even if nothing outside this
        # class has set it yet; external code may overwrite it.
        self.real_epoch = 0

        self.evaluate_on_tasks = 'no'

    def compute_offsets(self, task):
        """Return the ``(start, end)`` output-unit range belonging to ``task``."""
        offset1 = task * self.nc_per_task
        offset2 = (task + 1) * self.nc_per_task
        return int(offset1), int(offset2)

    def take_multitask_loss(self, bt, logits, y):
        """Average per-sample loss where each sample has its own task id.

        For sample ``i`` only the logits inside the output range of task
        ``bt[i]`` are used, and the label is shifted by the range start.
        """
        loss = 0.0
        for i, ti in enumerate(bt):
            offset1, offset2 = self.compute_offsets(ti)
            loss += self.loss(logits[i, offset1:offset2].unsqueeze(0),
                              y[i].unsqueeze(0) - offset1)
        return loss / len(bt)

    def take_singletask_loss(self, bt, logits, y):
        """Average per-sample loss when every sample belongs to task ``bt``."""
        offset1, offset2 = self.compute_offsets(bt)
        loss = 0.0
        for i in range(len(y)):
            loss += self.loss(logits[i, offset1:offset2].unsqueeze(0),
                              y[i].unsqueeze(0) - offset1)
        return loss / len(y)

    def forward(self, x, t):
        """Run the wrapped network on ``x`` and return its first output.

        For the CIFAR-like datasets the logits outside the output range of
        task ``t`` are overwritten in place with ``-10e10`` so predictions
        stay within that task's classes.
        """
        output, _ = self.net.forward(x)
        if self.is_cifar:
            # make sure we predict classes within the current task
            offset1, offset2 = self.compute_offsets(t)
            if offset1 > 0:
                output[:, :offset1].data.fill_(-10e10)
            if offset2 < self.n_outputs:
                output[:, offset2:self.n_outputs].data.fill_(-10e10)
        return output

    def rep_buff_mem_len(self, t):
        """Return the number of task-aware memory entries for tasks ``0..t``."""
        return sum(len(self.M_per_task[i]) for i in range(t + 1))

    def getBatch(self, x, y, t):
        """Assemble a training batch: sampled memory entries, then new data.

        Args:
            x: array of new inputs (needs ``.shape``), or ``None`` for a
                memory-only batch.
            y: labels matching ``x``.
            t: task id attached to every new sample.

        Returns:
            ``(bxs, bys, bts, bxs_pc, bys_pc, bts_pc)``. The first three are
            tensors (float inputs, flat long labels, flat long task ids),
            moved to the GPU when ``self.cuda`` is set. The last three are
            the Python lists the tensors were built from.
        """
        if x is not None:
            mxi = np.array(x)
            myi = np.array(y)
            mti = np.ones(x.shape[0], dtype=int) * t
        else:
            mxi = np.empty(shape=(0, 0))
            myi = np.empty(shape=(0, 0))
            mti = np.empty(shape=(0, 0))

        bxs = []
        bys = []
        bts = []
        if len(self.M) > 0:
            order = list(range(len(self.M)))
            osize = min(self.batchSize, len(self.M))
            for j in range(osize):
                # The index list is reshuffled before every draw, so the
                # same memory slot can be picked more than once.
                shuffle(order)
                mem_x, mem_y, mem_t = self.M[order[j]]
                bxs.append(np.array(mem_x))
                bys.append(np.array(mem_y))
                bts.append(np.array(mem_t))

        # NOTE: these are the very same list objects as bxs/bys/bts (no
        # copy), so after the loop below they also hold the new samples,
        # with the memory samples first.
        bxs_pc = bxs
        bys_pc = bys
        bts_pc = bts

        for i in range(len(myi)):
            bxs.append(mxi[i])
            bys.append(myi[i])
            bts.append(mti[i])

        bxs = torch.from_numpy(np.array(bxs)).float()
        bys = torch.from_numpy(np.array(bys)).long().view(-1)
        bts = torch.from_numpy(np.array(bts)).long().view(-1)

        # handle gpus if specified
        if self.cuda:
            bxs = bxs.cuda()
            bys = bys.cuda()
            bts = bts.cuda()

        return bxs, bys, bts, bxs_pc, bys_pc, bts_pc

    def _reservoir_push(self, xi, yi, t):
        """Offer every row of ``xi``/``yi`` to ``self.M`` via reservoir sampling."""
        for i in range(xi.shape[0]):
            self.age += 1
            if len(self.M) < self.memories:
                self.M.append([xi[i], yi[i], t])
            else:
                # random.randint is inclusive on both ends; the slot must be
                # drawn from exactly `age` candidates (0 .. age - 1) so that
                # the sample is kept with probability memories / age.
                p = random.randint(0, self.age - 1)
                if p < self.memories:
                    self.M[p] = [xi[i], yi[i], t]

    def observe(self, x, y, t):
        """Train on one batch, then update the replay memory.

        Args:
            x: input tensor for the batch.
            y: label tensor for the batch.
            t: task id of the batch.

        Returns:
            ``(loss_o, loss_e)`` as Python floats, taken from ``gnc``.

        Which batches reach the memory is controlled by
        ``args.selective_mem_push``:
            * ``'per_first_epoch'``: skipped once ``self.real_epoch > 1`` or
              ``self.pass_itr > 1``;
            * ``'per_all_epochs_sparse'``: only when
              ``self.iter_number_current`` is a multiple of
              ``args.sparse_push_val``;
            * ``'none'``: every batch.
        Any other value leaves the memory untouched.
        """
        # Hard-wired switch: the task-aware memory path is currently off.
        use_task_aware_memory = False

        xi = x.data.cpu().numpy()
        yi = y.data.cpu().numpy()

        if t != self.current_task:
            self.current_task = t

        loss_o, loss_e = self.gnc(x, y, t)

        if use_task_aware_memory:
            self.push_to_task_aware_mem(xi, yi, t)
        else:
            mode = self.args.selective_mem_push
            if mode == 'per_first_epoch':
                should_push = not (self.real_epoch > 1 or self.pass_itr > 1)
            elif mode == 'per_all_epochs_sparse':
                should_push = np.remainder(self.iter_number_current,
                                           self.args.sparse_push_val) == 0
            else:
                should_push = mode == 'none'

            if should_push:
                self._reservoir_push(xi, yi, t)

        return loss_o.item(), loss_e.item()

    def push_to_task_aware_mem(self, xi, yi, t):
        """Offer every row of ``xi``/``yi`` to the memory of task ``t``.

        Always returns 1.
        """
        for i in range(0, xi.shape[0]):
            self.age += 1
            # task aware reservoir sampling memory update:
            if len(self.M_per_task[t]) < self.samples_per_task:
                self.M_per_task[t].append([xi[i], yi[i], t])
            else:
                # NOTE(review): self.age counts pushes for all tasks; the
                # division by n_all_tasks presumably approximates the number
                # of samples seen for this task - confirm before relying on
                # the exact keep probability.
                p = random.randint(0, int(self.age / self.n_all_tasks))
                if p < self.samples_per_task:
                    self.M_per_task[t][p] = [xi[i], yi[i], t]
        return 1

    def batch_data_batch_mem(self, bxs, bys, bts):
        """Split a combined batch into its new-data tail and replay head.

        The last ``args.batch_size`` rows are treated as new data and all
        rows before them as replayed memory samples.

        Returns:
            ``(bxs_d, bys_d, bts_d, bxs_r, bys_r, bts_r)``: data part first,
            replay part second.
        """
        total = bxs.shape[0]
        # Clamp at zero: when the combined batch is shorter than
        # args.batch_size a negative split index would mislabel the leading
        # new samples as replay samples.
        split = max(total - self.args.batch_size, 0)

        bxs_r = bxs[0:split]
        bys_r = bys[0:split]
        bts_r = bts[0:split]

        bxs_d = bxs[split:total]
        bys_d = bys[split:total]
        bts_d = bts[split:total]

        return bxs_d, bys_d, bts_d, bxs_r, bys_r, bts_r

    def gnc(self, x, y, t):
        """gnc main learning loop.

        Runs ``self.glances`` passes over the batch. Each pass shuffles the
        batch, builds a training batch according to ``args.sampling_type``
        (``'hole_batch'``: one ``getBatch`` call; ``'peace_by_peace_batch'``:
        ``args.cifar_batches`` memory draws, keeping every
        ``args.jump_lag_mem_buff``-th replay sample, followed by the new
        data), optionally drops the replay part when
        ``args.is_ER_active == 'no'``, and applies one manual SGD step with
        learning rate ``args.gnc_lr``.

        Returns:
            ``(loss_o, loss_e)`` tensors from the last pass (both are the
            same loss). With ``self.glances == 0`` both are zero tensors.

        Raises:
            ValueError: if ``args.sampling_type`` or ``args.update_type`` is
                not one of the supported values.
        """
        self.args.evaluate_on_tasks = 'no'

        sampling_type = self.args.sampling_type
        if sampling_type not in ('peace_by_peace_batch', 'hole_batch'):
            raise ValueError('Unknown sampling type: {!r}'.format(sampling_type))
        if self.args.update_type != 'with-primary-loss':
            # Previously this printed a message and returned 0, which made
            # observe() fail with an unrelated unpacking error.
            raise ValueError('Unknow update type! {!r}'.format(self.args.update_type))

        mx = np.array(x.cpu().numpy())
        my = np.array(y.cpu().numpy())
        mt = np.ones(x.shape[0], dtype=int) * t

        # Defaults so the return value is defined when there are no passes.
        loss_o = loss_e = torch.zeros(())

        for pass_itr in range(self.glances):
            self.pass_itr = pass_itr
            perm = torch.randperm(x.size(0))
            x = x[perm]
            y = y[perm]

            if sampling_type == 'peace_by_peace_batch':
                bx_total = []
                by_total = []
                bt_total = []
                for _ in range(self.args.cifar_batches):
                    bx, by, bt, bx_pc, by_pc, bt_pc = self.getBatch(
                        x.cpu().numpy(), y.cpu().numpy(), t)
                    # Count replay rows on the flat label tensor; its length
                    # is always the number of rows in the batch.
                    _, _, _, _, by_r, _ = self.batch_data_batch_mem(bx, by, bt)
                    for k in range(0, by_r.shape[0], self.args.jump_lag_mem_buff):
                        bx_total.append(bx_pc[k])
                        by_total.append(by_pc[k])
                        bt_total.append(bt_pc[k])

                # New samples go last so batch_data_batch_mem can find them.
                for k in range(x.shape[0]):
                    bx_total.append(mx[k])
                    by_total.append(my[k])
                    bt_total.append(mt[k])

                bx_total = torch.from_numpy(np.array(bx_total)).float()
                by_total = torch.from_numpy(np.array(by_total)).long().view(-1)
                bt_total = torch.from_numpy(np.array(bt_total)).long().view(-1)

                # handle gpus if specified
                if self.cuda:
                    bx_total = bx_total.cuda()
                    by_total = by_total.cuda()
                    bt_total = bt_total.cuda()
            else:  # 'hole_batch'
                bx_total, by_total, bt_total, _, _, _ = self.getBatch(
                    x.cpu().numpy(), y.cpu().numpy(), t)

            if self.args.is_ER_active == 'no':
                # Replay disabled: keep only the new-data part of the batch.
                bx_total, by_total, bt_total, _, _, _ = self.batch_data_batch_mem(
                    bx_total, by_total, bt_total)

            prediction, _ = self.net.forward(bx_total, self.net.parameters(), bn_training=True)
            loss_o = self.take_multitask_loss(bt_total, prediction, by_total)

            self.net.zero_grad()
            loss_o.backward()
            if self.args.clip_grad_norm == 'clip_grad_norm[yes]':
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.args.grad_clip_norm)

            # Manual SGD step; parameters that received no gradient are skipped.
            for p in self.net.parameters():
                if p.grad is not None:
                    p.data = p.data - self.args.gnc_lr * p.grad
            loss_e = loss_o

            self.net.zero_grad()

        return loss_o, loss_e
