import pathlib
import pickle
import re

import numpy as np
import tensorflow as tf
from tensorflow.keras import mixed_precision as prec

try:
    from tensorflow.python.distribute import values
except Exception:
    from google3.third_party.tensorflow.python.distribute import values

# Alias so callers can write `tf.tensor(x)` by analogy with `np.array(x)`.
tf.tensor = tf.convert_to_tensor

# NumPy-style method names grafted onto TF tensor-like classes. Because each
# entry is a plain function stored on the class, it binds like a method:
# `x.mean(axis)` calls `tf.math.reduce_mean(x, axis)`, `x.astype(t)` calls
# `tf.cast(x, t)`, and so on.
_NUMPY_STYLE_METHODS = {
    "mean": tf.math.reduce_mean,
    "std": tf.math.reduce_std,
    "var": tf.math.reduce_variance,
    "sum": tf.math.reduce_sum,
    "any": tf.math.reduce_any,
    "all": tf.math.reduce_all,
    "min": tf.math.reduce_min,
    "max": tf.math.reduce_max,
    "abs": tf.math.abs,
    "logsumexp": tf.math.reduce_logsumexp,
    "transpose": tf.transpose,
    "reshape": tf.reshape,
    "astype": tf.cast,
}
for base in (tf.Tensor, tf.Variable, values.PerReplica):
    for _method, _impl in _NUMPY_STYLE_METHODS.items():
        setattr(base, _method, _impl)
# Drop the inner loop temporaries so they do not leak as module attributes.
del _method, _impl


# values.PerReplica.dtype = property(lambda self: self.values[0].dtype)

# tf.TensorHandle.__repr__ = lambda x: '<tensor>'
# tf.TensorHandle.__str__ = lambda x: '<tensor>'
# np.set_printoptions(threshold=5, edgeitems=0)


class Module(tf.Module):
    """`tf.Module` with pickle checkpoints and lazily created, named sub-modules."""

    @staticmethod
    def _summarize(params):
        """Return (number of arrays, total number of scalar entries) in `params`.

        `params` is any `tf.nest` structure whose leaves expose `.shape`.
        """
        leaves = tf.nest.flatten(params)
        return len(leaves), int(sum(np.prod(x.shape) for x in leaves))

    def save(self, filename, verbose=True):
        """Pickle the NumPy values of `self.variables` to `filename`.

        Args:
          filename: Path (str or pathlib.Path) of the checkpoint file to write.
          verbose: If true, print how many tensors and parameters were saved.
        """
        # Named `params` rather than `values` to avoid shadowing the module-level
        # `values` import (tensorflow.python.distribute.values).
        params = tf.nest.map_structure(lambda x: x.numpy(), self.variables)
        if verbose:
            amount, count = self._summarize(params)
            print(f"Save checkpoint with {amount} tensors and {count} parameters.")
        with pathlib.Path(filename).open("wb") as f:
            pickle.dump(params, f)

    def load(self, filename, verbose=True):
        """Assign values pickled by `save` back into `self.variables`.

        The checkpoint must have the same nested structure as `self.variables`;
        `tf.nest.map_structure` raises if it does not. Presumably this includes the
        case where lazily created sub-modules (see `get`) have not been built yet,
        so callers should build the module first — verify against callers.

        Args:
          filename: Path (str or pathlib.Path) of the checkpoint file to read.
          verbose: If true, print how many tensors and parameters were loaded.
        """
        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # Only load checkpoints from trusted sources.
        with pathlib.Path(filename).open("rb") as f:
            params = pickle.load(f)
        if verbose:
            amount, count = self._summarize(params)
            print(f"Load checkpoint with {amount} tensors and {count} parameters.")
        tf.nest.map_structure(lambda x, y: x.assign(y), self.variables, params)

    def get(self, name, ctor, *args, **kwargs):
        """Return the sub-module called `name`, creating it on first use.

        This lets layers be declared where they are used instead of in the
        constructor. `ctor(*args, **kwargs)` is only called the first time a
        given `name` is requested; later calls return the cached instance and
        ignore `ctor`, `args` and `kwargs`.
        """
        # Created lazily because subclasses are not required to call a custom
        # __init__. Keeping the cache as an attribute lets tf.Module discover
        # the sub-modules' variables.
        if not hasattr(self, "_modules"):
            self._modules = {}
        if name not in self._modules:
            self._modules[name] = ctor(*args, **kwargs)
        return self._modules[name]


