from collections.abc import MutableMapping
import time
from typing import Any, NamedTuple, Union, Optional

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import argparse
import os

from optimizers import *

import dataclasses

# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A simple example loader for an ASCII language-modelling """

from collections.abc import Iterable, Iterator
import itertools
import random
from typing import NamedTuple, TypeVar

import numpy as np


# dataset preparation

VOCAB_SIZE = 128  # Number of ASCII code points.
PAD_TOKEN = 0  # Reserved token id; corpora containing chr(PAD_TOKEN) are rejected by the loader.

_T = TypeVar('_T')  # Element type for the generic iterator helpers.

class Batch(NamedTuple):
    """A batch of next-token-prediction examples.

    As built by `load_ascii_dataset`, `targets` is `inputs` shifted left by
    one token, so `targets[:, t]` is the token that follows `inputs[:, t]`.
    """
    inputs: np.ndarray  # Integer tokens, shape [B, T].
    targets: np.ndarray  # Integer tokens, shape [B, T].


def repeat(dataset: Iterable[_T]) -> Iterator[_T]:
    """Returns an endless iterator that cycles over `dataset` forever.

    Items are cached on the first pass (this is `itertools.cycle`), so
    `dataset` itself is iterated only once.
    """
    endless_stream = itertools.cycle(dataset)
    return endless_stream


def shuffle(dataset: Iterator[_T], buffer_size: int) -> Iterator[_T]:
    """Approximately shuffles a stream using a fixed-size random buffer.

    The first `buffer_size` items are read into a buffer and shuffled. For each
    further item, a uniformly chosen buffer slot is emitted and refilled with
    the new item. When a finite input runs out, the items still held in the
    buffer are emitted too, so every input item is yielded exactly once.

    Args:
      dataset: Source of items. Any iterable is accepted (it is passed through
        `iter()`); it is consumed once.
      buffer_size: Number of items held in the shuffle buffer; must be >= 1.

    Yields:
      The items of `dataset` in a partially shuffled order.

    Raises:
      ValueError: if `buffer_size` < 1 (raised on the first `next()`).
    """
    if buffer_size < 1:
        raise ValueError(f'buffer_size must be positive, got {buffer_size}.')

    source = iter(dataset)
    # islice stops quietly on a short input instead of letting StopIteration
    # escape this generator (PEP 479 would turn that into RuntimeError).
    buffer = list(itertools.islice(source, buffer_size))
    random.shuffle(buffer)
    # This loop only runs when the buffer was completely filled.
    for item in source:
        idx = random.randint(0, buffer_size - 1)  # Inclusive.
        result = buffer[idx]
        buffer[idx] = item
        yield result
    # Flush the remainder of a finite input instead of silently dropping it.
    yield from buffer


