# coding: utf-8

# In[1]:


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function



import numpy as np
import scipy as sp
import csv
import copy
import six
import importlib
import os
import sys
import time

import tensorflow as tf
import argparse


from utils.model_util import *
from utils.train_util import *
from utils.test_util import *

from models.drl import DRL
from models.agem import AGEM
from models.multisim import MultiSim
from models.rho_margin import Rho_Margin

from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.keras.datasets import cifar10,cifar100

# Command-line interface of the experiment script.
# Each row is (short flag, long flag, default, type/converter, help text);
# every option takes exactly these three keyword settings.
_ARGUMENT_SPECS = [
    ('-sd', '--seed', 0, int, 'random seed'),
    ('-ds', '--dataset', 'mnist', str, 'specify datasets'),
    ('-dp', '--data_path', './data/', str, 'path to dataset'),
    ('-rp', '--result_path', './results/', str, 'the path for saving results'),
    ('-ttp', '--task_type', 'split', str, 'task type can be split, permuted, cross split, batch'),
    ('-e', '--epoch', 50, int, 'number of epochs'),
    ('-pe', '--print_epoch', 100, int, 'number of epochs of printing loss'),
    ('-csz', '--coreset_size', 0, int, 'size of each class in a coreset'),
    ('-ctp', '--coreset_type', 'random', str, 'type of coresets'),
    ('-cmod', '--coreset_mode', 'ring_buffer', str, 'construction mode of coresets'),
    ('-gtp', '--grad_type', 'adam', str, 'type of gradients optimizer'),
    ('-bsz', '--batch_size', 1, int, 'batch size'),
    ('-trsz', '--train_size', 1000, int, 'size of training set'),
    ('-tesz', '--test_size', -1, int, 'size of testing set'),
    ('-vdsz', '--valid_size', 0, int, 'size of validation set'),
    ('-nts', '--num_tasks', 10, int, 'number of tasks'),
    ('-lr', '--learning_rate', 0.001, float, 'learning rate'),
    ('-af', '--ac_fn', 'relu', str, 'activation function of hidden layers'),
    ('-cltp', '--cl_type', 'drl', str, 'cl type can be drl,agem,multisim,rho_margin'),
    ('-hdn', '--hidden', [100, 100], str2ilist, 'hidden units of each layer of the network'),
    ('-cv', '--conv', False, str2bool, 'if use CNN on top'),
    ('-B', '--B', 3, int, 'training batch size'),
    ('-rg', '--reg', None, str, 'regularizer'),
    ('-lrg', '--lambda_reg', 0., float, 'lambda of regularizer'),
    ('-disc', '--discriminant', False, str2bool, 'enable discriminant in drs cl'),
    ('-lam_dis', '--lambda_disc', 0.001, float, 'lambda discriminant'),
    ('-disa', '--dis_alpha', 2., float, 'alpha of DRL'),
    ('-er', '--ER', 'ER', str, 'experience replay strategy, can be ER,BER0, BER1, BER2'),
    ('-bit', '--batch_iter', 1, int, 'iterations on one batch'),
    ('-ntp', '--net_type', 'dense', str, 'network type, can be dense, conv, resnet18'),
    ('-fxbt', '--fixed_budget', True, str2bool, 'if budget of episodic memory is fixed or not'),
    ('-mbs', '--mem_bsize', 256, int, 'memory batch size used in AGEM'),
    ('-arcs', '--arc_scale', 16., float, 'scale of arcface'),
    ('-arcm', '--arc_margin', 0.5, float, 'margin of arcface'),
    ('-mta', '--mults_alpha', 2., float, 'alpha of multisim'),
    ('-mtb', '--mults_beta', 40., float, 'beta of multisim'),
    ('-mtl', '--mults_lamb', .5, float, 'lambda of multisim'),
    ('-st', '--strength', 100., float, 'strength of auxilary objective'),
    ('-rho', '--rho', 0, int, 'if larger than 0, compute rho spectrum after training'),
    ('-rbe', '--r_beta', 0.6, float, 'init value of beta in rho_margin'),
    ('-rga', '--r_gamma', 0.2, float, 'gamma in rho_margin'),
    ('-rpr', '--r_prob', 0.25, float, 'p_rho in rho_margin'),
    ('-cuda', '--cuda', -1, int, 'gpu id, -1 means no gpu'),
]

