import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping, ModelCheckpoint
import numpy as np
from datetime import datetime
from time import time
from operator import lt,le,ge,gt

class MyLearningRateScheduler(keras.callbacks.Callback):
    """Scale the optimizer's learning rate geometrically after a warm-up phase.

    The learning rate observed at the start of training is kept unchanged for
    every epoch index ``<= warmup`` (Keras epoch indices are 0-based, so the
    default ``warmup=1`` leaves the first two epochs untouched). Afterwards the
    rate applied at epoch ``e`` is ``init_lr * rate ** (e - warmup)``, clipped
    to ``[lr_min, lr_max]``.

    The schedule is only considered well defined when
    ``rate < 1`` and ``lr_min`` is given (decay towards a floor), or
    ``rate > 1`` and ``lr_max`` is given (growth towards a ceiling).
    In every other case ``rate`` is forced to ``1.0`` and the learning rate is
    pinned to its initial value.

    Args:
        warmup: Last 0-based epoch index that still uses the initial rate.
        rate: Multiplicative factor applied once per epoch after warm-up.
        lr_min: Lower clipping bound, or ``None`` for no lower bound.
        lr_max: Upper clipping bound, or ``None`` for no upper bound.
        verbose: If truthy, print the configuration warning and the learning
            rate chosen at the start of each epoch.
    """

    def __init__(self, warmup=1, rate=0.99, lr_min=None, lr_max=None, verbose=0):
        # Let the Keras base class set up its own bookkeeping attributes.
        super().__init__()
        self.warmup = warmup
        self.rate = rate
        self.lr_min = lr_min
        self.lr_max = lr_max
        self._init_lr = None
        self._verbose = verbose

        decays_to_floor = rate < 1.0 and lr_min is not None
        grows_to_ceiling = rate > 1.0 and lr_max is not None
        self._is_valid = bool(decays_to_floor or grows_to_ceiling)
        if not self._is_valid:
            if verbose:
                print("WARNING : rate,lr_min,lr_max is ambiguous. The schedule will not be applied")
            # A factor of 1.0 makes the schedule a no-op.
            self.rate = 1.0

    def _lr_decay_func(self, epoch):
        """Return the learning rate (a Python float) for 0-based ``epoch``."""
        if epoch <= self.warmup:
            return self._init_lr
        epoch_after_warmup = epoch - self.warmup
        new_lr = np.power(self.rate, epoch_after_warmup) * self._init_lr
        # np.clip accepts None for a bound that should not be enforced.
        return float(np.clip(new_lr, self.lr_min, self.lr_max))

    def on_train_begin(self, logs=None):
        """Capture the optimizer's starting learning rate as the schedule base."""
        self._init_lr = float(self.model.optimizer.learning_rate)
        if not self._is_valid:
            # Pin both bounds to the initial value so clipping cannot move it.
            self.lr_min = self._init_lr
            self.lr_max = self._init_lr

    def on_epoch_begin(self, epoch, logs=None):
        """Assign the scheduled learning rate to the optimizer.

        Raises:
            ValueError: If the optimizer has no ``learning_rate`` attribute.
        """
        # Guard on the attribute that is actually assigned below; the original
        # checked "lr", which some optimizers do not define even though they
        # expose "learning_rate".
        if not hasattr(self.model.optimizer, "learning_rate"):
            raise ValueError('Optimizer must have a "learning_rate" attribute.')
        new_lr = self._lr_decay_func(epoch)
        self.model.optimizer.learning_rate.assign(new_lr)
        if self._verbose:
            print("Epoch {}: lr={:.4f}".format(epoch, new_lr))
        
