import numpy as np
import pandas as pd
import os

import matplotlib.pyplot as plt
from lib.tf import callbacks
from lib.tf import models
from tensorflow.keras import metrics as tfMetrics
from tensorflow.keras import optimizers
from tensorflow.keras import losses
from lib import evaluate as ev

def return_metric_by_task(task,name_only=False):
    """Return the evaluation metrics associated with a task code.

    Args:
        task: Task code -- "r" (regression), "bc" (binary classification)
            or "mc" (multi-class classification).
        name_only: When truthy, return the metric names as strings instead
            of instantiated ``tensorflow.keras.metrics`` objects.

    Returns:
        A freshly built list of metric names, or of Keras metric objects
        whose ``name`` matches the corresponding string.

    Raises:
        ValueError: If ``task`` is not one of the supported codes.
    """
    if task == "r":
        return ["rmse"] if name_only else [tfMetrics.RootMeanSquaredError(name="rmse")]

    if task == "bc":
        if name_only:
            return ["auc", "acc"]
        # Probabilities (not logits) are expected, matching a sigmoid output head.
        return [
            tfMetrics.AUC(from_logits=False, name="auc"),
            tfMetrics.BinaryAccuracy(name="acc"),
        ]

    if task == "mc":
        return ["acc"] if name_only else [tfMetrics.SparseCategoricalAccuracy(name="acc")]

    raise ValueError(f"{task} invalid")

def return_model(task,num_output,mdl_nm,mdl_args,opt_nm,opt_args):
    """Instantiate and compile a model from ``lib.tf.models`` for a task.

    Args:
        task: Task code -- "r" (regression), "bc" (binary classification)
            or "mc" (multi-class classification).
        num_output: Number of output units. Overrides any ``num_output`` key
            already present in ``mdl_args``.
        mdl_nm: Attribute name of the model constructor in ``lib.tf.models``.
        mdl_args: Keyword arguments for the model constructor. It is not
            mutated; a merged copy is passed on.
        opt_nm: Attribute name of an optimizer in
            ``tensorflow.keras.optimizers`` (e.g. "Adam").
        opt_args: Keyword arguments for the optimizer constructor.

    Returns:
        The compiled model.

    Raises:
        ValueError: If ``task`` is not one of the supported codes.
    """
    # The losses below use from_logits=False, so the output head must emit
    # probabilities itself (sigmoid / softmax), not raw logits.
    if task == "r":
        loss_func = losses.MeanSquaredError()
        output_act = "linear"
    elif task == "bc":
        loss_func = losses.BinaryCrossentropy(from_logits=False,label_smoothing=0.0)
        output_act = "sigmoid"
    elif task == "mc":
        loss_func = losses.SparseCategoricalCrossentropy(from_logits=False)
        output_act = "softmax"
    else:
        raise ValueError(f"{task} invalid")
    # Single source of truth for per-task metrics (same objects as before).
    metrics = return_metric_by_task(task)

    # Task-derived settings take precedence over user-supplied model args.
    model_kwargs = {**mdl_args, "num_output": num_output, "output_act": output_act}

    mdl = getattr(models, mdl_nm)(**model_kwargs)
    optimizer = getattr(optimizers, opt_nm)(**opt_args)

    mdl.compile(optimizer, loss=loss_func, metrics=metrics)
    return mdl

def _infer_best_mode(task, monitor):
    """Return "min" or "max": the direction in which ``monitor`` improves."""
    metric = monitor[len("val_"):] if monitor.startswith("val_") else monitor
    # Losses and error metrics improve downwards regardless of the task.
    if "loss" in metric or metric in ("rmse", "mse", "mae"):
        return "min"
    # The classification metrics produced by return_metric_by_task improve upwards.
    if metric in ("auc", "acc"):
        return "max"
    # Unknown metric name: fall back to the original task-based convention.
    return "min" if task == "r" else "max"


