import ml_collections
from ml_collections import config_flags
from pathlib import Path
from absl import app

import math
import jax
import jax.numpy as jnp
import optax
import orbax.checkpoint as ocp
import logging

# Custom modules
from enf.bi_invariant.Rn_bi_invariant import RnBiInvariant
from enf.bi_invariant.absolute_position import AbsolutePositionND
from enf.bi_invariant.SE2_bi_invariant import SE2onR2BiInvariant
from enf.equivariant_neural_field import EquivariantNeuralField
from utils.datasets import get_dataloader


def get_config():
    """Return the default experiment configuration as an ``ml_collections.ConfigDict``.

    Top-level field ``seed`` plus four sections:
        nef:     Equivariant Neural Field architecture and bi-invariant choice.
        dataset: dataset name/location and dataloader settings.
        optim:   outer-loop learning rates (ENF backbone, meta-SGD) and
                 inner-loop settings.
        train:   number of epochs and logging / validation intervals.

    Side effect: sets the root ``logging`` logger level to INFO.
    """
    # Presumably so the logging.info epoch summaries emitted by main() are visible.
    logging.getLogger().setLevel(logging.INFO)

    nef_section = {
        "num_hidden": 128,
        "num_heads": 3,
        "num_out": 3,
        # main() lays the latent poses out on a square grid, so keep this 1 or a perfect square.
        "num_latents": 25,
        "latent_dim": 64,
        "emb_freq_mult_q": 2.0,
        "emb_freq_mult_v": 5.0,
        "k_nearest": 4,
        # main() understands "R2" and "SE2"; any other value selects absolute positions.
        "bi_invariant": "R2",
    }
    dataset_section = {
        "name": "cifar10",
        "path": "./data/cifar10",
        "batch_size": 32,
        "num_workers": 0,
    }
    optim_section = {
        "lr_enf": 1e-4,
        "lr_meta_sgd": 1e-3,
        # Initial value of every per-dimension meta-SGD inner learning rate.
        "init_inner_lr_a": 30.0,
        # Number of meta-SGD steps applied to the context vectors per batch.
        "inner_steps": 3,
        "num_points": 64 * 64,
    }
    train_section = {
        "num_epochs": 10,
        "log_interval": 10,
        "validation_interval": 10,
    }

    # Nested dicts are converted into nested ConfigDicts by the constructor.
    return ml_collections.ConfigDict({
        "seed": 68,
        "nef": nef_section,
        "dataset": dataset_section,
        "optim": optim_section,
        "train": train_section,
    })


# Register the ``--config`` command-line flag. Its default value is built by
# calling get_config() right here at import time, so get_config()'s side effect
# (setting the root logger level) also happens whenever this module is imported.
# NOTE(review): main() calls config.unlock() before adding fields, which suggests
# the flag hands back a locked ConfigDict — confirm against ml_collections.
_CONFIG = config_flags.DEFINE_config_dict("config", get_config())


def _coordinate_grid(height, width):
    """Return a ``(height * width, 2)`` array of pixel coordinates in [-1, 1].

    Entry ``r * width + c`` is ``(x_c, y_r)`` (column coordinate first), i.e. the
    grid is ordered like an image of spatial size ``(height, width)`` flattened
    in row-major order.
    """
    xs = jnp.linspace(-1, 1, width)
    ys = jnp.linspace(-1, 1, height)
    # Default 'xy' indexing: meshgrid(xs, ys) has shape (len(ys), len(xs)) =
    # (height, width), so the flattened order follows the image rows.
    grid = jnp.stack(jnp.meshgrid(xs, ys), axis=-1)
    return jnp.reshape(grid, (-1, 2))