class MyProgressCallback(keras.callbacks.Callback):
    """Log per-epoch training progress and record run metadata on the model.

    During training the callback stores a ``_cb_result`` dict on the model
    holding the wall-clock start/end time and a copy of the trainable weights
    taken at the beginning of the first epoch.

    When ``verbose > 0`` it emits, every ``interval`` epochs (and on epoch 1),
    a report containing the Keras ``logs`` values, optional metrics on an extra
    validation set, gradient-norm metrics, the learning rate and a few
    statistics of the model's interval layer.

    The report reads several attributes from the model:
    ``_norm_metrics``, ``_intv_l1``, ``_intv_l2`` and ``_intv_layer`` (with
    ``_center_left``, ``_center_right``, ``kernel`` and ``_bias``).
    NOTE(review): these are assumed to be provided by the project's model
    class; it is not visible from this file.

    Args:
        verbose: Reports are produced only when this is greater than 0.
        interval: Report every ``interval`` epochs; must be > 0.
        digit: Number of decimals used when rounding logged values.
        f: Stored as ``_f``; not used by this class.
        extra_validation: Optional extra evaluation data. A list/tuple is
            passed to ``model.evaluate`` as ``(x, y)``; anything else is
            passed as a single argument.
        batch_size: Forwarded to ``model.evaluate`` for list/tuple data.
        logger: Callable receiving the report text, or ``None`` to disable
            output.
    """

    def __init__(self, verbose=0, interval=1, digit=6, f=None, extra_validation=None, batch_size=-1, logger=print):
        # Let the Keras base class set up its own bookkeeping attributes.
        super().__init__()
        self._verbose = verbose
        self._f = f
        assert interval > 0, "Must be an integer > 0"
        self._interval = interval
        self._digit = digit
        self._extra_data = extra_validation
        self._extra_data_type = type(extra_validation).__name__
        self._batch_size = batch_size
        self.logger = logger
        self._first_epoch_seen = False

    def print_func(self, content):
        """Send ``content`` to the logger if one is set and verbosity is on."""
        if self.logger is not None and self._verbose:
            self.logger(content)

    def on_train_begin(self, logs=None):
        """Reset the result record stored on the model."""
        self._first_epoch_seen = False
        self.model._cb_result = {
            "StartTime": 0.0,
            "EndTime": 0.0,
            "W_init": dict(),
        }

    def on_epoch_begin(self, epoch, logs=None):
        """Start the epoch timer; on the first epoch also snapshot the weights."""
        self._start_time = time()

        # Trigger on the first epoch actually run rather than on index 0, so
        # the snapshot is still taken when training resumes with
        # initial_epoch > 0.
        if not getattr(self, "_first_epoch_seen", False):
            self._first_epoch_seen = True
            self.model._cb_result["StartTime"] = time()
            self.model._cb_result["W_init"] = {
                W.name: W.numpy().copy() for W in self.model.trainable_variables
            }

    def _evaluate_extra(self):
        """Evaluate the extra validation data and return ``extra_<name>:<value>`` items."""
        if isinstance(self._extra_data, (list, tuple)):
            extra_result = self.model.evaluate(
                self._extra_data[0], self._extra_data[1], verbose=0, batch_size=self._batch_size
            )
        else:
            extra_result = self.model.evaluate(self._extra_data, verbose=0)
        # evaluate() may hand back a bare scalar instead of a list.
        if not isinstance(extra_result, (list, tuple)):
            extra_result = [extra_result]
        rounded = [round(x, self._digit) for x in extra_result]
        names = [metric.name for metric in self.model.metrics]
        return [f"extra_{key}:{value}" for key, value in zip(names, rounded)]

    def _build_epoch_log(self, epoch, logs):
        """Assemble the multi-line report for 1-based ``epoch``."""
        zfill_epoch = str(epoch).zfill(4)
        parts = [f"{key}:{round(value, self._digit)}" for key, value in logs.items()]
        if self._extra_data is not None:
            parts.extend(self._evaluate_extra())
        main_str = " ".join(parts)

        # Elapsed time is measured after the extra evaluation, so it includes it.
        end_time = time()
        final_log = "Epoch{} {:.2f} {}".format(zfill_epoch, end_time - self._start_time, main_str)

        # Gradient-norm metrics accumulated during the epoch.
        grad_norm_log = "\n [Grad] " + " ".join(
            "{}={:.6f}".format(metric.name, metric.result().numpy())
            for metric in self.model._norm_metrics
        )
        # Non-trainable quantities that change during training.
        lr_log = "\n [NoTrain] lr={:.6f},l1={:.6f},l2={:.6f}".format(
            float(self.model.optimizer.learning_rate),
            float(self.model._intv_l1),
            float(self.model._intv_l2),
        )
        # Mean absolute values taken from the interval layer.
        intv_layer = self.model._intv_layer
        intv_log = "\n [IntvNorm] C={:.6f},W={:.6f},B={:.6f}".format(
            float(tf.reduce_mean(tf.math.abs(intv_layer._center_left - intv_layer._center_right))),
            float(tf.reduce_mean(tf.math.abs(intv_layer.kernel))),
            float(tf.reduce_mean(tf.math.abs(intv_layer._bias))),
        )
        return final_log + grad_norm_log + lr_log + intv_log

    def on_epoch_end(self, epoch, logs=None):
        """Emit the periodic report and reset the gradient-norm metrics."""
        epoch += 1
        is_report_epoch = (epoch % self._interval == 0) or (epoch == 1)
        if self._verbose > 0 and is_report_epoch:
            self.print_func(self._build_epoch_log(epoch, logs or {}))

        # The norm metrics must be reset once per epoch no matter what, so this
        # runs even when reporting is disabled (previously an early return on
        # verbose <= 0 skipped it and the metrics kept accumulating).
        for norm_metric in self.model._norm_metrics:
            norm_metric.reset_state()

    def on_train_end(self, logs=None):
        """Record the end time and report the total training duration."""
        self.model._cb_result["EndTime"] = time()
        self.print_func("End train: took {:.4f}secs".format(
            self.model._cb_result["EndTime"] - self.model._cb_result["StartTime"]))
        
