#!/usr/bin/env python3

import json

import distrax
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import optax

import flowee
import nn


class MeanEmbedder(eqx.Module):
    """Set embedder that returns a weighted average of per-element features.

    Every row of ``xs`` is mapped through ``phi``; the resulting feature
    vectors are averaged using ``ws`` as weights (uniform weights when
    ``ws`` is omitted).
    """

    phi: eqx.nn.MLP

    def __call__(self, xs: jax.Array, ws: jax.Array | None = None) -> jax.Array:
        if ws is None:
            ws = jnp.ones(xs.shape[0])

        features = jax.vmap(self.phi)(xs)
        weighted = features * ws[:, None]

        # NOTE(review): a zero total weight yields inf/nan here.
        total = jnp.sum(weighted, axis=0)
        return total / jnp.sum(ws)


class DeepSetsEmbedder(eqx.Module):
    """Deep Sets embedder: rho(sum_i w_i * phi(x_i)).

    Unlike ``MeanEmbedder``, the pooled features are summed and NOT
    normalized by the total weight before being passed through ``rho``.
    """

    phi: eqx.nn.MLP
    rho: eqx.nn.MLP

    def __call__(self, xs: jax.Array, ws: jax.Array | None = None) -> jax.Array:
        weights = jnp.ones(xs.shape[0]) if ws is None else ws

        per_element = jax.vmap(self.phi)(xs)
        pooled = jnp.sum(per_element * weights[:, None], axis=0)

        return self.rho(pooled)


class RecurrentEmbedder(eqx.Module):
    """Embedder that consumes observations one step at a time through an LSTM.

    Each step feeds the observation and the current hidden state to ``lstm``
    and maps the LSTM output through ``head`` to obtain an embedding.
    """

    lstm: nn.MultiLayerLSTM
    head: eqx.nn.MLP

    def __call__(self, obs: jax.Array, hidden: jax.Array) -> tuple[jax.Array, jax.Array]:
        """Advance one step; returns ``(embedding, new_hidden)``."""
        lstm_out, new_hidden = self.lstm(obs, hidden)
        embedding = self.head(lstm_out)
        return embedding, new_hidden

    def predict_sequence(self, obs: jax.Array, hidden: jax.Array) -> tuple[jax.Array, jax.Array]:
        """Run the embedder over ``obs`` along its leading axis with ``jax.lax.scan``.

        Returns ``(embeddings, final_hidden)``, where ``embeddings`` stacks the
        per-step embeddings along the leading axis.
        """
        def _scan_body(carry, step_obs):
            # lax.scan expects (new_carry, per_step_output).
            embedding, next_carry = self(step_obs, carry)
            return next_carry, embedding

        final_hidden, embeddings = jax.lax.scan(_scan_body, hidden, obs)
        return embeddings, final_hidden

    def reset(self) -> jax.Array:
        """Return an all-zeros float32 hidden state of shape (num_layers, 2 * hidden_size).

        NOTE(review): the factor 2 presumably packs the LSTM's h and c states
        side by side; confirm against nn.MultiLayerLSTM.
        """
        state_width = 2 * self.lstm.hidden_size
        return jnp.zeros((self.lstm.num_layers, state_width), jnp.float32)


