#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 20 00:22:26 2020

"""

from dataset import Dataset
from sampling import Batch,Sampling
import tensorflow as tf
import copy
import tensorflow_addons as tfa
import json
from model import Model
from helper_fn import getresponse
import numpy as np
import argparse
import pandas as pd
import os
# Seed the TensorFlow and NumPy random number generators with fixed values.
tf.random.set_seed(0)
np.random.seed(42)


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value (including ``"False"``) becomes ``True``.  This
    converter interprets the usual textual spellings instead.

    Parameters
    ----------
    value : str or bool
        Raw token from the command line, or an already-parsed boolean.

    Returns
    -------
    bool

    Raises
    ------
    argparse.ArgumentTypeError
        If ``value`` is not a recognised spelling of true/false.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, exactly as it was under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got %r" % (value,))


# Build the command-line interface; defaults are shown in ``--help``.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--split', help='Select training fold', type=int, default=0)
# NOTE: --d2v is deliberately kept as a 'True'/'False' string; it is converted
# to a boolean where the configuration dictionary is assembled.
parser.add_argument('--d2v', help='train only the metafeature extractor', type=str, choices=['True', 'False'], default='False')
# Parsed with _str2bool so that ``--fixed_hyperparameter False`` really disables it.
parser.add_argument('--fixed_hyperparameter', help='Sampling of the tasks with fixed dataset/hyperparameter pair or only fixed dataset', type=_str2bool, default=True)
parser.add_argument('--searchspace', help='Select metadataset', choices=['a', 'b', 'c'], type=str, default='a')
parser.add_argument('--ablation', help='Apply ablation on losses', choices=['0', '1', '2', '3', '4', '5', '6', 'None'], type=str, default='None')
parser.add_argument('--learning_rate', help='Learning rate', type=float, default=1e-2)
parser.add_argument('--alpha', help='batch identification task hyperparameter', type=float, default=1)
parser.add_argument('--beta', help='regularization task hyperparameter', type=float, default=0.5)
parser.add_argument('--delta', help='negative datasets weight', type=float, default=2)
parser.add_argument('--gamma', help='distance hyperparameter', type=float, default=1)
parser.add_argument('--reptile_steps', help='reptile-steps', type=int, default=5)

args    = parser.parse_args()

# The CLI spells "no ablation" as the string 'None'; map it to a real None.
args.ablation  = None if args.ablation == 'None' else args.ablation

# Resolve the input files relative to the directory containing this script.
rootdir     = os.path.dirname(os.path.realpath(__file__))
config_file = os.path.join(rootdir, "configurations", "iclr.json")
info_file   = os.path.join(rootdir, "metadataset", "info.json")

# Load the shared configuration; the context manager closes the file handle
# deterministically instead of leaving it to the garbage collector.
with open(config_file, 'r') as config_handle:
    configuration = json.load(config_handle)

# Override the shared configuration with run-specific settings
# (command-line arguments plus a few fixed values).
config_specs = {
    'split': args.split,
    # --d2v is restricted to the strings 'True'/'False', so a plain comparison
    # yields the same boolean as eval() without evaluating CLI text as code.
    'd2v': args.d2v == 'True',
    'fixed_hyperparameter': args.fixed_hyperparameter,
    'searchspace': args.searchspace,
    'ablation': args.ablation,
    'learning_rate': args.learning_rate,
    'alpha': args.alpha,
    'beta': args.beta,
    'delta': args.delta,
    'gamma': args.gamma,
    'k-reptile': args.reptile_steps,
    'minmax': True,
    'batch_size': 16,
    'backend_learning_rate': 1e-3,
}

configuration.update(config_specs)

# Merge in the entries that info.json stores under the selected search space.
with open(info_file, 'r') as info_handle:
    searchspaceinfo = json.load(info_handle)
configuration.update(searchspaceinfo[args.searchspace])

# Build the dataset wrapper from the merged configuration.
normalized_dataset = Dataset(configuration, rootdir, use_valid=True)

# Number of entries in each split of the original data.
nsource, ntarget, ntest = [
    len(normalized_dataset.orig_data[split_name])
    for split_name in ('train', 'valid', 'test')
]

# Two optimizers: Adam is passed to model.backend_train_step, SGD to
# model.train_step (see the training loop).
backendoptimizer = tf.keras.optimizers.Adam(
    learning_rate=configuration['backend_learning_rate'])
optimizer = tf.keras.optimizers.SGD(
    learning_rate=configuration['learning_rate'])

# Training model and its batch container.
model = Model(configuration, rootdir=rootdir)
batch = Batch(configuration['batch_size'])

# Evaluation model: works on a shallow copy of the configuration so that its
# batch size can differ from the training one (18 for search space 'c').
testconfiguration = configuration.copy()
if args.searchspace != 'c':
    testconfiguration['batch_size'] = 16
else:
    testconfiguration['batch_size'] = 18

testmodel = Model(testconfiguration, rootdir=rootdir, for_eval=True)
testbatch = Batch(testconfiguration['batch_size'])

# Print the summary of the wrapped model.
print(model.model.summary())

# Number of outer training iterations.
epochs = 10000

# Reset the metric trackers before training starts.
model.reset_states()


def fn(targetdataset, sourcedataset, config, targetsplit, sourcesplit):
    """Forward to ``normalized_dataset.instances``.

    ``targetsplit`` is passed on as the ``split`` keyword; every other
    argument keeps its name.  Not referenced in the visible part of this
    script.
    """
    return normalized_dataset.instances(
        targetdataset=targetdataset,
        sourcedataset=sourcedataset,
        config=config,
        split=targetsplit,
        sourcesplit=sourcesplit,
    )


# Task sampler used by the training loop.
sampler = Sampling(
    dataset=normalized_dataset,
    fixed_hyperparameter=configuration["fixed_hyperparameter"],
)

# Main training loop: each epoch runs `k-reptile` inner steps with the SGD
# optimizer, then one backend step with the Adam optimizer.
for epoch in range(epochs):
    # Fresh metric trackers for this epoch.
    model.reset_states()
    # Snapshot before the inner steps (original note: "copy backend from model").
    model.store()
    for inner_step in range(configuration["k-reptile"]):
        # Only the first inner step is sampled with reuse=False; every
        # later step in the same epoch passes reuse=True.
        batch = sampler.sample(
            batch,
            split='train',
            sourcesplit='train',
            reuse=inner_step > 0,
        )
        batch.collect()
        metrics = model.train_step(
            x=batch.input,
            y=batch.output,
            optimizer=optimizer,
            clip=True,
        )
        # Feed the step metrics into the training trackers.
        model.update_tracker(training=True, metrics=metrics)
    # Update the backend model.
    model.backend_train_step(backendoptimizer)
    # Sync the internal model (original note: "copy model from backend").
    model.set_weights()
    model.dump()
    # Persist the sampler's distribution table next to the model output.
    sampler.distribution.to_csv(
        os.path.join(model.directory, "distribution.csv"))