#  type: ignore

import argparse
import glob
import os
import pickle
from functools import partial
from typing import Any

import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import tensorflow.compat.v2 as tf

from models.vq_vae import build_vq_vae_fn


def _str_to_bool(value: str) -> bool:
    """Convert a command-line token such as ``true``/``false`` to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Returns:
        The parsed namespace. Attribute names are the option names with
        dashes replaced by underscores (e.g. ``--validation-log-epoch`` is
        exposed as ``validation_log_epoch``).
    """
    parser = argparse.ArgumentParser(description="Process training parameters.")
    parser.add_argument("--batch-size", type=int, help="Batch size.", default=128)
    parser.add_argument(
        "--num-hiddens",
        type=int,
        help="Number of hidden neurons in the residual stacks.",
        default=128,
    )
    parser.add_argument(
        "--num-residual-layers", type=int, help="Number of residual stacks.", default=2
    )
    parser.add_argument(
        "--num-residual-hiddens",
        type=int,
        help="Number of hidden neurons in the intermediate conv of residual stacks.",
        default=64,
    )
    parser.add_argument(
        "--embedding-dim",
        type=int,
        help="Embedding dimension (dimensionality of each codebook vector).",
        default=64,
    )
    parser.add_argument(
        "--num-embeddings",
        type=int,
        help=(
            "Number of possibilities for each code (the higher the more information "
            "can be kept)."
        ),
        default=512,
    )
    parser.add_argument(
        "--commitment-cost",
        type=float,
        help="Commitment cost (try different values to check the best one).",
        default=0.25,
    )
    parser.add_argument(
        "--vq-use-ema",
        # Not ``type=bool``: bool("False") is True, so the flag could never be
        # switched off from the command line.
        type=_str_to_bool,
        help="Whether to use exponential moving averages of weights (true/false).",
        default=True,
    )
    parser.add_argument(
        "--decay", type=float, help="Decay of the EMA scheme.", default=0.99
    )
    parser.add_argument(
        "--learning-rate",
        # Must be float: with ``type=int`` a value such as "0.001" is rejected.
        type=float,
        help="Learning rate of the optimizer.",
        default=2e-4,
    )
    parser.add_argument(
        "--opt", type=str, help="Optimizer (must be an optax submodule)", default="adam"
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        help="Number of training steps (one batch per step).",
        default=int(5e5),
    )
    parser.add_argument(
        "--num-devices", type=int, help="Number of devices to use.", default=None
    )
    parser.add_argument(
        "--save-dir", type=str, help="Saving directory", default="./training_results/"
    )
    parser.add_argument(
        "--data-folder", type=str, help="Data folder", default="/app/data/"
    )
    parser.add_argument(
        "--training-log-epoch",
        type=int,
        help="Interval (in steps) at which to log the training losses.",
        default=100,
    )
    parser.add_argument(
        "--validation-log-epoch",
        type=int,
        help="Interval (in steps) at which to log the validation losses.",
        default=5000,
    )

    return parser.parse_args()


def _list_image_files(data_folder: str) -> list:
    """Return the sorted file paths designated by ``data_folder``.

    A directory is expanded to the files it directly contains; any other
    value is treated as a glob pattern. (``glob.glob`` on a bare directory
    path would only return the directory itself.)
    """
    if os.path.isdir(data_folder):
        pattern = os.path.join(data_folder, "*")
    else:
        pattern = data_folder
    return sorted(path for path in glob.glob(pattern) if os.path.isfile(path))


def _build_datasets(filenames: list, batch_size: int) -> Any:
    """Build the batched ``(train_ds, val_ds)`` tf.data pipelines.

    The file list is shuffled once, the first 10% is held out for validation
    and the rest is used for training. Images are decoded as 3-channel JPEGs
    and scaled to ``[-0.5, 0.5]``.
    """
    image_count = len(filenames)
    val_size = int(0.1 * image_count)
    autotune = tf.data.AUTOTUNE

    def load_image(filename: str) -> Any:
        raw = tf.io.read_file(filename)
        return tf.cast(tf.io.decode_jpeg(raw, channels=3), tf.float32) / 255.0 - 0.5

    def pipeline(dataset: Any, shuffle: bool) -> Any:
        # Cache decoded images so later passes do not hit the disk again.
        dataset = dataset.map(load_image, num_parallel_calls=autotune).cache()
        if shuffle:
            dataset = dataset.shuffle(buffer_size=1000)
        return dataset.batch(batch_size, drop_remainder=True).prefetch(
            buffer_size=autotune
        )

    # reshuffle_each_iteration=False keeps the train/validation split fixed.
    list_ds = tf.data.Dataset.from_tensor_slices(filenames).shuffle(
        image_count, reshuffle_each_iteration=False
    )
    train_ds = pipeline(list_ds.skip(val_size), shuffle=True)
    # The validation set is not reshuffled: combined with drop_remainder it
    # would otherwise evaluate a different subset at every evaluation.
    val_ds = pipeline(list_ds.take(val_size), shuffle=False)
    return train_ds, val_ds