def load_ascii_dataset(
    corpus: str,
    *,
    batch_size: int,
    sequence_length: int,
    num_shuffle_batches: int = 10,
) -> Iterator[Batch]:
    """Loads a single-file ASCII dataset in memory.

    The corpus is tokenised character by character (token id = ASCII code
    point), cut into non-overlapping windows of `sequence_length + 1` tokens,
    and served forever as shuffled batches of next-token-prediction pairs.

    All validation happens eagerly, when this function is called, rather than
    on the first `next()` of the returned iterator.

    Args:
      corpus: Full text of the dataset. Must be ASCII and must not contain
        chr(PAD_TOKEN) (the null character), which is reserved for padding.
      batch_size: Number of sequences per batch.
      sequence_length: Number of tokens in each input/target sequence.
      num_shuffle_batches: Size of the shuffle buffer, measured in batches.

    Returns:
      An infinite iterator of `Batch`es. `inputs` and `targets` have shape
      [batch_size, sequence_length], and `targets` is `inputs` shifted by one.

    Raises:
      ValueError: if the corpus is not ASCII, contains the null byte, or is
        too short to fill `num_shuffle_batches` batches.
    """
    if not corpus.isascii():
        raise ValueError('Loaded corpus is not ASCII.')

    if chr(PAD_TOKEN) in corpus:  # Reserve 0 codepoint for pad token.
        raise ValueError('Corpus must not contain the null byte.')

    # Naively tokenise by taking ASCII codepoints: encoding an ASCII string
    # yields exactly one byte per character, equal to its code point.
    tokens = np.frombuffer(corpus.encode('ascii'), dtype=np.uint8).astype(np.int32)
    # Guard the empty case: np.max of a zero-size array raises.
    assert tokens.size == 0 or np.max(tokens) < VOCAB_SIZE

    crop_len = sequence_length + 1
    num_batches, remainder = divmod(tokens.size, batch_size * crop_len)
    if num_batches < num_shuffle_batches:
        raise ValueError(
            f'Only {num_batches} batches in the dataset; consider using a shorter '
            'sequence length or a smaller batch size.',
        )

    if remainder:
        tokens = tokens[:-remainder]  # Drop remainder (incomplete) batch.
    rows = tokens.reshape([-1, crop_len])  # [num_batches * batch_size, T + 1]

    return _ascii_batch_stream(rows, batch_size, num_shuffle_batches)


def _ascii_batch_stream(
    rows: np.ndarray,
    batch_size: int,
    num_shuffle_batches: int,
) -> Iterator[Batch]:
    """Endlessly yields shuffled `Batch`es built from the windows in `rows`."""
    stream = shuffle(repeat(rows), buffer_size=batch_size * num_shuffle_batches)
    while True:
        batch = np.stack([next(stream) for _ in range(batch_size)])
        # Each window holds T + 1 tokens: inputs drop the last, targets the first.
        yield Batch(inputs=batch[:, :-1], targets=batch[:, 1:])

# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Didactic example of an autoregressive Transformer-based language 

