import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import equinox as eqx
import optax
import dataclasses

import smarter_jax as sj
import flowee
import nn


class FlowEmbedding(eqx.Module):
  """Embeds samples to create a latent representation of their distribution.

  Trains a conditional normalizing flow model to generate samples from the
  latent representation.

  Requires:
  emb_net: function mapping sample -> latent
  gen_flow: conditional normalizing flow mapping prior -> sample
  """
  # Applied to one sample at a time (vmapped over the leading axis in `embed`).
  emb_net: eqx.Module
  # Conditional flow; the latent embedding is passed to it as `sideinfo`.
  gen_flow: flowee.Flow

  def generate(self, z, nsamples=1, *, key):
    """Draw samples from the flow conditioned on the latent `z`.

    NOTE(review): relies on `sj.vmap(..., splitkey=nsamples)` splitting `key`
    into `nsamples` subkeys and mapping `gen_flow.sample` over them, with
    `sideinfo` broadcast (`kw_axes=None`) -- confirm against smarter_jax.
    """
    return sj.vmap(self.gen_flow.sample, splitkey=nsamples, kw_axes=None)(
      key, sideinfo=z)

  def embed(self, xs, weights=None):
    """Return the (optionally weighted) mean of per-sample embeddings.

    Args:
      xs: samples stacked along axis 0 (N samples).
      weights: optional per-sample weights with N elements, e.g. shape (N,)
        or (N, 1).

    Returns:
      The mean of `emb_net(x)` over the N samples, or the weighted mean
      sum_i w_i * emb_net(x_i) / sum_i w_i when `weights` is given.
    """
    zs = jax.vmap(self.emb_net)(xs)
    if weights is None:
      return jnp.mean(zs, axis=0)
    # One weight per sample, broadcast over every trailing latent axis so
    # this works for latents of any rank (including scalar latents).
    w = jnp.reshape(weights, (-1,) + (1,) * (zs.ndim - 1))
    return jnp.sum(zs * w, axis=0) / jnp.sum(weights)

  def log_prob(self, xs, z, *, key=None):
    """Per-sample `gen_flow.log_prob` of `xs` conditioned on latent `z`.

    NOTE(review): `splitkey=True` is used even though `key` defaults to None;
    presumably smarter_jax tolerates a None key -- confirm.
    """
    return sj.vmap(self.gen_flow.log_prob, splitkey=True, kw_axes=None)(
      xs, sideinfo=z, key=key)

  @sj.with_subkeys
  def loss(self, xs, weights=None, *, shuffle=True, key=None):
    """Held-out flow loss: embed one subset of samples, score the rest.

    The embedding is computed from one part of `xs`, and the flow's
    per-sample `loss` is averaged over a disjoint part. The latent therefore
    has to describe the distribution, not memorize individual samples.

    Args:
      xs: samples stacked along axis 0 (N samples).
      weights: optional per-sample weights with N elements (shape (N,) or
        (N, 1)). If None, the first floor(N/2) samples (after shuffling) are
        embedded and the remaining ones are scored. Otherwise
        `split_mask(weights)` chooses the split: `mask * weights` weights the
        embedding and `(1 - mask) * weights` weights the held-out loss.
      shuffle: randomly permute the samples (and weights) before splitting.
      key: PRNG key. `sj.with_subkeys` makes it available as an iterator of
        subkeys (it is consumed with `next`).

    Returns:
      Scalar (weighted) mean of the held-out per-sample flow loss.

    Raises:
      ValueError: if `shuffle` is set and no key is given, or if fewer than
        two samples are passed in the unweighted case.
    """
    if shuffle:
      # Explicit check instead of `assert`, which is stripped under -O and
      # cannot be evaluated on a traced key.
      if key is None:
        raise ValueError("key required if shuffle=True")
      sinds = jax.random.permutation(next(key), jnp.arange(xs.shape[0]))
      xs = xs[sinds]
      if weights is not None:
        weights = weights[sinds]

    if weights is None:
      n = xs.shape[0]
      if n < 2:
        raise ValueError(
          f"need at least 2 samples to form a held-out split, got {n}")
      # Index split instead of jnp.split(xs, 2) so an odd N also works; the
      # halves are identical for even N.
      half = n // 2
      xs_train, xs_test = xs[:half], xs[half:]
      z = self.embed(xs_train)
      x_loss = sj.vmap(self.gen_flow.loss, splitkey=True, kw_axes=None)(
        xs_test, sideinfo=z, key=next(key))
      return jnp.mean(x_loss)

    # Flatten loss, mask and weights to (N,). If they mix (N,) and (N, 1)
    # shapes, elementwise products would otherwise broadcast to (N, N) and
    # silently corrupt the weighted average.
    flat_w = jnp.reshape(weights, (-1,))
    # NOTE(review): presumably mask entries are 0/1 marking samples used for
    # the embedding -- confirm against split_mask.
    mask = jnp.reshape(split_mask(weights), (-1,))
    z = self.embed(xs, mask * flat_w)
    x_loss = sj.vmap(self.gen_flow.loss, splitkey=True, kw_axes=None)(
      xs, sideinfo=z, key=next(key))
    x_loss = jnp.reshape(x_loss, (-1,))
    held_out = (1 - mask) * flat_w
    return jnp.sum(x_loss * held_out) / jnp.sum(held_out)


@eqx.filter_jit
def train(model, opt, opt_state, x, weights=None, *, key):
  """Run one optimizer step on the batch-mean of `model.loss`.

  Args:
    model: Equinox module exposing `loss(x, weights, key=...)`.
    opt: optax gradient transformation (anything with `.update`).
    opt_state: current state of `opt`.
    x: batch whose leading axis is mapped over by `sj.vmap(model.loss)`.
    weights: optional weights batched like `x`, or None.
    key: PRNG key; presumably split across the batch by
      `sj.vmap(..., splitkey=True)` -- confirm against smarter_jax.

  Returns:
    (updated model, updated opt_state, loss value computed before the update).
  """
  @eqx.filter_value_and_grad
  def compute_loss(model, x, weights, key):
    per_item = sj.vmap(model.loss, splitkey=True)(x, weights, key=key)
    return jnp.mean(per_item)

  loss_value, grads = compute_loss(model, x, weights, key=key)
  # Give optax only the array leaves as `params`. `grads` holds None at
  # non-differentiable leaves, and optimizers that read params (e.g. weight
  # decay) would otherwise hit a tree mismatch or non-array leaves.
  params = eqx.filter(model, eqx.is_array)
  updates, opt_state = opt.update(grads, opt_state, params)
  model = eqx.apply_updates(model, updates)

  return model, opt_state, loss_value
