import sys
import time
import tensorflow as tf
from deel.utils.lip_utils import rescale_grad_unit,redressage_grad,gradient_norm
from tensorflow.keras.metrics import top_k_categorical_accuracy,categorical_accuracy,binary_accuracy,AUC
from deel.datasets.oneclass_dataset import get_random_distribution,class_versus_random


def add_model_regularizer_loss(model, lambda_orth=0):
    """Return the sum of every loss registered in ``model.losses``.

    Keras collects layer regularization penalties (and any ``add_loss`` terms)
    in ``model.losses``; this sums them into a single tensor (0.0 when the
    list is empty).

    Args:
        model: Keras model whose ``losses`` are summed.
        lambda_orth: accepted for backward compatibility but not used here;
            callers such as ``train_step_hkr`` apply the weighting themselves.

    Returns:
        Scalar tensor with the total regularization loss.
    """
    return tf.reduce_sum(model.losses)




@tf.function
def train_step_binary(x, y,
                      model,
                      loss_fn,
                      t,
                      optimizer,
                      grad_coeff=0.1,
                      optim_margin=False,
                      rescale_grad=False):
    """Run one compiled optimisation step for a binary (single-logit) model.

    Args:
        x, y: input batch and binary labels.
        model: model returning logits; accuracy thresholds them at 0.
        loss_fn: callable ``loss_fn(y_true, logits)``.
        t: extra variable (the margin) optimised jointly when ``optim_margin``.
        optimizer: optimizer applying the gradients.
        grad_coeff: coefficient forwarded to ``redressage_grad``.
        optim_margin: also differentiate w.r.t. ``t`` and update it.
        rescale_grad: rescale the network gradients with ``redressage_grad``
            before applying them.

    Returns:
        dict with ``"loss"``, ``"categorical_accuracy"`` (binary accuracy at
        threshold 0; the key name is what the training loops read) and
        ``"grad_norm"``.
    """
    with tf.GradientTape() as w_tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)

    weights = model.trainable_weights
    if optim_margin:
        weights = weights + [t]
    grads = w_tape.gradient(loss_value, weights)
    if rescale_grad:
        if optim_margin:
            # Rescale only the network gradients and keep the gradient of t
            # (last entry); replacing the whole list would drop it and zip()
            # below would then silently skip updating t.
            grads[:-1] = redressage_grad(grads[:-1], weights[:-1], coeff=grad_coeff)
        else:
            grads = redressage_grad(grads, weights, coeff=grad_coeff)
    optimizer.apply_gradients(zip(grads, weights))

    acc = binary_accuracy(y, logits, threshold=0.)
    return {"loss": loss_value,
            "categorical_accuracy": acc,
            "grad_norm": gradient_norm(grads)}


def train_epoch_binary(
        train_it,
        val_it,
        model,
        loss_fn,
        metrics,
        t,
        optimizer,
        steps_per_epoch,
        callbacks,
        grad_coeff=0.1,
        optim_margin=False,
        rescale_grad=False,
        validation_step=50,
        epoch=0):
    """Train a binary classifier for one epoch, then optionally validate it.

    This loop is deliberately not wrapped in ``tf.function``: it pulls batches
    with ``next``, calls ``.numpy()``, reads the wall clock and runs Keras
    callbacks, none of which work inside a traced graph. The compiled
    per-batch update is ``train_step_binary``.

    Args:
        train_it: iterator yielding ``(x, y)`` training batches.
        val_it: iterator yielding ``(x, y)`` validation batches, or ``None``
            to skip validation.
        model: model returning logits; accuracy thresholds them at 0.
        loss_fn: callable ``loss_fn(y_true, logits)``.
        metrics: dict name -> ``tf.metrics.Mean``. Missing entries are created
            on demand and every entry is reset before returning.
        t: margin variable forwarded to ``train_step_binary`` (only updated
            when ``optim_margin`` is true).
        optimizer: optimizer applying the gradients.
        steps_per_epoch: number of batches drawn from ``train_it``.
        callbacks: iterable of Keras-style callbacks.
        grad_coeff, optim_margin, rescale_grad: forwarded to
            ``train_step_binary``.
        validation_step: number of batches drawn from ``val_it``.
        epoch: epoch index passed to ``on_epoch_end``.

    Returns:
        dict of epoch logs: the result of every metric, ``'time'`` (training
        seconds) and, when ``val_it`` is given, ``'AUC'``.
    """
    start_time = time.time()
    for batch in range(steps_per_epoch):
        x, y = next(train_it)
        for c in callbacks:
            c.on_batch_begin(batch, logs=None)

        results = train_step_binary(x, y,
                                    model,
                                    loss_fn,
                                    t,
                                    optimizer,
                                    grad_coeff=grad_coeff,
                                    rescale_grad=rescale_grad,
                                    optim_margin=optim_margin)

        batch_logs = {k: v.numpy().mean() for k, v in results.items()}
        for c in callbacks:
            c.on_train_batch_end(batch, logs=batch_logs)
        for k, v in results.items():
            if k not in metrics:
                metrics[k] = tf.metrics.Mean()
            metrics[k].update_state(v)

    total_time = time.time() - start_time

    auc_value = None
    if val_it is not None:
        for name in ("val_loss", "val_acc"):
            if name not in metrics:
                metrics[name] = tf.metrics.Mean()
        auc = AUC()
        for _ in range(validation_step):
            x, y = next(val_it)
            logits = model(x, training=False)
            loss_value = loss_fn(y, logits)
            auc.update_state(y, tf.math.sigmoid(logits))
            acc = binary_accuracy(y, logits, threshold=0.)
            metrics["val_loss"].update_state(loss_value.numpy())
            metrics["val_acc"].update_state(acc.numpy())
        auc_value = auc.result().numpy()

    logs = {k: m.result() for k, m in metrics.items()}
    logs['time'] = total_time
    if auc_value is not None:
        # Set after the dict is rebuilt from the metrics so it is not lost.
        logs['AUC'] = auc_value

    for c in callbacks:
        c.on_epoch_end(epoch, logs=logs)
    for m in metrics.values():
        m.reset_states()
    return logs