def train_nn(X,y,task,nn_nm="",nn_args=None,
             opt_nm="",opt_args=None,
             lr_warmup=1, lr_decay=1.0, lr_min=None, lr_max=None,
             epochs= 1000, batch_size= 1000,monitor="loss", patience=15,
             validation_data = None, extra_data = None,
             verbose=0,  interval=1,model_save_path = None, logger = print
            ):
    """Build, compile and fit a neural network for a regression/classification task.

    Args:
        X: Training inputs. The last axis (``X.shape[-1]``) is used as the
            feature dimension when building the model.
        y: Training targets. For "bc" they must contain exactly two distinct
            labels. For "mc" the number of distinct labels sets the output width.
        task: "r" (regression), "bc" (binary) or "mc" (multi-class).
        nn_nm: Model constructor name in ``lib.tf.models``.
        nn_args: Model constructor kwargs. They are copied, never mutated.
        opt_nm: Optimizer name in ``tensorflow.keras.optimizers``.
        opt_args: Optimizer kwargs. They are copied, never mutated.
        lr_warmup, lr_decay, lr_min, lr_max: Forwarded to
            ``callbacks.MyLearningRateScheduler``.
        epochs, batch_size: Passed to ``fit``. ``batch_size`` is also passed
            to the progress callback.
        monitor: History key used for checkpointing, early stopping and
            ``best_epoch``. Whether lower or higher is better is inferred
            from its name (loss/rmse -> lower, auc/acc -> higher).
        patience: Early-stopping patience. Early stopping is only enabled
            when ``validation_data`` is given and ``patience > 0``.
        validation_data: Passed to ``fit``.
        extra_data: Forwarded to the progress callback as ``extra_validation``.
        verbose, interval, logger: Forwarded to ``callbacks.MyProgressCallback``.
        model_save_path: If given, best weights are saved there. A path
            ending in ".h5" is used as-is; otherwise "weight.h5" is
            appended to it.

    Returns:
        dict with keys "mdl" (fitted model), "desc" (``stop_epoch``,
        ``best_epoch`` and the model args incl. ``num_output``) and
        "history" (the object returned by ``fit``).

    Raises:
        ValueError: On an unknown task, or "bc" targets without exactly
            two distinct labels.
    """
    # Work on copies so neither the caller's dicts nor shared defaults change.
    nn_args = dict(nn_args) if nn_args is not None else {}
    opt_args = dict(opt_args) if opt_args is not None else {}

    num_feat = X.shape[-1]
    if task == "r":
        num_output = 1
    elif task == "bc":
        n_labels = len(np.unique(y))
        if n_labels != 2:
            raise ValueError(f"task 'bc' requires exactly 2 distinct labels, got {n_labels}")
        num_output = 1
    elif task == "mc":
        num_output = len(np.unique(y))
    else:
        raise ValueError(f"{task} invalid")
    best_mode = _infer_best_mode(task, monitor)
    nn_args["num_output"] = num_output

    nn_mdl = return_model(task,num_output,nn_nm,nn_args,opt_nm,opt_args)
    nn_mdl.build((None,num_feat))
    if nn_nm == "SINNModel":
        nn_mdl.build_norm_metrics()

    # The list must NOT be named ``callbacks``: a local of that name would
    # shadow the imported ``lib.tf.callbacks`` module for the whole function
    # (UnboundLocalError on first use).
    cb_list = [callbacks.MyLearningRateScheduler(lr_warmup,lr_decay,lr_min,lr_max)]
    if model_save_path is not None:
        if model_save_path.endswith(".h5"):
            checkpoint_path = model_save_path
        else:
            checkpoint_path = os.path.join(model_save_path,"weight.h5")
        # NOTE(review): assumes lib.tf.callbacks exposes ModelCheckpoint and
        # EarlyStopping (e.g. re-exported from tf.keras.callbacks) -- confirm.
        cb_list.append(callbacks.ModelCheckpoint(checkpoint_path, save_weights_only=True,monitor=monitor,save_best_only=True,mode=best_mode))

    if validation_data is not None and patience > 0:
        cb_list.append(callbacks.EarlyStopping(monitor=monitor, patience=patience, verbose=0,restore_best_weights=True, mode=best_mode))
    # The progress/log callback is always attached.
    cb_list.append(callbacks.MyProgressCallback(verbose,interval,digit=5,extra_validation=extra_data,batch_size=batch_size,logger=logger))

    history = nn_mdl.fit(
        X,y, epochs= epochs, batch_size = batch_size, shuffle=True,
        callbacks=cb_list,
        verbose=0, validation_data= validation_data
    )
    monitored = history.history[monitor]
    stop_epoch = len(monitored)
    # Same direction as the checkpoint/early-stopping callbacks; 1-based epoch.
    pick_best = np.argmin if best_mode == "min" else np.argmax
    best_epoch = int(pick_best(monitored)) + 1

    desc = {"stop_epoch":stop_epoch,"best_epoch":best_epoch}
    desc.update(nn_args)
    result = {"mdl":nn_mdl,"desc":desc,"history":history}
    return result

