# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Train DPI."""
import datetime
import logging
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
import time
from typing import Any, Callable

from absl import app
from absl import flags
import flax
from flax.training import checkpoints
import jax
import jax.numpy as jnp
from jaxtyping import PyTree
from ml_collections.config_flags import config_flags
import numpy as np
import optax
from PIL import Image
import tensorflow as tf

import datasets
import inference_utils
import utils
from posterior_sampling import losses
from posterior_sampling import model_utils

# Command-line flags.
# --config: DPI config file.
_CONFIG = config_flags.DEFINE_config_file(
    name='config', default=None, help_string='DPI config.')
# --score_model_config: config file for the score model.
_SCORE_MODEL_CONFIG = config_flags.DEFINE_config_file(
    name='score_model_config',
    default=None,
    help_string='Score-model config.')
# --data_path (required): path to the NumPy array holding the test data.
_DATA_PATH = flags.DEFINE_string(
    name='data_path',
    default=None,
    help='Path to NumPy array of test data.',
    required=True)
# --image (required): index of the test image within that array.
_IMAGE = flags.DEFINE_integer(
    name='image', default=None, help='Test image index.', required=True)


def train(train_step_fn,
          params,
          optimizer,
          n_iters):
  """Runs the Gaussian DPI training loop and logs wall-clock step timings.

  The step function is wrapped in `jax.pmap` and called once before the timed
  loop so that compilation time is reported separately. Note that this warm-up
  call is a real step: it also updates the training state.

  Args:
    train_step_fn: Step function. After `jax.pmap` it is called as
      `train_step_fn(rng, state)` and must return
      `(state, (loss, loss_data, loss_prior, loss_entropy), samples)`.
    params: Initial parameters; passed to `optimizer.init` and stored in the
      training state.
    optimizer: Optimizer exposing `init(params)`.
    n_iters: Number of timed training steps to run.
  """
  config = _CONFIG.value

  # Construct training state.
  state = model_utils.GaussianState(
      step=0,
      opt_state=optimizer.init(params),
      params=params,
      data_weight=config.optim.lambda_data,
      prior_weight=config.optim.lambda_prior,
      entropy_weight=config.optim.lambda_entropy,
      rng=jax.random.PRNGKey(config.seed + 1)
  )

  # The state (argument 1) is donated to the pmapped step.
  p_train_step = jax.pmap(
      jax.jit(train_step_fn), axis_name='batch', donate_argnums=(1,))
  pstate = flax.jax_utils.replicate(state)
  # Create different random states for different processes in a
  # multi-host environment (e.g., TPU pods).
  rng = jax.random.fold_in(state.rng, jax.process_index())

  # JIT step.
  jit_start_time = time.perf_counter()
  rng, step_rngs = utils.psplit(rng)
  pstate, (ploss, _, _, _), _ = p_train_step(step_rngs, pstate)
  # JAX dispatches work asynchronously: wait for the result so that the
  # measured time covers the computation and not just the enqueue.
  jax.block_until_ready(ploss)
  logging.info(
      'JIT time: %.2f seconds', time.perf_counter() - jit_start_time)

  start_time = time.perf_counter()
  for step in range(n_iters):
    step_start_time = time.perf_counter()
    # Update data weight.
    data_weight = losses.data_weight_fn(
        step,
        start_order=config.optim.lambda_data_start_order,
        decay_rate=config.optim.lambda_data_decay_steps,
        final_data_weight=config.optim.lambda_data)
    pstate = pstate.replace(
        data_weight=flax.jax_utils.replicate(data_weight))

    rng, step_rngs = utils.psplit(rng)
    pstate, (ploss, _, _, _), _ = p_train_step(step_rngs, pstate)
    # Block before reading the clock (see the note on the JIT step above).
    jax.block_until_ready(ploss)

    logging.info(
        'step %d: %.5f seconds',
        step + 1, time.perf_counter() - step_start_time)

  total_time = time.perf_counter() - start_time
  # With no timed steps there is no average to report (and dividing by zero
  # would raise).
  if n_iters > 0:
    logging.info(
        'Avg. time/step (out of %d steps): %.8f',
        n_iters, total_time / n_iters)