parser = argparse.ArgumentParser()
for _short, _long, _default, _converter, _help in _ARGUMENT_SPECS:
    # Registration order follows the table, so the --help listing is unchanged.
    parser.add_argument(_short, _long, default=_default, type=_converter, help=_help)

# Parse sys.argv and echo the resolved configuration for the run log.
args = parser.parse_args()
print(args)

# Expose only the requested GPU to TensorFlow ("-1" means no GPU, per --cuda help).
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda)

# Seed TensorFlow and NumPy with the same value.
seed = args.seed
print('seed',seed)
tf.set_random_seed(seed)
np.random.seed(seed)

# Resolve the activation function through the project helper.
ac_fn = set_ac_fn(args.ac_fn)

dataset = args.dataset
# DATA_DIR is only defined for the MNIST-style datasets read from disk.
if dataset in ['mnist','fashion']:
    DATA_DIR = os.path.join(args.data_path,dataset)

print(dataset)

# Local copies of options that the data-loading code may override later.
result_path = args.result_path
hidden = args.hidden
conv = args.conv

print(args.task_type)

# Split tasks on the 10-class datasets always use 5 tasks;
# every other configuration takes the task count from --num_tasks.
_ten_class_split = 'split' in args.task_type and dataset in ['fashion','mnist','cifar10']
num_tasks = 5 if _ten_class_split else args.num_tasks

# A single output head is used for all tasks.
num_heads = 1
print('heads',num_heads)


# Load the full dataset and build the data of the first task.
#
# Supported task types: exactly 'permuted', or any type containing 'split'.
# Anything else now fails here with an explicit ValueError instead of a
# NameError further down when the task arrays turn out to be undefined.

