import pathlib
from contextlib import nullcontext
from datetime import datetime

import click
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import microsoft_nlp
from microsoft_nlp import OWT2, OWT2Records, OWTSample, paths
from microsoft_nlp.utils import (
    PaddedSparseCategoricalCrossentropy,
    count_flops_per_step,
)


class ApplicationState:
    """Encapsulates the state of the click application for use in nested commands.

    Constructing an instance builds the requested model inside the scope of the
    selected distribution strategy, echoes the model and parameter summaries,
    and then instantiates the dataset.

    The constructor arguments are stored as same-named attributes
    (``model_type``, ``dataset_cls``, ``vocab_size``, ``seq_len``,
    ``embed_dim``, ``n_layers``, ``weight_tying``, ``model_settings``). In
    addition the instance exposes ``data_dtype``, ``strategy``, ``model``,
    ``relevant_params`` and ``dataset``.
    """

    def __init__(
        self,
        model_type,
        dataset_cls,
        vocab_size,
        seq_len,
        embed_dim,
        n_layers,
        weight_tying,
        parallel_strategy,
        **model_settings,
    ):
        """Builds the model and loads the dataset.

        Args:
            model_type: Architecture name selecting one of the ``create_*``
                factories in ``microsoft_nlp.models``.
            dataset_cls: Key into ``dataset_to_cls``.
            vocab_size: Vocabulary size; must equal ``len(dataset.vocab)``.
            seq_len: Sequence length passed to the model factory.
            embed_dim: Embedding dimension passed to the model factory.
            n_layers: Number of layers passed to the model factory.
            weight_tying: Forwarded to every factory except the hfGPT1 one.
            parallel_strategy: Key into ``strategy_map``.
            **model_settings: Architecture-specific options; each branch reads
                the keys it needs by name.

        Raises:
            click.UsageError: If ``ff_dim`` is given for ``hfGPT1``, or if
                ``vocab_size`` does not match the dataset vocabulary.
            NotImplementedError: If ``model_type`` is not handled below.
        """
        self.model_type = model_type
        self.dataset_cls = dataset_cls
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.n_layers = n_layers
        self.weight_tying = weight_tying
        self.model_settings = model_settings
        # dtype later requested from the dataset; overridden for hfGPT1 below.
        self.data_dtype = tf.uint16

        click.echo("Loading model...")
        Strategy = strategy_map[parallel_strategy]
        self.strategy = Strategy()

        with self.strategy.scope():
            # Only the transformer branches supply an explicit embedding
            # parameter count; every other branch leaves this as None.
            embedding_params = None
            if model_type == "LMUD":
                self.model = microsoft_nlp.models.create_lmud(
                    vocab_size=vocab_size,
                    seq_len=seq_len,
                    n_layers=n_layers,
                    activation=model_settings["activation"],
                    order=model_settings["order"],
                    theta=model_settings["theta"],
                    theta_factor=model_settings["theta_factor"],
                    embed_dim=embed_dim,
                    n_filters=model_settings["n_filters"],
                    share_filters=model_settings["share_filters"],
                    lmud_order=model_settings["lmud_order"],
                    post_ffn=model_settings["post_ffn"],
                    pre_ffn=model_settings["pre_ffn"],
                    n_heads=model_settings["n_heads"],
                    weight_tying=weight_tying,
                    option2=model_settings["option2"],
                    lmue_like=model_settings["lmue_like"],
                    eqn11=model_settings["eqn11"],
                    order_increment=model_settings["order_increment"],
                )
            elif model_type.lower() == "lmue":
                self.model = microsoft_nlp.models.create_lmue(
                    vocab_size=vocab_size,
                    seq_len=seq_len,
                    n_layers=n_layers,
                    activation=model_settings["activation"],
                    order=model_settings["order"],
                    theta=model_settings["theta"],
                    theta_factor=model_settings["theta_factor"],
                    embed_dim=embed_dim,
                    ff_dim=model_settings["ff_dim"],
                    n_filters=model_settings["n_filters"],
                    share_filters=model_settings["share_filters"],
                    layernorm=model_settings["layernorm"],
                    positional_embed=model_settings["positional_embed"],
                    lmud_order=model_settings["lmud_order"],
                    pre_ffn=model_settings["pre_ffn"],
                    post_ffn=model_settings["post_ffn"],
                    option2=model_settings["option2"],
                    weight_tying=weight_tying,
                )
            elif model_type.lower() in ("lmumlp", "lmu-mlp"):
                self.model = microsoft_nlp.models.create_lmu_mlp(
                    vocab_size=vocab_size,
                    seq_len=seq_len,
                    n_layers=n_layers,
                    activation=model_settings["activation"],
                    order=model_settings["order"],
                    theta=model_settings["theta"],
                    theta_factor=model_settings["theta_factor"],
                    embed_dim=embed_dim,
                    ff_dim=model_settings["ff_dim"],
                    n_filters=model_settings["n_filters"],
                    share_filters=model_settings["share_filters"],
                    # NOTE(review): the "gating" setting is not forwarded here.
                    # gating=model_settings["gating"],
                    layernorm=model_settings["layernorm"],
                    weight_tying=weight_tying,
                )
            elif model_type == "LSTM":
                self.model = microsoft_nlp.models.create_lstm(
                    vocab_size=vocab_size,
                    seq_len=seq_len,
                    n_layers=n_layers,
                    embed_dim=embed_dim,
                    hidden_dim=model_settings["hidden_dim"],
                    weight_tying=weight_tying,
                )
            elif model_type == "hfGPT1":
                # Explicit exception instead of `assert`: assertions are
                # stripped under `python -O`, which would silently accept
                # an unused option.
                if model_settings["ff_dim"] is not None:
                    raise click.UsageError("'hfGPT1' does not use ff_dim")
                self.model = microsoft_nlp.models.create_huggingface_gpt1(
                    vocab_size=vocab_size,
                    seq_len=seq_len,
                    embed_dim=embed_dim,
                    n_layers=n_layers,
                    n_heads=self._resolve_n_heads(
                        model_settings["n_heads"], embed_dim
                    ),
                )
                self.data_dtype = tf.int32
                embedding_params = (vocab_size + seq_len) * embed_dim
            elif model_type == "GPT1":
                self.model = microsoft_nlp.models.create_model_gpt1(
                    vocab_size=vocab_size,
                    seq_len=seq_len,
                    n_layers=n_layers,
                    embed_dim=embed_dim,
                    ff_dim=model_settings["ff_dim"],
                    n_heads=self._resolve_n_heads(
                        model_settings["n_heads"], embed_dim
                    ),
                    weight_tying=weight_tying,
                )
                embedding_params = (vocab_size + seq_len) * embed_dim
            elif model_type == "GPT2":
                self.model = microsoft_nlp.models.create_model_gpt2(
                    vocab_size=vocab_size,
                    seq_len=seq_len,
                    n_layers=n_layers,
                    embed_dim=embed_dim,
                    ff_dim=model_settings["ff_dim"],
                    n_heads=self._resolve_n_heads(
                        model_settings["n_heads"], embed_dim
                    ),
                    weight_tying=weight_tying,
                )
                embedding_params = (vocab_size + seq_len) * embed_dim
            else:
                raise NotImplementedError(f"Unsupported: {model_type}")

        self.model.summary(print_fn=click.echo)
        # Value returned by parameter_summary; kept for later use by commands.
        self.relevant_params = microsoft_nlp.utils.parameter_summary(
            self.model, print_fn=click.echo, embedding_params=embedding_params
        )

        click.echo("Loading dataset...")
        self.dataset = dataset_to_cls[dataset_cls]()
        click.echo(str(self.dataset))
        dataset_vocab_size = len(self.dataset.vocab)
        if vocab_size != dataset_vocab_size:
            raise click.UsageError(
                f"vocab_size ({vocab_size}) does not match the size of the "
                f"dataset vocabulary ({dataset_vocab_size})"
            )

    @staticmethod
    def _resolve_n_heads(n_heads, embed_dim):
        """Returns ``n_heads``, defaulting to one head per 64 embedding dims."""
        if n_heads is None:
            return max(1, embed_dim // 64)
        return n_heads


# Decorator for sub-commands: passes the ApplicationState found on the click
# context as the command's first argument.
pass_application_state = click.make_pass_decorator(ApplicationState)

# Registry of dataset classes, keyed by the names accepted by --dataset_cls.
dataset_to_cls = dict(
    (
        ("owt-sample", OWTSample),
        ("owt2", OWT2),
        ("owt2-records", OWT2Records),
    )
)


@click.group()
@click.option(
    "--model_type",
    type=click.Choice(["LMUD", "LMUMLP", "LMUE", "LSTM", "hfGPT1", "GPT1", "GPT2"]),
    required=True,
    help="Type of model architecture",
)
@click.option(
    "--dataset_cls",
    type=click.Choice(list(dataset_to_cls.keys())),
    default="owt2-records",
    help="Name of dataset class",
    show_default=True,
)
@click.option(
    "--vocab_size",
    type=int,
    default=50257,
    help="Size of the vocabulary",
    show_default=True,
)
@click.option(
    "--seq_len",
    type=int,
    default=1024,
    help="Maximum length of the sequence",
    show_default=True,
)
@click.option("--embed_dim", type=int, help="Dimension of the word embedding vectors")
@click.option("--n_layers", type=int, help="Number of layers in architecture")
@click.option(
    "--weight_tying/--no_weight_tying",
    default=True,
    help="Enable weight tying",
    show_default=True,
)
@click.option(
    "--parallel_strategy",
    type=click.Choice(["none", "Mirrored"]),
    default="none",
    help="tf.distribute strategy for parallel training",
    show_default=True,
)
@click.option(
    "--activation",
    type=click.Choice(["relu", "gelu", "tanh", "sigmoid"]),
    default="gelu",
    help="Type of activation function to suggest to the architecture",
    show_default=True,
)
@click.option("--order", type=int, help="[LMU* only] LMU's order")
@click.option("--theta", type=float, help="[LMU* only] LMU's theta")
@click.option(
    "--theta_factor",
    type=float,
    default=1,
    help="[LMU* only] Factor by which to scale theta from one layer to the next",
    show_default=True,
)
@click.option(
    "--ff_dim",
    type=int,
    help="[LMU, GPT1 and GPT2 only] Dimensionality of each layer / ff sub_block",
)
@click.option(
    "--hidden_dim",
    type=int,
    help="[LSTM only] Hidden dimensionality of LSTM",
)
@click.option(
    "--n_filters", type=int, help="[LMU2Dli only] Number of filters from the LMU output"
)
@click.option(
    "--n_heads",
    type=int,
    default=None,
    help="[Transformers and LMUD only] Number of heads for the attention",
)
@click.option(
    "--share_filters/--no_share_filters",
    default=False,
    help="[LMU2Dli only] Whether to share the filters or be unique per dimension",
    show_default=True,
)
@click.option(
    "--gating/--no_gating",
    default=True,
    help="[LMUMLP only] Whether to use gating",
    show_default=True,
)
@click.option(
    "--layernorm/--no_layernorm",
    default=False,
    help="[LMUE only] Whether to use layer normalization",
    show_default=True,
)
@click.option(
    "--initial_dense",
    type=click.Choice(["affine", "non-linear", "none"]),
    default="affine",
    help="[LMU2Dli only] The type of Dense layer that appears right after the embedding layer",
    show_default=True,
)
@click.option(
    "--positional_embed/--no_positional_embed",
    default=False,
    help="[LMU2Dli only] Whether to use positional embedding",
    show_default=True,
)
@click.option(
    "--lmue_like/--no_lmue_like",
    default=False,
    help="[LMUD only] Whether to use LMUE-like stream sharing structure",
    show_default=True,
)
@click.option(
    "--lmud_order",
    type=float,
    default=1,
    help="[LMU2Dli-LMUD only] Size of 'dense_r' and 'dense_r_gate' matrices is set to 'int(lmud_order * order)'",
    show_default=True,
)
@click.option(
    "--pre_ffn",
    type=float,
    default=0,
    help="[LMUE only] Expansion factor for pre-FFN (0 to disable pre-FFN)",
    show_default=True,
)
@click.option(
    "--post_ffn",
    type=float,
    default=0,
    help="[LMUE and LMUD only] Expansion factor for post-FFN (0 to disable post-FFN)",
    show_default=True,
)
@click.option(
    "--option2/--no_option2",
    default=False,
    help="[LMUE and LMUD only] Whether to use a second gating stream in LMUE block",
    show_default=True,
)
@click.option(
    "--eqn11",
    type=int,
    default=0,
    help="[LMUD only] Four options available (0, 1, 2, 3). 0 is the default equation 11 option",
    show_default=True,
)
@click.option(
    "--order_increment",
    type=int,
    default=0,
    help="[LMUD* only] Amount by which the order is incremented from one layer to the next",
    show_default=True,
)
@click.pass_context
def cli(ctx, *args, **kwargs):
    # NOTE: documented with comments rather than a docstring, because click
    # would display a docstring here as the group's help text.
    #
    # Root command group. Every option declared above is forwarded to
    # ApplicationState, so the model and dataset are constructed before any
    # sub-command body runs; sub-commands receive the resulting object.
    ctx.obj = ApplicationState(*args, **kwargs)  # access via @pass_application_state


@cli.command()
@pass_application_state
@click.option("--batch_size", type=int, required=True, help="Batch size for training")
@click.option(
    "--tokens_per_epoch",
    type=int,
    default=200 * 1024 * 512,  # 104,857,600 tokens (~105M)
    help="Number of tokens to process per training epoch",
    show_default=True,
)
@click.option(
    "--patience",
    type=int,
    default=25,
    help="Number of epochs to wait for an improvement in train loss before stopping",
    show_default=True,
)
@click.option(
    "--steps",
    type=int,
    default=250000,
    help="Maximum number of training steps w.r.t. an effective batch size of 512",
    show_default=True,
)
@click.option(
    "--decay_steps",
    type=int,
    default=250000,
    help=(
        "Number of training steps for learning rate decay w.r.t. an effective batch "
        "size of 512"
    ),
    show_default=True,
)
@click.option(
    "--warmup_steps",
    type=int,
    default=3000,
    help=(
        "Number of training steps for learning rate warmup w.r.t. an effective batch "
        "size of 512"
    ),
    show_default=True,
)
@click.option(
    "--model_weights",
    type=click.Path(dir_okay=False, exists=True),
    help="Model weights to load before training",
)
@click.option(
    "--initial_epoch",
    type=int,
    default=0,
    help="Epoch to start at (typically for resuming an old run)",
    show_default=True,
)
@click.option(
    "--learning_scale",
    type=str,
    required=True,
    help="Scaling for the learning rate. A float "
    "or 'batch' (scale with square root of batch size)",
)
def train(
    app,
    batch_size,
    tokens_per_epoch,
    patience,
    steps,
    decay_steps,
    warmup_steps,
    model_weights,
    initial_epoch,
    learning_scale,
):
    """Trains an NLP model on the chosen dataset."""
    # NOTE: the docstring above is what click shows as this command's help
    # text, so implementation notes are kept in comments.
    #
    # Raises click.BadParameter when --batch_size is outside [1, 512] or when
    # --tokens_per_epoch is too small to contain a single training step.

    # The learning rate schedule has its constants configured for a batch size of
    #  512 and so we rescale them to get approximately the same effect.
    effective_batch_size = 512

    # Validate before touching the dataset so bad input fails fast. Explicit
    # exceptions are used because `assert` is stripped under `python -O`.
    if not 1 <= batch_size <= effective_batch_size:
        raise click.BadParameter(
            f"must be between 1 and {effective_batch_size}, got {batch_size}",
            param_hint="--batch_size",
        )
    seq_len = app.seq_len
    tokens_per_step = batch_size * seq_len
    if tokens_per_epoch < tokens_per_step:
        # Otherwise steps_per_epoch below would be 0 (or the epoch count
        # would divide by zero when tokens_per_epoch is 0).
        raise click.BadParameter(
            f"must be at least batch_size * seq_len = {tokens_per_step} tokens, "
            f"got {tokens_per_epoch}",
            param_hint="--tokens_per_epoch",
        )

    model = app.model
    dataset = app.dataset
    splits = dataset.get_splits(
        seq_len=seq_len, batch_size=batch_size, dtype=app.data_dtype
    )

    if model_weights is not None:
        click.echo(f"Warm restart from: {model_weights}")
        model.load_weights(model_weights)

    click.echo("Launching training run...")
    if initial_epoch > 0:
        click.echo(f"Starting learning schedule at epoch {initial_epoch}")
    # The timestamp keeps weights and log directories of separate runs apart.
    unique_name = f"{model.name}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    weights_file = paths.weights / f"{unique_name}.hdf5"
    click.echo(f"Saving weights to: {weights_file}")

    # Total token budget (steps at batch size 512) expressed in epochs.
    epochs = (steps * effective_batch_size * seq_len) // tokens_per_epoch
    # Number of actual (batch_size-sized) steps that make up one epoch.
    steps_per_epoch = tokens_per_epoch // tokens_per_step
    learning_scale = microsoft_nlp.learning_schedulers.learning_scale_from_string(
        learning_scale, batch_size, effective_batch_size
    )
    # decay/warmup are given w.r.t. batch size 512; convert to actual steps.
    lr_schedule = microsoft_nlp.learning_schedulers.CustomLearningRateScheduler(
        decay_steps=(decay_steps * effective_batch_size) // batch_size,
        warmup_steps=(warmup_steps * effective_batch_size) // batch_size,
        lr_scale=learning_scale,
        initial_step=initial_epoch * steps_per_epoch,
        non_embedding_params=app.relevant_params,
    )

    callbacks = [
        lr_schedule,
        tf.keras.callbacks.ModelCheckpoint(
            weights_file, monitor="loss", save_weights_only=True, save_best_only=True
        ),
        tf.keras.callbacks.EarlyStopping(monitor="loss", patience=patience),
        tf.keras.callbacks.TensorBoard(
            log_dir=paths.logs / unique_name,
            histogram_freq=1,
            update_freq=effective_batch_size // batch_size,  # in units of batches
        ),
    ]

    model.fit(
        splits["train"],
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=callbacks,
        initial_epoch=initial_epoch,
    )


@cli.command()
@pass_application_state
@click.option("--batch_size", type=int, required=True, help="Batch size for testing")
@click.option(
    "--n_examples",
    type=int,
    default=None,
    help="Number of examples to test (defaults to the entire test set)",
)
@click.option(
    "--model_weights",
    type=click.Path(dir_okay=False, exists=True),
    default=None,
    help="Model weights to load before testing",
)
@click.option(
    "--test_split",
    type=str,
    default="test",
    help="Dataset split to use for testing",
)
@click.option(
    "--pad_docs/--no_pad_docs",
    default=False,
    help="Whether to pad documents to use at most one document per sequence",
)
def test(app, batch_size, n_examples, model_weights, test_split, pad_docs):
    """Tests an NLP model on the chosen dataset."""
    # NOTE: the docstring above is what click shows as this command's help
    # text, so implementation notes are kept in comments.
    #
    # Evaluates the model on one split, echoes the overall loss, and writes a
    # per-token loss plot (.pdf) and array (.npy) under paths.plots.

    # When testing in the single document mode, we use the same pad_value as
    #  eod_token because the time-step where the input is eod_token we do not
    #  require the model to output anything.
    pad_value = app.dataset.eod_token if pad_docs else None
    split = app.dataset.get_split(
        split=test_split,
        seq_len=app.seq_len,
        batch_size=batch_size,
        dtype=app.data_dtype,
        pad_value=pad_value,
        shuffle_buffer=None,
        repeat=False,
    )
    if model_weights is not None:
        click.echo(f"Loading weights from '{model_weights}'")
        app.model.load_weights(model_weights)

    if n_examples is not None:
        batches = -(-n_examples // batch_size)  # take ceiling
        split = split.take(batches)
        click.echo(f"Testing ({batches} batches)...")
    else:
        click.echo("Testing...")

    with app.strategy.scope():
        token_loss = microsoft_nlp.metrics.TokenLoss(app.seq_len, mask_value=pad_value)
        app.model.compile(metrics=[token_loss])

        _, loss_per_token = app.model.evaluate(split)
        loss = token_loss.overall_loss()
        click.echo(f"Loss: {loss}")

    # plot and save loss per token
    # --model_weights is optional: without it there is no weights file to name
    # the outputs after (and pathlib.Path(None) raises TypeError), so fall
    # back to the model's own name.
    if model_weights is None:
        model_unique_name = app.model.name
    else:
        model_unique_name = pathlib.Path(model_weights).stem
    file_name = f"{model_unique_name}_{app.dataset_cls}_{test_split}"
    file_name = (file_name + "_pad") if pad_docs else file_name
    if n_examples is not None:
        file_name += f"_examples={n_examples}"
    plots_file = paths.plots / f"{file_name}.pdf"
    array_file = paths.plots / f"{file_name}.npy"

    plt.xscale("log")
    # x axis is the 1-based token position within the context window.
    plt.plot(tf.range(1, app.seq_len + 1), loss_per_token)
    plt.xlabel("Token Index in Context")
    plt.ylabel("Per-token Test Loss")
    plt.title(f"overall loss: {loss}")
    plt.savefig(plots_file)
    click.echo(f"Saved loss-per-token plot to '{plots_file}'")

    # save loss_per_token array
    np.save(array_file, loss_per_token)
    click.echo(f"Saved loss-per-token array to '{array_file}'")


@cli.command()
@pass_application_state
def summary(app):
    """Summarizes the NLP model and the chosen dataset."""
    # NOTE: the docstring above is what click shows as this command's help text.
    total_flops = count_flops_per_step(app.model)

    # The de-embedding matmul (embed_dim * vocab_size) is what gets subtracted
    # to obtain the non-embedding figure.
    deembedding_flops = 2 * app.vocab_size * app.embed_dim
    remaining_flops = total_flops - deembedding_flops

    print(f"FLOPs per step: {total_flops}")
    print(f"Non-embed FLOPs per step: {remaining_flops}")


class NullStrategy:
    """Placeholder strategy whose ``scope()`` is a no-op context manager.

    Lets code written as ``with strategy.scope():`` run unchanged when no
    real distribution strategy is selected.
    """

    def scope(self):
        """Returns a new context manager that does nothing on enter or exit."""
        do_nothing = nullcontext()
        return do_nothing


# Maps a parallel-strategy name to the class instantiated for it. Both None
# and "none" select the no-op NullStrategy.
strategy_map = dict.fromkeys((None, "none"), NullStrategy)
strategy_map["Mirrored"] = tf.distribute.MirroredStrategy


# Script entry point: invoke the click command group when run directly.
if __name__ == "__main__":
    cli()
