import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import absl.logging
absl.logging.set_verbosity(absl.logging.ERROR)
from lib import train
import numpy as np
from lib import utils
from lib import dataload as dl
from lib.config import *
from lib.tf import tf_utils as tfUtils
import tensorflow as tf
from tensorflow.keras import metrics as tfMetrics

from hyperopt import fmin, hp, STATUS_OK, STATUS_FAIL, tpe, Trials
from hyperopt.early_stop import no_progress_loss
from functools import partial
from hyperopt.pyll.stochastic import sample
from lib.hpot import EarlystopWithWarmup

from time import time
from datetime import datetime
import shutil

def param_converter(params):
    """Translate one hyperopt sample into SINNModel / training arguments.

    Side effect: ``params["num_layer"]`` is set to 2 in place before any other
    key is read, so a caller that keeps ``params`` will see that key.

    Args:
        params: hyperopt sample providing ``tl_unit``, ``tl_drop``,
            ``fcn_unit``, ``fcn_drop``, ``intv_norm``, ``batch_size``, ``lr``
            and ``beta1``.

    Returns:
        dict of model keyword arguments plus the training options ``lr`` and
        ``batch_size``.
    """
    params["num_layer"] = 2

    def clip_to_zero(key):
        # Anything not strictly positive (negative, -0.0, NaN) becomes 0.0.
        value = float(params[key])
        return value if value > 0.0 else 0.0

    units = int(params["tl_unit"])
    drop = clip_to_zero("tl_drop")
    widest = int(params["fcn_unit"])
    # Second dense layer is half as wide as the first.
    fcn_widths = [widest, int(params["fcn_unit"] / 2)]
    fcn_dropouts = [float(params["fcn_drop"])] * int(params["num_layer"])

    # The same sampled norm feeds both the L1 and the L2 penalty.
    l1_penalty = clip_to_zero("intv_norm")
    l2_penalty = clip_to_zero("intv_norm")

    n_batch = int(params["batch_size"])
    learning_rate = float(params["lr"])

    converted = dict(
        intv_units=units,
        intv_drop=drop,
        beta1=float(params["beta1"]),
        beta2=1.0,
        beta3=1.0,
        intv_l1=l1_penalty,
        intv_l2=l2_penalty,
        rdc_units=units,
        fcn_units=fcn_widths,
        fcn_drop=fcn_dropouts,
        lr=learning_rate,
        batch_size=n_batch,
        intv_act="softsign",
    )
    return converted


