# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sample images from trained score-model checkpoints and save them to disk."""

from absl import app
from absl import flags
from ml_collections.config_flags import config_flags
import logging
import os
import torch
import tensorflow as tf
import io
import time
import numpy as np

# Keep the import below for registering all model definitions
from models import ddpm, ncsnv2, ncsnpp
import sampling
from models import utils as mutils
from models.ema import ExponentialMovingAverage
import datasets
import sde_lib
from torchvision.utils import make_grid, save_image
from utils import restore_checkpoint
from models.utils import get_noise_fn
from samplers.dpm_solver_v3 import DPM_Solver_v3
import losses
from samplers.utils import NoiseScheduleVP
import functools

FLAGS = flags.FLAGS

# Command-line interface. `config` and `workdir` are mandatory; the remaining
# flags have defaults.
config_flags.DEFINE_config_file("config", None, "Training configuration.", lock_config=True)
flags.DEFINE_string("workdir", None, "Work directory.")
flags.DEFINE_string("eval_folder", "samples", "The folder name for storing evaluation results")
flags.DEFINE_string(
    "sample_folder",
    "sample",
    "The sub-folder name (inside each checkpoint's evaluation folder) for storing generated samples",
)
# The two flags below are only used to pick a precomputed DPM-Solver-v3
# statistics folder, whose name has the form
# {ckpt}_{eps}_{n_timesteps}_{num_gpus}_{n_batch}_{batch_size}.
flags.DEFINE_integer(
    "n_points_per_gpu",
    512,
    "Batch-size field (last field) of the DPM-Solver-v3 statistics folder name to select",
)
flags.DEFINE_integer(
    "n_timesteps",
    1200,
    "n_timesteps field (third field) of the DPM-Solver-v3 statistics folder name to select",
)
flags.mark_flags_as_required(["workdir", "config"])

# Hide all GPUs from TensorFlow (it is given an empty device list here).
tf.config.experimental.set_visible_devices([], "GPU")
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


def main(argv):
    """Entry point for `app.run`: forward the parsed flags to `sample`.

    Args:
        argv: Positional command-line arguments handed over by absl; not used.
    """
    sample(
        config=FLAGS.config,
        workdir=FLAGS.workdir,
        eval_folder=FLAGS.eval_folder,
        sample_dir=FLAGS.sample_folder,
        n_points_per_gpu=FLAGS.n_points_per_gpu,
        n_timesteps=FLAGS.n_timesteps,
    )


def _restore_checkpoint_with_retry(ckpt_path, state, device, retry_delays=(60, 120)):
    """Restore a checkpoint, retrying after each delay in `retry_delays` on failure.

    The checkpoint file may already exist but still be incomplete, so a failed
    restore is retried after sleeping. The final attempt is not guarded: its
    exception propagates to the caller.
    """
    for delay in retry_delays:
        try:
            return restore_checkpoint(ckpt_path, state, device=device)
        except Exception:  # not a bare except: Ctrl-C / SystemExit must still abort
            logging.warning(
                "Failed to restore %s; retrying in %d seconds", ckpt_path, delay, exc_info=True
            )
            time.sleep(delay)
    return restore_checkpoint(ckpt_path, state, device=device)


def _find_statistics_dir(statistics_path, ckpt, eps, n_timesteps, n_points_per_gpu):
    """Pick the statistics folder matching the given checkpoint and settings.

    Folder names are expected to look like
    {ckpt}_{eps}_{n_timesteps}_{num_gpus}_{n_batch}_{batch_size}; entries that
    do not parse in that form are skipped. Among the matching folders the one
    with the most steps, then the most samples
    (num_gpus * n_batch * batch_size), is chosen.

    Returns:
        A `(statistics_dir, max_steps)` tuple; `statistics_dir` is None (and
        `max_steps` is -1) when nothing matches.
    """
    statistics_dir = None
    max_steps, max_samples = -1, -1
    for folder in os.listdir(statistics_path):
        items = folder.split("_")
        try:
            folder_ckpt = int(items[0])
            folder_eps = float(items[1])
            steps = int(items[2])
            num_gpus, n_batch, batch_size = int(items[3]), int(items[4]), int(items[5])
        except (IndexError, ValueError):
            # Unrelated file or folder that does not follow the naming scheme.
            continue

        if (
            folder_ckpt != ckpt
            or folder_eps != eps
            or batch_size != n_points_per_gpu
            or steps != n_timesteps
        ):
            continue

        samples = num_gpus * n_batch * batch_size
        if (steps > max_steps) or (steps == max_steps and samples > max_samples):
            max_steps, max_samples = steps, samples
            statistics_dir = os.path.join(statistics_path, folder)
    return statistics_dir, max_steps


