import logging
import ml_collections
from ml_collections import config_flags
from absl import app
from pathlib import Path

import wandb
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
import orbax.checkpoint as ocp

# Custom modules
from enf.bi_invariant.Rn_bi_invariant import RnBiInvariant
from enf.bi_invariant.absolute_position import AbsolutePositionND
from enf.bi_invariant.SE2_bi_invariant import SE2onR2BiInvariant
from enf.equivariant_neural_field import EquivariantNeuralField

# Custom datasets
from utils.datasets import get_dataloader

#################################
### Model definitions          ##
#################################
from enf.equivariant_neural_field import RFFEmbedding

class MessagePassingBlock(nn.Module):
    """One message-passing step followed by a two-layer pointwise MLP.

    Attributes:
        num_hidden: Output width of the message-passing layer and of both
            Dense layers.
        freq_multiplier: Frequency scale forwarded to ``MessagePassing``.
    """
    num_hidden: int
    freq_multiplier: float

    @nn.compact
    def __call__(self, bi_inv, c):
        """Update node features ``c`` using the pairwise bi-invariants ``bi_inv``.

        Returns:
            Updated node features whose last axis has ``num_hidden`` channels.
        """
        messages = MessagePassing(self.num_hidden, self.freq_multiplier)(bi_inv, c)

        # Submodules are constructed in the same order as the original one-liners
        # (Dense, then LayerNorm, then Dense) so Flax's automatic names -- and thus
        # the parameter tree -- stay exactly the same.
        first_dense = nn.Dense(self.num_hidden)
        hidden = first_dense(nn.LayerNorm()(nn.silu(messages)))

        second_dense = nn.Dense(self.num_hidden)
        return second_dense(nn.silu(hidden))


class MessagePassing(nn.Module):
    """Continuous-kernel convolution over a fully connected point cloud.

    Pairwise bi-invariants are mapped to per-channel edge kernels, which weight
    the sender features before they are summed into each receiver.

    Attributes:
        num_hidden: Number of feature channels; also the kernel width, so the
            incoming node features must have ``num_hidden`` channels.
        freq_multiplier: ``std`` of the (non-learnable) random Fourier features
            applied to the bi-invariants.
    """
    num_hidden: int
    freq_multiplier: float

    def setup(self):
        # Maps bi-invariants to kernel values: a fixed random Fourier feature
        # embedding followed by a small MLP. (Presumably the RFF embedding is used
        # to help the MLP fit high-frequency functions of the bi-invariants.)
        self.kernel_basis = nn.Sequential([
            RFFEmbedding(embedding_dim=self.num_hidden, learnable_coefficients=False, std=self.freq_multiplier),
            nn.Dense(self.num_hidden), nn.gelu, nn.Dense(self.num_hidden), nn.gelu])

        self.feature_transform = nn.Dense(self.num_hidden)

        # Construct bias
        self.bias_param = self.param('bias', nn.initializers.zeros, (self.num_hidden,))

    def __call__(self, bi_inv, c):
        """ Perform message passing on a fully connected pointcloud.

        Args:
            bi_inv: Pairwise bi-invariants. The einsum below indexes the kernel
                derived from them as (batch, receivers, senders, num_hidden), so
                presumably bi_inv is (batch, receivers, senders, bi_inv_dim).
            c: Node features of shape (batch, num_points, num_hidden).

        Returns:
            Array of shape (batch, num_points, num_hidden).
        """
        kernel = self.kernel_basis(bi_inv)

        # Perform the appearance convolution [batch, senders, channels] * [batch, receivers, senders, channels]
        # -> [batch, receivers, channels]
        c = jnp.einsum('bsc,brsc->brc', c, kernel) + self.bias_param
        c = self.feature_transform(c)
        return c


