
"""VAE model definitions."""

from flax import linen as nn
from jax import random
import jax.numpy as jnp
from functools import partial
import jax



# Group conic activation
@partial(jax.jit, static_argnames=['channel_axis', 'variant', 'eps', 'num_groups', 'project_axes', 'share_axis'])
def group_colu(input, channel_axis=-1, variant="soft", eps=1e-7, num_groups=100, project_axes=False, share_axis=True):
    """Group conic linear unit: gate groups of channels by their cone axis.

    The channel axis is split into leading "axis" channels ``y`` followed by
    ``num_groups`` equal-length "cone sections" ``x``. With group size ``S``
    (one axis + ``S - 1`` section channels)::

        share_axis=True:  C = 1 + G * (S - 1)   (a single axis shared by all groups)
        share_axis=False: C = G + G * (S - 1)   (one axis per group, i.e. C = G * S)

    Each section is multiplied by a mask computed from ``y / (||x_g|| + eps)``.
    The axis channels are returned unchanged unless ``project_axes`` is set.
    The output has the same shape as ``input``, with the axis channels first.

    Args:
        input: array whose ``channel_axis`` holds the C channels.
        channel_axis: channel axis. Must be negative so that any number of
            leading (batch) axes is supported.
        variant: mask nonlinearity. One of "soft" (sigmoid(m - 0.5)),
            "softapprox" (sigmoid(4m - 2)), "softmax", or "hard" (clip(m, 0, 1)).
        eps: added to norms to avoid division by zero.
        num_groups: number of groups G. 0 returns ``input`` unchanged. If every
            section has length 1, the op falls back to pointwise SiLU
            (variant "soft") or ReLU (any other variant).
        project_axes: also gate axes 1..G-1 by axis 0. Requires
            ``share_axis=False``.
        share_axis: whether all groups share a single axis channel.

    Returns:
        Array with the same shape as ``input``.

    Raises:
        AssertionError: if C does not fit the layout above, ``channel_axis`` is
            non-negative, or ``project_axes`` is combined with ``share_axis``.
        NotImplementedError: for an unknown ``variant``.
    """
    if num_groups == 0:  # trivial case: no cones
        return input
    num_channels = input.shape[channel_axis]
    if (share_axis and num_groups == num_channels - 1) or (not share_axis and num_groups * 2 == num_channels):
        # pointwise case: every cone section is one-dimensional
        return jax.nn.silu(input) if variant == "soft" else jax.nn.relu(input)

    # Validate before channel_axis is used for slicing and reshaping below.
    # A negative axis keeps the leading (broadcast) dimensions arbitrary.
    assert channel_axis < 0, "channel_axis must be negative"

    # y = axes, x = cone sections
    if share_axis:
        assert (num_channels - 1) % num_groups == 0, "Channel size must be a multiple of number of cones plus one"
        num_axes = 1
        group_size = (num_channels - 1) // num_groups + 1  # S = (C - 1) / G + 1
    else:
        assert num_channels % num_groups == 0, "Channel size must be a multiple of number of cones"
        num_axes = num_groups
        group_size = num_channels // num_groups  # S = C / G
    y = input.take(jnp.arange(num_axes), axis=channel_axis)
    x = input.take(jnp.arange(num_axes, num_channels), axis=channel_axis)

    x_old_shape = x.shape
    y_old_shape = y.shape
    lead = input.shape[:channel_axis]
    # For channel_axis == -1 there are no trailing dims (shape[0:] would be everything).
    trail = input.shape[channel_axis + 1:] if channel_axis < -1 else ()
    x = x.reshape(lead + (num_groups, group_size - 1) + trail)  # NG(S-1)[HW]
    y = y.reshape(lead + ((1 if share_axis else num_groups), 1) + trail)  # N11[HW] or NG1[HW]
    # After the reshape the group dim sits at channel_axis - 1, so the (negative)
    # channel_axis now indexes the within-group section dimension.
    xn = jnp.linalg.norm(x, axis=channel_axis, keepdims=True)  # NG1[HW]

    if project_axes:
        assert not share_axis, "project_axes is not compatible with share_axis"
        group_axis = channel_axis - 1
        y0 = y.take(jnp.arange(1), axis=group_axis)  # N11[HW]
        y1 = y.take(jnp.arange(1, num_groups), axis=group_axis)  # N(G-1)1[HW]
        yn = jnp.linalg.norm(y1, axis=group_axis, keepdims=True)  # N11[HW]
        ymask = y0 / (yn + eps)
        ymask = jax.nn.sigmoid(ymask - .5) if variant == "soft" else ymask.clip(0, 1)
        y = jnp.concatenate([y0, ymask * y1], axis=group_axis)

    mask = y / (xn + eps)  # NG1[HW]
    if variant == "softmax":
        # NOTE(review): `mask` has size 1 along channel_axis, so this softmax is
        # identically 1 (no gating). Softmax over the group axis
        # (channel_axis - 1) may have been intended; confirm before relying on it.
        mask = jax.nn.softmax(mask, axis=channel_axis)
    elif variant == "softapprox":
        mask = jax.nn.sigmoid(4 * mask - 2)
    elif variant == "soft":
        mask = jax.nn.sigmoid(mask - .5)
    elif variant == "hard":
        mask = mask.clip(0, 1)
    else:
        raise NotImplementedError(
            f"variant must be one of 'soft', 'softapprox', 'softmax' or 'hard', got {variant!r}.")

    x = (mask * x).reshape(x_old_shape)  # NGS[HW] -> original section layout
    y = y.reshape(y_old_shape)
    return jnp.concatenate([y, x], axis=channel_axis)


