# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training and evaluation logic."""

from absl import logging
import time
import flax
from flax import linen as nn
import flax.jax_utils
from flax.training import train_state
import flax.training
import flax.training.common_utils
import input_pipeline
import models
import utils as vae_utils
import jax
from jax import random
import jax.numpy as jnp
import ml_collections
import optax
import tensorflow_datasets as tfds
from tqdm import tqdm
from tensorboardX import SummaryWriter
import functools
import yaml

@jax.vmap
def kl_divergence(mean, logvar):
  """Closed-form KL(N(mean, exp(logvar)) || N(0, 1)), one value per example.

  `jax.vmap` maps over the leading (batch) axis; inside an example the
  per-dimension terms are summed over every remaining axis.

  Args:
    mean: Gaussian means, with a leading batch axis.
    logvar: Log-variances, same shape as `mean`.

  Returns:
    Array of shape (batch,) with the KL divergence of each example.
  """
  variance = jnp.exp(logvar)
  mean_sq = jnp.square(mean)
  terms = 1 + logvar - mean_sq - variance
  return -0.5 * jnp.sum(terms)

@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
  """Per-example binary cross-entropy computed from raw logits.

  `jax.vmap` maps over the leading (batch) axis; within an example the
  element-wise losses are summed over every remaining axis.

  Both log-probabilities are taken with `log_sigmoid`, using the identity
  log(1 - sigmoid(x)) == log_sigmoid(-x). The previous formulation,
  log(-expm1(log_sigmoid(x))), underflows to log(0) = -inf for large positive
  logits in float32; when the label is 1 this gave 0 * -inf = NaN.

  Args:
    logits: Unnormalized predictions, with a leading batch axis.
    labels: Targets with the same shape as `logits` (values presumably in
      [0, 1] -- TODO confirm against the input pipeline).

  Returns:
    Array of shape (batch,) holding the summed BCE for each example.
  """
  log_p = jax.nn.log_sigmoid(logits)       # log(sigmoid(x))
  log_not_p = jax.nn.log_sigmoid(-logits)  # log(1 - sigmoid(x)), finite
  return -jnp.sum(labels * log_p + (1.0 - labels) * log_not_p)

def compute_metrics(recon_x, x, mean, logvar):
  """Batch-averaged VAE losses.

  Args:
    recon_x: Reconstruction logits, passed to binary_cross_entropy_with_logits.
    x: Reconstruction targets.
    mean: Posterior means for the KL term.
    logvar: Posterior log-variances for the KL term.

  Returns:
    Dict with 'bce' (batch-mean reconstruction loss), 'kld' (batch-mean KL
    divergence) and 'loss' (their sum).
  """
  reconstruction = binary_cross_entropy_with_logits(recon_x, x).mean()
  divergence = kl_divergence(mean, logvar).mean()
  metrics = dict(bce=reconstruction, kld=divergence)
  metrics['loss'] = reconstruction + divergence
  return metrics

# Evaluation lives at module level (not inside train_and_evaluate) so it can also be called from outside this file.
def eval_f(params, raw_images, z, z_rng, model, config):
  """Score one evaluation batch and decode a fixed set of latents.

  Args:
    params: Model parameters (the 'params' collection).
    raw_images: Evaluation images. They are divided by 255 when
      `config.dataset_name == 'cifar10'` and used as-is otherwise.
    z: Latent vectors passed to `vae.generate`.
    z_rng: PRNG key forwarded to the model call.
    model: Flax module. Calling it must return (recon, mean, logvar), and it
      must expose a `generate` method.
    config: Only `config.dataset_name` is read.

  Returns:
    Tuple `(metrics, comparison, generated)`:
      * `metrics` comes from `compute_metrics`.
      * `comparison` is the first 8 inputs followed by their first 8
        reconstructions.
      * `generated` holds the decoded `z`.
    Images are reshaped to (32, 32, 3) for cifar10 and to (28, 28, 1)
    otherwise.
  """
  is_cifar = config.dataset_name == 'cifar10'
  image_shape = (32, 32, 3) if is_cifar else (28, 28, 1)

  def _evaluate(vae):
    if is_cifar:
      images = raw_images / 255.0
    else:
      images = raw_images
    recon_images, mean, logvar = vae(images, z_rng)
    originals = images[:8].reshape(-1, *image_shape)
    reconstructions = recon_images[:8].reshape(-1, *image_shape)
    comparison = jnp.concatenate([originals, reconstructions])

    samples = vae.generate(z).reshape(-1, *image_shape)
    metrics = compute_metrics(recon_images, images, mean, logvar)
    return metrics, comparison, samples

  bound = nn.apply(_evaluate, model)
  return bound({'params': params})
  