class HpotWrapper:
    """Expose a SINNModel training run as a hyperopt objective.

    Every :meth:`objective` call trains one model with the sampled
    hyperparameters (through ``train.train_nn``), scores it on the validation
    split and keeps the best model seen so far in ``self._best_model``.
    """

    def __init__(self,task,train_data,early_stop, epochs,valid_data=None, seed=None, save_path=None, logger=None,std=0.5):
        """
        Args:
            task: task key; ``TASK_MAP[task]["metrics"]`` lists the metrics.
                The monitored metric is minimized when ``task == "r"`` and
                maximized otherwise.
            train_data: ``(X, y)`` pair used for fitting.
            early_stop: patience forwarded to ``train_nn``.
            epochs: maximum number of epochs per trial.
            valid_data: ``(X, y)`` pair used for monitoring. Required despite
                the ``None`` default.
            seed: when not None, passed to ``tf.keras.utils.set_random_seed``
                at the start of every trial and forwarded to the model args.
            save_path: directory under which the temporary ``weight_temp``
                checkpoint directory is created. Required.
            logger: callable taking one string; ``print`` is used when None.
            std: value forwarded to the model as ``nn_args["std"]``.
        """
        assert save_path is not None, "save_path should not be None"
        assert len(train_data) == 2, "train_data should be an (X, y) pair"
        assert valid_data is not None and len(valid_data) == 2, "valid_data should be an (X, y) pair"

        self._task = task
        self._metrics = TASK_MAP[self._task]["metrics"]
        self._train_func = train.train_nn

        # hyperopt always minimizes, so metrics to be maximized are negated.
        self._sign = 1.0 if task == "r" else -1.0

        self._train_data = train_data
        self._valid_data = valid_data
        self._n_train = len(train_data[0])

        self._early_stop = early_stop
        self._epochs = epochs
        self._seed = seed
        self._save_path = save_path
        self._temp_save_path = os.path.join(save_path, "weight_temp")

        self._best_model = None
        # +inf guarantees the first successful trial is recorded as the best,
        # however large its (sign-adjusted) loss is.
        self._best_loss = float("inf")
        self._best_history = None
        self._num_try = 0
        self.logger = logger
        self.std = std

    def objective(self,params):
        """Train one model for the hyperopt sample ``params`` and score it.

        Returns:
            hyperopt result dict. ``loss`` is the best-epoch value of the last
            configured validation metric multiplied by ``self._sign``, so
            hyperopt can always minimize it. A NaN score yields
            ``status=STATUS_FAIL`` with ``loss=np.inf`` and ``params=None``.
        """
        start = time()
        if self._seed is not None:
            tf.keras.utils.set_random_seed(self._seed)

        args = param_converter(params)
        args.update({"seed": self._seed})
        lrstart = args["lr"]
        lrend = lrstart * 0.1
        batch_size = args["batch_size"]

        metrics_object = [getattr(tfMetrics, metric)(name=METRIC_NM_MAP[metric]) for metric in self._metrics]
        # The last configured metric drives early stopping and model selection.
        monitor_metric = ["val_" + metric.name for metric in metrics_object][-1]
        self.write(f"Trial {self._num_try} by {monitor_metric}, NNParam : {args}")

        # lr and batch_size are passed to train_nn directly, not as model args.
        del args["batch_size"]
        del args["lr"]
        args["std"] = self.std

        os.makedirs(self._temp_save_path, exist_ok=True)

        nn_result = self._train_func(
            X=self._train_data[0],y=self._train_data[1],task=self._task,
            nn_nm="SINNModel",nn_args=args,
            opt_nm="Adam",opt_args={"learning_rate":lrstart},
            epochs=self._epochs, batch_size= batch_size, monitor=monitor_metric,
            patience=self._early_stop, verbose=0,
            validation_data=(self._valid_data[0],self._valid_data[1]), extra_data=None,
            lr_warmup=1,lr_decay=0.999, lr_min=lrend, lr_max=lrstart,
            interval=10, model_save_path = self._temp_save_path,
            logger = None
        )

        history = nn_result["history"].history
        best_round = np.argmin(history[monitor_metric]) if self._sign == 1.0 else np.argmax(history[monitor_metric])
        loss = history[monitor_metric][best_round]
        loss_last = history[monitor_metric][-1]
        status = STATUS_FAIL if np.isnan(loss) else STATUS_OK
        loss *= self._sign
        self.write(f"End : Run {len(history[monitor_metric])} epochs, Best {loss} at {best_round+1} (Last loss {loss_last})")
        if status == STATUS_FAIL:
            # Discard whatever the failed run left on disk so the next trial
            # starts from an empty checkpoint directory.
            shutil.rmtree(self._temp_save_path, ignore_errors=True)
            return_result = {"loss":np.inf, "params": None,'best_round': None,'train_time': time()-start, 'status': status, "test":None}
        else:
            model = nn_result["mdl"]
            metric_names = [x.name for x in model.metrics]

            # Reload the weights train_nn saved in the temp dir (presumably the
            # best-epoch checkpoint -- confirm in train_nn), then remove the dir.
            self.write("Load weights from {}".format(self._temp_save_path))
            model.load_weights(os.path.join(self._temp_save_path, "weight.h5"))
            shutil.rmtree(self._temp_save_path)

            train_scores = model.evaluate(*self._train_data, verbose=0, batch_size=batch_size*2)
            valid_scores = model.evaluate(*self._valid_data, verbose=0, batch_size=batch_size*2)

            self.write(f"Metrics : {metric_names}")
            self.write(f"  Train :  {train_scores}")
            self.write(f"  Valid :  {valid_scores}")

            if self._best_loss > loss:
                self.write(f"*** Best Updated : {self._best_loss}->{loss}")
                self._best_loss = loss
                self._best_model = model
                self._best_history = history

            return_result = {"loss":loss, "params": params,'best_round': best_round+1,'train_time': time()-start, 'status': status}
        self.write("{}".format("="*50))
        self._num_try += 1
        return return_result

    def write(self,content):
        """Send ``content`` to ``self.logger`` if set, otherwise print it."""
        if self.logger is not None:
            self.logger(content)
        else:
            print(content)
        
