"""Models and layers for the divisibility toy dataset.

Note that a lot of these probably already exist somewhere, but it's
quicker/easier just to define them here than to search for them.
"""
import dataclasses
import os

import tensorflow as tf

from em.datasets import divisibility as divis_ds
from em.util import hdf5_util

###############################################################################


def layer_from_layer_config(layer_config: str, activation: str) -> tf.keras.layers.Layer:
    """Build one hidden layer from a ``"<class>:<params>"`` config string.

    The config is lower-cased before parsing and the parameters are a
    comma-separated list of integers.

    Examples of possible configs:
        "dense:128"
        "dense_res:128"
        "ffw_res:128,512"
        "ffw_res_nln:128,512" (the nln stands for no layer norm)

    Args:
        layer_config: Layer description in the format shown above.
        activation: Activation passed through to the constructed layer.

    Returns:
        A freshly constructed Keras layer.

    Raises:
        ValueError: If the layer class is unknown, or if the string does not
            split into the expected number of parts / integer parameters.
    """
    kind, raw_params = layer_config.lower().split(':')
    params = raw_params.split(',')

    # Single-parameter layers: plain dense and its residual variant.
    if kind in ('dense', 'dense_res'):
        (units,) = params
        dense_cls = tf.keras.layers.Dense if kind == 'dense' else DenseRes
        return dense_cls(int(units), activation=activation)

    # Two-parameter feed-forward residual blocks; the value says whether the
    # block applies layer normalization.
    ffw_layer_norm = {'ffw_res': True, 'ffw_res_nln': False}
    if kind in ffw_layer_norm:
        d_model, d_ff = params
        return FfwResBlock(
            int(d_model),
            int(d_ff),
            activation=activation,
            use_layer_norm=ffw_layer_norm[kind],
        )

    raise ValueError(f'Invalid layer config: {layer_config}')


def needs_initial_layer(layer_config: str) -> bool:
    """Return True when the configured layer class is a residual variant.

    The class name is the part of ``layer_config`` before the ``:``, compared
    case-insensitively. A config that does not contain exactly one ``:``
    raises ``ValueError`` from the unpacking.
    """
    kind, _params = layer_config.lower().split(':')
    residual_kinds = {'dense_res', 'ffw_res', 'ffw_res_nln'}
    return kind in residual_kinds


def get_hidden_size(layer_config: str) -> int:
    """Return the first numeric parameter of ``layer_config`` as an int.

    For example ``"dense:128"`` and ``"ffw_res:128,512"`` both give ``128``.
    A config that does not contain exactly one ``:``, or whose first
    parameter is not an integer, raises ``ValueError``.
    """
    # The class name is not needed here, but the strict two-way unpack is kept
    # so malformed configs are still rejected.
    _kind, raw_params = layer_config.lower().split(':')
    first_param = raw_params.split(',')[0]
    return int(first_param)

###############################################################################


@dataclasses.dataclass
class DivisModelConfig:
    """Hyperparameters consumed by ``create_model``."""

    # Layer description string, e.g. "dense:128" or "ffw_res:128,512".
    layer_config: str
    # Number of hidden layers built from ``layer_config``.
    n_layers: int
    # Width of each embedding vector in the embeddings table.
    embeddings_size: int
    # Activation handed to every hidden layer.
    activation_fn: str = dataclasses.field(default='relu')


def create_model(config: DivisModelConfig) -> tf.keras.Sequential:
    """Assemble the sequential model described by ``config``.

    The stack is: embeddings lookup, flatten, an optional linear projection to
    the hidden size (only for residual layer classes), ``config.n_layers``
    hidden layers, and a final 2-unit dense layer with no activation.
    """
    layer_spec = config.layer_config

    stack = [EmbeddingsLayer(embeddings_size=config.embeddings_size)]
    stack.append(tf.keras.layers.Flatten())

    # Residual layers add their input to their output, so the input is first
    # projected to the hidden width.
    if needs_initial_layer(layer_spec):
        hidden_size = get_hidden_size(layer_spec)
        stack.append(tf.keras.layers.Dense(hidden_size, activation=None))

    stack.extend(
        layer_from_layer_config(layer_spec, config.activation_fn)
        for _ in range(config.n_layers)
    )

    stack.append(tf.keras.layers.Dense(2, activation=None))
    return tf.keras.Sequential(stack)