class ConditionalFlow(eqx.Module):
    """Normalizing flow conditioned on an embedding of a set of samples.

    ``embedding_network`` turns a collection of samples into one embedding
    vector, which is then passed to ``flow_network`` as ``sideinfo`` for every
    point that is sampled or scored.
    """

    embedding_network: eqx.Module
    flow_network: flowee.Flow

    def embed(self, samples: jax.Array, weights: jax.Array | None = None) -> jax.Array:
        """Embed ``samples`` (optionally weighted) with the embedding network."""
        return self.embedding_network(samples, weights)

    def sample(
        self, embedding: jax.Array, num_samples: int, key: jax.random.PRNGKey
    ) -> jax.Array:
        """Draw ``num_samples`` points from the flow conditioned on ``embedding``.

        ``key`` is split into one key per draw, and the same embedding is fed
        as ``sideinfo`` to every draw.
        """
        return jax.vmap(self.flow_network.sample)(
            jax.random.split(key, num_samples),
            sideinfo=jnp.repeat(embedding[jnp.newaxis], num_samples, axis=0)
        )

    def log_prob(
        self, samples: jax.Array, embedding: jax.Array, key: jax.random.PRNGKey
    ) -> jax.Array:
        """Evaluate the flow's ``log_prob`` of each row of ``samples`` given ``embedding``.

        One key (split from ``key``) is passed per row.
        """
        return jax.vmap(self.flow_network.log_prob)(
            samples,
            sideinfo=jnp.repeat(embedding[jnp.newaxis], samples.shape[0], axis=0),
            key=jax.random.split(key, samples.shape[0])
        )

    def compute_loss(
        self, data: jax.Array, *, shuffle: bool = True, key: jax.random.PRNGKey
    ) -> jax.Array:
        """Compute the loss given samples from one distribution.

        The rows of ``data`` are optionally shuffled and split at the midpoint.
        The first half is embedded; the flow loss of every row of the second
        half, conditioned on that embedding, is averaged. With an odd number
        of rows the extra row goes to the second (scored) half.

        Args:
            data: samples from a single distribution; leading axis = sample.
            shuffle: randomly permute the rows before splitting.
            key: PRNG key. Always required: besides the optional shuffle, it
                supplies the per-row keys passed to ``flow_network.loss``.

        Returns:
            Scalar mean flow loss over the scored half.

        Raises:
            ValueError: if ``key`` is None or ``data`` has fewer than 2 rows.
        """
        # Validate with exceptions rather than `assert` (stripped under -O).
        # The key is needed even when shuffle=False, for the loss keys below.
        if key is None:
            raise ValueError('Key is required to compute the loss!')

        num_samples = data.shape[0]
        if num_samples < 2:
            raise ValueError(
                f'At least 2 samples are required to compute the loss, got {num_samples}!'
            )

        if shuffle:
            # Shuffle individual samples within a single distribution
            key, permutation_key = jax.random.split(key, 2)
            indices = jax.random.permutation(permutation_key, jnp.arange(num_samples))

            data = data[indices]

        # Split at the midpoint explicitly: jnp.split(data, 2) requires an even
        # number of rows and raises otherwise. Identical result for even sizes.
        half = num_samples // 2
        data_train, data_test = data[:half], data[half:]
        embedding = self.embed(data_train)

        num_test = data_test.shape[0]
        # Same keys as `key, *loss_keys = split(key, 1 + n)` followed by
        # jnp.array(loss_keys), without the Python-list round-trip.
        loss_keys = jax.random.split(key, 1 + num_test)[1:]
        flow_loss = jnp.mean(
            jax.vmap(self.flow_network.loss)(
                data_test,
                sideinfo=jnp.repeat(embedding[jnp.newaxis], num_test, axis=0),
                key=loss_keys
            )
        )

        return flow_loss


class RecurrentConditionalFlow(ConditionalFlow):
    """Conditional flow whose embedding is produced step by step by a recurrent embedder.

    ``embedding_network`` must provide ``reset()`` and ``predict_sequence()``
    (as ``RecurrentEmbedder`` does).
    """

    def embed(self, observation: jax.Array, hidden: jax.Array) -> tuple[jax.Array, jax.Array]:
        """Advance the recurrent embedder by one observation; returns ``(embedding, hidden)``."""
        return self.embedding_network(observation, hidden)

    def reset(self) -> jax.Array:
        """Return the embedder's initial hidden state."""
        return self.embedding_network.reset()

    def compute_loss(
        self, data: jax.Array, *, shuffle: bool = True, key: jax.random.PRNGKey
    ) -> float:
        """Mean flow loss over a sequence of ``(observations, samples)`` steps.

        ``data`` is unpacked as ``(observations, samples)``. The observations
        are run through the embedder from its reset state via
        ``predict_sequence``. The samples at each index of the leading axis
        are scored by the flow conditioned on the embedding at the same index.
        The per-step mean losses are then averaged.

        ``shuffle`` is accepted for signature compatibility with
        ``ConditionalFlow.compute_loss`` but is not used here.
        """
        observations, samples = data

        initial_hidden = self.embedding_network.reset()
        # The final hidden state is not needed for the loss.
        embeddings, _ = self.embedding_network.predict_sequence(observations, initial_hidden)

        def step_loss(step_samples: jax.Array, step_embedding: jax.Array,
                      step_key: jax.random.PRNGKey) -> float:
            count = step_samples.shape[0]
            tiled = jnp.repeat(step_embedding[jnp.newaxis], count, axis=0)
            losses = jax.vmap(self.flow_network.loss)(
                step_samples, sideinfo=tiled, key=jax.random.split(step_key, count)
            )
            return jnp.mean(losses)

        num_steps = samples.shape[0]
        # Drop the first split key, exactly as `key, *loss_keys = ...` did.
        step_keys = jax.random.split(key, 1 + num_steps)[1:]
        per_step = jax.vmap(step_loss)(samples, embeddings, step_keys)

        return jnp.mean(per_step)