# # Multi-level Group conic activation # not quite working though
#     y, x = x.take(jnp.arange(dim_cone ** level), axis=-1), x.take(jnp.arange(level,dim_cone), axis=-1)
#     for level in range(depth):
#         if num_channels % dim_cone ** level == 0:
#             break
#         # axes, cone sections, rest
#         section_dim_from, section_dim_to = dim_cone ** level,dim_cone ** (level+1)
#         z, x = x.take(jnp.arange(section_dim_from, section_dim_to), axis=-1), x.take(jnp.arange(section_dim_to, num_channels), axis=-1)
#         zn = jnp.linalg.norm(x,axis=channel_axis,keepdims=True) # NG1HW, per-group norm, or the S dimension
#         mask = y / (zn + eps) # NG1HW


# def group_colu(x):
#     """spherical normalization"""
#     xn = jnp.linalg.norm(x,axis=-1,keepdims=True) # NG1HW, per-group norm, or the S dimension
#     return x / xn

class Encoder(nn.Module):
  """VAE encoder.

  Structure: one hidden Dense layer, an activation, then two Dense heads
  producing the 20-dimensional posterior mean and log-variance.
  """

  latents: int = 500  # width of the hidden layer
  model_name: str = "cifar10"  # accepted for a uniform constructor; not used by the encoder
  act_fn: str = "silu"  # "colu" selects group_colu; any other value uses SiLU
  num_groups: int = 100  # forwarded to group_colu
  variant: str = "soft"  # forwarded to group_colu
  share_axis: bool = True  # forwarded to group_colu

  @nn.compact
  def __call__(self, x):
    hidden = nn.Dense(self.latents, name='fc1')(x)
    if self.act_fn == "colu":
      hidden = group_colu(hidden, num_groups=self.num_groups, variant=self.variant, share_axis=self.share_axis)
    else:
      hidden = nn.silu(hidden)
    mean_head = nn.Dense(20, name='fc2_mean')
    logvar_head = nn.Dense(20, name='fc2_logvar')
    return mean_head(hidden), logvar_head(hidden)


