#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
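Evaluate a trained checkpoint on the held-out (test) fold of a search space.

For every target dataset of the selected split, the script records the
zero-shot distance-to-minimum computed from the per-target overview CSV stored
with the checkpoint. When ``--top > 0``, it also runs a sequential loop: each
iteration refits the model on the configurations selected so far, then adds one
more candidate drawn from the ``--top`` highest-predicted unseen configurations.

Created on Tue Nov 17 07:45:10 2020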
"""

from dataset import Dataset
from sampling import Batch,TestSampling
import tensorflow as tf
import json
from model import Model
from helper_fn import getresponse
from metrics import distance_to_minimum
import numpy as np
import argparse
import pandas as pd
import os
# Seed TensorFlow and NumPy for reproducible runs. NumPy's global RNG also
# drives the np.random.choice draw in the sequential refitting loop.
tf.random.set_seed(0)
np.random.seed(42)


def _str2bool(value):
    """Convert a command-line token to ``bool`` for use as an argparse ``type``.

    ``type=bool`` is a trap: ``bool("False")`` is ``True`` because every
    non-empty string is truthy, so ``--fixed_hyperparameter False`` was
    silently treated as ``True``. This accepts the usual spellings,
    case-insensitively. The empty string maps to ``False``, as ``bool("")`` did.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):  # defensive: allow an already-parsed bool
        return value
    token = value.strip().lower()
    if token in ("1", "true", "t", "yes", "y"):
        return True
    if token in ("0", "false", "f", "no", "n", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--split', help='Select training fold', type=int, default=0)
# Required: the checkpoint path below is built from int(loaditeration).
parser.add_argument('--loaditeration', help='load iterations for testing', type=int, required=True)
parser.add_argument('--n_iters', default=100, type=int, nargs='?',
                    help='number of iterations for optimization method')
parser.add_argument('--searchspace', help='Select training split scheme',
                    choices=['a', 'b', 'c'], type=str, default='a')
parser.add_argument('--fixed_hyperparameter',
                    help='Sampling of the tasks with fixed dataset/hyperparameter pair or only fixed dataset',
                    type=_str2bool, default=True)
parser.add_argument('--k', help='number of zero-shot configurations', type=int, default=20)
parser.add_argument('--metric', help='Select top ? ', type=str, choices=['min_20', 'min_50'], default='min_20')
parser.add_argument('--ablation', help='Select Ablation loss',
                    choices=['0', '1', '2', '3', '4', '5', '6'], type=str)
parser.add_argument('--fold', help='Select validation/test fold', choices=['test'], type=str, default='test')
parser.add_argument('--top', help='greedy selection', type=int, default=1)
args = parser.parse_args()

# Checkpoint directory for this search space / split / ablation variant.
prefix = "vanilla/" if args.ablation is None else f"ablation-{args.ablation}/"
rootdir = os.path.dirname(os.path.realpath(__file__))
soludir = os.path.join(rootdir, "checkpoints", f"searchspace-{args.searchspace}",
                       f"split-{args.split}", "gr20si", prefix, "iclr")

# Configuration saved with the checkpoint (JSON). Use a context manager so the
# file handle is closed deterministically instead of being leaked.
with open(os.path.join(soludir, "configuration.txt"), "r") as config_file:
    configuration = json.load(config_file)

loaditeration = args.loaditeration
if loaditeration is None:
    # int(None) below would otherwise fail with an opaque TypeError.
    parser.error("--loaditeration is required")
loaddir = os.path.join(rootdir, configuration["savedir"], f"iteration-{int(loaditeration)}",
                       "weights", "weights")

# Meta-split table; columns such as "test-<split>" list the target dataset files.
splits_file = os.path.join(rootdir, "metadataset", f"searchspace-{args.searchspace}",
                           f"searchspace-{args.searchspace}-splits.csv")
metasplits = pd.read_csv(splits_file, index_col=0)
# Build the dataset with validation data enabled.
normalized_dataset = Dataset(configuration, rootdir, use_valid=True)

# Size of the training split in the dataset; forwarded to getresponse().
nsource = len(normalized_dataset.orig_data['train'])
backendoptimizer = tf.keras.optimizers.Adam(configuration['backend_learning_rate'])
optimizer = tf.keras.optimizers.SGD(configuration['learning_rate'])

# Batch size depends on the search space; set before the model is built.
configuration["batch_size"] = 16 if args.searchspace != 'c' else 18
configuration["ablation"] = args.ablation  # how to retrain
model = Model(configuration, rootdir=rootdir, for_eval=True)
batch = Batch(configuration["batch_size"])
# model.model appears to be a tf.keras Model. Its summary() prints the table
# itself and returns None, so wrapping it in print() only added a stray "None".
model.model.summary()


def fn(targetdataset, sourcedataset, config, targetsplit, sourcesplit):
    """Return ``normalized_dataset.instances`` for the given target/source pair.

    Adapter passed to ``getresponse``: it maps the positional
    ``targetsplit``/``sourcesplit`` arguments onto the ``split``/``sourcesplit``
    keywords of ``Dataset.instances``.
    """
    return normalized_dataset.instances(targetdataset=targetdataset,
                                        sourcedataset=sourcedataset,
                                        config=config,
                                        split=targetsplit,
                                        sourcesplit=sourcesplit)

# One column per target file: zero-shot distance-to-minimum, and the
# distance-to-minimum trajectory of the sequential refitting procedure.
zerooverview = pd.DataFrame(None)
iteroverview = pd.DataFrame(None)

# Loop-invariant output locations. `objective` is defined before the loop so
# that the final to_csv() also works when the fold lists no targets (it
# previously raised NameError in that case).
objective = os.path.join(args.fold, args.metric, "zero-shot")
os.makedirs(os.path.join(soludir, objective), exist_ok=True)
sequentialdir = os.path.join(args.fold, args.metric, "sequential",
                             f"ablation-{args.ablation}" if args.ablation is not None else "",
                             f"top-{args.top}", f"initial-{args.k}")

for ntarget, file in enumerate(metasplits[f"{args.fold}-{args.split}"].dropna()):
    # Reload the checkpoint so refitting on one target does not carry over.
    # NOTE(review): optimizer/backendoptimizer state (e.g. Adam moments) is not
    # reset here and persists across targets; confirm this is intended.
    model.model.load_weights(loaddir, by_name=False, skip_mismatch=False)

    # Zero-shot evaluation from the overview CSV stored with the checkpoint.
    overview_path = os.path.join(rootdir, configuration["savedir"], f"iteration-{int(loaditeration)}",
                                 args.fold, f"overview-targetdataset-{ntarget}.csv")
    results = pd.read_csv(overview_path, header=0, index_col=0)
    d2m = pd.DataFrame(distance_to_minimum(results, targetdataset=ntarget, split=args.fold,
                                           dataset=normalized_dataset))
    d2m.columns = [file]
    zerooverview = pd.concat([zerooverview, d2m], axis=1)

    if args.top > 0:
        response = np.asarray(results["response"]).reshape(-1,)
        refittingobjective = os.path.join(sequentialdir, file)
        os.makedirs(os.path.join(soludir, refittingobjective), exist_ok=True)
        # Seed the selection with the k configurations of highest predicted response.
        x = list(np.argsort(results['targetresponse'])[::-1][:args.k])
        opt = max(response)  # best observed response (assumes larger is better)
        while len(x) < args.n_iters:
            # Stop as soon as a configuration attaining the optimum is selected.
            if opt in [response[idx] for idx in x]:
                break
            sampler = TestSampling(dataset=normalized_dataset,
                                   fixed_hyperparameter=args.fixed_hyperparameter)
            if not args.fixed_hyperparameter:
                # One Reptile-style update over the whole selected collection.
                model.store()
                for _ in range(configuration["k-reptile"]):
                    batch = sampler.sample(batch, split=args.fold, sourcesplit='train',
                                           targetdataset=ntarget, collection=x)
                    batch.collect()
                    model.train_step(x=batch.input, y=batch.output, optimizer=optimizer,
                                     clip=True, no_metrics=True)
                model.backend_train_step(backendoptimizer)
                model.set_weights()
            else:
                # One Reptile-style update per selected configuration.
                for col in range(len(x)):
                    model.store()
                    for _ in range(configuration["k-reptile"]):
                        batch = sampler.sample(batch, split=args.fold, sourcesplit='train',
                                               targetdataset=ntarget, collection=x, index=col)
                        batch.collect()
                        model.train_step(x=batch.input, y=batch.output, optimizer=optimizer,
                                         clip=True, no_metrics=True)
                    model.backend_train_step(backendoptimizer)
                    model.set_weights()
            results, _ = getresponse(model, batch, normalized_dataset, ntarget=ntarget, split=args.fold,
                                     nsource=nsource, sourcesplit='train', fn=fn,
                                     directory=os.path.join(soludir, refittingobjective, f"iter-{len(x)}"),
                                     index=True, return_response=True)
            assert len(results) == 1
            results = results[0]
            # Unseen candidates, best predicted first. A set keeps membership O(1).
            chosen = set(x)
            newx = [cand for cand in np.argsort(results['targetresponse'])[::-1] if cand not in chosen]
            if not newx:
                # Every candidate is already selected; np.random.choice would
                # raise ValueError on an empty population.
                break
            x += [np.random.choice(newx[:args.top])]
        d2m = distance_to_minimum(x, ntarget, split=args.fold, dataset=normalized_dataset, csv=False)
        iteroverview = pd.concat([iteroverview, pd.DataFrame(d2m, columns=[file])], axis=1)
        # Rewritten after every target so partial results survive an interruption.
        iteroverview.to_csv(os.path.join(soludir, sequentialdir, "results.csv"))
zerooverview.to_csv(os.path.join(soludir, objective, "results.csv"))