Glossary of shapes:
- B: Batch size.
- T: Sequence length.
- D: Model embedding size.
- H: Number of attention heads.
- V: Vocabulary size.
"""

def _layer_norm(x: jax.Array) -> jax.Array:
    """Normalises `x` over its last axis with a freshly created LayerNorm.

    Every call constructs a new `hk.LayerNorm`, so each call site gets its own
    learnable scale and offset.
    """
    return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)


@dataclasses.dataclass
class Transformer(hk.Module):
    """A transformer stack.

    Each of the `num_layers` layers applies, with residual connections:
    LayerNorm -> multi-head self-attention -> dropout, then
    LayerNorm -> MLP (Linear, GELU, Linear) -> dropout.
    A final LayerNorm is applied to the output of the stack.
    """

    num_heads: int  # Number of attention heads.
    num_layers: int  # Number of transformer (attention + MLP) layers to stack.
    attn_size: int  # Size of the attention (key, query, value) vectors.
    dropout_rate: float  # Probability with which to apply dropout.
    widening_factor: int = 4  # Factor by which the MLP hidden layer widens.
    name: Optional[str] = None  # Optional identifier for the module.

    def __call__(
            self,
            embeddings: jax.Array,  # [B, T, D]
            mask: jax.Array,  # [B, T]
    ) -> jax.Array:  # [B, T, D]
        """Transforms input embedding sequences to output embedding sequences.

        Args:
          embeddings: Input embeddings, shape [B, T, D].
          mask: Per-position input mask, shape [B, T]. It is combined with a
            causal (lower-triangular) mask and passed to hk.MultiHeadAttention.

        Returns:
          Output embeddings, shape [B, T, D], after a final LayerNorm.

        Note: dropout is applied unconditionally (there is no is_training
        flag), so this module always draws keys via `hk.next_rng_key()`.
        """

        # Initialiser scale shrinks as the stack gets deeper (2 / num_layers).
        initializer = hk.initializers.VarianceScaling(2 / self.num_layers)
        _, seq_len, model_size = embeddings.shape

        # Compute causal mask for autoregressive sequence modelling.
        mask = mask[:, None, None, :]  # [B, H=1, T'=1, T]
        causal_mask = np.tril(np.ones((1, 1, seq_len, seq_len)))  # [B=1, H=1, T, T]
        mask = mask * causal_mask  # [B, H=1, T, T]

        h = embeddings
        # New modules are constructed on every iteration, so each layer owns
        # its own parameters. The construction order fixes Haiku's parameter
        # names and RNG consumption, so keep it stable.
        for _ in range(self.num_layers):
            # First the attention block.
            attn_block = hk.MultiHeadAttention(
                num_heads=self.num_heads,
                key_size=self.attn_size,
                model_size=model_size,
                w_init=initializer,
            )
            h_norm = _layer_norm(h)
            h_attn = attn_block(h_norm, h_norm, h_norm, mask=mask)
            h_attn = hk.dropout(hk.next_rng_key(), self.dropout_rate, h_attn)
            h = h + h_attn

            # Then the dense block.
            dense_block = hk.Sequential([
                hk.Linear(self.widening_factor * model_size, w_init=initializer),
                jax.nn.gelu,
                hk.Linear(model_size, w_init=initializer),
            ])
            h_norm = _layer_norm(h)
            h_dense = dense_block(h_norm)
            h_dense = hk.dropout(hk.next_rng_key(), self.dropout_rate, h_dense)
            h = h + h_dense

        return _layer_norm(h)


@dataclasses.dataclass
class LanguageModel(hk.Module):
    """An autoregressive transformer-based language model.

    Embeds tokens, adds learned positional embeddings, runs `transformer` over
    the result, and projects the output back to vocabulary logits.
    """

    transformer: Transformer
    model_size: int  # Embedding size.
    vocab_size: int  # Size of the vocabulary.
    pad_token: int  # Identity of the padding token (used for masking inputs).
    name: Optional[str] = None  # Optional identifier for the module.

    def __call__(
        self,
        tokens: jax.Array,  # Batch of sequences of input tokens, shape [B, T].
    ) -> jax.Array:  # Batch of sequences of output token logits, shape [B, T, V].
        """Forward pass, producing a sequence of logits.

        Args:
          tokens: Integer token ids, shape [B, T].

        Returns:
          Unnormalised logits over the vocabulary, shape [B, T, V].
        """
        # True where the token id is greater than pad_token.
        # NOTE(review): this assumes pad_token is the smallest id in use (it is
        # 0 in this file); confirm if pad_token is ever changed.
        input_mask = jnp.greater(tokens, self.pad_token)
        unused_batch_size, seq_len = tokens.shape

        # Embed the input tokens and positions.
        embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
        token_embedding_map = hk.Embed(
            self.vocab_size, embed_dim=self.model_size, w_init=embed_init)
        token_embeddings = token_embedding_map(tokens)
        # One learned vector per position. The parameter shape depends on T,
        # so presumably the same sequence length must be used at init and apply.
        positional_embeddings = hk.get_parameter(
            'positional_embeddings', [seq_len, self.model_size], init=embed_init)
        input_embeddings = token_embeddings + positional_embeddings  # [B, T, D]

        # Run the transformer over the inputs.
        embeddings = self.transformer(input_embeddings, input_mask)  # [B, T, D]

        # Decode the embeddings (here, we use untied weights).
        return hk.Linear(self.vocab_size)(embeddings)  # [B, T, V]


# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Trains a transformer for language modeling on a small text 

This example serves to demonstrate:
    - A clean Haiku transformer implementation.
    - An example minimal training loop around it.

This example runs on ASCII text files.
We have not tuned the hyperparameters at all.

Example, using Karpathy's tiny_shakespeare dataset:
$ wget -O /tmp/shakespeare.txt \
    https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
$ python3 examples/transformer/train.py \
    --data=/tmp/shakespeare.txt
"""