def _build_search_space(dataset):
    """Return the hyperopt search space; the batch-size range is wider for ``BIG_DATASET`` members."""
    # Several bounds are padded by half a quantization step: hyperopt's
    # ``quniform`` yields ``round(uniform(low, high) / q) * q``, so the padding
    # gives the grid end points the same sampling width as interior points.
    space = {
        'tl_unit': hp.quniform('tl_unit', 50-5, 500+5, 10),
        'tl_drop': hp.quniform('tl_drop', 0.00-0.025,0.50+0.025,0.05),
        'fcn_unit':hp.quniform('fcn_unit', 100-10, 1000+10,20),
        'fcn_drop': hp.quniform('fcn_drop', 0.00-0.025,0.50+0.025,0.05),
        'beta1': hp.quniform('beta1', 1,20,0.1),

        'intv_norm':  hp.quniform('intv_norm', 1e-5,1e-3,1e-4),
        'lr':hp.pchoice("lr",[(0.5,hp.quniform('lr1', 1e-4,1e-3+1e-4,2e-4)),(0.5,hp.quniform('lr2', 1e-3,4e-3,5e-4))]),
    }
    if dataset in BIG_DATASET:
        space.update({'batch_size': hp.quniform('batch_size', 512-64,4096+64,128)})
    else:
        space.update({'batch_size': hp.quniform('batch_size', 128-16,1024+16,32)})
    return space


def _save_trials_csv(trials, save_path):
    """Write one row per hyperopt trial (result info + sampled params) to ``save_path/hpot.csv``."""
    import pandas as pd
    result_list = [elem["result"] for elem in trials.trials]
    # Failed trials report ``params=None``; pandas cannot build a frame from a
    # list of dicts that contains None, so substitute an empty dict (NaN row).
    hyperparam_df = pd.DataFrame([res.get("params") or {} for res in result_list])
    other_df = pd.DataFrame([{key: value for key, value in res.items() if key != "params"} for res in result_list])
    final_df = pd.concat([other_df, hyperparam_df], keys=["info", "params"], axis=1)
    final_df.reset_index(drop=False, inplace=True)
    final_df.to_csv(os.path.join(save_path, "hpot.csv"), index=False)