def parse_network_hyperparams(hyperparams: str) -> list[int]:
    """Parse a comma-separated list of integers, e.g. ``"64,64"`` -> ``[64, 64]``.

    Each field is converted with ``int``, so surrounding whitespace is
    tolerated; a non-integer field (including an empty one) raises
    ``ValueError``.
    """
    fields = hyperparams.split(',')
    return list(map(int, fields))


def _create_embedding_network(
    hyperparams: dict[str, int | str], key: jax.random.PRNGKey
) -> tuple[eqx.Module, jax.random.PRNGKey]:
    """Build the embedding network selected by ``hyperparams['embedder']``.

    Returns the network together with the leftover PRNG key, so the caller's
    key-splitting sequence is the same as if this code were inlined.

    Raises:
        KeyError: if a required hyperparameter is missing.
        ValueError: if the embedder type is unknown.
    """
    embedder = hyperparams['embedder']
    obs_size = hyperparams['obs_size']
    num_cards = hyperparams['num_cards']
    embedding_size = hyperparams['embedding_size']
    num_layers = hyperparams['num_layers']
    hidden_size = hyperparams['hidden_size']

    if embedder == 'mean':
        key, phi_key = jax.random.split(key, 2)
        phi = eqx.nn.MLP(num_cards, embedding_size, hidden_size, num_layers, key=phi_key)
        return MeanEmbedder(phi), key

    if embedder == 'deepsets':
        key, phi_key, rho_key = jax.random.split(key, 3)
        phi = eqx.nn.MLP(num_cards, embedding_size, hidden_size, num_layers, key=phi_key)
        rho = eqx.nn.MLP(embedding_size, embedding_size, hidden_size, num_layers, key=rho_key)
        return DeepSetsEmbedder(phi, rho), key

    if embedder == 'rnn':
        key, lstm_key, head_key = jax.random.split(key, 3)
        lstm = nn.MultiLayerLSTM(obs_size, num_layers, hidden_size, key=lstm_key)
        head = eqx.nn.MLP(hidden_size, embedding_size, hidden_size, num_layers, key=head_key)
        return RecurrentEmbedder(lstm, head), key

    raise ValueError(f'Unknown embedder type: {embedder}!')


def _coupling_mask(i: int, num_cards: int) -> jax.Array:
    """Mask for the i-th coupling layer: block size 1 on even layers, 3 on odd ones."""
    block_size = 1 if i % 2 == 0 else 3
    return flowee.create_mask((num_cards,), (block_size,), dtype=jnp.uint8)


def _create_flow_network(
    hyperparams: dict[str, int], num_cards: int, embedding_size: int,
    key: jax.random.PRNGKey
) -> flowee.Sequential:
    """Build the flow: a dequantization layer followed by conditional coupling layers.

    The coupling MLPs take ``num_cards + embedding_size`` inputs (data plus
    embedding side information). ``distrax.Normal(0.0, 1.0)`` is registered
    via ``add_prior`` over shape ``(num_cards,)``.

    Raises:
        KeyError: if a required hyperparameter is missing.
    """
    dequant_num_layers = hyperparams['dequant_num_layers']
    dequant_hidden_size = hyperparams['dequant_hidden_size']
    dequant_num_params = hyperparams['dequant_num_params']
    num_coupling_layers = hyperparams['num_coupling_layers']
    # Distinct names: these used to overwrite the embedder's num_layers/hidden_size.
    flow_num_layers = hyperparams['num_layers']
    flow_hidden_size = hyperparams['hidden_size']
    num_params = hyperparams['num_params']

    key, dequant_key = jax.random.split(key, 2)
    dequant_layer = flowee.Dequantize(
        max_val=num_cards - 1, in_dtype=jnp.int8,
        var_flow=flowee.Coupling(
            _coupling_mask(0, num_cards),
            nn.MultiMLP(
                (num_cards, num_cards), num_cards * dequant_num_params,
                dequant_hidden_size, dequant_num_layers, key=dequant_key
            ),
            dual=True
        )
    )

    _, *flow_keys = jax.random.split(key, num_coupling_layers + 1)
    coupling_layers = [
        flowee.Coupling(
            _coupling_mask(i, num_cards),
            nn.MLP(
                num_cards + embedding_size, num_cards * num_params,
                flow_hidden_size, flow_num_layers, key=flow_keys[i]
            ),
            flowee.ParameterizedNLSq((num_cards,)),
            dual=True
        )
        for i in range(num_coupling_layers)
    ]

    flow_network = flowee.Sequential([dequant_layer] + coupling_layers)
    flow_network.add_prior(distrax.Normal(0.0, 1.0), (num_cards,))

    return flow_network