def _init_latent_poses(num_latents, with_orientation=False):
    """Return initial latent poses laid out on a regular grid over [-1, 1]^2.

    A single latent sits at the origin. Otherwise ``num_latents`` must be a
    perfect square ``side**2`` and the poses are the centres of a
    ``side x side`` tiling of [-1, 1]^2.

    Args:
        num_latents: number of latent points (1 or a perfect square).
        with_orientation: append a zero orientation channel (used for SE2).

    Returns:
        Array of shape ``(1, num_latents, 2)``, or ``(1, num_latents, 3)`` when
        ``with_orientation`` is True.

    Raises:
        ValueError: if ``num_latents`` is not positive, or is larger than 1 and
            not a perfect square (a square grid cannot hold exactly that many).
    """
    if num_latents < 1:
        raise ValueError(f"num_latents must be positive, got {num_latents}")
    if num_latents == 1:
        poses = jnp.zeros((1, 1, 2))
    else:
        side = math.isqrt(num_latents)
        if side * side != num_latents:
            raise ValueError(
                f"num_latents must be 1 or a perfect square, got {num_latents}")
        # Inset by half a cell (cell width is 2 / side) so poses sit at cell centres.
        lims = 1 - 1 / side
        axis_vals = jnp.linspace(-lims, lims, side)
        poses = jnp.stack(jnp.meshgrid(axis_vals, axis_vals), axis=-1)
        poses = jnp.reshape(poses, (1, -1, 2))
    if with_orientation:
        poses = jnp.concatenate((poses, jnp.zeros((1, num_latents, 1))), axis=-1)
    return poses


