import tensorflow as tf
import tensorflow_compression as tfc
# from common.constrained_opt_lib import ConstraintManager
# from common.transforms import QuadAnalysis, QuadSynthesis, HyperAnalysis, HyperSynthesis
# from common.transforms import class_builder as transform_builder  # Just use the default class map.
from common.utils import ClassBuilder
# from common.latent_rvs_lib import UQLatentRV, LatentRVCollection
# from common.latent_rvs_utils import sga_schedule_at_step
# from common.immutabledict import immutabledict
from common.image_utils import mse_psnr
from common import data_lib
from common import tf_schedule as schedule
# from common import profile_utils
from common.custom_metrics import Metrics
from collections import OrderedDict
from ml_collections import ConfigDict
from absl import logging
from rdvae.nn_models import get_activation, make_mlp
import rdvae.tfc_utils as tfc_utils
from common.utils import get_keras_optimizer, pairwise_dist_squared
import tensorflow_probability as tfp
from nerd.dcgan import Generator

# Short aliases for TensorFlow Probability namespaces. `tfd` is used to build the priors in `Model`;
# `tfb` is not referenced in the visible part of this file.
tfd = tfp.distributions
tfb = tfp.bijectors

# Shared default for the optional config arguments of `Model`. This is a plain (mutable) dict that
# is reused across calls, so it must never be modified in place.
# EMPTY_DICT = immutabledict()
EMPTY_DICT = {}

# NOTE(review): not referenced in the visible part of this file -- presumably kept for external
# users of this module; confirm before removing.
CODING_RANK = 1


# Based on bounding-rd/uncons_gop.py
def batch_distortion(x, y, distort_type='mse'):
  """
  Given two batches of ndarrays, compute pairwise distortion.

  Every example is flattened to a vector, the squared distances between all (x, y) pairs are
  obtained from `pairwise_dist_squared`, and the result is rescaled according to `distort_type`.
  The result is laid out as [B, P] (x along rows), which is the more natural layout for chunking.
  :param x: [B, d1, d2, ...]
  :param y: [P, d1, d2, ...]; presumably must flatten to the same per-example size as x (this is
    not checked here).
  :param distort_type: 'mse' (squared distance divided by the flattened per-example size of x),
    'sse' (squared distance) or 'half_sse' (half of the squared distance).
  :return: [B, P] tensor of pairwise distortions.
  :raises NotImplementedError: if `distort_type` is not one of the names above.
  """
  # Reject unsupported names up front, before spending time on the pairwise computation.
  if distort_type not in ('mse', 'sse', 'half_sse'):
    raise NotImplementedError(f"Unsupported distort_type: '{distort_type}'.")

  # Flatten everything but the leading (batch) axis.
  vx = tf.reshape(x, [tf.shape(x)[0], -1])
  vy = tf.reshape(y, [tf.shape(y)[0], -1])
  norm_diff_squared = pairwise_dist_squared(vx, vy)

  if distort_type == 'mse':
    # Use the dynamic shape: the static `vx.shape[1]` can be None when tracing with partially
    # unknown input shapes, and None cannot be cast to a tensor.
    data_dim = tf.cast(tf.shape(vx)[1], vx.dtype)
    return norm_diff_squared / data_dim  # [B, P]
  if distort_type == 'half_sse':
    return norm_diff_squared * 0.5  # [B, P]
  return norm_diff_squared  # 'sse'; [B, P]