class GradientRecorder(keras.callbacks.Callback):
    """Record gradients of a loss w.r.t. the trainable variables during training.

    At the end of epoch 1 and of every ``interval``-th epoch (1-based), the
    model is run on ``X`` with ``training=False``, ``loss_func(y, output)`` is
    differentiated with respect to ``model.trainable_variables`` and the result
    is stored keyed by variable name. When training ends the collected data is
    attached to the model as ``_epoch_gradient_result`` and
    ``_epoch_log_result``.

    Args:
        X: Inputs passed to the model when computing the recorded loss.
        y: Targets passed as first argument to ``loss_func``.
        loss_func: Callable ``loss_func(y, output)`` returning the loss to
            differentiate.
        interval: Record every ``interval`` epochs (epoch 1 is always
            recorded).
        max_epoch: Last 1-based epoch to record. A negative value (the
            default) means no limit.
        save_as: ``"raw"`` stores each gradient as a NumPy array; ``"sum"``
            stores the sum of absolute gradient values per variable.
    """

    def __init__(self, X, y, loss_func, interval=1, max_epoch=-1, save_as="raw"):
        # Let the Keras base class set up its own bookkeeping attributes.
        super().__init__()
        self._X = X
        self._y = y
        self._loss_func = loss_func
        self._interval = interval
        self._max_epoch = max_epoch
        assert save_as in ("raw", "sum"), 'save_as must be "raw" or "sum"'
        self._save_as = save_as

        self._epoch_result = {}
        self._loss_result = {}

    def on_epoch_end(self, epoch, logs=None):
        """Compute and store the gradients if this epoch is scheduled."""
        epoch += 1

        # A negative max_epoch means "unlimited". Comparing against the default
        # of -1 directly would skip every epoch and record nothing at all.
        if self._max_epoch >= 0 and epoch > self._max_epoch:
            return
        if epoch % self._interval != 0 and epoch != 1:
            return

        variables = {var.name: var for var in self.model.trainable_variables}
        with tf.GradientTape() as tape:
            output = self.model(self._X, training=False)
            loss_value = self._loss_func(self._y, output)
        gradients = tape.gradient(loss_value, variables)

        # tape.gradient yields None for variables the loss does not depend on;
        # keep those as None instead of failing on ``None.numpy()``.
        if self._save_as == "raw":
            recorded = {
                name: None if grad is None else grad.numpy()
                for name, grad in gradients.items()
            }
        else:  # "sum"
            recorded = {
                name: None if grad is None else tf.reduce_sum(tf.math.abs(grad)).numpy()
                for name, grad in gradients.items()
            }

        self._epoch_result[epoch] = recorded
        # Store a copy so later changes to the logs dict cannot alter the record.
        self._loss_result[epoch] = dict(logs) if logs is not None else None

    def on_train_end(self, logs=None):
        """Attach shallow copies of the collected results to the model."""
        self.model._epoch_gradient_result = dict(self._epoch_result)
        self.model._epoch_log_result = dict(self._loss_result)