def create_model(
    embedding_net_hyperparams: dict[str, int | str],
    flow_net_hyperparams: dict[str, int],
    key: jax.random.PRNGKey
) -> ConditionalFlow:
    """Create an embedding network plus conditional flow from hyperparameter dicts.

    ``embedding_net_hyperparams['embedder']`` is one of ``'mean'``,
    ``'deepsets'`` or ``'rnn'``. For ``'rnn'`` a ``RecurrentConditionalFlow``
    is returned, otherwise a ``ConditionalFlow``. The number of inexact-array
    parameters is printed.

    Raises:
        KeyError: if a required hyperparameter is missing.
        ValueError: if the embedder type is unknown.
    """
    # Key consumption order (embedder, then dequantizer, then couplings) is
    # unchanged, so a given key reproduces the same initialization.
    embedding_network, key = _create_embedding_network(embedding_net_hyperparams, key)
    flow_network = _create_flow_network(
        flow_net_hyperparams,
        num_cards=embedding_net_hyperparams['num_cards'],
        embedding_size=embedding_net_hyperparams['embedding_size'],
        key=key,
    )

    if embedding_net_hyperparams['embedder'] == 'rnn':
        model = RecurrentConditionalFlow(embedding_network, flow_network)
    else:
        model = ConditionalFlow(embedding_network, flow_network)

    num_params = sum(
        leaf.size for leaf in jtu.tree_leaves(eqx.filter(model, eqx.is_inexact_array))
    )
    print('Model parameters:', num_params)

    return model


@eqx.filter_jit
def train_step(
    model: ConditionalFlow, optimizer: optax.GradientTransformation,
    optimizer_state: optax.OptState, data: jax.Array, key: jax.random.PRNGKey
) -> tuple[ConditionalFlow, optax.OptState, jax.Array]:
    """Run one optimization step on a batch of datasets.

    ``model.compute_loss`` is vmapped over the leading (batch) axis of every
    leaf of ``data``, with one PRNG key per batch element. The batch-mean
    loss is differentiated and the optimizer update is applied.

    Returns:
        ``(updated_model, updated_optimizer_state, loss_value)``.
    """
    # All leaves share the leading batch axis. Reading it from the first leaf
    # works for a plain array as well as tuples/lists/dicts of arrays.
    batch_size = jtu.tree_leaves(data)[0].shape[0]

    @eqx.filter_value_and_grad
    def compute_loss(model: ConditionalFlow, data: jax.Array, key: jax.random.PRNGKey) -> jax.Array:
        per_example = jax.vmap(model.compute_loss)(data, key=jax.random.split(key, batch_size))
        return jnp.mean(per_example)

    loss_value, grads = compute_loss(model, data, key)
    # Pass optax only the array leaves, as Equinox recommends. This matches
    # the structure of `grads` (inexact arrays, None elsewhere).
    params = eqx.filter(model, eqx.is_inexact_array)
    updates, optimizer_state = optimizer.update(grads, optimizer_state, params)
    model = eqx.apply_updates(model, updates)

    return model, optimizer_state, loss_value


def load_model(model_file: str, key: jax.random.PRNGKey) -> ConditionalFlow:
    """Rebuild a model from a checkpoint file.

    The first line of the file is a JSON object of keyword arguments for
    ``create_model``. ``key`` is used to build a template model. The rest of
    the file is read by ``eqx.tree_deserialise_leaves`` into that template.
    The hyperparameters are printed one ``name: value`` pair per line.
    """
    with open(model_file, 'rb') as checkpoint:
        header = checkpoint.readline().decode()
        hyperparams = json.loads(header)

        template = create_model(**hyperparams, key=key)
        print('\n'.join(f'{name}: {value}' for name, value in hyperparams.items()))

        restored = eqx.tree_deserialise_leaves(checkpoint, template)

    return restored