def sample(config, workdir, eval_folder="samples", sample_dir="sample", n_points_per_gpu=128, n_timesteps=1200):
    """Generate samples for every checkpoint in the configured range and save them.

    For each checkpoint index from `config.eval.begin_ckpt` to
    `config.eval.end_ckpt` (inclusive) this waits for the checkpoint file,
    restores it, copies the EMA parameters into the model, and writes each
    round of samples to `<workdir>/<eval_folder>/ckpt_<n>/<sample_dir>/` as
    `samples_<round>.npz`. An image grid of the first round is written there
    as `sample.png`.

    Args:
        config: Experiment configuration (seed, data, model, eval, sampling, device, ...).
        workdir: Working directory containing `checkpoints/` and, for
            DPM-Solver-v3, `statistics/`.
        eval_folder: Sub-folder of `workdir` that receives the outputs.
        sample_dir: Sub-folder (per checkpoint) that receives the samples.
        n_points_per_gpu: Batch-size field used to select the DPM-Solver-v3
            statistics folder.
        n_timesteps: n_timesteps field used to select the DPM-Solver-v3
            statistics folder.

    Raises:
        NotImplementedError: If `config.training.sde` is not "vpsde".
        FileNotFoundError: If DPM-Solver-v3 is requested but no matching
            statistics folder exists.
    """
    # Fix the seed for z = sde.prior_sampling(shape).to(device) in deterministic sampling
    torch.manual_seed(config.seed)
    # Create directory to eval_folder
    eval_dir = os.path.join(workdir, eval_folder)
    tf.io.gfile.makedirs(eval_dir)

    # Build data pipeline
    # NOTE(review): train_ds, eval_ds and scaler are not used below; the calls are
    # kept in case they have side effects -- confirm and remove if they do not.
    train_ds, eval_ds, _ = datasets.get_dataset(config, uniform_dequantization=config.data.uniform_dequantization, evaluation=True)

    # Create data normalizer and its inverse
    scaler = datasets.get_data_scaler(config)
    inverse_scaler = datasets.get_data_inverse_scaler(config)

    # Initialize model; optimizer and EMA are part of the state handed to restore_checkpoint
    score_model = mutils.create_model(config)
    optimizer = losses.get_optimizer(config, score_model.parameters())
    ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
    state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)

    checkpoint_dir = os.path.join(workdir, "checkpoints")

    # Setup SDEs
    if config.training.sde.lower() == "vpsde":
        sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
        sampling_eps = 1e-3
    else:
        raise NotImplementedError(f"SDE {config.training.sde} unsupported.")

    sampling_shape = (config.eval.batch_size, config.data.num_channels, config.data.image_size, config.data.image_size)

    begin_ckpt = config.eval.begin_ckpt
    logging.info("begin checkpoint: %d", begin_ckpt)
    for ckpt in range(begin_ckpt, config.eval.end_ckpt + 1):
        ckpt_path = os.path.join(checkpoint_dir, f"checkpoint_{ckpt}.pth")

        # Wait if the target checkpoint doesn't exist yet
        waiting_message_printed = False
        while not tf.io.gfile.exists(ckpt_path):
            if not waiting_message_printed:
                logging.warning("Waiting for the arrival of checkpoint_%d", ckpt)
                waiting_message_printed = True
            time.sleep(60)

        # Retry after 60 s and then 120 s in case the file exists but is not ready for reading
        state = _restore_checkpoint_with_retry(ckpt_path, state, device=config.device)
        ema.copy_to(score_model.parameters())

        if config.sampling.method == "dpm_solver_v3":
            # DPM-Solver-v3 requires precomputed statistics; choose the folder to use.
            statistics_path = os.path.join(workdir, "statistics")
            statistics_dir, max_steps = _find_statistics_dir(
                statistics_path, ckpt, config.sampling.eps, n_timesteps, n_points_per_gpu
            )
            if statistics_dir is None:
                # Explicit raise instead of assert so the check survives `python -O`.
                raise FileNotFoundError(
                    f"No appropriate statistics found in {statistics_path} for checkpoint {ckpt}."
                )
            print("Use statistics", statistics_dir)

            noise_pred_fn = get_noise_fn(sde, score_model, train=False, continuous=True)
            ns = NoiseScheduleVP("linear", continuous_beta_0=sde.beta_0, continuous_beta_1=sde.beta_1)
            dpm_solver_v3 = DPM_Solver_v3(
                statistics_dir,
                max_steps,
                noise_pred_fn,
                ns,
                steps=config.sampling.steps,
                t_start=sde.T,
                t_end=config.sampling.eps,
                skip_type=config.sampling.skip_type,
                degenerated=config.sampling.degenerated,
                device=config.device,
            )

            def dpm_solver_v3_sampler():
                """Draw one batch with DPM-Solver-v3; returns (samples, number of steps)."""
                with torch.no_grad():
                    x = sde.prior_sampling(sampling_shape).to(config.device)
                    x = dpm_solver_v3.sample(
                        x,
                        order=config.sampling.order,
                        p_pseudo=config.sampling.predictor_pseudo,
                        use_corrector=config.sampling.use_corrector,
                        c_pseudo=config.sampling.corrector_pseudo,
                        lower_order_final=config.sampling.lower_order_final,
                    )
                return inverse_scaler(x), config.sampling.steps

            sampling_fn = dpm_solver_v3_sampler
        else:
            sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps)
            sampling_fn = functools.partial(sampling_fn, score_model)

        # Directory to save samples, one per checkpoint
        this_sample_dir = os.path.join(eval_dir, f"ckpt_{ckpt}", sample_dir)
        tf.io.gfile.makedirs(this_sample_dir)
        logging.info(this_sample_dir)
        # NOTE(review): the "+ 1" yields one extra round when num_samples is an exact
        # multiple of batch_size; kept as-is because downstream code may rely on it.
        num_sampling_rounds = config.eval.num_samples // config.eval.batch_size + 1
        for r in range(num_sampling_rounds):
            samples_raw, n = sampling_fn()
            logging.info("sampling -- ckpt: %d, round: %d (NFE %d)", ckpt, r, n)
            # Channels-last uint8 copy of the batch for storage
            samples = np.clip(samples_raw.permute(0, 2, 3, 1).cpu().numpy() * 255.0, 0, 255).astype(np.uint8)
            samples = samples.reshape((-1, config.data.image_size, config.data.image_size, config.data.num_channels))
            # Write samples to disk or Google Cloud Storage
            with tf.io.gfile.GFile(os.path.join(this_sample_dir, f"samples_{r}.npz"), "wb") as fout:
                io_buffer = io.BytesIO()
                np.savez_compressed(io_buffer, samples=samples)
                fout.write(io_buffer.getvalue())

            if r == 0:
                # Save a roughly square preview grid of the first batch
                nrow = int(np.sqrt(samples_raw.shape[0]))
                image_grid = make_grid(samples_raw, nrow, padding=2)
                with tf.io.gfile.GFile(os.path.join(this_sample_dir, "sample.png"), "wb") as fout:
                    save_image(image_grid, fout)


# Run only when executed as a script: app.run is given `main` as the entry point.
if __name__ == "__main__":
    app.run(main)