class Decoder(nn.Module):
  """VAE decoder.

  Maps a latent code through a hidden Dense layer and an activation to
  per-pixel logits (no sigmoid applied).
  """
  latents: int = 500  # width of the hidden layer
  model_name: str = "cifar10"  # 'cifar10' -> 3072 outputs, anything else -> 784
  act_fn: str = "silu"  # "colu" selects group_colu; any other value uses SiLU
  num_groups: int = 100  # forwarded to group_colu
  variant: str = "soft"  # forwarded to group_colu
  share_axis: bool = True  # forwarded to group_colu

  @nn.compact
  def __call__(self, z):
    hidden = nn.Dense(self.latents, name='fc1')(z)
    if self.act_fn == "colu":
      hidden = group_colu(hidden, num_groups=self.num_groups, variant=self.variant, share_axis=self.share_axis)
    else:
      hidden = nn.silu(hidden)
    num_pixels = 784
    if self.model_name == 'cifar10':
      num_pixels = 3072
    return nn.Dense(num_pixels, name='fc2')(hidden)


class VAE(nn.Module):
  """Full VAE model: Encoder -> reparameterized sample -> Decoder.

  All hyperparameters are forwarded unchanged to both submodules.
  """

  latents: int = 500
  model_name: str = "cifar10"
  act_fn: str = "silu"
  num_groups: int = 100
  variant: str = "soft"
  share_axis: bool = True

  def setup(self):
    shared = dict(
        latents=self.latents,
        model_name=self.model_name,
        act_fn=self.act_fn,
        num_groups=self.num_groups,
        variant=self.variant,
        share_axis=self.share_axis,
    )
    self.encoder = Encoder(**shared)
    self.decoder = Decoder(**shared)

  def __call__(self, x, z_rng):
    """Return (reconstruction logits, posterior mean, posterior log-variance)."""
    mean, logvar = self.encoder(x)
    sampled = reparameterize(z_rng, mean, logvar)
    return self.decoder(sampled), mean, logvar

  def generate(self, z):
    """Decode latent codes `z` and squash the logits into [0, 1]."""
    logits = self.decoder(z)
    return nn.sigmoid(logits)


def reparameterize(rng, mean, logvar):
  """Sample z = mean + noise * exp(logvar / 2), with noise ~ N(0, I) drawn from `rng`."""
  noise = random.normal(rng, logvar.shape)
  sigma = jnp.exp(0.5 * logvar)
  return mean + noise * sigma


def create_model(latents, model_name="binarized_mnist", act_fn="silu", num_groups=100, variant="soft", share_axis=True):
  """Construct a VAE module; every argument maps onto the VAE field of the same name."""
  hparams = dict(
      latents=latents,
      model_name=model_name,
      act_fn=act_fn,
      num_groups=num_groups,
      variant=variant,
      share_axis=share_axis,
  )
  return VAE(**hparams)


"""Input pipeline for VAE dataset."""

import jax
import jax.numpy as jnp
import tensorflow as tf
import tensorflow_datasets as tfds


def build_train_set(batch_size, ds_builder):
  """Return an endless iterator of shuffled training batches as NumPy arrays.

  Each example is flattened by prepare_image. The mapped split is cached,
  repeated indefinitely, shuffled with a 50000-element buffer, and batched
  into `batch_size` rows.
  """
  pipeline = (
      ds_builder.as_dataset(split=tfds.Split.TRAIN)
      .map(prepare_image, num_parallel_calls=tf.data.AUTOTUNE)
      .cache()
      .repeat()
      .shuffle(50000)
      .batch(batch_size)
  )
  return iter(tfds.as_numpy(pipeline))


def build_test_set(ds_builder):
  """Load the first 10000 flattened test examples as a single device array."""
  batched = (
      ds_builder.as_dataset(split=tfds.Split.TEST)
      .map(prepare_image, num_parallel_calls=tf.data.AUTOTUNE)
      .batch(10000)
  )
  # Only the first batch is kept; the test splits used here hold exactly 10000 images.
  first_batch = next(iter(batched))
  return jax.device_put(jnp.array(first_batch))


def prepare_image(x):
  """Turn a tfds example's 'image' into a flat float32 vector (pixel values unscaled)."""
  # Optional binarization, intentionally disabled (binarized_mnist is already binary):
  # x = tf.where(x > 127.5, 1.0, 0.0)
  pixels = tf.cast(x['image'], tf.float32)
  return tf.reshape(pixels, (-1,))