# Fashion-MNIST must be fetched from its own URL; plain MNIST uses the default source.
_mnist_source = {'source_url': 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'} \
                if dataset == 'fashion' else {}

if args.task_type == 'permuted':
    # DATA_DIR only exists for the MNIST-style datasets.
    if dataset not in ['mnist','fashion']:
        raise ValueError("task type 'permuted' supports datasets 'mnist' and 'fashion', got %r" % dataset)

    data = input_data.read_data_sets(DATA_DIR,one_hot=True,**_mnist_source)
    # NOTE: despite its name this is the identity order; no shuffling is applied.
    shuffle_ids = np.arange(data.train.images.shape[0])
    X_TRAIN = data.train.images[shuffle_ids][:args.train_size]
    Y_TRAIN = data.train.labels[shuffle_ids][:args.train_size]
    # A negative test size (default -1) means "use everything"; slicing with
    # [:-1] would silently drop the last test sample.
    _test_limit = args.test_size if args.test_size >= 0 else None
    X_TEST = data.test.images[:_test_limit]
    Y_TEST = data.test.labels[:_test_limit]
    out_dim = Y_TRAIN.shape[1]
    cl_n = out_dim # number of classes in each task
    cl_cmb = None
    # generate data for first task
    x_train_task,y_train_task,x_test_task,y_test_task,cl_k,clss = gen_next_task_data(args.task_type,X_TRAIN,Y_TRAIN,X_TEST,Y_TEST,
                                                                                    train_size=args.train_size,test_size=args.test_size)

elif 'split' in args.task_type:
    if 'cifar' in dataset:
        # CIFAR overrides the network layout chosen on the command line.
        if args.net_type == 'resnet18':
            conv = False
            hidden = []
        else:
            conv = True
            hidden = [512,512]

        # Pick the loader, the total number of classes and the classes per task.
        if dataset == 'cifar10':
            _cifar_loader, _n_classes, cl_n = cifar10, 10, 2
        elif dataset == 'cifar100':
            _cifar_loader, _n_classes, cl_n = cifar100, 100, int(100/num_tasks)
        else:
            raise ValueError("unsupported cifar dataset %r, expected 'cifar10' or 'cifar100'" % dataset)

        (X_TRAIN, Y_TRAIN), (X_TEST, Y_TEST) = _cifar_loader.load_data()
        Y_TRAIN,Y_TEST = Y_TRAIN.reshape(-1), Y_TEST.reshape(-1)
        # standardize data
        X_TRAIN,X_TEST = standardize_flatten(X_TRAIN,X_TEST,flatten=False)
        print('data shape',X_TRAIN.shape)

        # Multi-head: one output per class of a task; single head: one per class overall.
        out_dim = cl_n if num_heads > 1 else _n_classes

        cl_cmb = np.arange(_n_classes)
        cl_k = 0

        x_train_task,y_train_task,x_test_task,y_test_task,cl_k,clss = gen_next_task_data(args.task_type,X_TRAIN,Y_TRAIN,X_TEST,Y_TEST,train_size=args.train_size,test_size=args.test_size,
                                                                    cl_n=cl_n,cl_k=cl_k,cl_cmb=cl_cmb,out_dim=out_dim,num_heads=num_heads)

    else:
        # DATA_DIR only exists for the MNIST-style datasets.
        if dataset not in ['mnist','fashion']:
            raise ValueError("unsupported dataset %r for split tasks" % dataset)

        data = input_data.read_data_sets(DATA_DIR,**_mnist_source)
        # The validation split of the reader is folded back into the training set.
        X_TRAIN = np.concatenate([data.train.images,data.validation.images],axis=0)
        Y_TRAIN = np.concatenate([data.train.labels,data.validation.labels],axis=0)
        X_TEST = data.test.images
        Y_TEST = data.test.labels
        if conv:
            X_TRAIN,X_TEST = X_TRAIN.reshape(-1,28,28,1),X_TEST.reshape(-1,28,28,1)

        if num_heads > 1:
            out_dim = 2
        else:
            out_dim = 2 * num_tasks

        cl_cmb = np.arange(10)
        cl_k = 0
        cl_n = 2 # 2 classes per task

        x_train_task,y_train_task,x_test_task,y_test_task,cl_k,clss = gen_next_task_data(args.task_type,X_TRAIN,Y_TRAIN,X_TEST,Y_TEST,train_size=args.train_size,
                                                                    test_size=args.test_size,cl_n=cl_n,cl_k=cl_k,cl_cmb=cl_cmb,out_dim=out_dim,num_heads=num_heads)

else:
    raise ValueError("unsupported task type %r: expected 'permuted' or a type containing 'split'" % args.task_type)


# Actual number of samples in the first task's train/test splits.
TRAIN_SIZE = x_train_task.shape[0]
TEST_SIZE = x_test_task.shape[0]
original_batch_size = args.batch_size
# Fall back to full-batch training when the requested batch size exceeds --train_size.
batch_size = TRAIN_SIZE if args.batch_size > args.train_size else args.batch_size
print('batch size',batch_size)

# set results path and file name
# makedirs also creates missing parent directories and tolerates an existing
# directory, unlike the exists()+mkdir() pair which fails on nested paths.
os.makedirs(result_path, exist_ok=True)

# The run name encodes the main hyper-parameters so runs do not overwrite each other.
file_name = dataset+'_tsize'+str(TRAIN_SIZE)+'_cset'+str(args.coreset_size)+args.coreset_type+'_bsize'+str(batch_size)\
            +'_e'+str(args.epoch)+'_fxb'+str(args.fixed_budget)+'_'+args.task_type+'_disc'+str(args.discriminant)+'_'\
            +str(args.ER)+'_'+args.grad_type+'_'+args.cl_type+'_sd'+str(seed)

file_path = os.path.join(result_path,file_name)
# Project helper; the returned value is used below as a prefix for result files
# (presumably a directory path ending with a separator -- TODO confirm).
file_path = config_result_path(file_path)
# Record the full configuration next to the results.
with open(file_path+'configures.txt','w') as f:
    f.write(str(args))


# Build the input/output placeholders and the layer-size list of the network.
# 'resnet18' takes priority over the plain conv stack, which takes priority
# over the dense network.
_is_resnet = args.net_type == 'resnet18'
_is_conv_stack = conv and not _is_resnet

# Settings shared by the resnet and dense cases; the conv stack overrides them below.
dropout = None
conv_net_shape,strides = None, None
pooling = False

if _is_resnet or _is_conv_stack:
    # Image-shaped input: keep every per-sample dimension of the task data.
    x_ph = tf.placeholder(dtype=tf.float32,shape=[None]+list(x_train_task.shape[1:]))
    in_dim = None
else:
    # Flat input: one feature vector per sample.
    in_dim = x_train_task.shape[1]
    x_ph = tf.placeholder(dtype=tf.float32,shape=[None,in_dim])

if _is_conv_stack:
    dropout = 0.5
    pooling = True
    # Entries look like [kernel_h, kernel_w, in_channels, out_channels] -- TODO confirm.
    if 'cifar' in dataset:
        conv_net_shape = [[3,3,3,32],[3,3,32,32],[3,3,32,64],[3,3,64,64]]
        strides = [[1,2,2,1],[1,2,2,1],[1,1,1,1],[1,1,1,1]]
        hidden = [512,256]
    else:
        conv_net_shape = [[4,4,1,32],[4,4,32,32]]
        strides = [[1,2,2,1],[1,1,1,1]]

# Label placeholder and overall layer sizes: input, hidden layers, output.
y_ph = tf.placeholder(dtype=tf.float32,shape=[None,out_dim])
net_shape = [in_dim]+hidden+[out_dim]



# Instantiate the continual-learning model selected by --cl_type.

# Positional arguments common to every model class.
_model_args = (net_shape,x_ph,y_ph,num_heads,batch_size,args.coreset_size,args.coreset_type)

# Keyword arguments common to every model class.
_shared_kwargs = dict(conv=conv,dropout=dropout,ac_fn=ac_fn,B=args.B,
                      coreset_mode=args.coreset_mode,task_type=args.task_type,
                      batch_iter=args.batch_iter,net_type=args.net_type,
                      fixed_budget=args.fixed_budget,reg=args.reg,lambda_reg=args.lambda_reg)

_cl_type = args.cl_type

if _cl_type == 'drl':
    Model = DRL(*_model_args,discriminant=args.discriminant,lambda_dis=args.lambda_disc,
                ER=args.ER,alpha=args.dis_alpha,**_shared_kwargs)

elif _cl_type == 'agem':
    Model = AGEM(*_model_args,mem_batch_size=args.mem_bsize,**_shared_kwargs)

elif _cl_type == 'multisim':
    Model = MultiSim(*_model_args,ER=args.ER,mem_batch_size=args.mem_bsize,
                     alpha=args.mults_alpha,beta=args.mults_beta,lamb=args.mults_lamb,
                     strength=args.strength,**_shared_kwargs)

elif _cl_type == 'rho_margin':
    Model = Rho_Margin(*_model_args,ER=args.ER,mem_batch_size=args.mem_bsize,
                       beta=args.r_beta,gamma=args.r_gamma,p_rho=args.r_prob,
                       strength=args.strength,**_shared_kwargs)

else:
    raise TypeError('Wrong type of model')

# Set up the model's training operations with the chosen optimizer settings.
Model.init_inference(learning_rate=args.learning_rate,decay=None,grad_type=args.grad_type)


# Start training tasks
# test_sets/valid_sets accumulate one (x, y) pair per task seen so far;
# acc_record keeps the per-task accuracies after each task, avg_accs their mean.
test_sets, valid_sets = [],[]
avg_accs ,acc_record = [],[]


saver = tf.train.Saver()
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

tf.global_variables_initializer().run(session=sess)
# Use the resolved task count: split tasks on the 10-class datasets force it to 5,
# so looping over args.num_tasks would retrain on the last task and append
# duplicate test sets.
print('num tasks',num_tasks)

# Hold out a validation slice only when it is non-empty and smaller than --train_size.
use_validation = 0 < args.valid_size < args.train_size

time_count = 0.
for t in range(num_tasks):
    # get test data
    test_sets.append((x_test_task,y_test_task))
    if use_validation:
        # The last valid_size training samples become this task's validation set.
        valid_sets.append((x_train_task[-args.valid_size:],y_train_task[-args.valid_size:]))
        x_train_task, y_train_task = x_train_task[:-args.valid_size], y_train_task[:-args.valid_size]

    # Only the training call itself is timed; time_count is cumulative over tasks.
    start = time.time()
    Model.train_task(sess,t,x_train_task,y_train_task,epoch=args.epoch,print_iter=args.print_epoch)
    end = time.time()

    time_count += end-start
    print('training time',time_count)

    if valid_sets:
        print('********start validation********')
        accs, probs, _ = Model.test_all_tasks(t,valid_sets,sess,args.epoch,saver=saver,file_path=file_path,confusion=False)

    print('*********start testing***********')
    # The test accuracies overwrite the validation ones; only test results are recorded.
    accs, probs, _ = Model.test_all_tasks(t,test_sets,sess,args.epoch,saver=saver,file_path=file_path,confusion=False)

    acc_record.append(accs)
    avg_accs.append(np.mean(accs))

    if t < num_tasks-1:
        # Prepare the data of the next task.
        # NOTE(review): the possibly clipped `batch_size` is passed as
        # original_batch_size, while the `original_batch_size` variable defined
        # earlier is unused -- confirm this is intended.
        x_train_task,y_train_task,x_test_task,y_test_task,cl_k,clss = Model.update_task_data_and_inference(sess,t,args.task_type,X_TRAIN,Y_TRAIN,X_TEST,Y_TEST,out_dim,
                                                                                                        original_batch_size=batch_size,cl_n=cl_n,cl_k=cl_k,cl_cmb=cl_cmb,clss=clss,
                                                                                                        x_train_task=x_train_task,y_train_task=y_train_task,rpath=file_path,
                                                                                                        train_size=args.train_size,test_size=args.test_size)
    
# Persist the results of the run as CSV files under the result prefix.
# Files handed to csv.writer are opened with newline='' as the csv module
# requires; otherwise extra blank rows appear on platforms that translate
# line endings.

# One row per trained task: accuracies on all tasks seen up to that point.
with open(file_path+'accuracy_record.csv','w',newline='') as f:
    writer = csv.writer(f,delimiter=',')
    writer.writerows(acc_record)

# Single row: mean accuracy after each task.
with open(file_path+'avg_accuracy.csv','w',newline='') as f:
    writer = csv.writer(f,delimiter=',')
    writer.writerow(avg_accs)

# Single value: cumulative training time in seconds.
# (The file name spelling is kept so existing tooling keeps finding it.)
with open(file_path+'eplapsed_time.csv','w',newline='') as f:
    writer = csv.writer(f,delimiter=',')
    writer.writerow([time_count])

# Optional post-training analysis, enabled with --rho > 0.
if args.rho > 0:
    print('compute rho spectrum on test sets')
    # Stack at most 1000 test inputs from every task into one batch.
    # (The labels are not needed here, so they are no longer stacked.)
    tx = np.vstack([ts[0][:1000] for ts in test_sets])
    # The resnet18 model additionally needs its `training` flag switched off.
    feed_dict = {Model.training:False,Model.x_ph:tx} if Model.net_type=='resnet18' else {Model.x_ph:tx}
    # Model.H[-2] is presumably the activation of the penultimate layer -- TODO confirm.
    features = sess.run(Model.H[-2],feed_dict=feed_dict)
    rho = rho_spectrum(features,mode=2)
    print('rho',rho)

    # newline='' is required for files handed to csv.writer.
    with open(file_path+'rho.csv','w',newline='') as f:
        writer = csv.writer(f,delimiter=',')
        writer.writerow([rho])







    