class EquivMPNN(nn.Module):
    """Equivariant message-passing classifier over latent point clouds.

    Attributes:
        num_hidden: Hidden width used by the stem, every block and the readout.
        num_layers: Number of stacked ``MessagePassingBlock`` layers.
        num_out: Number of output logits.
        bi_invariant: Module computing pairwise bi-invariants between poses.
        freq_multiplier: Frequency scale forwarded to every message-passing block.
    """
    num_hidden: int
    num_layers: int
    num_out: int

    bi_invariant: nn.Module
    freq_multiplier: float

    def setup(self):
        # Lift the incoming context vectors to the hidden width.
        self.c_stem = nn.Dense(self.num_hidden)

        # Stack of identically configured message-passing blocks.
        self.message_passing_layers = [
            MessagePassingBlock(self.num_hidden, self.freq_multiplier)
            for _ in range(self.num_layers)
        ]

        # Pointwise MLP producing ``num_out`` logits per node.
        self.readout_scalar = nn.Sequential([
            nn.Dense(self.num_hidden),
            nn.gelu,
            nn.Dense(self.num_hidden),
            nn.gelu,
            nn.Dense(self.num_out)
        ])

    def __call__(self, p, c, g, t):
        """Classify a batch of latent point clouds.

        Args:
            p: Poses, array of shape (batch, num_points, spatial_dim).
            c: Context vectors, array of shape (batch, num_points, num_in).
            g: Gaussian window parameters; accepted to match the latent tuple
                but not used by this network.
            t: Unused (callers in this file pass ``None``).

        Returns:
            Logits of shape (batch, num_out): per-node readouts averaged over
            the node axis.
        """
        # Pairwise bi-invariants between all latent poses.
        bi_inv = self.bi_invariant(p, p)

        features = self.c_stem(c)
        for block in self.message_passing_layers:
            features = block(bi_inv, features)

        # Read out every node, then mean-pool over the node axis.
        per_node_logits = self.readout_scalar(features)
        return jnp.mean(per_node_logits, axis=1)



#################################
### Main function              ##
#################################
def get_config():
    """Build the default configuration for classifier training.

    The returned ConfigDict contains the ENF checkpoint location, the
    ``classifier``, ``train`` and ``optim`` sub-configs, and the random seed.
    As a side effect, the root logger's level is set to INFO.
    """
    initial_lr = 1e-4

    config = ml_collections.ConfigDict()
    config.checkpoint_path = "./checkpoints/cifar-enf"

    # Classifier (EquivMPNN) hyperparameters.
    # NOTE(review): ``k`` is not read by the code in this file.
    config.classifier = ml_collections.ConfigDict(dict(
        num_hidden=256,
        num_layers=4,
        freq_multiplier=1.0,
        k=9,
    ))

    # Training-loop settings.
    config.train = ml_collections.ConfigDict(dict(
        num_epochs=250,
        batch_size=32,
        log_interval=10,
        val_interval=10,
    ))
    logging.getLogger().setLevel(logging.INFO)

    # Optimizer settings; ``lr`` starts out equal to ``start_lr``.
    config.optim = ml_collections.ConfigDict(dict(
        decay_lr_every=400,
        start_lr=initial_lr,
        lr=initial_lr,
    ))

    config.seed = 68
    return config


# Set config flags
# Registers the absl ``config`` flag with ``get_config()`` as its default value;
# ``main`` reads the result through ``_CONFIG.value``. Note that ``get_config()``
# runs at import time, so its side effect (setting the root log level) does too.
# NOTE(review): ``main`` calls ``config.unlock()`` before adding keys, presumably
# because this flag hands back a locked ConfigDict -- confirm.
_CONFIG = config_flags.DEFINE_config_dict("config", get_config())