# Training hyperparameters.
BATCH_SIZE = 256  # Sequences per batch (passed to load_ascii_dataset in main()).
SEQUENCE_LENGTH = 64  # Tokens per input/target sequence.
LEARNING_RATE = 3e-4  # Adam learning rate used by optimiser(); main() passes --lr to run_* instead.
GRAD_CLIP_VALUE = 1  # Global-norm gradient clipping threshold used by optimiser().
LOG_EVERY = 50  # NOTE(review): not referenced by the code in this file.
MAX_STEPS = 500  # Step budget; main() runs MAX_STEPS // num_parall steps per run.
SEED = 0  # NOTE(review): not referenced here; main() seeds from --seed instead.

# Model hyperparameters.
NUM_LAYERS = 6  # Number of transformer (attention + MLP) layers.
NUM_HEADS = 8  # Number of attention heads.
MODEL_SIZE = 128  # Embedding size D.
KEY_SIZE = 32  # Per-head key/query/value size (Transformer.attn_size).
DROPOUT_RATE = 0.1  # Dropout probability after attention and MLP blocks.

# Helpful type aliases.
_Batch = Batch
_Metrics = MutableMapping[str, Any]  # e.g. {'step': ..., 'loss': ...} as returned by update().


class TrainingState(NamedTuple):
    """Container for the training state.

    Created by `init` and advanced by `update`. Being a NamedTuple, it is a
    JAX pytree and can be passed into and out of jitted functions.
    """
    params: hk.Params  # Current network parameters.
    opt_state: optax.OptState  # Optimiser state (e.g. gradient moments).
    rng_key: jax.Array  # RNG used for e.g. dropout. Split on each update step.
    step: jax.Array  # Tracks the number of training steps.


def forward_pass(tokens: Union[np.ndarray, jax.Array]) -> jax.Array:
    """Runs the language model on `tokens` and returns its logits.

    Must be called inside a Haiku transform (it builds Haiku modules).
    `tokens` has shape [B, T]; the result has shape [B, T, V].
    """
    # Build the transformer stack first, then the model that wraps it (the
    # same construction order as passing it inline as a keyword argument).
    stack = Transformer(
        num_heads=NUM_HEADS,
        num_layers=NUM_LAYERS,
        attn_size=KEY_SIZE,
        dropout_rate=DROPOUT_RATE,
    )
    model = LanguageModel(
        transformer=stack,
        model_size=MODEL_SIZE,
        vocab_size=VOCAB_SIZE,
        pad_token=PAD_TOKEN,
    )
    logits = model(tokens)
    return logits


def optimiser() -> optax.GradientTransformation:
    """Builds the gradient transformation used by `init` and `update`.

    Gradients are clipped to a global norm of `GRAD_CLIP_VALUE`, then fed to
    Adam with learning rate `LEARNING_RATE`, b1=0.9 and b2=0.99.
    """
    clip = optax.clip_by_global_norm(GRAD_CLIP_VALUE)
    adam = optax.adam(LEARNING_RATE, b1=0.9, b2=0.99)
    return optax.chain(clip, adam)


@hk.transform
def loss_fn(data: _Batch) -> jax.Array:
    """Mean negative log-likelihood of `data.targets` under the model.

    The average is taken over every position whose input token is not
    `PAD_TOKEN`. Because of `hk.transform`, use it via `loss_fn.init` and
    `loss_fn.apply(params, rng, data)`.
    """
    log_probs = jax.nn.log_softmax(forward_pass(data.inputs))  # [B, T, V]
    target_log_probs = jnp.sum(
        jax.nn.one_hot(data.targets, VOCAB_SIZE) * log_probs, axis=-1
    )  # [B, T]

    keep = jnp.not_equal(data.inputs, PAD_TOKEN)  # [B, T]
    total_nll = -jnp.sum(target_log_probs * keep)
    return total_nll / jnp.sum(keep)  # []