class Optimizer(tf.Module):
    """Keras optimizer wrapper adding clipping, weight decay, LR warmup and loss scaling.

    Args:
      name: Prefix used for metric keys, numeric-check messages and the
        weight-decay name pattern.
      lr: Base learning rate.
      eps: Epsilon passed to the Adam-family optimizers (unused by SGD variants).
      clip: If truthy, gradients are clipped to this global norm (must be >= 1).
      wd: Weight-decay factor in [0, 1), applied directly to matching variables
        as `var <- (1 - wd) * var` before each update. None or 0 disables it.
      opt: One of "adam", "nadam", "adamax", "sgd", "momentum".
      warmup: If nonzero, the learning rate ramps linearly from 0 to `lr` over
        this many updates.
      wd_pattern: Regex searched in "<name>/<variable name>" to choose which
        variables are decayed.

    Raises:
      AssertionError: If `wd` or `clip` is out of range.
      KeyError: If `opt` is not a supported optimizer name.
    """

    def __init__(
        self,
        name,
        lr,
        eps=1e-4,
        clip=None,
        wd=None,
        opt="adam",
        warmup=0,
        wd_pattern=r".*",
    ):
        # `wd` defaults to None, and `0 <= None` raises TypeError in Python 3,
        # so None must be accepted explicitly.
        assert wd is None or 0 <= wd < 1, wd
        assert not clip or 1 <= clip, clip
        self._name = name
        self._clip = clip
        self._wd = wd
        self._wd_pattern = wd_pattern
        self._updates = tf.Variable(0, trainable=False, dtype=tf.int64)
        self._lr = lr
        if warmup:
            # Callable learning rate: re-evaluated by the Keras optimizer, so it
            # reads the current update count each time.
            self._lr = lambda: lr * tf.clip_by_value(
                self._updates.astype(tf.float32) / warmup, 0.0, 1.0
            )
        self._opt = {
            "adam": lambda: tf.optimizers.Adam(self._lr, epsilon=eps),
            "nadam": lambda: tf.optimizers.Nadam(self._lr, epsilon=eps),
            "adamax": lambda: tf.optimizers.Adamax(self._lr, epsilon=eps),
            "sgd": lambda: tf.optimizers.SGD(self._lr),
            "momentum": lambda: tf.optimizers.SGD(self._lr, 0.9),
        }[opt]()
        # Use dynamic loss scaling when the global mixed-precision policy
        # computes in float16.
        self._mixed = prec.global_policy().compute_dtype == tf.float16
        if self._mixed:
            self._opt = prec.LossScaleOptimizer(self._opt, dynamic=True)
        # Print the parameter count only on the first call.
        self._once = True

    @property
    def variables(self):
        """Variables owned by the wrapped Keras optimizer (slots, iterations, ...)."""
        return self._opt.variables()

    def __call__(self, tape, loss, modules):
        """Differentiate `loss` w.r.t. the modules' variables and apply one update.

        Args:
          tape: The `tf.GradientTape` that recorded `loss`.
          loss: Scalar float32 loss tensor.
          modules: A module, or a sequence of modules, whose `variables` are trained.

        Returns:
          Dict with `<name>_loss`, `<name>_grad_norm` and, under mixed
          precision, `<name>_loss_scale`.
        """
        assert loss.dtype == tf.float32, (self._name, loss.dtype)
        assert len(loss.shape) == 0, (self._name, loss.shape)
        metrics = {}

        # Find variables.
        modules = modules if hasattr(modules, "__len__") else (modules,)
        varibs = tf.nest.flatten([module.variables for module in modules])
        if self._once:
            count = int(sum(np.prod(x.shape) for x in varibs))
            print(f"Found {count} {self._name} parameters.")
            self._once = False

        # Check loss. The metric records the unscaled loss.
        tf.debugging.check_numerics(loss, self._name + "_loss")
        metrics[f"{self._name}_loss"] = loss

        # Compute the gradient, scaling and unscaling under mixed precision.
        if self._mixed:
            # Re-enter the tape so the scaling op is recorded and differentiated.
            with tape:
                loss = self._opt.get_scaled_loss(loss)
        grads = tape.gradient(loss, varibs)
        if self._mixed:
            grads = self._opt.get_unscaled_gradients(grads)
            metrics[f"{self._name}_loss_scale"] = self._opt.loss_scale

        # Distributed sync.
        context = tf.distribute.get_replica_context()
        if context:
            grads = context.all_reduce("mean", grads)

        # Gradient clipping.
        norm = tf.linalg.global_norm(grads)
        if not self._mixed:
            # Skipped under mixed precision: with dynamic loss scaling, non-finite
            # gradients are presumably expected occasionally and handled by the
            # LossScaleOptimizer rather than treated as an error.
            tf.debugging.check_numerics(norm, self._name + "_norm")
        if self._clip:
            grads, _ = tf.clip_by_global_norm(grads, self._clip, norm)
        metrics[f"{self._name}_grad_norm"] = norm

        # Weight decay, applied to the variables before the gradient step.
        if self._wd:
            self._apply_weight_decay(varibs)

        # Apply gradients. Gradients were already all-reduced above, so the
        # optimizer must not aggregate them again.
        self._opt.apply_gradients(
            zip(grads, varibs), experimental_aggregate_gradients=False
        )
        self._updates.assign_add(1)

        return metrics

    def _apply_weight_decay(self, varibs):
        """Scale each variable matching `wd_pattern` by `(1 - wd)` in place."""
        for var in varibs:
            if re.search(self._wd_pattern, self._name + "/" + var.name):
                var.assign((1 - self._wd) * var)