def _prefetched_iterator(generator):
    """Return an iterator over a prefetched ``tf.data`` pipeline built from ``generator``.

    ``generator`` is a callable accepted by ``tf.data.Dataset.from_generator``
    whose items are ``(x, y)`` pairs, both emitted as ``tf.float32``.
    """
    dataset = tf.data.Dataset.from_generator(generator, (tf.float32, tf.float32))
    return iter(dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE))


def fit_one_class(model,
                  X,
                  validation,
                  loss_fct,
                  optimizer,
                  validation_step=50,
                  steps_per_epoch=50,
                  callbacks=None,
                  epochs=20,
                  schedule=None,
                  change=0.1,
                  rescale_grad=False,
                  scale=None,
                  batch_size=16,
                  grad_coeff=0.1,
                  trace_func=None,
                  test=None,
                  filename=None,
                  optim_margin=False):
    """Train a one-class model on ``X`` versus a generated "random" class.

    Training batches come from ``class_versus_random``. The training generator
    is rebuilt after every epoch, passing the current model, the previous
    ``'curent_vs'`` entry and ``change`` back to ``class_versus_random``.

    Args:
        model: model returning a single logit per sample.
        X: data of the positive class, forwarded to ``class_versus_random``.
        validation: generator callable yielding ``(x, y)`` validation batches.
        loss_fct: loss callable; its ``margin`` is passed to
            ``train_step_binary`` and, with ``schedule``, its
            ``set_margin_coeff`` is called every epoch.
        optimizer: optimizer whose ``lr`` variable ``schedule`` may assign.
        validation_step: validation batches evaluated per epoch.
        steps_per_epoch: training batches per epoch.
        callbacks: Keras-style callbacks (``None`` means no callbacks).
        epochs: number of epochs.
        schedule: optional dict with per-epoch sequences ``"margins"`` and
            ``"lrs"``.
        change, scale, batch_size: forwarded to ``class_versus_random``.
        rescale_grad, grad_coeff, optim_margin: forwarded to
            ``train_step_binary``.
        trace_func: optional ``trace_func(model, dtset[, filename=...])``
            called after every epoch.
        test: unused; kept for backward compatibility.
        filename: prefix for the ``trace_func`` image file name.
    """
    callbacks = [] if callbacks is None else callbacks
    for c in callbacks:
        c.set_model(model)

    dtset = class_versus_random(X, batch_size=batch_size, scale=scale)
    train_it = _prefetched_iterator(dtset['train'])
    val_it = _prefetched_iterator(validation)

    logs = {}
    for c in callbacks:
        c.on_train_begin(logs=logs)
    metrics = {"val_loss": tf.metrics.Mean(), "val_acc": tf.metrics.Mean()}

    for e in range(epochs):
        if schedule is not None:
            loss_fct.set_margin_coeff(schedule["margins"][e])
            optimizer.lr.assign(schedule["lrs"][e])
            tf.print("margin :", schedule["margins"][e], " lr", schedule["lrs"][e])

        start_time = time.time()
        for batch in range(steps_per_epoch):
            x, y = next(train_it)
            for c in callbacks:
                c.on_batch_begin(batch, logs=None)

            results = train_step_binary(x, y,
                                        model,
                                        loss_fct,
                                        loss_fct.margin,
                                        optimizer,
                                        grad_coeff=grad_coeff,
                                        rescale_grad=rescale_grad,
                                        optim_margin=optim_margin)

            batch_logs = {k: v.numpy().mean() for k, v in results.items()}
            for c in callbacks:
                c.on_train_batch_end(batch, logs=batch_logs)
            for k, v in results.items():
                if k not in metrics:
                    metrics[k] = tf.metrics.Mean()
                metrics[k].update_state(v)
        total_time = time.time() - start_time

        auc = AUC()
        for _ in range(validation_step):
            x, y = next(val_it)
            logits = model(x, training=False)
            loss_value = loss_fct(y, logits)
            auc.update_state(y, tf.math.sigmoid(logits))
            acc = binary_accuracy(y, logits, threshold=0.)
            metrics["val_loss"].update_state(loss_value.numpy())
            metrics["val_acc"].update_state(acc.numpy())

        logs = {k: m.result() for k, m in metrics.items()}
        logs['time'] = total_time
        logs['AUC'] = auc.result().numpy()

        # Regenerate the training set from the current model for the next epoch.
        dtset = class_versus_random(X, model=model, batch_size=batch_size, scale=scale,
                                    X_prev_vs=dtset['curent_vs'], change=change)
        train_it = _prefetched_iterator(dtset['train'])

        if trace_func is not None:
            if filename is None:
                trace_func(model, dtset)
            else:
                trace_func(model, dtset, filename=filename + "_" + str(e) + ".png")
        for c in callbacks:
            c.on_epoch_end(e, logs=logs)
        print(f"Epoch {e+1}/{epochs}")
        tf.print(f"time : {total_time:.2f}s *** loss:",metrics["loss"].result(), "accuracy:",metrics["categorical_accuracy"].result(),"margin",loss_fct.margin, "val accuracy", metrics["val_acc"].result(), " val AUC",logs ['AUC'])

        for m in metrics.values():
            m.reset_states()
        sys.stdout.flush()

    for c in callbacks:
        c.on_train_end(logs=logs)


