from typing import Any

import equinox as eqx
import jax
import jax.numpy as jnp

import flows.flowee as flowee


def split_mask(weights: jax.Array):
    """Return a boolean mask selecting a prefix holding about half the total weight.

    The prefix ends at the first index whose cumulative weight reaches half of
    the total weight (``searchsorted`` with its default ``side='left'``). That
    index is clamped to ``[0, n - 2]``. For ``n >= 2`` this means the mask
    selects at least one element and leaves at least one element unselected.
    """
    n = weights.shape[0]
    running_total = jnp.cumsum(weights)
    halfway = running_total[-1] / 2.0
    # Clamp so neither side of the split ends up empty (when n >= 2).
    cut = jnp.clip(jnp.searchsorted(running_total, halfway), 0, n - 2)
    return jnp.arange(n) <= cut


class FlowEmbedding(eqx.Module):
    """Embeds samples to create a latent representation of their distribution.

    Trains a conditional normalizing flow model to generate samples from the
    latent representation.

    Requires:
    emb_net: function mapping sample -> latent
    gen_flow: conditional normalizing flow mapping prior -> sample
    """

    # Per-sample network; `embed` maps it over the samples and averages the outputs.
    emb_net: eqx.Module
    # Conditional flow; conditioned on the embedding `z` produced by `embed`.
    gen_flow: flowee.Flow

    def generate(self, z: jax.Array, nsamples: int, key: jax.Array, params: Any | None = None):
        """Draw `nsamples` samples from `gen_flow` conditioned on `z`.

        One PRNG key is split off per sample. `z` and `params` are shared by
        all samples (vmapped with `in_axes=None`).
        """
        keys = jax.random.split(key, nsamples)
        return jax.vmap(self.gen_flow.sample, in_axes=(0, None, None))(keys, z, params)

    def embed(self, xs: jax.Array, weights: jax.Array | None = None):
        """Embed a set of samples as a single normalized vector.

        Args:
            xs: samples stacked along axis 0; `emb_net` is vmapped over that axis.
                Assumes `emb_net` returns a 1-D latent per sample (TODO confirm),
                so the stacked latents are (n, d).
            weights: optional per-sample weights, one per row of `xs`. Defaults
                to uniform weights. NOTE: if the weights sum to zero the result
                is NaN.

        Returns:
            `(embedding, norm)`. `embedding` is the weighted mean latent divided
            by `norm`. `norm` is that mean's Euclidean norm plus 1e-8, where the
            epsilon guards against division by zero.
        """
        zs = jax.vmap(self.emb_net)(xs)

        if weights is None:
            weights = jnp.ones(xs.shape[0])

        embedding = jnp.sum(zs * weights.reshape((-1, 1)), axis=0) / jnp.sum(weights)
        norm = jnp.linalg.norm(embedding) + 1e-8

        return embedding / norm, norm

    def log_prob(self, xs: jax.Array, z: jax.Array, key: jax.Array | None = None):
        """Per-sample `gen_flow.log_prob` of `xs` conditioned on `z`.

        If `key` is given, it is split into one key per sample. Otherwise `None`
        is passed as the key. The literal `1` and the trailing `None` are
        forwarded positionally to `gen_flow.log_prob`; their meaning is defined
        by `flowee.Flow` (presumably a sample count and flow params — verify).
        """
        if key is None:
            return jax.vmap(self.gen_flow.log_prob, in_axes=(0, None, None, None, None))(
                xs, 1, None, z, None
            )

        keys = jax.random.split(key, xs.shape[0])
        return jax.vmap(
            self.gen_flow.log_prob,
            in_axes=(0, None, 0, None, None),
        )(xs, 1, keys, z, None)

    def loss(
        self,
        xs: jax.Array,
        weights: jax.Array | None = None,
        shuffle: bool = True,
        key: jax.Array | None = None,
        params: Any | None = None,
    ):
        """Held-out flow loss conditioned on an embedding of the other samples.

        The samples are split into two parts. One part is embedded with `embed`.
        `gen_flow.loss` is then evaluated on the other part, conditioned on that
        embedding. All returned statistics are averages over the evaluated
        (held-out) samples only.

        - Without `weights`: the (optionally shuffled) samples are split by
          position. The first `n // 2` are embedded and the rest are evaluated;
          for odd `n`, the evaluated part gets the extra sample. The result is a
          plain mean.
        - With `weights`: `split_mask` selects a prefix holding about half of
          the total weight to embed. The result is a weight-averaged mean over
          the remaining samples. The flow is evaluated on every sample so that
          shapes stay static under `jit`; the embedded samples get zero weight.

        Args:
            xs: samples stacked along axis 0 (at least two).
            weights: optional per-sample weights, one per row of `xs`.
            shuffle: if True, permute `xs` (and `weights`) before splitting,
                using a key split from `key`.
            key: PRNG key. Always required, because per-sample keys for
                `gen_flow.loss` are drawn from it.
            params: forwarded unchanged to `gen_flow.loss`.

        Returns:
            `(loss, prior_log_prob, ldj, norm)`. The first three are held-out
            averages of the three per-sample outputs of `gen_flow.loss`. `norm`
            is the embedding norm returned by `embed`.

        Raises:
            ValueError: if `key` is None or `xs` has fewer than two samples.
        """
        if key is None:
            raise ValueError('key is required by FlowEmbedding.loss')
        if xs.shape[0] < 2:
            raise ValueError('FlowEmbedding.loss needs at least two samples to split')

        if shuffle:
            key, perm_key = jax.random.split(key)
            sinds = jax.random.permutation(perm_key, jnp.arange(xs.shape[0]))
            xs = xs[sinds]
            if weights is not None:
                weights = weights[sinds]

        # Per-sample flow loss: sample and key are mapped; conditioning and params are shared.
        flow_loss = jax.vmap(self.gen_flow.loss, in_axes=(0, None, 0, None, None))

        if weights is None:
            n_train = xs.shape[0] // 2
            xs_train, xs_test = xs[:n_train], xs[n_train:]
            z, norm = self.embed(xs_train)

            keys = jax.random.split(key, xs_test.shape[0])
            loss, prior_log_prob, ldj = flow_loss(xs_test, 1, keys, z, params)

            return jnp.mean(loss), jnp.mean(prior_log_prob), jnp.mean(ldj), norm

        mask = split_mask(weights)
        z, norm = self.embed(xs, mask * weights)

        keys = jax.random.split(key, xs.shape[0])
        loss, prior_log_prob, ldj = flow_loss(xs, 1, keys, z, params)

        # Held-out weights: zero for samples used to build the embedding.
        test_weights = jnp.where(mask, 0.0, weights)
        total_test_weight = jnp.sum(test_weights)

        def held_out_mean(values):
            return jnp.sum(values * test_weights) / total_test_weight

        return (
            held_out_mean(loss),
            held_out_mean(prior_log_prob),
            held_out_mean(ldj),
            norm,
        )