def visualize_nn_performance(mdl,history,X_train,y_train,X_test,y_test,X_valid,y_valid,title="",fig_args=None,digit=4,second_axis=None,batch_size=1000):
    """Evaluate a compiled model on train/valid/test and plot its loss curves.

    Args:
        mdl: Compiled model whose loss is a BinaryCrossentropy,
            SparseCategoricalCrossentropy or MeanSquaredError instance
            (the losses ``return_model`` uses).
        history: The ``History`` object returned by ``fit`` or its
            ``.history`` dict. It must contain "loss" and "val_loss".
        X_train, y_train, X_test, y_test, X_valid, y_valid: Data passed to
            ``mdl.evaluate`` for each split.
        title: Optional title line shown above the metric summary.
        fig_args: Keyword arguments for ``plt.subplots``.
        digit: Number of decimals used for metrics in the title.
        second_axis: Optional history metric name (e.g. "auc"). When given,
            it and ``val_<name>`` are drawn dashed on a twin y-axis.
        batch_size: Batch size used by ``mdl.evaluate``.

    Returns:
        ``(fig, history_df, perf_result_df)``. ``history_df`` is the history
        with an "epoch" column. ``perf_result_df`` holds the unrounded
        evaluate() results, one row per split ("train", "valid", "test").

    Raises:
        ValueError: If the model's loss type is not one of the supported ones.
    """
    supported_losses = ("BinaryCrossentropy", "SparseCategoricalCrossentropy", "MeanSquaredError")
    loss_name = type(mdl.loss).__name__
    if loss_name not in supported_losses:
        raise ValueError(f"unsupported loss {loss_name}; expected one of {supported_losses}")
    fig_args = {} if fig_args is None else fig_args

    # Accept the History object from fit() as well as its plain dict.
    history = getattr(history, "history", history)
    history_df = pd.DataFrame(history).reset_index().rename({"index":"epoch"},axis=1)

    splits = {"train": (X_train, y_train), "valid": (X_valid, y_valid), "test": (X_test, y_test)}
    results = {
        name: mdl.evaluate(X, y, batch_size=batch_size, verbose=0, return_dict=True)
        for name, (X, y) in splits.items()
    }
    perf_result_df = pd.DataFrame(list(results.values()), index=list(results))
    rounded = {name: [round(value, digit) for value in res.values()] for name, res in results.items()}

    fig,ax = plt.subplots(1,1,**fig_args)
    # Black = training curve, red = validation curve.
    ax.plot(history_df["loss"],color="black")
    ax.plot(history_df["val_loss"],color="red")
    ax.set_ylabel("Loss")
    ax.set_xlabel("epoch")
    if second_axis is not None:
        ax2 = ax.twinx()
        ax2.plot(history_df[second_axis], linestyle="--",color="black")
        ax2.plot(history_df[f"val_{second_axis}"], linestyle="--",color="red")
        ax2.set_ylabel(second_axis)

    summary = f"Train:{rounded['train']}/Valid:{rounded['valid']}/Test:{rounded['test']}"
    ax.set_title(f"{title}\n{summary}" if title != "" else summary, fontsize=9.0)

    return fig, history_df, perf_result_df