@tf.function
def train_step_hkr(x, y,
                   model,
                   loss_fn,
                   t,
                   optimizer,
                   lambda_orth=0,
                   grad_coeff=0.1,
                   optim_margin=False,
                   redress=False):
    """Run one compiled optimisation step for a multi-class (HKR-style) model.

    Args:
        x, y: input batch and labels; ``y`` is presumably one-hot (the margin
            statistics select the true class with ``y == 1``) — confirm.
        model: model returning one logit per class.
        loss_fn: callable ``loss_fn(y_true, logits)``.
        t: extra variable (margins) optimised jointly when ``optim_margin``.
        optimizer: optimizer applying the gradients.
        lambda_orth: weight of ``add_model_regularizer_loss``; 0 disables it.
        grad_coeff: coefficient forwarded to ``redressage_grad``.
        optim_margin: also differentiate w.r.t. ``t`` and update it.
        redress: rescale the network gradients with ``redressage_grad``
            before applying them (the gradient of ``t`` is left unchanged).

    Returns:
        dict of tensors: ``loss``, ``categorical_accuracy``, ``regul``,
        ``grad_norm`` and logit statistics of the batch (``robustness``,
        ``avg_value``, ``abs_margin``, ``margin``, ``margin_std``).
    """
    with tf.GradientTape() as w_tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
        if lambda_orth != 0:
            regul = lambda_orth * add_model_regularizer_loss(model, lambda_orth=lambda_orth)
        else:
            # A typed zero tensor instead of the Python int 0, so "regul" is
            # always a tensor of the loss dtype; callers call .numpy() on it.
            regul = tf.zeros((), dtype=loss_value.dtype)
        final_loss = loss_value + regul

    weights = model.trainable_weights
    if optim_margin:
        weights = weights + [t]
    grads = w_tape.gradient(final_loss, weights)
    if redress:
        if optim_margin:
            # Keep the gradient of t (last entry) untouched.
            grads[:-1] = redressage_grad(grads[:-1], weights[:-1], coeff=grad_coeff)
        else:
            grads = redressage_grad(grads, weights, coeff=grad_coeff)
    optimizer.apply_gradients(zip(grads, weights))

    acc = categorical_accuracy(y, logits)
    results = {"loss": loss_value,
               "categorical_accuracy": acc,
               "regul": regul,
               "grad_norm": gradient_norm(grads)}

    # Replace the true-class logit with the batch-wide minimum so it cannot be
    # picked as the highest "other" class.
    masked = tf.where(y == 1, tf.reduce_min(logits), logits)
    true_logit = tf.reduce_sum(logits * y, axis=1)
    best_other = tf.reduce_max(masked, axis=1)
    margins = true_logit - best_other
    results["robustness"] = tf.reduce_mean(true_logit)
    results["avg_value"] = tf.reduce_mean(tf.abs(logits))
    results["abs_margin"] = tf.reduce_mean(tf.abs(margins))
    results["margin"] = tf.reduce_mean(margins)
    results["margin_std"] = tf.math.reduce_std(margins)
    return results