def _save_checkpoint(path: str, params: Any, state: Any) -> None:
    """Pickle replica 0 of the pmapped ``params`` and ``state`` to ``path``."""
    with open(path, "wb") as f:
        pickle.dump(
            (
                jax.tree_util.tree_map(lambda p: p[0], params),
                jax.tree_util.tree_map(lambda p: p[0], state),
            ),
            f,
        )


def main(args: argparse.Namespace):
    """Train a VQ-VAE on the images in ``args.data_folder``.

    Builds the data pipelines, initialises the model and optimizer on
    ``num_devices`` devices, runs ``args.num_epochs`` data-parallel training
    steps, periodically evaluates on the validation split, writes checkpoints
    (one per evaluation plus the best one as ``vq-vae.pkl``) and finally saves
    a plot of the training metrics.

    Args:
        args: Namespace produced by ``parse_args``.

    Raises:
        ValueError: If the batch size is not divisible by the number of
            devices, or if there are fewer training images than one batch.
        FileNotFoundError: If no input file is found.
    """
    tf.enable_v2_behavior()
    print("JAX version {}".format(jax.__version__))
    print("Haiku version {}".format(hk.__version__))
    print("TF version {}".format(tf.__version__))

    num_devices = args.num_devices if args.num_devices else len(jax.devices())
    if args.batch_size % num_devices != 0:
        raise ValueError("batch_size must be divisible by num_devices")

    # Data loading.
    filenames = _list_image_files(args.data_folder)
    if not filenames:
        raise FileNotFoundError(f"No image files found in {args.data_folder!r}")
    train_ds, val_ds = _build_datasets(filenames, args.batch_size)

    first_batch = next(iter(train_ds), None)
    if first_batch is None:
        raise ValueError("Not enough training images to form a single batch")
    first_batch = first_batch.numpy()

    train_data_variance = np.var(first_batch)
    print("train data variance: %s" % train_data_variance)

    # Build modules.
    num_hiddens = args.num_hiddens
    num_residual_layers = args.num_residual_layers
    num_residual_hiddens = args.num_residual_hiddens
    embedding_dim = args.embedding_dim
    num_embeddings = args.num_embeddings
    commitment_cost = args.commitment_cost

    vq_vae_fn = build_vq_vae_fn(
        num_hiddens,
        num_residual_hiddens,
        num_residual_layers,
        embedding_dim,
        num_embeddings,
        args.decay,
        args.vq_use_ema,
        commitment_cost,
        train_data_variance,
    )
    forward = hk.transform_with_state(partial(vq_vae_fn, is_training=True))
    optimizer = getattr(optax, args.opt)(args.learning_rate)

    # Initialization: every batch is reshaped to (device, per-device batch, ...).
    num_epochs = args.num_epochs
    rng = jax.random.PRNGKey(42)
    new_shape = (num_devices, args.batch_size // num_devices) + first_batch.shape[1:]

    params, state = jax.pmap(forward.init, in_axes=(None, 0), axis_name="device")(
        rng, first_batch.reshape(new_shape)
    )
    opt_state = jax.pmap(optimizer.init, axis_name="device")(params)

    def loss_fn(params, state, data):
        # Pack model output and state together.
        model_output, state = forward.apply(params, state, None, data)
        loss = model_output["loss"]
        return loss, (model_output, state)

    loss_grad = jax.grad(loss_fn, has_aux=True)

    def cross_device_mean(tree):
        return jax.tree_util.tree_map(
            lambda leaf: jax.lax.pmean(leaf, axis_name="device"), tree
        )

    # Training step to be pmapped (pmap compiles it, so no extra jit is needed).
    def train_step(params, state, opt_state, data):
        grads, (model_output, state) = loss_grad(params, state, data)
        grads = cross_device_mean(grads)
        model_output = cross_device_mean(model_output)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        # NOTE(review): `state` is not averaged across devices here, so any
        # per-device statistics it holds may diverge between replicas -
        # confirm against the model definition.
        return params, state, opt_state, model_output

    # Validation step to be pmapped; the updated state is discarded.
    def validation_step(params, state, data):
        _, (model_output, _) = loss_fn(params, state, data)
        return cross_device_mean(model_output)

    # Build the pmapped functions once instead of on every iteration.
    p_train_step = jax.pmap(train_step, axis_name="device")
    p_validation_step = jax.pmap(validation_step, axis_name="device")

    def evaluate(params, state):
        """Return per-batch validation metrics for the whole validation set."""
        losses, recon_errors, perplexities, vq_losses = [], [], [], []
        for val_batch in val_ds:
            results = jax.device_get(
                p_validation_step(params, state, val_batch.numpy().reshape(new_shape))
            )
            # Outputs are pmean-ed, so every replica holds the same value.
            losses.append(results["loss"][0])
            recon_errors.append(results["recon_error"][0])
            perplexities.append(results["vq_output"]["perplexity"][0])
            vq_losses.append(results["vq_output"]["loss"][0])
        return losses, recon_errors, perplexities, vq_losses

    # Logs
    train_losses = []
    train_recon_errors = []
    train_perplexities = []
    train_vqvae_loss = []

    # Save folder
    subfolder = f"{num_residual_layers}-{num_hiddens}-{num_residual_hiddens}"
    subfolder += f"-{embedding_dim}-{num_embeddings}-{commitment_cost}"
    out_dir = os.path.join(args.save_dir, subfolder)
    os.makedirs(out_dir, exist_ok=True)

    best_loss = np.inf
    training_log_epoch = args.training_log_epoch
    val_loss_epoch = args.validation_log_epoch

    step = 0
    for batch in train_ds.repeat():
        if step >= num_epochs:
            break
        step += 1

        params, state, opt_state, train_results = p_train_step(
            params, state, opt_state, batch.numpy().reshape(new_shape)
        )
        train_results = jax.device_get(train_results)
        train_losses.append(train_results["loss"][0])
        train_recon_errors.append(train_results["recon_error"][0])
        train_perplexities.append(train_results["vq_output"]["perplexity"][0])
        train_vqvae_loss.append(train_results["vq_output"]["loss"][0])

        if training_log_epoch > 0 and step % training_log_epoch == 0:
            print(
                f"[Step {step}/{num_epochs}] "
                + ("train loss: %f " % np.mean(train_losses[-training_log_epoch:]))
                + (
                    "recon_error: %.3f "
                    % np.mean(train_recon_errors[-training_log_epoch:])
                )
                + (
                    "perplexity: %.3f "
                    % np.mean(train_perplexities[-training_log_epoch:])
                )
                + ("vqvae loss: %.3f" % np.mean(train_vqvae_loss[-training_log_epoch:]))
            )

        if val_loss_epoch > 0 and step % val_loss_epoch == 0:
            val_losses, val_recons, val_perps, val_vq_losses = evaluate(params, state)
            if val_losses:
                print(
                    f"[Step {step}/{num_epochs}] "
                    + ("test loss: %f " % np.mean(val_losses))
                    + ("recon_error (test): %.3f " % np.mean(val_recons))
                    + ("perplexity (test): %.3f " % np.mean(val_perps))
                    + ("vqvae loss (test): %.3f" % np.mean(val_vq_losses))
                )
            else:
                print(
                    f"[Step {step}/{num_epochs}] "
                    "validation set has no full batch; skipping evaluation."
                )

            # Model weights
            _save_checkpoint(
                os.path.join(out_dir, f"vq-vae-{step}.pkl"), params, state
            )
            # Keep a copy of the best model, judged on this evaluation only.
            if val_losses and np.mean(val_losses) < best_loss:
                best_loss = np.mean(val_losses)
                _save_checkpoint(os.path.join(out_dir, "vq-vae.pkl"), params, state)

    # Plotting
    fig = plt.figure(figsize=(16, 8))
    ax = fig.add_subplot(1, 2, 1)
    ax.plot(train_recon_errors)
    ax.set_yscale("log")
    ax.set_title("NMSE.")

    ax = fig.add_subplot(1, 2, 2)
    ax.plot(train_perplexities)
    ax.set_title("Average codebook usage (perplexity).")
    plt.savefig(os.path.join(out_dir, "metrics.png"), transparent=True, dpi=300)
    plt.close(fig)


if __name__ == "__main__":
    # Script entry point: parse the command line and launch training.
    main(parse_args())
