import numpy as np
import tensorflow as tf
from tensorflow import keras

from microsoft_nlp.utils import count_embedding_params


# Learning rate schedule
class CustomLearningRateScheduler(keras.callbacks.Callback):
    """Keras callback that sets the optimizer learning rate before every batch.

    The schedule is a linear warmup from 0 to ``base_lr`` over
    ``warmup_steps`` steps, followed by a cosine decay over ``decay_steps``
    steps down to ``alpha * base_lr``, where it stays afterwards. The result
    is multiplied by ``lr_scale``.

    ``base_lr`` is derived in ``on_train_begin`` from the number of
    non-embedding parameters ``N`` as ``0.003239 - 0.0001395 * log(N)``.

    Args:
        decay_steps: Number of steps of the cosine decay phase.
        warmup_steps: Number of steps of the linear warmup phase.
        lr_scale: Multiplier applied to every scheduled learning rate.
        initial_step: Value the global step counter starts from (e.g. when
            resuming training).
        non_embedding_params: Number of non-embedding parameters. When
            ``None`` it is computed from the attached model in
            ``on_train_begin``.
        alpha: Fraction of ``base_lr`` that the cosine decay ends at.
        lr_scale_reset_epoch: Epoch index (as passed to ``on_epoch_end``) at
            whose end ``lr_scale`` is reset to 1. ``None`` disables the
            reset. Defaults to 80, the value that used to be hard-coded.
    """

    def __init__(
        self,
        decay_steps=250000,
        warmup_steps=3000,
        lr_scale=1,
        initial_step=0,
        non_embedding_params=None,
        alpha=0.0001,
        lr_scale_reset_epoch=80,
    ):
        super().__init__()
        # Stored as float32 tensors so the arithmetic in `schedule` stays in
        # a single dtype.
        self.decay_steps = tf.cast(decay_steps, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)
        self.lr_scale = lr_scale
        self.non_embedding_params = non_embedding_params
        self.alpha = alpha  # alpha * base_lr is the asymptotic learning rate
        self.lr_scale_reset_epoch = lr_scale_reset_epoch

        # Both are placeholders until `on_train_begin` computes the real values.
        self.base_lr = 0.0  # learning rate at the end of linear warmup
        self.slope = 0.0  # slope of linear warmup
        # Global step counter; incremented once per training batch.
        self.total_steps = initial_step

    def on_train_begin(self, logs=None):
        """Compute ``base_lr`` and the warmup slope from the model size."""
        # Set the base learning rate (i.e. lr after the warmup) as described
        # in the scaling laws paper: LR(N) ~ 0.003239 - 0.0001395 * log(N),
        # where 'N' is the number of non-embedding parameters.
        if self.non_embedding_params is None:
            embedding_params = count_embedding_params(self.model)
            self.non_embedding_params = self.model.count_params() - embedding_params

        N = tf.cast(self.non_embedding_params, tf.float32)

        # Base learning rate (lr at the end of warmup).
        # NOTE(review): this expression turns negative for very large N
        # (log(N) > ~23.2); no clamping is applied here.
        self.base_lr = 0.003239 - 0.0001395 * tf.math.log(N)

        # Slope of the linear warmup, so that lr == base_lr at warmup_steps.
        self.slope = self.base_lr / self.warmup_steps

    def on_epoch_end(self, epoch, logs=None):
        """Reset ``lr_scale`` to 1 at the end of ``lr_scale_reset_epoch``."""
        if (
            self.lr_scale_reset_epoch is not None
            and epoch == self.lr_scale_reset_epoch
        ):
            self.lr_scale = 1

    def on_train_batch_begin(self, batch, logs=None):
        """Write the scheduled learning rate to the optimizer, then advance the step."""
        scheduled_lr = self.schedule(self.total_steps)

        # Push the value to the optimizer before this batch is processed.
        tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr)

        self.total_steps += 1

    def schedule(self, step):
        """Return the scaled learning rate for global step ``step``.

        Linear warmup for ``step < warmup_steps``; otherwise cosine decay
        from ``base_lr`` to ``alpha * base_lr`` over ``decay_steps`` steps,
        holding the final value once the decay is complete.
        """
        if step < self.warmup_steps:
            lr = self.slope * step
        else:
            # Steps elapsed since the end of warmup, clamped so the rate
            # stays at its floor after the decay phase has finished.
            step = tf.math.minimum(step - self.warmup_steps, self.decay_steps)
            step = tf.cast(step, tf.float32)
            cosine_decay = 0.5 * (
                1.0 + tf.math.cos(tf.constant(np.pi) * step / self.decay_steps)
            )
            # Rescale the [0, 1] cosine factor to [alpha, 1].
            decay = (1 - self.alpha) * cosine_decay + self.alpha
            lr = self.base_lr * decay

        return self.lr_scale * lr


def learning_scale_from_string(raw_string, batch_size, effective_batch_size):
    """Parse a learning-rate scale specification into a number.

    Accepted forms (case-insensitive, surrounding whitespace ignored):

    * ``"batch"`` -- returns ``sqrt(batch_size / effective_batch_size)``.
    * ``"batch-<factor>"`` -- same, multiplied by ``<factor>``, which must be
      a plain non-negative decimal made only of digits and ``.``.
    * any other string accepted by ``float()`` -- returned as that float.

    Args:
        raw_string: The specification string.
        batch_size: Numerator of the batch-size ratio.
        effective_batch_size: Denominator of the batch-size ratio.

    Returns:
        The learning-rate scale factor.

    Raises:
        NotImplementedError: If ``raw_string`` matches none of the accepted
            forms, including a malformed ``batch-`` suffix.
    """
    string = raw_string.strip().lower()
    if string.startswith("batch"):
        if string == "batch":
            batch_factor = 1
        elif string.startswith("batch-"):
            substring = string[len("batch-"):]
            # Explicit validation instead of `assert`, which is stripped
            # under `python -O`. Only digits and '.' are allowed, so signs,
            # exponents, 'inf' and 'nan' are rejected.
            if not substring or not all(s in "0123456789." for s in substring):
                raise NotImplementedError("learning_scale must be 'batch-##.##'")
            try:
                batch_factor = float(substring)
            except ValueError as err:
                # e.g. "1.2.3" or "." pass the character check but are not numbers.
                raise NotImplementedError(
                    "learning_scale must be 'batch-##.##'"
                ) from err
        else:
            raise NotImplementedError("learning_scale must be 'batch-##.##'")

        # Square-root scaling with the batch-size ratio;
        # see https://arxiv.org/pdf/2006.09092.pdf
        return batch_factor * np.sqrt(batch_size / effective_batch_size)

    try:
        return float(string)
    except ValueError as err:
        raise NotImplementedError(
            f"learning_scale must be a float or 'batch', got: {raw_string}"
        ) from err