@jax.jit
def _init_state(rng: jax.Array, data: _Batch) -> TrainingState:
    """Jitted core of `init`: creates random parameters and optimiser state."""
    rng, init_rng = jax.random.split(rng)
    initial_params = loss_fn.init(init_rng, data)
    initial_opt_state = optimiser().init(initial_params)
    return TrainingState(
        params=initial_params,
        opt_state=initial_opt_state,
        rng_key=rng,
        step=jnp.array(0),
    )


def init(rng: jax.Array, data: _Batch) -> TrainingState:
    """Makes an initial training state (random parameters).

    Args:
      rng: PRNG key. One half of its split initialises the parameters; the
        other half becomes the state's `rng_key`.
      data: An example batch, used only for its shapes during initialisation.

    Returns:
      A `TrainingState` with fresh parameters, optimiser state and step 0.
    """
    state = _init_state(rng, data)
    # Report outside jit: a print inside a jitted function runs only while
    # tracing, so it would be skipped on cached calls (e.g. later runs).
    num_params = hk.data_structures.tree_size(state.params)
    print(f"Number of parameters: {num_params}")
    return state


@jax.jit
def update(
    state: TrainingState, data: _Batch
) -> tuple[TrainingState, _Metrics]:
    """Applies one optimiser step (see `optimiser`: global-norm clip + Adam).

    Returns the advanced state and a metrics dict holding the pre-update step
    index and the loss evaluated at the pre-update parameters.
    """
    next_key, dropout_key = jax.random.split(state.rng_key)
    loss, grads = jax.value_and_grad(loss_fn.apply)(state.params, dropout_key, data)

    param_updates, opt_state = optimiser().update(grads, state.opt_state)
    metrics = {'step': state.step, 'loss': loss}
    advanced = TrainingState(
        params=optax.apply_updates(state.params, param_updates),
        opt_state=opt_state,
        rng_key=next_key,
        step=state.step + 1,
    )
    return advanced, metrics

def parse_arguments(argv: Optional[list[str]] = None):
    """Parses command-line options for the OptEx transformer experiment.

    Args:
      argv: Arguments to parse, without the program name. When None (the
        default), argparse reads `sys.argv[1:]`, as before.

    Returns:
      The parsed `argparse.Namespace`.
    """
    parser = argparse.ArgumentParser(description='OptEx(Network) experiments')
    parser.add_argument('--data', default="./data/harrypotter1.txt", type=str, help='path to the ASCII text corpus')
    parser.add_argument('--opt_name', default="sgd", type=str, help='optimizer name')
    parser.add_argument('--lr', default=1e-2, type=float, help='learning rate')
    parser.add_argument('--method', default="optex", type=str, help='method name')
    parser.add_argument('--num_parall', default=4, type=int, help='number of parallel iterations')
    parser.add_argument('--seed', default=0, type=int, help='seed for random number generator')
    parser.add_argument('--num_runs', default=3, type=int, help='number of runs')
    parser.add_argument('--edim', default=100000, type=int, help='effective dimension passed to the optex method')
    parser.add_argument('--tune_every', default=50, type=int, help='re-tune the optex length scale every this many steps')
    return parser.parse_args(argv)

def _resolve_method(method: str):
    """Returns the module-level `run_<method>` callable (e.g. from `optimizers`)."""
    # Look the name up explicitly instead of eval()-ing a CLI-provided string.
    runner = globals().get(f"run_{method}")
    if not callable(runner):
        raise ValueError(f"Unknown method {method!r}: no callable 'run_{method}' found.")
    return runner