def main(_):
  """Builds the DPI problem for one test image and runs `train`.

  Reads the DPI and score-model configs from flags, loads the test image
  selected by `--image` from `--data_path`, sets the likelihood noise scale
  from the image's DC amplitude, generates measurements, and builds the
  optimizer and loss functions before calling `train`.

  Args:
    _: Unused positional command-line arguments passed by `app.run`.

  Raises:
    app.UsageError: If `--config` or `--score_model_config` is not provided.
  """
  config = _CONFIG.value
  score_model_config = _SCORE_MODEL_CONFIG.value
  # Both config flags default to None; fail with a clear message instead of an
  # attribute error further down.
  if config is None or score_model_config is None:
    raise app.UsageError(
        'Both --config and --score_model_config must be provided.')
  sde, t0 = utils.get_sde(score_model_config)

  image_size = config.data.image_size
  image_shape = (image_size, image_size, config.data.num_channels)
  image_dim = np.prod(image_shape)

  if utils.is_coordinator():
    logging.info('[INFO] local device count = %d', jax.local_device_count())
    logging.info('[INFO] device count = %d', jax.device_count())
    logging.info('[INFO] process count = %d', jax.process_count())

  # Get true image.
  true_image = np.load(_DATA_PATH.value)[_IMAGE.value]

  # Determine noise scale.
  # Negligible noise:
  # kspace = np.fft.fft2(true_image, axes=(1, 2), norm='ortho')
  # dc_amplitude = np.abs(kspace[0][0][0])
  # noise_sigma = dc_amplitude * 0.001
  # config.likelihood.noise_scale = noise_sigma
  # likelihood = inference_utils.get_likelihood(config)

  # Regular noise:
  kspace = np.fft.fft2(true_image, axes=(0, 1), norm='ortho')
  dc_amplitude = np.abs(kspace[0, 0, 0])
  # corresponds to maximum SNR of 40 dB
  # Cast to a plain Python float so the config holds a built-in type whatever
  # the dtype of the loaded image.
  noise_sigma = float(dc_amplitude * 0.0005)
  config.likelihood.noise_scale = noise_sigma

  # Get likelihood module. This must happen after the noise scale is written
  # into the config.
  likelihood = inference_utils.get_likelihood(config)

  # Get measurements.
  y = likelihood.get_measurement(
      jax.random.PRNGKey(0), true_image[None, :, :, :])

  # Initialize Gaussian parameters.
  params = {
      'mean': jnp.ones(image_dim) * config.gauss.mean_init_scale,
      'std': jnp.ones(image_dim) * config.gauss.std_init_scale
  }

  # Create optimizer.
  if config.optim.warmup > 0:
    schedule = optax.linear_schedule(
        init_value=0, end_value=config.optim.learning_rate,
        transition_steps=config.optim.warmup)
  else:
    schedule = optax.constant_schedule(config.optim.learning_rate)
  optimizer = optax.adam(
      learning_rate=schedule,
      b1=config.optim.adam_beta1, b2=config.optim.adam_beta2,
      eps=config.optim.adam_eps)

  # Data loss function.
  data_loss_fn = losses.get_data_loss_fn(likelihood, y)

  # Prior loss function.
  # Get `ProbabilityFlow` module (only needed for the 'ode' prior).
  if config.optim.prior == 'ode':
    prob_flow = inference_utils.get_prob_flow(config, score_model_config)
  else:
    prob_flow = None
  # Get `score_fn` (only needed for the 'dsm' and 'sm' priors).
  if config.optim.prior in ['dsm', 'sm']:
    score_fn = inference_utils.get_score_fn(config, score_model_config)
  else:
    score_fn = None
  prior_loss_fn = losses.get_prior_loss_fn(
      config, score_fn=score_fn, sde=sde, prob_flow=prob_flow,
      t0=t0, t1=sde.T, dt0=config.prob_flow.dt0)

  # Get step function.
  train_step_fn = losses.get_gaussian_train_step_fn(
      config, optimizer, data_loss_fn, prior_loss_fn)

  train(train_step_fn, params, optimizer, n_iters=config.training.n_iters)


# Script entry point: hand `main` to absl's `app.run`.
if __name__ == '__main__':
  app.run(main)