def run(run,seed,tag,
        max_eval, hpot_es, nn_es,epochs,save_path,
        n_startup=20,warmup=20, el_candid=24,desc=""
       ):
    """Run a hyperopt search for SINNModel on dataset ``run`` and save the results.

    Args:
        run: dataset name; selects ``dl.load_<run>``, ``DATA_DIR/<run>`` and
            ``DATA_TASK_MAP[run]``.
        seed: seed for TF, the data loader and hyperopt's ``rstate``; may be None.
        tag: not used inside this function.
        max_eval: maximum number of hyperopt trials.
        hpot_es: first argument of ``EarlystopWithWarmup`` (presumably the
            search-level patience in trials).
        nn_es: per-trial early-stopping patience passed to ``HpotWrapper``.
        epochs: maximum epochs per trial.
        save_path: output directory for ``log.txt``, ``weights.h5``,
            ``config.json``, ``result.json`` and ``hpot.csv``.
        n_startup: ``n_startup_jobs`` for TPE (random trials before modelling).
        warmup: second argument of ``EarlystopWithWarmup``.
        el_candid: ``n_EI_candidates`` for TPE.
        desc: free-text description stored in ``result.json``.

    Raises:
        RuntimeError: if the search finished without a usable best model.
    """
    tfUtils.set_precision("float32")
    result_dict = {"param":None,"metric":None,"train":None,"valid":None,"test1":None}
    logger = utils.Logger(os.path.join(save_path,"log.txt"))

    space = _build_search_space(run)

    data_load_func = getattr(dl, f"load_{run}")
    data_dir = os.path.join(DATA_DIR, run)
    if seed is not None:
        tf.keras.utils.set_random_seed(seed)
    if run in PREDEFINED_TEST:
        # No test_size here (presumably these datasets ship their own test split).
        X_train,X_valid,X_test,y_train,y_valid,y_test,info = data_load_func(data_dir,seed=seed,as_frame=False,normalize=True,qtran=False)
    else:
        X_train,X_valid,X_test,y_train,y_valid,y_test,info = data_load_func(data_dir,test_size=0.2,seed=seed,as_frame=False,normalize=True,qtran=False)

    task = DATA_TASK_MAP[run]
    wrapper = HpotWrapper(task,(X_train,y_train),nn_es, epochs, valid_data=(X_valid,y_valid), seed=seed, save_path=save_path, logger=logger)
    logger("Run {}({},{}), train:valid:test={}:{}:{}".format(run,task,X_train.shape[-1],len(X_train),len(X_valid),len(X_test)))

    trials = Trials()
    es_func = EarlystopWithWarmup(hpot_es, warmup)
    algo = partial(tpe.suggest, n_startup_jobs=n_startup, n_EI_candidates=el_candid)
    fmin(fn=wrapper.objective, space=space, algo=algo, max_evals=max_eval,
         trials=trials, early_stop_fn=es_func, rstate=np.random.default_rng(seed))
    best_params = trials.best_trial["result"]["params"]

    best_mdl = wrapper._best_model
    if best_mdl is None:
        raise RuntimeError("hyperopt finished without a stored best model; nothing to evaluate")
    metric_names = [m.name for m in best_mdl.metrics]
    train_scores = best_mdl.evaluate(X_train,y_train, verbose=0)
    valid_scores = best_mdl.evaluate(X_valid,y_valid, verbose=0)
    test_scores = best_mdl.evaluate(X_test,y_test, verbose=0)

    wrapper.write(f"Trial Best, Param : {best_params}")
    wrapper.write(f"Metrics : {metric_names}")
    wrapper.write(f"  Train :  {train_scores}")
    wrapper.write(f"  Valid :  {valid_scores}")
    wrapper.write(f"   Test :  {test_scores}")

    result_dict["param"] = best_params
    result_dict["metric"] = metric_names
    result_dict["train"] = train_scores
    result_dict["valid"] = valid_scores
    result_dict["test1"] = test_scores
    result_dict["info"] = {"scale":str(info.get("scale",None)),"task":task,"num_feat":str(info.get("num_feat",None))}
    result_dict["desc"] = desc

    best_mdl.save_weights(os.path.join(save_path,"weights.h5"))
    utils.save_json(best_mdl.return_config(),os.path.join(save_path,"config.json"))
    utils.save_json(result_dict,os.path.join(save_path,"result.json"))

    # Save all trial information
    _save_trials_csv(trials, save_path)

if __name__ == "__main__":
    # CLI entry point: parse options, create the per-run output directory
    # (SINN_SAVE_DIR/<run>/<tag>-<seed>) and launch the hyperopt search.
    import argparse
    parser = argparse.ArgumentParser()
    # --run is mandatory: it names the dataset and is joined into the save path.
    parser.add_argument("--run",type=str,required=True)
    parser.add_argument("--seed",type=int,default=None)
    parser.add_argument("--tag",type=str)
    parser.add_argument("--max_eval",type=int,default=100)
    parser.add_argument("--hpot_es",type=int,default=20)
    parser.add_argument("--nn_es",type=int,default=200)
    parser.add_argument("--epochs",type=int,default=1000)
    parser.add_argument("--desc",type=str,default="")

    parser.add_argument("--n_startup",type=int,default=20)
    parser.add_argument("--warmup",type=int,default=0)
    parser.add_argument("--el_candid",type=int,default=24)

    args = parser.parse_args()

    if args.seed is not None:
        np.random.seed(args.seed)
    result_save_dir = os.path.join(SINN_SAVE_DIR,args.run,f"{args.tag}-{args.seed}")
    os.makedirs(result_save_dir,exist_ok=True)
    print("{} : save at {}, run/tag/seed ={}/{}/{}".format(datetime.today(),result_save_dir,args.run, args.tag, args.seed))

    # desc is forwarded so it ends up in result.json (it was parsed but dropped before).
    run(args.run, args.seed, args.tag, args.max_eval, args.hpot_es, args.nn_es, args.epochs, result_save_dir,
        args.n_startup, args.warmup, args.el_candid, desc=args.desc
       )