def _tune_length_scale(inter_results: dict, train_fraction: float = 0.8):
    """Re-fits the optex length scale on a random train/held-out split of the history."""
    xs = np.concatenate(inter_results["x_history"], axis=0)
    ys = np.concatenate(inter_results["g_history"], axis=0)

    num_points = len(xs)
    # Sample WITHOUT replacement so the training subset really contains
    # `train_fraction` of the distinct points and is disjoint from the rest.
    train_idx = np.random.choice(num_points, int(train_fraction * num_points), replace=False)
    held_out = np.ones(num_points, dtype=bool)
    held_out[train_idx] = False

    # NOTE(review): effective_dim is fixed at 5000 here while the optimiser
    # itself uses --edim; confirm this difference is intended.
    return tuning_mattern(
        xs[train_idx],
        ys[train_idx],
        xs[held_out],
        ys[held_out],
        choice=[0.01, 0.1, 1, 10, 100],
        effective_dim=5000,
    )


def main():
    """Runs the OptEx transformer experiment and saves the per-step losses.

    For each of `--num_runs` runs: initialises a fresh model, flattens its
    parameters into one vector, and calls `run_<method>` on `--num_parall`
    batches per step for `MAX_STEPS // num_parall` steps. The losses of all
    runs are saved as a `.npy` file under `./results/transformer-<corpus>/`.
    """
    # Imported explicitly: `import jax` alone may not load the submodule.
    from jax.flatten_util import ravel_pytree

    args = parse_arguments()
    run_method = _resolve_method(args.method)  # Fail fast on an unknown method.

    folder = args.data.split('/')[-1].split('.txt')[0]
    root = f"./results/transformer-{folder}"
    os.makedirs(root, exist_ok=True)

    with open(args.data) as file:
        train_dataset = load_ascii_dataset(
            corpus=file.read(),
            batch_size=BATCH_SIZE,
            sequence_length=SEQUENCE_LENGTH,
        )

    all_results = []

    for r in range(args.num_runs):
        # Initialise the model parameters.
        rng = jax.random.PRNGKey(args.seed + r * 1234)
        data = next(train_dataset)
        state = init(rng, data)

        x0, unravel_fn = ravel_pytree(state.params)

        def flat_loss(p, k, d):
            """Loss as a function of the flat parameter vector `p`."""
            return loss_fn.apply(unravel_fn(p), k, d)

        print("# of iters:", MAX_STEPS // (args.num_parall))

        arguments = {
            "opt_name":     "optax."+args.opt_name,
            "lr":           args.lr,
            "x0":           x0,
            "num_iters":    1,
            "num_parall":   args.num_parall,
            "opt_state":    None,
        }

        if args.method in ["optex", "line_search", "benchmark"]:
            arguments["inter_results"] = {}
        if args.method == "optex":
            arguments["effective_dim"] = args.edim
            arguments["inter_results"].update({"length_scale": 1.0})

        steps = MAX_STEPS // (arguments['num_parall'])
        print("\nTraining start for [%s]..." % args.method.upper())

        # Thread the key through the loop so every step draws a fresh dropout key.
        rng_key = state.rng_key
        loss_result = []
        for step in range(steps):
            rng_key, net_rng = jax.random.split(rng_key)
            arguments["datas"] = [[net_rng, next(train_dataset)] for _ in range(args.num_parall)]

            x, fx, opt_state = run_method(flat_loss, **arguments)

            loss_result.append(fx)
            print(fx)

            arguments.update({
                "x0":           x,
                "opt_state":    opt_state,
            })

            if args.method == "optex" and step % args.tune_every == args.tune_every - 1:
                length_scale = _tune_length_scale(arguments["inter_results"])
                arguments["inter_results"].update({
                    "length_scale": length_scale,
                })
                print("optimized length scale:", length_scale)
        all_results.append(loss_result)
    np.save(f"{root}/{args.opt_name}({args.lr})-{MAX_STEPS}x{args.num_parall}-{args.method}.npy", np.array(all_results))
    

if __name__ == '__main__':
    # Run the experiment only when executed as a script, not when imported.
    main()