
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import scipy as sp
import csv
import copy
import six
import importlib
import os
import sys

import tensorflow as tf

from .cl_base_model import CL_NN
from utils.model_util import *
from utils.train_util import *
from utils.resnet_util import *
from utils.net_util import define_d_net
from functools import reduce
from scipy.special import softmax


class DRL(CL_NN):
    """Continual-learning classifier trained on replayed memory samples, with an
    optional discriminative representation loss.

    Every optimizer step uses ``B`` samples drawn from an episodic memory
    (``self.core_sets``) instead of the raw incoming batch. With
    ``task_type == 'split'`` the memory maps a class id to an array of stored
    inputs of that class; for any other task type it maps a task id to an
    ``(inputs, labels)`` tuple.
    """

    def __init__(self, net_shape, x_ph, y_ph, num_heads=1, batch_size=500, coreset_size=0,
                 coreset_type='random', conv=False, dropout=None, initialization=None,
                 ac_fn=tf.nn.relu, conv_net_shape=None, strides=None, pooling=False, B=3,
                 discriminant=False, lambda_dis=.001, coreset_mode='online', batch_iter=1,
                 task_type='split', net_type='dense', fixed_budget=True, ER='BER1',
                 reg=None, lambda_reg=5e-4, alpha=2., *args, **kargs):
        """Store the hyper-parameters and build the model graph.

        Arguments not listed here are forwarded to ``CL_NN``.

        Args:
            num_heads: must be 1; only a single output head is supported.
            B: number of memory samples used for every optimizer step.
            discriminant: add the discriminative representation term to the loss.
            lambda_dis: weight of the discriminative term.
            coreset_mode: memory update mode; with ``'ring_buffer'`` each incoming
                batch is stored in memory before training on it.
            batch_iter: optimizer steps taken per incoming data batch.
            task_type: ``'split'`` (sigmoid outputs, memory keyed by class) or
                anything else (softmax outputs, memory keyed by task).
            net_type: ``'dense'`` or ``'resnet18'``.
            fixed_budget: cap the total memory size at ``coreset_size`` (True) or
                trim each memory entry separately (False).
            ER: replay scheme, ``'ER'``, ``'BER0'``, ``'BER1'`` or ``'BER2'``
                (see ``update_train_batch``).
            reg: regularization option forwarded to the network builder; for the
                resnet ``'l1'``/``'l2'`` select the regularizer. Any truthy value
                adds the collected regularization loss to the objective.
            lambda_reg: weight of the regularization loss.
            alpha: weight of same-class (positive) pairs in the discriminative term.
        """
        assert(num_heads==1)
        super(DRL, self).__init__(net_shape, x_ph, y_ph, num_heads, batch_size, coreset_size,
                                  coreset_type, conv, dropout, initialization, ac_fn,
                                  conv_net_shape, strides, pooling, coreset_mode=coreset_mode,
                                  B=B, task_type=task_type, *args, **kargs)

        self.B = B  # memory samples per training step
        self.discriminant = discriminant
        self.lambda_dis = lambda_dis
        self.ER = ER
        self.batch_iter = batch_iter
        self.net_type = net_type
        self.fixed_budget = fixed_budget  # fixed total memory budget or not
        self.x_core_sets, self.y_core_sets = None, None
        self.core_sets = {}
        self.ll, self.kl = 0., 0.
        self.reg = reg
        self.lambda_reg = lambda_reg
        self.alpha = alpha
        print('DRS_CL: B {}, ER {}, dis {}, batch iter {}'.format(B,ER,discriminant,batch_iter))

        # Must run after the attributes above: config_loss reads self.B, self.reg, ...
        self.define_model(initialization=initialization,dropout=dropout,reg=reg)

    def define_model(self,initialization=None,dropout=None,reg=None,*args,**kargs):
        """Build the network, the training loss and its gradients.

        Sets ``self.H`` (layer outputs; the last one is used as logits),
        ``self.vars`` (trainable variables), ``self.ll``/``self.kl``/``self.dis``
        (loss components) and ``self.grads``.
        """
        if self.net_type == 'dense':
            net_shape = [self.conv_net_shape,self.net_shape] if self.conv else self.net_shape
            self.qW, self.qB, self.H = define_d_net(self.x_ph,net_shape=net_shape,reuse=False,conv=self.conv,
                                                    ac_fn=self.ac_fn,scope='task',pooling=self.pooling,
                                                    strides=self.strides,initialization=initialization,reg=reg)
            self.vars = self.qW+self.qB

        elif self.net_type == 'resnet18':
            # Same resnet-18 as used in the GEM paper.
            self.training = tf.placeholder(tf.bool, name='train_phase')
            kernels = [3, 3, 3, 3, 3]
            filters = [20, 20, 40, 80, 160]
            strides = [1, 0, 2, 2, 2]
            if reg == 'l2':
                regularizer = tf.contrib.layers.l2_regularizer(scale=0.01)
            elif reg == 'l1':
                regularizer = tf.contrib.layers.l1_regularizer(scale=0.01)
            else:
                regularizer = None
            self.H, self.vars = resnet18_conv_feedforward(self.x_ph,kernels=kernels,filters=filters,strides=strides,
                                                          out_dim=self.net_shape[-1],train_phase=self.training,
                                                          regularizer=regularizer)
            self.qW, self.qB = [],[]

        if not self.conv:
            self.conv_W,self.conv_h = None,None
        else:
            raise NotImplementedError('Not support Conv NN yet.')

        # The regularization term (third return value) is stored in self.kl.
        loss,self.ll,self.kl,self.dis = self.config_loss(self.x_ph,self.y_ph,self.vars,self.H,discriminant=self.discriminant)
        self.grads = tf.gradients(loss,self.vars)

    def init_inference(self,learning_rate,decay=None,grad_type='adam',*args,**kargs):
        """Configure the optimizer, then the training op (``self.inference``)."""
        self.config_optimizer(starter_learning_rate=learning_rate,decay=decay,grad_type=grad_type)
        self.config_inference(*args,**kargs)

    def config_inference(self,*args,**kargs):
        """Create the MLE training step applying ``self.grads`` with ``self.task_optimizer``."""
        self.inference = MLE_Inference(var_list=self.vars,grads=self.grads,optimizer=self.task_optimizer,ll=self.ll,kl=self.kl)

    def config_loss(self,x,y,var_list,H,discriminant=True,likelihood=True,compact_center=False,*args,**kargs):
        """Assemble the training objective.

        Args:
            x, var_list, compact_center: unused.
            y: label placeholder.
            H: list of layer outputs; ``H[-1]`` is used as the logits.
            discriminant: add ``lambda_dis * dis``. For every layer output, ``dis``
                adds the mean over all sample pairs of the representation inner
                product, weighted 1 for different-label pairs and ``alpha`` for
                distinct same-label pairs. Pairs are identified via ``y @ y.T``,
                which assumes one-hot labels.
            likelihood: add the cross-entropy term (sigmoid for ``'split'``,
                softmax otherwise).

        Returns:
            ``(loss, ll, reg, dis)``; disabled components stay ``0.``.
        """
        loss,ll,reg,dis = 0.,0.,0.,0.

        if likelihood:
            if self.task_type == 'split':
                ll = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=H[-1],labels=y))
            else:
                ll = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=H[-1],labels=y))
            loss += ll

        if discriminant:
            yids = tf.matmul(y, tf.transpose(y))
            # The reshape below assumes every fed batch has exactly self.B rows.
            N = self.B
            mask = tf.eye(N)
            for h in H:
                if len(h.shape) > 2:
                    h = tf.reshape(h,[N,-1])
                sim = tf.matmul(h,tf.transpose(h))
                dis += tf.reduce_mean(sim*(1.-yids)+self.alpha*sim*(yids-mask))
            loss += self.lambda_dis * dis

        if self.reg:
            reg = tf.losses.get_regularization_loss()
            loss += self.lambda_reg *reg

        return loss,ll,reg,dis

    @staticmethod
    def _active_classes(y):
        """Return the sorted indices of label columns with at least one positive entry.

        Unlike ``argsort(mask)[-count:]`` this yields an empty array (not every
        index) when no column is active.
        """
        return np.flatnonzero(np.sum(y,axis=0) > 0)

    def _one_hot(self,n,c):
        """Return an ``(n, num_outputs)`` label matrix whose column ``c`` is all ones."""
        y = np.zeros([n,self.net_shape[-1]])
        y[:,c] = 1
        return y

    def _entry_len(self,entry):
        """Number of samples stored in one memory entry."""
        return len(entry) if self.task_type == 'split' else len(entry[0])

    def update_train_batch(self,t,s,sess,feed_dict,*args,**kargs):
        """Replace the batch in ``feed_dict`` with ``self.B`` samples drawn from memory.

        All sampling is uniform with replacement within the chosen memory entries.

        * ``t <= 0`` (first task): samples from the whole memory.
        * ``t > 0`` with ``ER == 'ER'``: ``int(B/2)`` samples from "old" memory
          (classes absent from the incoming batch for ``'split'``, tasks ``< t``
          otherwise), the rest from the remaining memory.
        * ``t > 0`` otherwise (balanced replay): samples are spread evenly over
          memory entries. For ``'split'``, when ``B`` is small relative to the
          number of classes, ``BER0``/``BER1``/``BER2`` pick a subset of classes
          that always includes the classes of the incoming batch (see
          ``_sample_ber``).

        Args:
            t: current task index.
            s, sess: unused.
            feed_dict: dict holding the incoming batch under ``self.x_ph``/``self.y_ph``.

        Returns:
            ``feed_dict``, updated in place to hold the memory batch.
        """
        y_batch = feed_dict[self.y_ph]
        buffer_size = self.B

        if t > 0:
            clss_batch = self._active_classes(y_batch)
            if self.ER == 'ER':
                coreset_x, coreset_y = self._sample_er(t,clss_batch,buffer_size)
            else:
                coreset_x, coreset_y = self._sample_ber(clss_batch,buffer_size)
        else:
            coreset_x, coreset_y = self._sample_first_task(t,buffer_size)

        feed_dict.update({self.x_ph:coreset_x,self.y_ph:coreset_y})
        return feed_dict

    def _sample_er(self,t,clss_batch,buffer_size):
        """Plain replay: ``int(B/2)`` samples from old memory, the rest from current memory."""
        if self.task_type == 'split':
            # Materialize once so that inputs and labels are stacked in the same order.
            clss_mem = list(set(self.core_sets.keys()) - set(clss_batch))
            cx = np.vstack([self.core_sets[c] for c in clss_batch])
            cy = np.vstack([self._one_hot(self.core_sets[c].shape[0],c) for c in clss_batch])
            mem_x = np.vstack([self.core_sets[c] for c in clss_mem])
            mem_y = np.vstack([self._one_hot(self.core_sets[c].shape[0],c) for c in clss_mem])
        else:
            old = [xy for c, xy in self.core_sets.items() if c < t]
            cur = [xy for c, xy in self.core_sets.items() if not c < t]
            mem_x = np.vstack([xy[0] for xy in old])
            mem_y = np.vstack([xy[1] for xy in old])
            cx = np.vstack([xy[0] for xy in cur])
            cy = np.vstack([xy[1] for xy in cur])

        m_N = int(buffer_size/2)
        c_N = buffer_size-m_N
        mids = np.random.choice(mem_x.shape[0],size=m_N)
        cids = np.random.choice(cx.shape[0],size=c_N)
        return np.vstack([cx[cids],mem_x[mids]]), np.vstack([cy[cids],mem_y[mids]])

    def _sample_ber(self,clss_batch,buffer_size):
        """Balanced replay: spread ``buffer_size`` samples evenly over memory entries.

        Each entry gets ``int(B / num_entries)`` samples; the ``B % num_entries``
        leftovers go one each to randomly chosen entries. For ``'split'`` tasks
        with few samples per class:

        * ``BER0`` (``B < num_classes``): one sample from every class of the
          incoming batch and from ``B - len(clss_batch)`` other random classes.
        * ``BER1`` (``B <= num_classes``): as above but with ``B - 1`` classes, one
          of which contributes a second sample (one same-class pair in total).
        * ``BER2`` (``B < 2 * num_classes``): two samples from each of ``int(B/2)``
          classes (batch classes plus random others); if ``B`` is odd one of them
          contributes a third.
        """
        coreset_x, coreset_y = [], []
        num_cl = len(self.core_sets)
        per_cl_size = int(buffer_size/num_cl)
        rd = buffer_size % num_cl

        if self.task_type != 'split':
            clss = np.random.choice(list(self.core_sets.keys()),size=rd,replace=False)
            for i, (task_x, task_y) in self.core_sets.items():
                tsize = per_cl_size+1 if rd > 0 and i in clss else per_cl_size
                ids = np.random.choice(len(task_x),size=tsize)
                coreset_x.append(task_x[ids])
                coreset_y.append(task_y[ids])
            return np.vstack(coreset_x), np.vstack(coreset_y)

        if self.ER == 'BER0' and per_cl_size == 0:
            rd = rd - len(clss_batch)
            crange = list(set(self.core_sets.keys()).difference(set(clss_batch)))
            clss = np.random.choice(crange,size=rd,replace=False)
            clss = np.concatenate([clss,clss_batch])
            per_cl_size = 1
            rd_clss = []

        elif self.ER == 'BER1' and buffer_size <= num_cl:
            if per_cl_size == 1 and rd == 0:
                # B == number of classes: drop one random non-batch class and give
                # one of the remaining classes a second sample.
                clss = set(self.core_sets.keys()).difference(clss_batch)
                rd_clss = np.random.choice(list(clss),size=1,replace=False)
                clss = clss.difference([rd_clss[0]])
                clss = np.concatenate([list(clss),clss_batch])
                rd_clss = np.random.choice(list(clss),size=1,replace=False)
                rd = 1
            elif per_cl_size == 0:
                # B < number of classes: B-1 classes, one of them sampled twice.
                rd = rd - len(clss_batch)
                crange = list(set(self.core_sets.keys()).difference(set(clss_batch)))
                clss = np.random.choice(crange,size=rd-1,replace=False)
                clss = np.concatenate([list(clss),clss_batch])
                rd_clss = np.random.choice(list(clss),size=1,replace=False)
                rd = 1
                per_cl_size = 1

        elif self.ER == 'BER2' and per_cl_size <= 1:
            per_cl_size = 2
            crange = list(set(self.core_sets.keys()).difference(set(clss_batch)))
            # Built-in int(): the np.int alias was removed in NumPy 1.24.
            clss = np.random.choice(crange,size=int(buffer_size/2-len(clss_batch)),replace=False)
            clss = np.concatenate([clss,clss_batch])
            rd_clss = np.random.choice(clss,size=buffer_size-len(clss)*2,replace=False)
            rd = len(rd_clss)

        else:
            clss = set(self.core_sets.keys())
            rd_clss = np.random.choice(list(self.core_sets.keys()),size=rd,replace=False) if rd > 0 else []

        for i, cx in self.core_sets.items():
            if i in clss:
                tsize = per_cl_size+1 if rd > 0 and i in rd_clss else per_cl_size
            else:
                tsize = 0
            if tsize > 0:
                ids = np.random.choice(len(cx),size=tsize)
                coreset_x.append(cx[ids])
                coreset_y.append(self._one_hot(tsize,i))
        return np.vstack(coreset_x), np.vstack(coreset_y)

    def _sample_first_task(self,t,buffer_size):
        """Draw ``buffer_size`` samples uniformly from the memory of the first task."""
        if self.task_type == 'split':
            keys = list(self.core_sets.keys())
            cx = np.vstack([self.core_sets[c] for c in keys])
            cy = np.vstack([self._one_hot(self.core_sets[c].shape[0],c) for c in keys])
            cx, cy = shuffle_data(cx,cy)
        else:
            cx, cy = self.core_sets[t]

        bids = np.random.choice(len(cx),size=buffer_size)
        return cx[bids], cy[bids]

    def train_update_step(self,t,s,sess,feed_dict,err=0.,x_train_task=None,y_train_task=None,local_iter=0,*args,**kargs):
        """Run one optimizer step on a memory batch.

        In ``'ring_buffer'`` mode the incoming batch is first added to memory, but
        only on the first local iteration (``local_iter == 0``).

        Returns:
            ``err`` unchanged.
        """
        assert(self.coreset_size > 0)

        x_batch, y_batch = feed_dict[self.x_ph], feed_dict[self.y_ph]

        if local_iter == 0 and self.coreset_mode == 'ring_buffer':
            self.update_ring_buffer(t,x_batch,y_batch,sess=sess)

        feed_dict = self.update_train_batch(t,s,sess,feed_dict)
        self.inference.update(sess=sess,feed_dict=feed_dict)

        return err

    def train_task(self,sess,t,x_train_task,y_train_task,epoch,print_iter=5,
                   tfb_merged=None,tfb_writer=None,tfb_avg_losses=None,*args,**kargs):
        """Train on one task for ``epoch`` epochs.

        The task data is reshuffled every epoch and split into mini-batches of
        ``self.batch_size``; each mini-batch gets ``self.batch_iter`` update steps
        via ``train_update_step``. Every ``print_iter`` epochs, when the
        discriminative loss is enabled, the loss components of the last step are
        printed. The ``tfb_*`` arguments are unused.
        """
        num_iter = int(np.ceil(x_train_task.shape[0]/self.batch_size))
        feed_dict = None  # stays None if no step is taken (empty task or batch_iter == 0)

        for e in range(epoch):
            shuffle_inds = np.arange(x_train_task.shape[0])
            np.random.shuffle(shuffle_inds)
            x_train_task = x_train_task[shuffle_inds]
            y_train_task = y_train_task[shuffle_inds]
            err = 0.
            ii = 0
            for it in range(num_iter):
                x_batch,y_batch,ii = get_next_batch(x_train_task,self.batch_size,ii,labels=y_train_task)

                for local_iter in range(self.batch_iter):
                    # Fresh dict per step: update_train_batch overwrites the batch entries in place.
                    feed_dict = {self.x_ph:x_batch,self.y_ph:y_batch}
                    if self.net_type == 'resnet18':
                        feed_dict.update({self.training:True})

                    err = self.train_update_step(t,it,sess,feed_dict,err,x_train_task,y_train_task,
                                                 local_iter=local_iter,*args,**kargs)
            if (e+1)%print_iter==0 and self.discriminant and feed_dict is not None:
                ll,kl,dis = sess.run([self.ll,self.kl,self.dis],feed_dict=feed_dict)
                print('epoch',e+1,'ll',ll,'kl',kl,'dis',dis)

    def update_inference(self,sess,*args,**kargs):
        """Reset the training step state between tasks (a no-op when warm starting)."""
        self.inference.reinitialization(sess)

    def update_ring_buffer(self,t,x_batch,y_batch,sess=None):
        """Append the incoming batch to memory, then enforce the memory budget.

        ``'split'``: inputs are appended per class present in ``y_batch`` (rows with
        a 1 in that class column). Otherwise: inputs and labels are appended to
        the entry of task ``t``.
        """
        if self.task_type == 'split':
            for c in self._active_classes(y_batch):
                new_x = x_batch[y_batch[:,c]==1]
                cx = self.core_sets.get(c,None)
                self.core_sets[c] = new_x if cx is None else np.vstack([cx,new_x])
        else:
            cxy = self.core_sets.get(t,None)
            cx = x_batch if cxy is None else np.vstack([cxy[0],x_batch])
            cy = y_batch if cxy is None else np.vstack([cxy[1],y_batch])
            self.core_sets[t] = (cx,cy)

        self.online_update_coresets(self.coreset_size,self.fixed_budget,t,sess=sess)

    def online_update_coresets(self,coreset_size,fixed_budget,t,sess=None):
        """Trim memory to ``coreset_size``, dropping the oldest (earliest appended) samples.

        Only acts when ``coreset_mode == 'ring_buffer'``:

        * ``fixed_budget``: the total number of stored samples is capped at
          ``coreset_size``; each excess sample is removed from the currently
          largest entry (ties go to the entry inserted first).
        * otherwise, ``'split'``: each class entry keeps its last ``coreset_size``
          samples.
        * otherwise, other task types: within task ``t`` each class keeps its last
          ``int(coreset_size / num_classes)`` samples.
        """
        if self.coreset_mode != 'ring_buffer':
            return

        if fixed_budget:
            keys = list(self.core_sets.keys())
            lens = [self._entry_len(self.core_sets[c]) for c in keys]
            drops = [0]*len(keys)
            # Same greedy rule as removing one sample at a time from the largest
            # entry, but each entry is sliced only once at the end.
            excess = sum(lens) - coreset_size
            while excess > 0:
                j = int(np.argmax(lens))
                lens[j] -= 1
                drops[j] += 1
                excess -= 1
            for c, n in zip(keys,drops):
                if n == 0:
                    continue
                entry = self.core_sets[c]
                if self.task_type == 'split':
                    self.core_sets[c] = entry[n:]
                else:
                    self.core_sets[c] = (entry[0][n:],entry[1][n:])

        elif self.task_type == 'split':
            for i in self.core_sets.keys():
                cx = self.core_sets[i]
                if coreset_size < len(cx):
                    self.core_sets[i] = cx[-coreset_size:]

        else:
            ## permuted task ##
            cx, cy = self.core_sets[t]
            num_per_cls = int(coreset_size/cy.shape[1])
            num_cls = np.sum(cy,axis=0).astype(int)
            over = np.flatnonzero(num_cls > num_per_cls)
            # NOTE(review): as before, nothing is trimmed when num_per_cls == 0,
            # so memory is unbounded if coreset_size < number of classes.
            if len(over) > 0 and num_per_cls > 0:
                keep = np.ones(len(cx),dtype=bool)
                for c in over:
                    # flatnonzero returns ascending row indices, i.e. oldest first
                    # (argsort of a boolean mask is not stable by default).
                    rows = np.flatnonzero(cy[:,c]==1)
                    keep[rows[:max(len(rows)-num_per_cls,0)]] = False
                self.core_sets[t] = (cx[keep],cy[keep])

    def test_all_tasks(self,t,test_sets,sess,epoch=10,saver=None,file_path=None,confusion=False,*args,**kargs):
        """Evaluate on every ``(x, y)`` pair in ``test_sets`` and print the accuracies.

        ``t``, ``epoch``, ``saver`` and ``file_path`` are unused.

        Returns:
            ``(acc_record, pred_probs, cfmtx)``: per-test-set accuracies,
            per-test-set predictions, and the confusion matrices summed over test sets.
        """
        acc_record, pred_probs = [], []
        dim = test_sets[0][1].shape[1]
        cfmtx = np.zeros([dim,dim])
        feed_dict = {self.training:False} if self.net_type=='resnet18' else {}

        for ts in test_sets:
            acc, y_probs, cfm = predict(ts[0],ts[1],self.x_ph,self.H[-1],self.batch_size,sess,regression=False,confusion=confusion,feed_dict=feed_dict)
            print('accuracy',acc)
            acc_record.append(acc)
            pred_probs.append(y_probs)
            # NOTE(review): guard in case predict returns no matrix when confusion=False.
            if cfm is not None:
                cfmtx += cfm
        print('avg accuracy',np.mean(acc_record))
        return acc_record,pred_probs,cfmtx