def main(_):
    """Fit ENF latents for every image and train an equivariant classifier on them.

    Restores a fitted Equivariant Neural Field (parameters, shared initial poses
    and meta-SGD step sizes), plus the config it was fitted with, from
    ``config.checkpoint_path``. For every batch, a short inner loop fits context
    vectors to the images. These are standardized with statistics computed over
    the training set and fed to ``EquivMPNN``, which is trained with a
    cross-entropy loss.

    Args:
        _: Positional command-line arguments forwarded by ``absl.app.run``; unused.

    Raises:
        FileNotFoundError: If ``config.checkpoint_path`` holds no checkpoint.
    """
    config = _CONFIG.value
    config.unlock()

    # Define checkpointing
    checkpoint_options = ocp.CheckpointManagerOptions(
        save_interval_steps=10,
        max_to_keep=1,
    )
    checkpoint_manager = ocp.CheckpointManager(
        directory=Path(config.checkpoint_path).absolute(),
        options=checkpoint_options,
        item_handlers={
            'state': ocp.StandardCheckpointHandler(),
            'config': ocp.JsonCheckpointHandler(),
        },
        item_names=['state', 'config']
    )

    # Load the latest ENF checkpoint. Fail with a clear message instead of
    # passing ``None`` to ``restore`` when the directory is empty.
    latest_step = checkpoint_manager.latest_step()
    if latest_step is None:
        raise FileNotFoundError(
            f"No ENF checkpoint found in {Path(config.checkpoint_path).absolute()}")
    ckpt = checkpoint_manager.restore(latest_step)
    ((enf_params, poses), meta_sgd_params), enf_config = ckpt['state'], ckpt['config']

    # Store config used to fit the enf
    config.enf_config = ml_collections.ConfigDict(enf_config)

    # Determine number of classes
    num_classes = 10
    config.num_classes = num_classes
    NUM_CLASSES = num_classes

    # Overwrite the batch size
    config.enf_config.dataset.batch_size = config.train.batch_size
    config.enf_config.dataset.num_signals_train = -1

    ##############################
    # Initializing the model
    ##############################

    # Load dataset, get sample image, create corresponding coordinates
    train_dloader, test_dloader = get_dataloader(config.enf_config.dataset)
    sample_img, _ = next(iter(train_dloader))
    img_shape = sample_img.shape[1:]

    # Create coordinate grid (identical for every image in a batch)
    x = jnp.stack(jnp.meshgrid(jnp.linspace(-1, 1, sample_img.shape[1]), jnp.linspace(-1, 1, sample_img.shape[2])), axis=-1)
    x = jnp.reshape(x, (-1, 2))
    x = jnp.repeat(x[None, ...], sample_img.shape[0], axis=0)

    # Use the same bi-invariant the ENF was fitted with. Every branch must read
    # from ``config.enf_config``; the classifier config itself has no ``nef`` entry.
    bi_invariant_type = config.enf_config.nef.bi_invariant
    if bi_invariant_type == "R2":
        bi_inv = RnBiInvariant(2)
    elif bi_invariant_type == "SE2":
        bi_inv = SE2onR2BiInvariant()
    else:
        bi_inv = AbsolutePositionND(2)

    # Define the models
    classifier = EquivMPNN(
        num_hidden=config.classifier.num_hidden,
        num_layers=config.classifier.num_layers,
        num_out=NUM_CLASSES,
        bi_invariant=bi_inv,
        freq_multiplier=config.classifier.freq_multiplier,
    )

    enf_model = EquivariantNeuralField(
        num_hidden=config.enf_config.nef.num_hidden,
        num_heads=config.enf_config.nef.num_heads,
        num_out=config.enf_config.nef.num_out,
        latent_dim=config.enf_config.nef.latent_dim,
        bi_invariant=bi_inv,
        embedding_freq_multiplier=[config.enf_config.nef.emb_freq_mult_q, config.enf_config.nef.emb_freq_mult_v],
        k_nearest=config.enf_config.nef.k_nearest
    )

    # Create dummy latents for model init
    d_c = jnp.ones((config.enf_config.dataset.batch_size, config.enf_config.nef.num_latents, config.enf_config.nef.latent_dim))  # context vectors
    d_g = jnp.ones((config.enf_config.dataset.batch_size, config.enf_config.nef.num_latents, 1))  # gaussian window parameter

    # Set key
    key = jax.random.PRNGKey(config.seed)

    # Initialize the classifier and its optimizer
    classifier_params = classifier.init(key, jnp.broadcast_to(poses, (config.enf_config.dataset.batch_size, *poses.shape[1:])), d_c, d_g, None)
    classifier_optimizer = optax.adam(learning_rate=config.optim.lr)
    classifier_opt_state = classifier_optimizer.init(classifier_params)

    # Count number of trainable parameters
    num_params = sum([p.size for p in jax.tree.flatten(classifier_params)[0]])
    logging.info(f"Number of trainable parameters: {num_params}.")

    ##############################
    # Training logic
    ##############################
    @jax.jit
    def inner_loop(params, x_i, y_i, key):
        # Unpack params
        (enf_params, poses), meta_sgd_params = params

        # Take the batch size from the targets so a smaller final batch also
        # works; the coordinate grid is shared, so it is simply sliced to match.
        batch_size = y_i.shape[0]
        x_i = x_i[:batch_size]

        # Broadcast over batch size
        p = jnp.broadcast_to(poses, (batch_size, config.enf_config.nef.num_latents, poses.shape[-1]))

        # Add some noise to the poses
        p = p + jax.random.normal(key, p.shape) * 0.2 / jnp.sqrt(config.enf_config.nef.num_latents)

        # Initialize values for the poses, note that these depend on the bi-invariant, context and window
        c = jnp.ones((batch_size, config.enf_config.nef.num_latents, config.enf_config.nef.latent_dim))  # context vectors
        g = jnp.ones((batch_size, config.enf_config.nef.num_latents, 1)) * 2 / jnp.sqrt(
            config.enf_config.nef.num_latents)  # gaussian window parameter

        def mse_loss(z, x_i, y_i):
            out = enf_model.apply(enf_params, x_i, *z)
            return jnp.sum(jnp.mean((out - y_i) ** 2, axis=(1, 2)), axis=0)

        # Only the context vectors are adapted; poses and windows stay fixed.
        for _ in range(config.enf_config.optim.inner_steps):
            _, grads = jax.value_and_grad(mse_loss)((p, c, g), x_i, y_i)

            # Update the latent features
            c = c - meta_sgd_params * grads[1]

        # Return loss with resulting latents
        return mse_loss((p, c, g), x_i, y_i), (p, c, g)

    @jax.jit
    def cross_entropy_loss(params, p, c, g, labels):
        out = nn.log_softmax(classifier.apply(params, p, c, g, None), axis=-1)
        # Flatten labels so (batch,) and (batch, 1) inputs behave the same. Without
        # this, comparing the (batch,) argmax with (batch, 1) labels would broadcast
        # to (batch, batch) and report a meaningless accuracy.
        labels = jnp.reshape(labels, (-1,))
        one_hot_labels = jax.nn.one_hot(labels, num_classes=NUM_CLASSES)
        return -jnp.mean(jnp.sum(one_hot_labels * out, axis=-1)), jnp.mean(jnp.argmax(out, axis=-1) == labels)

    # Let's define a train step
    @jax.jit
    def classifier_train_step(classifier_params, classifier_opt_state, enf_params, poses, x_i, y_i, labels, key):
        # Perform an inner loop with the trained ENF to obtain latents z.
        enf_params, meta_sgd_params = enf_params

        # Perform inner loop to obtain latents
        _, (p_0, c_0, g_0) = inner_loop(((enf_params, poses), meta_sgd_params), x_i, y_i, key)

        # Standardize the context vectors. The statistics are defined further
        # below, before the first call, and are captured as constants when traced.
        c_0 = (c_0 - context_vectors_mean) / context_vectors_std
        key = jax.random.split(key)[0]

        # Take gradients wrt cross-entropy loss
        (loss, acc), grad = jax.value_and_grad(cross_entropy_loss, has_aux=True)(classifier_params, p_0, c_0, g_0, labels)

        # Get gradient updates
        classifier_updates, classifier_opt_state = classifier_optimizer.update(grad, classifier_opt_state)
        classifier_params = optax.apply_updates(classifier_params, classifier_updates)

        return (loss, acc), classifier_params, classifier_opt_state, key

    # Calculate statistics of the context vectors over the training set
    cs = []
    for img, _ in train_dloader:
        y = jnp.reshape(img, (img.shape[0], -1, img.shape[-1]))
        _, (_, c, _) = inner_loop([[enf_params, poses], meta_sgd_params], x, y, key)
        cs.append(c)
    cs = jnp.concatenate(cs, axis=0)
    context_vectors_mean, context_vectors_std = jnp.mean(cs, axis=(0, 1)), jnp.std(cs, axis=(0, 1))
    logging.info(f"Mean: {context_vectors_mean}, std: {context_vectors_std}")

    # Perform sanity check inner loop
    y_s = sample_img.reshape((sample_img.shape[0], -1, sample_img.shape[-1]))
    loss, (p_b, c, g) = inner_loop([[enf_params, poses], meta_sgd_params], x, y_s, key)
    logging.info(f"Sanity check reconstruction loss: {loss}")

    recons = enf_model.apply(enf_params, x, p_b, c, g)
    recons = jnp.reshape(recons, (sample_img.shape[0], *img_shape))

    # Plot every reconstruction of the sample batch on a grid with up to 8 columns.
    num_recons = recons.shape[0]
    num_cols = min(8, num_recons)
    num_rows = -(-num_recons // num_cols)  # ceiling division
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(2 * num_cols, 2 * num_rows), squeeze=False)
    for idx, ax in enumerate(axs.flat):
        ax.axis('off')
        if idx < num_recons:
            recon = jnp.clip(recons[idx], 0, 1)
            # imshow rejects (H, W, 1) arrays; drop a singleton channel axis.
            if recon.ndim == 3 and recon.shape[-1] == 1:
                recon = recon[..., 0]
            ax.imshow(recon, cmap='gray' if recon.ndim == 2 else None)

    plt.show()
    logging.info("Sanity check passed.")

    # Training loop
    for epoch in range(config.train.num_epochs):
        epoch_loss, epoch_acc = [], []
        for img, labels in train_dloader:
            # Flatten img to (batch, num_pixels, channels)
            y = jnp.reshape(img, (img.shape[0], -1, img.shape[-1]))

            # Perform training step
            (loss, acc), classifier_params, classifier_opt_state, key = classifier_train_step(
                classifier_params,
                classifier_opt_state,
                [enf_params, meta_sgd_params],
                poses,
                x, y, labels,
                key,
            )
            epoch_loss.append(loss)
            epoch_acc.append(acc)

        if epoch % config.train.val_interval == 0:
            val_loss, val_acc = [], []
            for img, labels in test_dloader:
                # Flatten img to (batch, num_pixels, channels)
                y_i = jnp.reshape(img, (img.shape[0], -1, img.shape[-1]))

                # Perform inner loop to obtain latents
                _, (p_0, c_0, g_0) = inner_loop(((enf_params, poses), meta_sgd_params), x, y_i, key)

                # Standardize the context vectors with the training-set statistics
                c_0 = (c_0 - context_vectors_mean) / context_vectors_std
                key = jax.random.split(key)[0]

                # Evaluate the cross-entropy loss and accuracy (no parameter update)
                loss, acc = cross_entropy_loss(classifier_params, p_0, c_0, g_0, labels)
                val_loss.append(loss)
                val_acc.append(acc)

            logging.info(f"val_loss {sum(val_loss) / len(val_loss)} -- val_acc {sum(val_acc) / len(val_acc)}")
        logging.info(f"epoch {epoch} -- loss: {sum(epoch_loss) / len(epoch_loss)}, acc: {sum(epoch_acc) / len(epoch_acc)}")


# Script entry point: hand control to ``main`` through absl's ``app.run``.
if __name__ == "__main__":
    app.run(main)
