import ml_collections
from ml_collections import config_flags
from pathlib import Path
from absl import app

import jax
import jax.numpy as jnp
import optax
import orbax.checkpoint as ocp
import logging
from functools import partial


# Custom modules
from enf.bi_invariant.SE2_bi_invariant import SE2onR2BiInvariant
from enf.equivariant_neural_field import EquivariantNeuralField
from utils.datasets import get_dataloader



def get_config():
    """Build the default configuration for the segmentation experiment.

    The returned config holds three top-level fields (``load_from_checkpoint``,
    ``checkpoint_to_load``, ``seed``) followed by four sections: ``seg_nef``,
    ``dataset``, ``optim`` and ``train``.

    Side effect: the root logger's level is set to ``logging.INFO``.

    Returns:
        An ``ml_collections.ConfigDict`` populated with the default values.
    """
    # Model used for segmentation
    seg_nef = ml_collections.ConfigDict({
        "num_hidden": 128,
        "num_heads": 8,
        "num_out": 1,
        "emb_freq_mult_q": 2.0,
        "emb_freq_mult_v": 3.0,
        "k_nearest": 4,
    })

    # Dataset settings
    dataset = ml_collections.ConfigDict({
        "name": "ombria",
        "path": "./data/ombria",
        "num_signals_train": -1,
        "num_signals_test": -1,
        "batch_size": 16,
        "num_workers": 0,
        "tr_resolution": 256,
        "te_resolution": 256,
        "mask_points": False,
        # Used by main() as a fraction of the test points when mask_points is set.
        "num_mask_points": 0.5,
    })

    # Optimizer settings
    optim = ml_collections.ConfigDict({
        "lr_enf": 1e-4,
        "lr_meta_sgd": 1e-3,
        "init_inner_lr_a": 30.0,
        "inner_steps": 3,
        "num_points": 128 * 128,
        "num_points_seg": 64 * 64,
        "start_seg": 0,
    })

    # Training loop settings
    train = ml_collections.ConfigDict({
        "num_epochs": 500,
        "log_interval": 10,
        "validation_interval": 10,
    })

    # Assemble the root config, keeping the field order: scalars first, then sections
    config = ml_collections.ConfigDict({
        "load_from_checkpoint": True,
        "checkpoint_to_load": "./checkpoints/ombria-enf",
        "seed": 68,
    })
    config.seg_nef = seg_nef
    config.dataset = dataset
    config.optim = optim
    config.train = train

    # Make INFO-level messages visible on the root logger
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    return config


# Register the `config` flag, using the result of get_config() (evaluated once,
# at import time) as its default value. main() reads it through _CONFIG.value.
_CONFIG = config_flags.DEFINE_config_dict("config", get_config())