"""Training and evaluation logic."""

from absl import logging
import time
import flax
from flax import linen as nn
import flax.jax_utils
from flax.training import train_state
import flax.training
import flax.training.common_utils
import utils as vae_utils
import jax
from jax import random
import jax.numpy as jnp
import ml_collections
import optax
import tensorflow_datasets as tfds
from tqdm import tqdm
from tensorboardX import SummaryWriter
import functools
import yaml

@jax.vmap
def kl_divergence(mean, logvar):
  """KL(N(mean, exp(logvar)) || N(0, I)) for one example.

  vmap maps this over the leading (batch) axis and returns one value per row.
  """
  per_dim = 1 + logvar - jnp.square(mean) - jnp.exp(logvar)
  return -0.5 * jnp.sum(per_dim)

@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
  """Summed Bernoulli negative log-likelihood of `labels` given `logits`, per example.

  vmap maps this over the leading (batch) axis. `labels` are expected in [0, 1].
  """
  log_p = jax.nn.log_sigmoid(logits)  # log(sigmoid(l))
  # log(1 - sigmoid(l)) == log_sigmoid(-l). This form stays finite for large
  # positive logits, where log(-expm1(log_sigmoid(l))) hits log(0) and then
  # produces 0 * -inf = NaN for label 1.
  log_not_p = jax.nn.log_sigmoid(-logits)
  return -jnp.sum(labels * log_p + (1.0 - labels) * log_not_p)

def compute_metrics(recon_x, x, mean, logvar):
  """Batch-averaged reconstruction ('bce') and KL ('kld') terms, plus their sum ('loss')."""
  metrics = {
      'bce': binary_cross_entropy_with_logits(recon_x, x).mean(),
      'kld': kl_divergence(mean, logvar).mean(),
  }
  metrics['loss'] = metrics['bce'] + metrics['kld']
  return metrics

# Datasets whose tfds images are already {0, 1}; the other datasets used here
# store uint8 pixels in [0, 255].
_BINARY_DATASETS = ('binarized_mnist',)


def _scale_images(images, dataset_name):
  """Map raw pixels to [0, 1], the label range the BCE (Bernoulli) loss expects."""
  return images if dataset_name in _BINARY_DATASETS else images / 255.0


# define evaluation function separately for calling it outside
def eval_f(params, raw_images, z, z_rng, model, config):
  """Evaluate `model` on a batch of test images and decode fixed latents.

  Args:
    params: VAE parameters.
    raw_images: flattened images straight from the input pipeline (unscaled).
    z: latent codes decoded via `vae.generate`.
    z_rng: PRNG key for the reparameterization noise.
    model: the VAE module.
    config: config providing `dataset_name`.

  Returns:
    A tuple `(metrics, comparison, generate_images)`:
      - metrics: dict with 'bce', 'kld' and 'loss'.
      - comparison: the first 8 scaled inputs stacked on their first 8
        reconstructions (decoder outputs as returned by the model).
      - generate_images: the output of `vae.generate(z)`.
    Images are reshaped to NHWC.
  """
  image_shape = (32, 32, 3) if config.dataset_name == 'cifar10' else (28, 28, 1)

  def eval_model(vae):
    images = _scale_images(raw_images, config.dataset_name)
    recon_images, mean, logvar = vae(images, z_rng)
    comparison = jnp.concatenate([
        images[:8].reshape(-1, *image_shape),
        recon_images[:8].reshape(-1, *image_shape),
    ])
    generate_images = vae.generate(z).reshape(-1, *image_shape)
    metrics = compute_metrics(recon_images, images, mean, logvar)
    return metrics, comparison, generate_images

  return nn.apply(eval_model, model)({'params': params})