def load_model_from_file(filepath: str):
    """Rebuild a model from a saved weights file and load its weights.

    Args:
        filepath: Path to the saved file; a leading ``~`` is expanded.

    Returns:
        A ``(model, model_config)`` tuple, where ``model_config`` is the
        ``DivisModelConfig`` reconstructed from the file's metadata.
    """
    resolved_path = os.path.expanduser(filepath)
    weights, metadata = hdf5_util.load_np_arrays_with_metadata(resolved_path)

    model_config = DivisModelConfig(**metadata['model_config'])
    model = create_model(model_config)

    # Call the model once on a symbolic int64 input whose width comes from the
    # saved dataset config; presumably this creates the layer variables so
    # that set_weights has something to assign into.
    ds_config = divis_ds.DivisibilityDatasetConfig(**metadata['final_ds_config'])
    symbolic_input = tf.keras.Input(
        shape=[ds_config.n_total_digits],
        dtype=tf.int64,
    )
    model(symbolic_input)

    model.set_weights(weights)
    return model, model_config


###############################################################################


class EmbeddingsLayer(tf.keras.layers.Layer):
    """Looks up rows of a trainable embeddings table by integer token id."""

    def __init__(self, embeddings_size: int, vocab_size: int = 10, **kwargs):
        """Create the ``[vocab_size, embeddings_size]`` embeddings table.

        Args:
            embeddings_size: Length of each embedding vector.
            vocab_size: Number of rows in the table (defaults to 10).
            **kwargs: Forwarded to the base Keras ``Layer`` constructor.
        """
        super().__init__(**kwargs)

        table_shape = [vocab_size, embeddings_size]
        self.embeddings_table = self.add_weight(
            name='embeddings_table',
            shape=table_shape,
            trainable=True,
        )

    def call(self, token_ids):
        """Return the table rows selected by ``token_ids``."""
        table = self.embeddings_table
        return tf.gather(table, token_ids)


class DenseRes(tf.keras.layers.Dense):
    """Dense layer with a skip connection: returns ``x + dense(x)``."""

    def call(self, x, *args, **kwargs):
        """Apply the dense transform to ``x`` and add ``x`` back in.

        Extra positional and keyword arguments are accepted but are not
        forwarded to the parent ``call``.
        """
        dense_out = super().call(x)
        return x + dense_out


class FfwResBlock(tf.keras.layers.Layer):
    """Like a transformer FFW block.

    Computes ``x + dense2(dense1(x))`` and, when enabled, applies layer
    normalization to that sum. Both dense layers use ``activation``.
    """

    def __init__(self, d_model, d_ff, activation, use_layer_norm=True, **kwargs):
        """Create the two dense sub-layers and the optional layer norm.

        Args:
            d_model: Output width of the block (units of the second dense
                layer); the residual add needs the input to be compatible
                with this width.
            d_ff: Width of the inner (first) dense layer.
            activation: Activation used by both dense layers.
            use_layer_norm: Whether to normalize the residual sum.
            **kwargs: Standard Keras ``Layer`` arguments (e.g. ``name``,
                ``dtype``), forwarded to the base constructor.
        """
        super().__init__(**kwargs)
        # Sub-layers are created in a fixed order (dense1, dense2, layer_norm);
        # keep it stable since saved weight lists are positional.
        self.dense1 = tf.keras.layers.Dense(d_ff, activation=activation)
        self.dense2 = tf.keras.layers.Dense(d_model, activation=activation)
        if use_layer_norm:
            self.layer_norm = tf.keras.layers.LayerNormalization()
        else:
            self.layer_norm = None

    def call(self, x):
        """Return the residual sum, layer-normalized if configured."""
        ret = x + self.dense2(self.dense1(x))
        if self.layer_norm is not None:
            ret = self.layer_norm(ret)
        return ret