@eqx.filter_jit
def train(model, opt, opt_state, xs, key):
    """Take one optimizer step on the mean of `model.loss` over a batch of sets.

    Args:
        model: module with a `loss(xs_set, key=...)` method returning a tuple
            whose first element is the scalar loss (e.g. `FlowEmbedding`).
        opt: optimizer exposing `update(grads, state, params)`; presumably an
            optax `GradientTransformation` — confirm against the caller.
        opt_state: current optimizer state.
        xs: batch of sample sets; `model.loss` is vmapped over axis 0, so each
            `xs[i]` is one set.
        key: PRNG key, split into one key per set.

    Returns:
        `(model, opt_state, loss)` after the update, where `loss` is the batch
        mean loss evaluated before the update.
    """

    @eqx.filter_value_and_grad
    def compute_loss(model, xs, keys):
        # Only the first output (the loss) is optimized; the diagnostics are dropped.
        loss, *_ = jax.vmap(model.loss)(xs, key=keys)
        return jnp.mean(loss)

    keys = jax.random.split(key, xs.shape[0])
    loss, grads = compute_loss(model, xs, keys)

    # Pass only the array leaves as `params`, as in Equinox's optax examples.
    # Optimizers that read params (e.g. weight decay) cannot handle the
    # module's non-array leaves.
    updates, opt_state = opt.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)

    return model, opt_state, loss