def train_and_evaluate(config: ml_collections.ConfigDict):
  """Train and evaluate pipeline.

  Training is data-parallel over all local devices. Each device processes
  `config.batch_size` examples per step, and gradients are averaged across
  devices so every parameter replica receives the same update. Scalars go to
  TensorBoard under log/...; a checkpoint is written roughly every 15 minutes
  and once at the end.

  Returns:
    The minimum device-averaged training loss seen during training.
  """
  rng = random.PRNGKey(config.seed)
  rng, key = random.split(rng)
  num_devices = jax.local_device_count()
  global_batch_size = config.batch_size * num_devices

  ds_builder = tfds.builder(config.dataset_name)
  ds_builder.download_and_prepare()

  logging.info('Initializing dataset.')
  train_ds = build_train_set(global_batch_size, ds_builder)
  test_ds = build_test_set(ds_builder)
  # Split the whole test set across devices; its size (10000 here) must be
  # divisible by the number of local devices.
  test_ds = flax.training.common_utils.shard(test_ds)

  logging.info('Initializing model.')
  dim = 3072 if config.dataset_name == 'cifar10' else 784
  init_data = jnp.ones((config.batch_size, dim), jnp.float32)
  model = create_model(config.latents, config.dataset_name, config.act_fn, num_groups=config.num_groups, variant=config.variant, share_axis=config.share_axis)
  params = model.init(key, init_data, rng)['params']
  state = train_state.TrainState.create(
      apply_fn=model.apply,
      params=params,
      # tx=optax.sgd(config.learning_rate),
      tx=optax.adam(config.learning_rate),
  )
  state = flax.jax_utils.replicate(state)

  def train_step(state, batch, z_rng):
    """One optimizer step on this device's shard; returns (state, grad_norm, loss)."""
    def loss_fn(params):
      recon_x, mean, logvar = model.apply(
          {'params': params}, batch, z_rng
      )
      bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
      kld_loss = kl_divergence(mean, logvar).mean()
      return bce_loss + kld_loss

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # All-reduce across devices. Without this, each replica would apply only its
    # local gradient and the per-device parameter copies would drift apart.
    grads = jax.lax.pmean(grads, axis_name='batch')
    loss = jax.lax.pmean(loss, axis_name='batch')
    grad_norm = optax.global_norm(grads)
    return state.apply_gradients(grads=grads), grad_norm, loss

  train_step_pmap = jax.pmap(train_step, axis_name='batch', donate_argnums=(0,))
  eval_f_pmap = jax.pmap(functools.partial(eval_f, model=model, config=config))

  rng, z_key, eval_rng = random.split(rng, 3)
  eval_rng = random.split(eval_rng, num_devices)
  z = random.normal(z_key, (64, 20))  # fixed z for visualization of samples
  z = flax.jax_utils.replicate(z)
  # One epoch = one pass over the training split at the global (all-device) batch size.
  steps_per_epoch = (
      ds_builder.info.splits['train'].num_examples // global_batch_size
  )

  log_dir = f'log/{config.act_fn}_shareaxis_{config.share_axis}_{config.variant}/C{config.latents}_group{config.num_groups}/'
  run_dir = log_dir + f'{time.strftime("%m_%d_%H:%M", time.localtime())}/'
  writer = SummaryWriter(run_dir)
  try:
    with open(run_dir + "config.yaml", 'w') as f:
      yaml.dump(config.to_dict(), f)
    checkpoint_time = time.time()
    epoch_bar = tqdm(range(config.num_epochs), desc=f"{config.variant}, shareaxis={config.share_axis}, C={config.latents}, group={config.num_groups}", position=0, leave=True)
    min_loss = float(1e9)
    for epoch in epoch_bar:
      step_bar = tqdm(range(steps_per_epoch), position=1, leave=False)
      for step in step_bar:
        batch = _scale_images(next(train_ds), config.dataset_name)
        batch = flax.training.common_utils.shard(batch)
        rng, key = random.split(rng)
        key = random.split(key, num_devices)
        state, grad_norm, loss = train_step_pmap(state, batch, key)
        grad_norm, loss = flax.jax_utils.unreplicate((grad_norm, loss))
        min_loss = min(min_loss, loss.item())
        step_bar.set_postfix({"grad_norm": grad_norm, "loss": loss})
        global_step = epoch * steps_per_epoch + step
        writer.add_scalar('train/grad_norm', grad_norm, global_step)
        writer.add_scalar('train/loss', loss, global_step)
      step_bar.close()

      # save model every 15 minutes for longer training
      if time.time() - checkpoint_time > 900:
        vae_utils.save_model(jax.device_get(flax.jax_utils.unreplicate(state.params)), f'results/{config.act_fn}_group{config.num_groups}.msgpack')
        checkpoint_time = time.time()

      metrics, comparison, sample = eval_f_pmap(state.params, test_ds, z, eval_rng)
      # Every device evaluated an equal-sized shard, so the mean of the
      # per-device means is the mean over the whole test set.
      metrics = jax.tree_util.tree_map(lambda m: m.mean(axis=0), metrics)
      comparison, sample = flax.jax_utils.unreplicate((comparison, sample))

      epoch_bar.set_postfix(metrics)
      writer.add_scalar('eval/loss', metrics['loss'], epoch)
      writer.add_scalar('eval/bce', metrics['bce'], epoch)
      writer.add_scalar('eval/kld', metrics['kld'], epoch)
      if epoch % 10 == 0:
        writer.add_images('eval/sample', sample, epoch, dataformats='NHWC')
  finally:
    writer.close()  # flush pending events even if training raises

  # save
  params_offdevice = jax.device_get(flax.jax_utils.unreplicate(state.params))  # don't forget to unreplicate
  ckpt_dir = log_dir + f'seed{config.seed}.msgpack'
  vae_utils.save_model(params_offdevice, ckpt_dir)
  return min_loss