def fit_hkr(model,
            train,
            validation,
            loss_fct,
            optimizer,
            steps_per_epoch=50,
            validation_step=50,
            callbacks=None,
            epochs=20,
            betas=None,
            redress=False,
            verbose=2,
            grad_coeff=0.1,
            optim_margin=False,
            lambda_orth=0):
    """Train a multi-class model with ``train_step_hkr`` and validate every epoch.

    Args:
        model: model returning one logit per class.
        train, validation: iterables (e.g. ``tf.data.Dataset``) whose
            iterators yield ``(x, y)`` batches.
        loss_fct: loss callable. If it has a ``margins`` attribute, that is
            passed to ``train_step_hkr`` as the jointly optimisable variable
            and summarised after each epoch; if it has ``beta``, ``betas``
            values are assigned to it per epoch.
        optimizer: optimizer applying the gradients.
        steps_per_epoch: training batches per epoch.
        validation_step: validation batches per epoch.
        callbacks: Keras-style callbacks (``None`` means no callbacks).
        epochs: number of epochs.
        betas: optional per-epoch sequence of values for ``loss_fct.beta``.
        redress, grad_coeff, optim_margin, lambda_orth: forwarded to
            ``train_step_hkr``.
        verbose: unused; kept for backward compatibility.

    Raises:
        ValueError: if ``optim_margin`` is true but ``loss_fct`` has no
            ``margins`` to optimise.
    """
    callbacks = [] if callbacks is None else callbacks
    margin_variables = getattr(loss_fct, 'margins', None)
    if optim_margin and margin_variables is None:
        raise ValueError("optim_margin=True requires loss_fct to define 'margins'")

    for c in callbacks:
        c.set_model(model)

    train_it = iter(train)
    val_it = iter(validation)
    logs = {}
    for c in callbacks:
        c.on_train_begin(logs=logs)
    metrics = {"val_loss": tf.metrics.Mean(), "val_acc": tf.metrics.Mean()}

    for e in range(epochs):
        if betas is not None and hasattr(loss_fct, 'beta'):
            loss_fct.beta.assign(betas[e])

        start_time = time.time()
        for batch in range(steps_per_epoch):
            x, y = next(train_it)
            for c in callbacks:
                c.on_batch_begin(batch, logs=None)

            results = train_step_hkr(x, y,
                                     model,
                                     loss_fct,
                                     margin_variables,
                                     optimizer,
                                     grad_coeff=grad_coeff,
                                     redress=redress,
                                     lambda_orth=lambda_orth,
                                     optim_margin=optim_margin)

            batch_logs = {k: v.numpy().mean() for k, v in results.items()}
            for c in callbacks:
                c.on_train_batch_end(batch, logs=batch_logs)
            for k, v in results.items():
                if k not in metrics:
                    metrics[k] = tf.metrics.Mean()
                metrics[k].update_state(v)
        total_time = time.time() - start_time

        for _ in range(validation_step):
            x, y = next(val_it)
            logits = model(x, training=False)
            loss_value = loss_fct(y, logits)
            acc = categorical_accuracy(y, logits)
            metrics["val_loss"].update_state(loss_value.numpy())
            metrics["val_acc"].update_state(acc.numpy())

        logs = {k: m.result() for k, m in metrics.items()}
        logs['time'] = total_time

        for c in callbacks:
            c.on_epoch_end(e, logs=logs)
        print(f"Epoch {e+1}/{epochs}")
        print(f"time : {total_time:.2f}s *** loss:",metrics["loss"].result(), "accuracy:",metrics["categorical_accuracy"].result())
        if margin_variables is not None:
            print("max m",tf.reduce_max(margin_variables).numpy(),"min m",tf.reduce_min(margin_variables).numpy(), "men m",tf.reduce_mean(margin_variables).numpy())
        for m in metrics.values():
            m.reset_states()
        sys.stdout.flush()

    for c in callbacks:
        c.on_train_end(logs=logs)