# Encapsulates model + optimizer.
# It's also possible to inherit from tf.keras.Model, although it might make the model construction
# code more cumbersome (from what I've seen, tf.keras.Model must define build() and call(),
# and these methods must be runnable in graph mode), but can simplify distributed training since
# it's already integrated into Model.fit,
# https://www.tensorflow.org/guide/distributed_training#use_tfdistributestrategy_with_keras_modelfit
# (whereas adding support for distributed training to a custom training loop can take more work).
class Model(tf.Module):
  """
  A prior over latents plus a decoder, bundled with the optimizer used to train them.

  Model samples are produced by drawing latents from the prior and passing them through the decoder
  (an MLP or a DCGAN-style `Generator`). The training loss is computed from the pairwise distortion
  between a data batch and `num_samples` freshly drawn model samples; see `get_loss`.

  Every key of `transform_config` is copied onto the instance as an attribute. The code in this
  class reads `decoder_type`, `decoder_activation` and `prior_type`; `decoder_units` (MLP decoder)
  or `img_shape` and `nn_size` (DCGAN decoder); and, optionally, `posterior_type`.
  """

  # `prior_type` prefixes selecting a per-dimension scalar mixture prior, specified like 'gmm_2'
  # (first letter: g = Gaussian, l = logistic; second letter: s = scale mixture, m = full mixture).
  _MIXTURE_PRIOR_PREFIXES = ("gsm_", "gmm_", "lsm_", "lmm_")

  def __init__(self,
               rd_lambda,
               latent_dim,
               distort_type,
               num_samples,
               data_dim=None,  # Only used by MLP decoder.
               scheduled_num_steps=5000,
               transform_config=EMPTY_DICT,
               optimizer_config=EMPTY_DICT,
               dtype='float32'):
    """
    :param rd_lambda: multiplier applied to the distortion inside the loss.
    :param latent_dim: dimensionality of the latent space the prior is defined on.
    :param distort_type: passed to `batch_distortion` ('mse', 'sse' or 'half_sse').
    :param num_samples: number of model samples drawn for every loss evaluation.
    :param data_dim: output size of the MLP decoder; required when decoder_type == 'mlp'.
    :param scheduled_num_steps: total number of steps the learning-rate schedule is built for.
    :param transform_config: mapping whose entries are copied onto the instance as attributes (see
      the class docstring). The default is a shared dict and is never mutated here.
    :param optimizer_config: mapping of optimizer / lr-schedule options; see `_get_optimizer`. The
      default is a shared dict and is never mutated here (a copy is made).
    :param dtype: dtype used for the MLP decoder and for the prior parameters.
    """
    super().__init__()
    self.latent_dim = latent_dim
    self.data_dim = data_dim
    self.distort_type = distort_type
    self._scheduled_num_steps = scheduled_num_steps
    self._rd_lambda = rd_lambda
    self.num_samples = num_samples
    self.dtype = dtype

    # Set up lr and optimizer
    self._optimizer_config = optimizer_config
    self.optimizer, self._lr_schedule_fn = self._get_optimizer(self._optimizer_config,
                                                               self._scheduled_num_steps)

    # Build the decoder and the prior.
    self._init_transforms(transform_config)

  def _get_optimizer(self, optimizer_config, scheduled_num_steps):
    """
    Build the learning-rate schedule and the optimizer from `optimizer_config`.

    Recognized keys (all optional): 'learning_rate' (1e-4), 'reduce_lr_after' (0.8),
    'reduce_lr_factor' (0.1), 'warmup_steps' or 'warmup_until' (0.02, a fraction of
    `scheduled_num_steps`), 'warmup_start_step' (0), 'lr_drop_steps' (selects the multi-drop
    schedule) and 'name' ('adam'). Any remaining entries are forwarded to the optimizer constructor.
    :param optimizer_config: mapping of options; it is copied, not modified.
    :param scheduled_num_steps: total number of steps the schedule should span.
    :return: (optimizer, lr_schedule_fn)
    """
    optimizer_config = dict(optimizer_config)  # Make a copy to avoid mutating the original.

    learning_rate = optimizer_config.pop("learning_rate", 1e-4)
    reduce_lr_after = optimizer_config.pop("reduce_lr_after", 0.8)
    reduce_lr_factor = optimizer_config.pop("reduce_lr_factor", 0.1)
    # Always consume both warmup keys so that neither leaks into the optimizer kwargs below when
    # both are supplied; an explicit 'warmup_steps' takes precedence over 'warmup_until'.
    warmup_until = optimizer_config.pop("warmup_until", 0.02)
    warmup_steps = optimizer_config.pop("warmup_steps", None)
    if warmup_steps is None:
      warmup_steps = int(warmup_until * scheduled_num_steps)
    warmup_start_step = optimizer_config.pop("warmup_start_step", 0)

    if "lr_drop_steps" in optimizer_config:
      # Specify a "multi-drop" lr schedule with explicit steps at which to drop lr.
      lr_schedule_fn = schedule.CustomDropCompressionSchedule(
        base_learning_rate=learning_rate,
        total_num_steps=scheduled_num_steps,
        drop_steps=optimizer_config.pop("lr_drop_steps"),
        drop_factor=reduce_lr_factor,
        warmup_steps=warmup_steps,
        warmup_start_step=warmup_start_step,
      )
    else:
      lr_schedule_fn = schedule.CompressionSchedule(
        base_learning_rate=learning_rate,
        total_num_steps=scheduled_num_steps,
        warmup_steps=warmup_steps,
        warmup_start_step=warmup_start_step,
        drop_after=reduce_lr_after,
        drop_factor=reduce_lr_factor,
      )

    optimizer_cls = get_keras_optimizer(optimizer_config.pop('name', 'adam'))
    optimizer = optimizer_cls(learning_rate=lr_schedule_fn, **optimizer_config)
    return optimizer, lr_schedule_fn

  def _init_transforms(self, transform_config=EMPTY_DICT):
    """
    Create the decoder and the prior as described by `transform_config`.

    Sets `self.decoder` and `self._prior`; for mixture priors `self._prior` stays None and the
    mixture parameters `self.logits`, `self.log_scale` and `self.loc` are created instead.
    :param transform_config: mapping copied onto the instance as attributes.
    :raises ValueError: if `prior_type` is not recognized, or if the MLP decoder is requested
      without a `data_dim`.
    """
    dtype = self.dtype
    # NOTE: this writes straight into the instance dict, so a key that matches an existing
    # attribute name (e.g. 'latent_dim') replaces that attribute.
    self.__dict__.update(transform_config)
    latent_dim = self.latent_dim

    if self.decoder_type == 'mlp':
      if self.data_dim is None:
        raise ValueError("data_dim must be provided when decoder_type is 'mlp'.")
      decoder = make_mlp(
        # list(...) so that tuple-valued (e.g. frozen-config) `decoder_units` also work.
        list(self.decoder_units) + [self.data_dim],
        get_activation(self.decoder_activation, dtype),
        "decoder",
        input_shape=[latent_dim],
        dtype=dtype,
      )
    else:
      # DCGAN
      decoder = Generator(self.img_shape, self.latent_dim, self.nn_size, self.decoder_activation)
    self.decoder = decoder

    self._prior = None
    prior_type = self.prior_type
    if prior_type == "deep":
      self._prior = tfc_utils.MyDeepFactorized(
        batch_shape=[self.latent_dim], dtype=self.dtype)
    elif prior_type == 'std_gaussian':  # use 'gmm_1' for gaussian prior with learned mean/scale
      self._prior = tfd.MultivariateNormalDiag(
        loc=tf.zeros([self.latent_dim], dtype=self.dtype),
        scale_diag=tf.ones([self.latent_dim], dtype=self.dtype))
    elif prior_type[:4] in self._MIXTURE_PRIOR_PREFIXES:  # mixture prior; specified like 'gmm_2'
      # This only implements a scalar mixture for each dimension, and the dimensions themselves are
      # still fully factorized just like tfc.DeepFactorized
      components = int(prior_type[4:])
      shape = (self.latent_dim, components)
      self.logits = tf.Variable(tf.random.normal(shape, dtype=self.dtype))
      self.log_scale = tf.Variable(
        tf.random.normal(shape, mean=2., dtype=self.dtype))
      if "s" in prior_type:  # scale mixture: all components share a fixed zero location
        self.loc = 0.
      else:
        self.loc = tf.Variable(tf.random.normal(shape, dtype=self.dtype))
    else:
      raise ValueError(f"Unknown prior_type: '{self.prior_type}'.")

  def prior(self, conv_unoise=False):
    """
    Return the prior distribution over latents.

    For mixture priors the distribution object is rebuilt from the current values of the mixture
    variables on every call; otherwise the object created in `_init_transforms` is returned.
    :param conv_unoise: if True, wrap the prior in `tfc_utils.MyUniformNoiseAdapter`.
    :return: a distribution object.
    :raises ValueError: if no prior is available for the current `prior_type`.
    """
    if self._prior is not None:
      prior = self._prior
    elif self.prior_type[:4] in self._MIXTURE_PRIOR_PREFIXES:
      family = tfd.Normal if self.prior_type.startswith("g") else tfd.Logistic
      prior = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(logits=self.logits),
        # softplus (rather than exp) maps the unconstrained parameter to a positive scale.
        components_distribution=family(
          loc=self.loc, scale=tf.math.softplus(self.log_scale)),
      )
    else:
      # Without this branch `prior` would be unbound below (e.g. if prior_type was changed after
      # construction).
      raise ValueError(f"Unknown prior_type: '{self.prior_type}'.")
    if conv_unoise:  # convolve with uniform noise for NTC compression model
      prior = tfc_utils.MyUniformNoiseAdapter(prior)
    return prior

  @property
  def global_step(self):
    """The optimizer's iteration counter."""
    return self.optimizer.iterations

  @property
  def _scheduled_lr(self):
    """
    Current learning rate as reported by the optimizer.

    This is just for logging/debugging purpose. Should equal self._lr_schedule_fn(self.global_step)
    Also see https://github.com/google-research/google-research/blob/bb5e979a2d9389850fda7eb837ef9c8b8ba8244b/vct/src/models.py#672
    NOTE(review): `_decayed_lr` is a private optimizer method; whether it exists depends on the
    optimizer class returned by `get_keras_optimizer` -- confirm when upgrading TensorFlow/Keras.
    """
    return self.optimizer._decayed_lr(tf.float32)

  @property
  def _scheduled_rd_lambda(self):
    """Returns the scheduled rd-lambda.
    Based on https://github.com/google-research/google-research/blob/master/vct/src/models.py#L400
    No warmup schedule is currently applied; this is simply `rd_lambda` converted to a tensor.
    """
    return tf.convert_to_tensor(self._rd_lambda)

  def get_loss(self, x, training=None):
    """
    Compute the training loss and the logged metrics on a batch.

    With C the [batchsize, num_samples] pairwise distortion between `x` and fresh model samples and
    lam = rd_lambda:
      log_marg_x = logsumexp(-lam * C, axis=1) - log(num_samples)
      loss       = -mean(log_marg_x)
      pi_density = exp(-log_marg_x - lam * C)
      distortion = mean(C * pi_density)
      rate       = loss - lam * distortion
    :param x: a batch of data, [batchsize, ...].
    :param training: accepted for interface compatibility; not used.
    :return: (loss, metrics) where metrics records loss, distortion, rate and scheduled_lr.
    """
    y_samples = self.sample(self.num_samples)
    C = batch_distortion(x, y_samples, self.distort_type)  # [batchsize, num_samples]
    scaled_C = self._rd_lambda * C
    # Cast to the dtype of C rather than a hard-coded float32, so float64 models work too.
    log_n = tf.math.log(tf.cast(self.num_samples, C.dtype))
    log_marg_x = tf.reduce_logsumexp(-scaled_C, axis=1, keepdims=True) - log_n  # - phi(x)
    loss = - tf.reduce_mean(log_marg_x)

    # optimal coupling density w.r.t product dist; [batchsize, num_samples]
    pi_density = tf.exp(-log_marg_x - scaled_C)
    # Computed as C * exp(.) rather than exp(log(C) + .), which can underflow.
    distortion = tf.reduce_mean(C * pi_density)
    rate = loss - self._rd_lambda * distortion
    record_dict = dict(loss=loss, distortion=distortion, rate=rate, scheduled_lr=self._scheduled_lr)

    metrics = Metrics.make()
    metrics.record_scalars(record_dict)
    return loss, metrics

  def train_step(self, batch):
    """
    Run one optimization step on `batch`.
    :param batch: a batch of data, passed to `get_loss`.
    :return: the metrics from `get_loss`, with 'grad_norm' (global gradient norm) added.
    """
    with tf.GradientTape() as tape:
      loss, metrics = self.get_loss(batch, training=True)

    var_list = self.trainable_variables
    gradients = tape.gradient(loss, var_list)
    self.optimizer.apply_gradients(zip(gradients, var_list))
    metrics.record_scalar('grad_norm', tf.linalg.global_norm(gradients))
    return metrics

  def validation_step(self, batch, training=False) -> Metrics:
    """
    Evaluate the loss on `batch` without updating any variables.
    :param batch: a batch of data, passed to `get_loss`.
    :param training: forwarded to `get_loss`.
    :return: the metrics from `get_loss`.
    """
    _, metrics = self.get_loss(batch, training=training)
    return metrics

  def sample(self, num_samples):
    """
    Draw samples from the compression model.
    :param num_samples: int
    :return: a [num_samples, data_shape] tensor.
    """
    # `posterior_type` only exists if the transform config supplied it; treat a missing value as
    # "not uniform" instead of failing with an AttributeError.
    posterior_type = getattr(self, 'posterior_type', None)
    if self.prior_type == 'deep' and posterior_type == 'uniform':
      # sample from the discretized prior (as would be in actual entropy coding)
      # the prior is not actually convolved with unoise; this is just to get quantized samples
      prior = self.prior(conv_unoise=True)
      samples = prior.sample(num_samples, quantized=True)
    else:
      samples = self.prior(conv_unoise=False).sample(num_samples)
    # Explicit None check: do not rely on the truthiness of a model object.
    if self.decoder is not None:
      samples = self.decoder(samples)
    return samples

  def create_images(self, num_samples=1000, coords_to_scatter=(0, 1), title=None):
    """
    Plot img metrics. Here we simply visualize samples from the model.

    With a `Generator` decoder, 10 samples are drawn and shown as images; otherwise `num_samples`
    samples are drawn and two of their coordinates are shown as a scatter plot.
    :param num_samples: number of samples for the scatter plot (ignored for image decoders).
    :param coords_to_scatter: the pair of sample coordinates used as scatter x / y.
    :param title: optional plot title.
    :return: {'samples': img}, where img is the rendered figure from `plot_utils.fig_to_np_arr`.
    """
    import matplotlib.pyplot as plt
    import plot_utils

    if isinstance(self.decoder, Generator):
      model_samples = self.sample(10)
      cmap = None
      vmin = vmax = None
      if self.decoder.img_shape[-1] == 1:  # single-channel images are drawn in grayscale
        cmap = 'gray'
        vmin, vmax = 0, 255
      fig = plot_utils.plot_float_imgs(model_samples, figsize=(12, 8), cmap=cmap, vmin=vmin, vmax=vmax)
      if title:
        plt.gca().set_title(title)
    else:
      model_samples = self.sample(num_samples)
      fig, ax = plt.subplots()
      i, j = coords_to_scatter
      ax.scatter(model_samples[:, i], model_samples[:, j], marker='.', alpha=0.3, label=r'$\nu$')
      ax.legend()
      ax.set_aspect('equal')
      if title:
        ax.set_title(title)

    img = plot_utils.fig_to_np_arr(fig)
    plt.close(fig)

    return {'samples': img}