def train_and_evaluate(config: ml_collections.ConfigDict):
  """Train a VAE data-parallel across local devices, evaluating each epoch.

  Builds the TFDS train/test pipelines and initializes the model with an
  Adam optimizer. The train state is replicated across local devices, and
  training runs for `config.num_epochs` epochs. Per-step training scalars,
  per-epoch eval scalars and (every 10 epochs) generated samples are written
  to TensorBoard under a run directory derived from the config. The config
  itself is dumped to `config.yaml` in that directory. Parameters are
  checkpointed when more than 15 minutes have passed since the last save
  (checked once per epoch), and once more at the end.

  Args:
    config: Experiment configuration. Fields read here: seed, dataset_name,
      batch_size, latents, act_fn, num_groups, variant, share_axis,
      learning_rate, num_epochs.

  Returns:
    The minimum per-step training loss observed, as a Python float.
  """
  rng = random.PRNGKey(config.seed)
  rng, key = random.split(rng)
  n_devices = jax.local_device_count()
  is_cifar = config.dataset_name == 'cifar10'

  ds_builder = tfds.builder(config.dataset_name)
  ds_builder.download_and_prepare()

  logging.info('Initializing dataset.')
  # NOTE(review): the global batch is hard-coded to 8 * batch_size, so this
  # assumes 8 local devices; `shard` needs it divisible by n_devices.
  train_ds = input_pipeline.build_train_set(config.batch_size * 8, ds_builder)
  test_ds = input_pipeline.build_test_set(ds_builder)
  test_ds = flax.training.common_utils.shard(test_ds)  # 10000 = 8 * 1250

  logging.info('Initializing model.')
  dim = 3072 if is_cifar else 784
  init_data = jnp.ones((config.batch_size, dim), jnp.float32)
  model = models.model(
      config.latents,
      config.dataset_name,
      config.act_fn,
      num_groups=config.num_groups,
      variant=config.variant,
      share_axis=config.share_axis,
  )
  params = model.init(key, init_data, rng)['params']
  state = train_state.TrainState.create(
      apply_fn=model.apply,
      params=params,
      tx=optax.adam(config.learning_rate),
  )
  state = flax.jax_utils.replicate(state)

  def train_step(state, batch, z_rng):
    """One synchronized optimizer step; runs on every device under pmap."""

    def loss_fn(params):
      recon_x, mean, logvar = model.apply({'params': params}, batch, z_rng)
      bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
      kld_loss = kl_divergence(mean, logvar).mean()
      return bce_loss + kld_loss

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # All-reduce over the pmap axis so every replica applies the same update.
    # Without this, each device trains on its own shard and the replicas
    # drift apart (only device 0's params would be logged and saved).
    grads = jax.lax.pmean(grads, axis_name='batch')
    loss = jax.lax.pmean(loss, axis_name='batch')
    grad_norm = optax.global_norm(grads)
    return state.apply_gradients(grads=grads), grad_norm, loss

  train_step_pmap = jax.pmap(train_step, 'batch', donate_argnums=(0,))
  eval_f_pmap = jax.pmap(functools.partial(eval_f, model=model, config=config))

  rng, z_key, eval_rng = random.split(rng, 3)
  eval_rng = random.split(eval_rng, n_devices)
  # Fixed latents for sample visualization.
  # NOTE(review): width is hard-coded to 20; if config.latents is the decoder
  # input width this should be (64, config.latents) -- confirm in models.py.
  z = random.normal(z_key, (64, 20))
  z = flax.jax_utils.replicate(z)
  # NOTE(review): each step consumes config.batch_size * 8 examples (assuming
  # build_train_set's first argument is the batch size), so one "epoch" here
  # covers the training split roughly 8 times.
  steps_per_epoch = (
      ds_builder.info.splits['train'].num_examples // config.batch_size
  )

  log_dir = (
      f'log/{config.act_fn}_shareaxis_{config.share_axis}_{config.variant}/'
      f'C{config.latents}_group{config.num_groups}/'
  )
  run_dir = log_dir + f'{time.strftime("%m_%d_%H:%M", time.localtime())}/'
  writer = SummaryWriter(run_dir)
  try:
    with open(run_dir + "config.yaml", 'w') as f:
      yaml.dump(config.to_dict(), f)

    checkpoint_time = time.time()
    epoch_bar = tqdm(
        range(config.num_epochs),
        desc=f"{config.variant}, shareaxis={config.share_axis}, C={config.latents}, group={config.num_groups}",
        position=0,
        leave=True,
    )
    min_loss = float(1e9)
    for epoch in epoch_bar:
      step_bar = tqdm(range(steps_per_epoch), position=1, leave=False)
      for step in step_bar:
        batch = next(train_ds)
        batch = batch / 255.0 if is_cifar else batch
        batch = flax.training.common_utils.shard(batch)
        rng, key = random.split(rng)
        step_keys = random.split(key, n_devices)
        state, grad_norm, loss = train_step_pmap(state, batch, step_keys)
        # Values are pmean-ed, hence identical on every device.
        grad_norm, loss = flax.jax_utils.unreplicate((grad_norm, loss))
        min_loss = min(min_loss, loss.item())
        step_bar.set_postfix({"grad_norm": grad_norm, "loss": loss})
        global_step = epoch * steps_per_epoch + step
        writer.add_scalar('train/grad_norm', grad_norm, global_step)
        writer.add_scalar('train/loss', loss, global_step)
      step_bar.close()

      # Save the model every 15 minutes for longer training runs.
      if time.time() - checkpoint_time > 900:
        vae_utils.save_model(
            jax.device_get(flax.jax_utils.unreplicate(state.params)),
            f'results/{config.act_fn}_group{config.num_groups}.msgpack',
        )
        checkpoint_time = time.time()

      metrics, _, sample = eval_f_pmap(state.params, test_ds, z, eval_rng)
      # `shard` splits the test set into equal-size per-device shards, so the
      # mean of the per-device means is the full-test-set mean (rather than
      # device 0's shard alone).
      metrics = jax.tree_util.tree_map(lambda m: m.mean(), metrics)
      sample = flax.jax_utils.unreplicate(sample)

      epoch_bar.set_postfix(metrics)
      writer.add_scalar('eval/loss', metrics['loss'], epoch)
      writer.add_scalar('eval/bce', metrics['bce'], epoch)
      writer.add_scalar('eval/kld', metrics['kld'], epoch)
      if epoch % 10 == 0:
        writer.add_images('eval/sample', sample, epoch, dataformats='NHWC')
  finally:
    writer.close()

  # Replicas are kept identical by pmean, so device 0's copy is the model.
  params_offdevice = jax.device_get(flax.jax_utils.unreplicate(state.params))
  ckpt_dir = log_dir + f'seed{config.seed}.msgpack'
  vae_utils.save_model(params_offdevice, ckpt_dir)
  return min_loss