def main(_):
    """Train a segmentation ENF on latents fitted by a checkpointed ENF.

    Steps:
      1. Restore ``[[enf_params, poses], meta_sgd_params]`` and the ``nef``
         config from the latest checkpoint in ``config.checkpoint_to_load``.
      2. For every batch, fit per-sample context latents with a few
         meta-SGD steps on the image reconstruction loss (``inner_loop``).
         The restored ENF parameters are only read, never updated here.
      3. Train ``seg_model`` to predict the mask from those latents using
         binary cross-entropy, and report the loss and IoU per epoch.

    Args:
        _: Positional argument supplied by ``absl.app.run``; unused.

    Raises:
        FileNotFoundError: If the checkpoint directory contains no step.
        ValueError: If an inner-loop step would receive zero points.
    """
    # Get config
    config = _CONFIG.value
    config.unlock()

    # Load prev checkpoint; always release the manager, even if restore fails
    p_checkpoint_manager = ocp.CheckpointManager(
        directory=Path(config.checkpoint_to_load).absolute(),
        item_handlers={
            'state': ocp.StandardCheckpointHandler(),
            'config': ocp.JsonCheckpointHandler(),
        },
        item_names=['state', 'config']
    )
    try:
        latest_step = p_checkpoint_manager.latest_step()
        if latest_step is None:
            raise FileNotFoundError(
                f"No checkpoint found in {Path(config.checkpoint_to_load).absolute()}"
            )
        ckpt = p_checkpoint_manager.restore(latest_step)
    finally:
        p_checkpoint_manager.close()
    [enf_params, poses], meta_sgd_params = ckpt['state']
    config.nef = ml_collections.ConfigDict(ckpt['config']['nef'])

    ##############################
    # Initializing the model
    ##############################

    def make_coords(sample):
        """Return a flattened [-1, 1]^2 coordinate grid, repeated per batch item.

        Uses ``sample.shape[0]`` as the batch size and ``sample.shape[1:3]``
        as the grid extents; the result has shape (batch, points, 2).
        """
        grid = jnp.stack(
            jnp.meshgrid(
                jnp.linspace(-1, 1, sample.shape[1]),
                jnp.linspace(-1, 1, sample.shape[2]),
            ),
            axis=-1,
        )
        grid = jnp.reshape(grid, (-1, 2))
        return jnp.repeat(grid[None, ...], sample.shape[0], axis=0)

    # Load dataset, get sample images, create corresponding coordinates
    train_dloader, test_dloader = get_dataloader(config.dataset)
    sample_img, _ = next(iter(train_dloader))
    x_tr = make_coords(sample_img)
    sample_img, _ = next(iter(test_dloader))
    x_te = make_coords(sample_img)

    # Random keys: seed from the config and give every consumer its own key
    key = jax.random.PRNGKey(config.seed)
    key, init_key, mask_key = jax.random.split(key, 3)

    # Define the (frozen) reconstruction model
    model = EquivariantNeuralField(
        num_hidden=config.nef.num_hidden,
        num_heads=config.nef.num_heads,
        num_out=config.nef.num_out,
        latent_dim=config.nef.latent_dim,
        bi_invariant=SE2onR2BiInvariant(),
        embedding_freq_multiplier=[config.nef.emb_freq_mult_q, config.nef.emb_freq_mult_v],
        k_nearest=config.nef.k_nearest
    )

    # Optimizer state for meta SGD. It is only threaded through run_epoch;
    # nothing in this script steps it.
    meta_sgd_optimizer = optax.adam(learning_rate=config.optim.lr_meta_sgd)
    meta_sgd_opt_state = meta_sgd_optimizer.init(meta_sgd_params)

    # Define the segmentation model; it consumes latents of the restored ENF,
    # hence latent_dim comes from config.nef
    seg_model = EquivariantNeuralField(
        num_hidden=config.seg_nef.num_hidden,
        num_heads=config.seg_nef.num_heads,
        num_out=config.seg_nef.num_out,
        latent_dim=config.nef.latent_dim,
        bi_invariant=SE2onR2BiInvariant(),
        embedding_freq_multiplier=[config.seg_nef.emb_freq_mult_q, config.seg_nef.emb_freq_mult_v],
        k_nearest=config.seg_nef.k_nearest
    )
    latent_shape = (config.dataset.batch_size, config.nef.num_latents)
    seg_params = seg_model.init(
        init_key,
        x_tr[:, :1024],
        jnp.broadcast_to(poses, latent_shape + (poses.shape[-1],)),
        jnp.ones(latent_shape + (config.nef.latent_dim,)),
        jnp.ones(latent_shape + (1,)) * 2 / jnp.sqrt(config.nef.num_latents))

    # Define optimizer for the seg model
    seg_optimizer = optax.adam(learning_rate=config.optim.lr_enf)
    seg_opt_state = seg_optimizer.init(seg_params)

    # Sample a fixed set of test points the inner loop may observe at validation
    if config.dataset.mask_points:
        num_visible = int(x_te.shape[1] * config.dataset.num_mask_points)
        input_mask = jax.random.permutation(mask_key, jnp.arange(x_te.shape[1]))[:num_visible]
    else:
        input_mask = None

    ##############################
    # Training logic
    ##############################
    # input_mask is an ordinary (non-static) argument: an index array is not
    # hashable, and None is a valid empty pytree, so the branch below is
    # resolved at trace time.
    @jax.jit
    def inner_loop(params, x_c, y_c, key, input_mask=None):
        """Fit context latents to (x_c, y_c) with a few meta-SGD steps.

        Returns the reconstruction loss on the last step's points and the
        latent tuple ``(poses, c, g)``.
        """
        # Unpack params
        (enf_params, poses), meta_sgd_params = params

        if input_mask is not None:
            point_idcs = input_mask
        else:
            # Shuffle all points; each inner step consumes its own slice
            point_idcs = jax.random.permutation(key, jnp.arange(x_c.shape[1]))

        # Broadcast over batch size
        latent_shape = (config.dataset.batch_size, config.nef.num_latents)
        poses = jnp.broadcast_to(poses, latent_shape + (poses.shape[-1],))

        # Initialize values for the poses, note that these depend on the bi-invariant, context and window
        c = jnp.ones(latent_shape + (config.nef.latent_dim,))  # context vectors
        g = jnp.ones(latent_shape + (1,)) * 2 / jnp.sqrt(config.nef.num_latents)  # gaussian window parameter

        def mse_loss(z, x_i, y_i):
            out = model.apply(enf_params, x_i, *z)
            return jnp.sum(jnp.mean((out - y_i) ** 2, axis=(1, 2)), axis=0)

        num_points = config.optim.num_points
        for i in range(config.optim.inner_steps):
            step_idcs = point_idcs[i * num_points:(i + 1) * num_points]
            # Shapes are static under jit, so this check runs at trace time.
            # An empty slice would make the mean NaN and corrupt the latents.
            if step_idcs.shape[0] == 0:
                raise ValueError(
                    f"Inner step {i} has no points: {point_idcs.shape[0]} available, "
                    f"{config.optim.inner_steps} steps of {num_points} requested."
                )
            x_i = x_c[:, step_idcs]
            y_i = y_c[:, step_idcs]
            grads = jax.grad(mse_loss)((poses, c, g), x_i, y_i)

            # Update only the context vectors (grads follows the (poses, c, g) order)
            c = c - meta_sgd_params * grads[1]

        # Return loss with resulting latents
        return mse_loss((poses, c, g), x_i, y_i), (poses, c, g)

    def seg_loss(seg_params, x_i, y_i, z):
        """Return (binary cross-entropy, IoU) of the segmentation prediction."""
        # Forward pass
        logits = seg_model.apply(seg_params, x_i, *z)

        # Compute the binary cross entropy loss
        loss = optax.sigmoid_binary_cross_entropy(logits, y_i).mean()

        # Calculate IoU; a logit above zero counts as a positive prediction
        pred_mask = logits > 0.
        target_mask = y_i.astype(bool)
        intersection = jnp.sum(target_mask & pred_mask)
        union = jnp.sum(target_mask | pred_mask)
        # An empty union means prediction and target agree on "nothing";
        # report 1.0 instead of 0/0 = NaN, which would poison the epoch mean.
        iou = jnp.where(union > 0, intersection / jnp.maximum(union, 1), 1.0)
        return loss, iou

    def subsample(x_i, y_i, key):
        """Pick config.optim.num_points_seg random points along axis 1."""
        point_idcs = jax.random.permutation(key, jnp.arange(x_i.shape[1]))[:config.optim.num_points_seg]
        return x_i[:, point_idcs], y_i[:, point_idcs]

    @jax.jit
    def seg_train_step(x_i, y_i, z, seg_params, seg_opt_state, key):
        """Run one optimizer step of the segmentation model on a point subsample."""
        x_i, y_i = subsample(x_i, y_i, key)

        # Compute the loss and the gradients
        (loss, iou), grads = jax.value_and_grad(seg_loss, has_aux=True)(seg_params, x_i, y_i, z)

        # Update the segmentation model
        updates, seg_opt_state = seg_optimizer.update(grads, seg_opt_state)
        seg_params = optax.apply_updates(seg_params, updates)

        return (loss, iou), seg_params, seg_opt_state

    @jax.jit
    def seg_eval_step(x_i, y_i, z, seg_params, key):
        """Evaluate loss and IoU on a point subsample without computing gradients."""
        x_i, y_i = subsample(x_i, y_i, key)
        return seg_loss(seg_params, x_i, y_i, z)

    def run_epoch(enf_params, seg_params, dloader, poses, meta_sgd_params, seg_opt_state, meta_sgd_opt_state, key, epoch, val=False):
        """Iterate once over dloader, training (or only evaluating when val=True).

        Returns the mean segmentation loss followed by the (possibly updated)
        state, in the same order as the arguments.
        """
        epoch_loss = []
        epoch_seg_iou = []
        params = [[enf_params, poses], meta_sgd_params]
        for img, mask in dloader:
            # Flatten all dims between batch and channel into one point axis
            y = jnp.reshape(img, (img.shape[0], -1, img.shape[-1]))
            mask = jnp.reshape(mask, (mask.shape[0], -1, 1))

            # Fresh randomness for every batch and for each consumer
            key, inner_key, seg_key = jax.random.split(key, 3)

            if not val:
                _, z = inner_loop(params, x_tr, y, inner_key)
                (batch_loss, batch_iou), seg_params, seg_opt_state = seg_train_step(
                    x_tr, mask, z, seg_params, seg_opt_state, seg_key
                )
            else:
                _, z = inner_loop(params, x_te, y, inner_key, input_mask)
                batch_loss, batch_iou = seg_eval_step(x_te, mask, z, seg_params, seg_key)

            epoch_loss.append(batch_loss)
            epoch_seg_iou.append(batch_iou)

        mean_loss = sum(epoch_loss) / len(epoch_loss)
        mean_iou = sum(epoch_seg_iou) / len(epoch_seg_iou)
        logging.info(f"{'val' if val else 'train'} - epoch {epoch} -- loss: {mean_loss} -- seg_iou {mean_iou}")
        return mean_loss, enf_params, seg_params, poses, meta_sgd_params, seg_opt_state, meta_sgd_opt_state, key

    # Training loop
    for epoch in range(config.train.num_epochs):
        _, enf_params, seg_params, poses, meta_sgd_params, seg_opt_state, meta_sgd_opt_state, key = run_epoch(
            enf_params, seg_params, train_dloader, poses, meta_sgd_params, seg_opt_state, meta_sgd_opt_state, key, epoch
        )

        if epoch % config.train.validation_interval == 0:
            # Separate key so validation does not replay the next epoch's randomness
            key, val_key = jax.random.split(key)
            run_epoch(
                enf_params, seg_params, test_dloader, poses, meta_sgd_params, seg_opt_state, meta_sgd_opt_state, val_key, epoch, val=True
            )

# Script entry point: when executed directly, hand main over to absl's app.run.
if __name__ == "__main__":
    app.run(main)