class MLE_Inference:
    """Maximum-likelihood training step that applies precomputed gradients.

    ``optimizer`` is indexed as a pair: element 0 provides ``apply_gradients``
    and element 1 is passed as its ``global_step``.
    """

    def __init__(self,var_list,grads,optimizer=None,ll=0.,kl=0.,*args,**kargs):
        """Store the variables, gradients and loss terms, then build the train op."""
        self.var_list, self.grads = var_list, grads
        self.optimizer = optimizer
        self.ll, self.kl = ll, kl
        self.config_train()

    def reinitialization(self,sess=None,scope='task',warm_start=True,*args,**kargs):
        """Re-initialize the variables under ``scope`` unless warm starting."""
        if warm_start:
            return
        reinitialize_scope(scope=scope,sess=sess)

    def config_train(self,*args,**kargs):
        """Build ``self.train``, pairing every gradient with its variable."""
        apply_op, step = self.optimizer[0], self.optimizer[1]
        pairs = [(g, v) for g, v in zip(self.grads,self.var_list)]
        self.train = apply_op.apply_gradients(pairs,global_step=step)

    def update(self,sess,feed_dict=None,*args,**kargs):
        """Run one training step in ``sess`` with ``feed_dict``."""
        sess.run(self.train, feed_dict)
    