def main(_):
    """Meta-train an Equivariant Neural Field on the configured image dataset.

    The outer loop learns three things with Adam:
      * the ENF backbone parameters,
      * the initial latent poses, which are shared by all images,
      * per-dimension meta-SGD learning rates.

    Each image's context vectors are fitted in a short inner loop of
    ``config.optim.inner_steps`` meta-SGD steps. After every epoch,
    ``[[enf_params, poses], meta_sgd_params]`` and the config are checkpointed.
    """

    # Get config; unlock it so extra fields can be added below.
    config = _CONFIG.value
    config.unlock()

    # Directory the checkpoint manager writes to.
    config.checkpoint_path = "./checkpoints/cifar-enf"

    ##############################
    # Initializing the model
    ##############################

    # Load dataset and grab one batch to infer the image resolution.
    train_dloader, test_dloader = get_dataloader(config.dataset)
    sample_img, _ = next(iter(train_dloader))

    # Seed from the config. Split off a dedicated key for parameter init so it
    # is not reused for the training noise.
    key = jax.random.PRNGKey(config.seed)
    key, init_key = jax.random.split(key)

    # Pixel coordinates, shared by every image.
    # NOTE(review): assumes channel-last batches (batch, height, width, channels),
    # as implied by the flattening in run_epoch — confirm against get_dataloader.
    coords = _coordinate_grid(sample_img.shape[1], sample_img.shape[2])
    init_batch_size = sample_img.shape[0]
    x = jnp.broadcast_to(coords, (init_batch_size,) + coords.shape)

    # Select the bi-invariant (anything other than "R2"/"SE2" -> absolute positions)
    if config.nef.bi_invariant == "R2":
        bi_inv = RnBiInvariant(2)
    elif config.nef.bi_invariant == "SE2":
        bi_inv = SE2onR2BiInvariant()
    else:
        bi_inv = AbsolutePositionND(2)

    # Define the model
    model = EquivariantNeuralField(
        num_hidden=config.nef.num_hidden,
        num_heads=config.nef.num_heads,
        num_out=config.nef.num_out,
        latent_dim=config.nef.latent_dim,
        bi_invariant=bi_inv,
        embedding_freq_multiplier=[config.nef.emb_freq_mult_q, config.nef.emb_freq_mult_v],
        k_nearest=config.nef.k_nearest
    )

    # Dummy latents, only used to trace model.init. Their batch size matches x.
    d_c = jnp.ones((init_batch_size, config.nef.num_latents, config.nef.latent_dim))  # context vectors
    d_g = jnp.ones((init_batch_size, config.nef.num_latents, 1))  # gaussian window parameter

    # Learnable initial poses shared across images: (1, num_latents, 2), or
    # (1, num_latents, 3) with an orientation channel for SE2.
    poses = _init_latent_poses(
        config.nef.num_latents, with_orientation=config.nef.bi_invariant == "SE2")

    # Init the model
    enf_params = model.init(
        init_key,
        x[:, :config.optim.num_points],
        jnp.broadcast_to(poses, (init_batch_size, config.nef.num_latents, poses.shape[-1])),
        d_c,
        d_g,
    )

    # Optimizer for the ENF backbone and the shared initial poses
    enf_optimizer = optax.adam(learning_rate=config.optim.lr_enf)
    enf_opt_state = enf_optimizer.init([enf_params, poses])

    # Per-dimension inner-loop learning rates (meta-SGD) and their optimizer
    meta_sgd_params = jnp.ones(config.nef.latent_dim) * config.optim.init_inner_lr_a
    meta_sgd_optimizer = optax.adam(learning_rate=config.optim.lr_meta_sgd)
    meta_sgd_opt_state = meta_sgd_optimizer.init(meta_sgd_params)

    # Define checkpointing
    checkpoint_options = ocp.CheckpointManagerOptions(
        save_interval_steps=1,
        max_to_keep=1,
    )
    checkpoint_manager = ocp.CheckpointManager(
        directory=Path(config.checkpoint_path).absolute(),
        options=checkpoint_options,
        item_handlers={
            'state': ocp.StandardCheckpointHandler(),
            'config': ocp.JsonCheckpointHandler(),
        },
        item_names=['state', 'config']
    )

    ##############################
    # Training logic
    ##############################
    @jax.jit
    def inner_loop(params, x_c, y_c, key):
        """Fit per-image context vectors with ``inner_steps`` meta-SGD steps.

        Args:
            params: ``[[enf_params, poses], meta_sgd_params]``.
            x_c: coordinates, one row of points per image in the batch.
            y_c: target values at those coordinates.
            key: PRNG key for the noise added to the initial poses.

        Returns:
            ``(loss, (poses, c, g))``. ``loss`` is the per-image MSE after
            adaptation, summed over the batch. The tuple holds the latents
            used for it.
        """
        (enf_params, poses), meta_sgd_params = params

        # Batch size comes from the data (static under jit), so a smaller final
        # batch works too.
        batch_size = x_c.shape[0]
        num_latents = config.nef.num_latents

        # Broadcast the shared poses over the batch and jitter them slightly.
        poses = jnp.broadcast_to(poses, (batch_size, num_latents, poses.shape[-1]))
        poses = poses + jax.random.normal(key, poses.shape) * 0.2 / jnp.sqrt(num_latents)

        # Initial context vectors and gaussian window parameters.
        c = jnp.ones((batch_size, num_latents, config.nef.latent_dim))
        g = jnp.ones((batch_size, num_latents, 1)) * 2 / jnp.sqrt(num_latents)

        def mse_loss(z, x_i, y_i):
            out = model.apply(enf_params, x_i, *z)
            # Per-image MSE summed (not averaged) over the batch. Each image's
            # latents only affect its own term, so per-image inner-loop
            # gradients do not depend on the batch size.
            return jnp.sum(jnp.mean((out - y_i) ** 2, axis=(1, 2)), axis=0)

        for _ in range(config.optim.inner_steps):
            grads = jax.grad(mse_loss)((poses, c, g), x_c, y_c)
            # Only the context vectors are adapted. The per-dimension meta-SGD
            # rates broadcast over the latent dim.
            c = c - meta_sgd_params * grads[1]

        return mse_loss((poses, c, g), x_c, y_c), (poses, c, g)

    @jax.jit
    def outer_step(x_i, y_i, enf_params, poses, meta_sgd_params, enf_opt_state, meta_sgd_opt_state, key):
        """Run one meta-learning step and update the meta-parameters.

        Differentiates the adapted loss through the inner loop. Then updates
        the ENF backbone, the initial poses and the meta-SGD learning rates
        with their Adam optimizers.

        Returns the loss, the adapted latents, the updated parameters and
        optimizer states, and a fresh PRNG key for the next step.
        """
        # Consume one half of the key now; hand the other back for the next step.
        key, new_key = jax.random.split(key)

        (loss, z), grads = jax.value_and_grad(inner_loop, has_aux=True)(
            [[enf_params, poses], meta_sgd_params], x_i, y_i, key)

        # grads mirrors the params structure: [[enf_grads, poses_grads], meta_sgd_grads].
        enf_updates, enf_opt_state = enf_optimizer.update(grads[0], enf_opt_state)
        enf_params, poses = optax.apply_updates([enf_params, poses], enf_updates)

        meta_sgd_updates, meta_sgd_opt_state = meta_sgd_optimizer.update(grads[1], meta_sgd_opt_state)
        meta_sgd_params = optax.apply_updates(meta_sgd_params, meta_sgd_updates)

        return loss, z, enf_params, poses, meta_sgd_params, enf_opt_state, meta_sgd_opt_state, new_key

    def run_epoch(enf_params, dloader, poses, meta_sgd_params, enf_opt_state, meta_sgd_opt_state, key, epoch, val=False):
        """Run one pass over ``dloader`` and log the mean per-batch loss.

        In training mode every batch takes an ``outer_step``. With ``val=True``
        only the inner-loop adaptation is evaluated: no meta-gradients are
        computed, and the parameters, optimizer states and ``key`` are returned
        unchanged.

        Raises:
            ValueError: if ``dloader`` yields no batches.
        """
        epoch_loss = []
        for i, (img, _) in enumerate(dloader):
            # Flatten all dims between batch and channels into one point axis.
            y = jnp.reshape(img, (img.shape[0], -1, img.shape[-1]))
            # Coordinates sized to this batch, so a smaller final batch works too.
            x_b = jnp.broadcast_to(coords, (y.shape[0],) + coords.shape)

            if val:
                # Distinct pose noise per batch without consuming the training key.
                loss, _ = inner_loop(
                    [[enf_params, poses], meta_sgd_params], x_b, y, jax.random.fold_in(key, i))
            else:
                loss, _, enf_params, poses, meta_sgd_params, enf_opt_state, meta_sgd_opt_state, key = outer_step(
                    x_b, y, enf_params, poses, meta_sgd_params, enf_opt_state, meta_sgd_opt_state, key)

            epoch_loss.append(loss)

        if not epoch_loss:
            raise ValueError(
                f"{'val' if val else 'train'} dataloader yielded no batches in epoch {epoch}")
        mean_loss = float(sum(epoch_loss) / len(epoch_loss))
        logging.info("%s - epoch %d -- loss: %s", "val" if val else "train", epoch, mean_loss)
        return mean_loss, enf_params, poses, meta_sgd_params, enf_opt_state, meta_sgd_opt_state, key

    # Training loop
    try:
        for epoch in range(config.train.num_epochs):
            loss, enf_params, poses, meta_sgd_params, enf_opt_state, meta_sgd_opt_state, key = run_epoch(
                enf_params, train_dloader, poses, meta_sgd_params, enf_opt_state, meta_sgd_opt_state, key, epoch
            )

            if epoch % config.train.validation_interval == 0:
                run_epoch(
                    enf_params, test_dloader, poses, meta_sgd_params, enf_opt_state, meta_sgd_opt_state, key, epoch,
                    val=True
                )

            checkpoint_manager.save(step=epoch, args=ocp.args.Composite(
                state=ocp.args.StandardSave([[enf_params, poses], meta_sgd_params]),
                config=ocp.args.JsonSave(config.to_dict()))
            )
    finally:
        # Orbax may finish saves in the background. close() waits for
        # outstanding saves, so the last checkpoint is complete before exit.
        checkpoint_manager.close()


if __name__ == "__main__":
    # Script entry point: absl's app.run parses the command-line flags (including
    # the --config flag registered at module level) and then calls main().
    app.run(main)