"""Default Hyperparameter configuration."""

import ml_collections
def get_config():
  """Get the default hyperparameter configuration.

  With act_fn="colu", `latents` must match group_colu's channel layout:
    share_axis=True:  (latents - 1) % num_groups == 0
    share_axis=False: latents % num_groups == 0
  """
  config = ml_collections.ConfigDict()
  config.dataset_name = 'mnist'  # or "cifar10"
  config.model_name = "mnist"  # NOTE(review): not read anywhere in this file; dataset_name is passed as model_name
  config.act_fn = "colu"
  config.variant = "soft"
  config.share_axis = True
  config.learning_rate = 1e-3
  # 1 shared axis + 100 groups * 5 section channels. The previous value of 500
  # fails group_colu's (latents - 1) % num_groups == 0 check (499 % 100 != 0).
  config.latents = 501
  config.num_groups = 100
  config.batch_size = 128  # per device
  config.num_epochs = 100
  config.seed = 0
  return config

"""Main file for running the VAE example.

The original file is intentionally kept short (it has since been merged into this single file). The majority of the logic is in libraries
that can be easily tested and imported in Colab.
"""

from absl import app 
# import jax
import tensorflow as tf

# The following snippet needs a config.py and is run with the flag --config=config.py
# from clu import platform
# from ml_collections import config_flags
# from absl import flags
# FLAGS = flags.FLAGS

# config_flags.DEFINE_config_file(
#     'config',
#     None,
#     'File path to the training hyperparameter configuration.',
#     lock_config=True,
# )

def main(argv):
  """absl entry point: train with the hard-coded defaults from get_config().

  `argv` is accepted but ignored.
  """
  # To reject extra command-line arguments, restore:
  #   if len(argv) > 1:
  #     raise app.UsageError('Too many command-line arguments.')
  #
  # Optional setup kept for reference (needs `jax`, `logging` and `clu.platform`):
  #   # Make sure tf does not allocate gpu memory.
  #   tf.config.experimental.set_visible_devices([], 'GPU')
  #   logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
  #   logging.info('JAX local devices: %r', jax.local_devices())
  #   # Tag the task with its JAX host (task 0 is not guaranteed to be host 0).
  #   platform.work_unit().set_task_status(
  #       f'process_index: {jax.process_index()}, '
  #       f'process_count: {jax.process_count()}'
  #   )
  train_and_evaluate(get_config())

if __name__ == "__main__":
  # Hand control to absl's runner, which invokes main(argv).
  app.run(main)
