# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sequence to sequence model."""

from typing import Any, Callable, Dict, Tuple

from absl import logging
from flax import linen as nn
from flax.training import common_utils
import gin
import jax
import jax.numpy as jnp
import  metrics_summary
from transformer import decoder_stack
from transformer import metric_utils
from transformer import text_dataset
from transformer import tasks
import numpy as np
import seqio


Array = jnp.ndarray
MetricsSummary = metrics_summary.MetricsSummary


# TODO(mrabe): Remove this function and find a better way to turn text metrics
# into text on tensorboard.
def process_summaries(vocab: seqio.Vocabulary,
                      met_summary: MetricsSummary,
                      mode: str) -> MetricsSummary:
  """Compute some additional summaries, and convert tokens to text.

  Adds a "perplexity" metric whenever "nats_per_token" is present. In modes
  whose name starts with "generate", the first "gen_tokens" entry is decoded
  to text, logged, added as the text summary "gen_text", and "gen_tokens" is
  removed from the metric dictionary.

  Args:
    vocab: The vocabulary to detokenize generated text.
    met_summary: The summary object to process. It is updated in place.
    mode: The mode of the summary (e.g. "test", "train")

  Returns:
    The summary object that was passed in, after modification.

  Raises:
    ValueError: If the first "gen_tokens" entry is not 2-dimensional.
  """
  mdict = met_summary.current_metric_dict()

  # Calculate perplexity from the average nats_per_token over all replicas.
  # This has to be done here, because the perplexities themselves can't be
  # averaged in the usual way.
  if "nats_per_token" in mdict:
    nats_per_token = mdict["nats_per_token"].to_value()
    met_summary.add({"perplexity": np.exp(nats_per_token)})

  if mode.startswith("generate") and "gen_tokens" in mdict:
    # Convert output tokens to example output text.
    # Write text to both the summary, and pretty-print to the log file.
    gen_toks = mdict["gen_tokens"].to_value()[0]
    if np.ndim(gen_toks) != 2:
      # Format the shape explicitly: `"%s" % shape` would unpack the shape
      # tuple as format arguments and raise TypeError instead of ValueError.
      raise ValueError(
          f"Unsupported shape for gen_tokens: {np.shape(gen_toks)}")

    ntoks = np.shape(gen_toks)[-1]
    gen_text = text_dataset.decode_tokens(gen_toks, vocab, max_length=ntoks)
    logging.info("Generated text = %s", gen_text)
    met_summary.add_text({"gen_text": gen_text})
    del mdict["gen_tokens"]   # Otherwise it will turn into a histogram.

  return met_summary


@gin.configurable
def process_summaries_function(vocab: seqio.Vocabulary) -> Callable[
    [MetricsSummary, str], MetricsSummary]:
  """Bind a vocabulary to process_summaries.

  For use with training_loop.process_summaries_function.

  Args:
    vocab: The vocabulary used to detokenize generated text.

  Returns:
    A callable taking (met_summary, mode) that forwards to
    process_summaries with the given vocabulary.
  """
  def _process_with_vocab(met_summary: MetricsSummary, mode: str):
    return process_summaries(vocab=vocab, met_summary=met_summary, mode=mode)

  return _process_with_vocab


@gin.configurable
class DecoderOnlyLanguageModel(nn.Module):
  """Decoder only language modeling.

  Wraps a decoder stack built by `decoder_factory`. Calling the module
  computes a next-token cross-entropy loss and metrics for a batch; in modes
  whose name starts with "generate" it additionally runs autoregressive
  generation and reports generation metrics.
  """

  # Mode name. It is passed to the decoder factory, compared against "train"
  # to decide the default for input randomization, and checked for the
  # "generate" prefix to enable text generation.
  mode: str
  # Task configuration; batch_size, sequence_length and vocab_size are read.
  task_config: decoder_stack.TransformerTaskConfig = gin.REQUIRED
  # Called in setup() with `mode` and `task_config` keyword arguments.
  decoder_factory: Callable[[], Any] = gin.REQUIRED

  # Sampling method used by generate(): "sample", "nucleus" or "greedy".
  sample_method: str = "sample"
  # If True, the non-generate metrics include per-token "token_losses".
  output_token_losses: bool = False
  # If True, the metrics include "logits" (and "gen_logits" when generating).
  output_logits: bool = False

  # If True (and randomization is active), input tokens are replaced by
  # random token ids drawn independently for each position.
  randomize_tokens_ind: bool = False
  # If True (and randomization is active), token ids are remapped through a
  # random per-example shuffling of the vocabulary.
  randomize_tokens: bool = False
  # Bernoulli probability used for the token/vocabulary replacement masks.
  p_randomize_vocab_per_instance: float = 1.
  # Bernoulli probability that an example takes part in `randomize_tokens`.
  p_randomize_per_instance: float = 1.

  def get_fake_input(self):
    """Build an all-ones dummy batch for initializing the model.

    Returns:
      A dict with "targets" (int32, [batch_size, sequence_length]),
      "start_of_sequence" (bool, [batch_size]) and "epoch" (int32,
      [batch_size]). A bool "loss_mask" shaped like "targets" is included
      only when the dataset reports loss-mask tokens for this mode.
    """
    config = self.task_config
    batch = config.batch_size
    token_shape = [batch, config.sequence_length]

    fake_inputs = {}
    fake_inputs["targets"] = jnp.ones(token_shape, dtype=jnp.int32)
    fake_inputs["start_of_sequence"] = jnp.ones([batch], dtype=jnp.bool_)
    fake_inputs["epoch"] = jnp.ones([batch], dtype=jnp.int32)

    # The loss mask is not part of the dummy input by default, as it can
    # cause a slowdown during evaluation and perhaps inference.
    mask_tokens = text_dataset.get_loss_mask_tokens(split=self.mode)
    if mask_tokens != (None, None):
      fake_inputs["loss_mask"] = jnp.ones(token_shape, dtype=jnp.bool_)
    return fake_inputs

  def metrics_summary_operations(self, aggregate_over: str) -> Dict[str, str]:
    """Return the aggregation operation for every recorded metric.

    Args:
      aggregate_over: Either "steps" or "devices".

    Returns:
      A dict mapping metric name to "mean" or "sum".

    Raises:
      ValueError: If aggregate_over is neither "steps" nor "devices".
    """
    ops = dict.fromkeys(
        (
            "loss",
            "nats_per_token",
            "bits_per_token",
            "bits_per_char",
            "accuracy",
            "losses",
            "sentence_accuracy",
            "average_sentence_accuracy",
            "average_losses",
            "num_examples",
            "num_tokens",
            "num_chars_per_device",
            "num_chars_per_batch",
            "nonzero_tokens",
            "num_tokens_per_device",
            "num_tokens_per_batch",
            "epoch",
        ),
        "mean")
    ops["num_examples"] = "sum"

    # Per-task variants for every registered seqio task.
    per_task_suffixes = (
        ("_sentence_accuracy", "mean"),
        ("_losses", "mean"),
        ("_num_examples", "sum"),
        ("_nats_per_token", "mean"),
        ("_accuracy", "mean"),
    )
    for task_name in list(seqio.TaskRegistry.names()):
      for suffix, op in per_task_suffixes:
        ops[task_name + suffix] = op

    # Task-specific metrics, both averaged and per indexed sub-task.
    for task, metric_name in tasks.ALL_METRICS.items():
      ops["average_" + metric_name] = "mean"
      for indexed_task in tasks.ALL_INDEXED_TASKS[task]:
        ops[indexed_task + "_" + metric_name] = "mean"

    if aggregate_over not in ("steps", "devices"):
      raise ValueError("Don't know how to aggregate over: %s" % aggregate_over)

    if aggregate_over == "devices":
      # Ensure that statistics that refer to the total batch size stay constant
      # as TPU topologies change. For those we have to sum over devices, but
      # compute the mean over steps.
      for batch_total in ("num_tokens_per_batch", "num_chars_per_batch",
                          "loss"):
        ops[batch_total] = "sum"
    return ops

  def setup(self):
    """Instantiate the decoder from the configured factory."""
    factory_kwargs = {"mode": self.mode, "task_config": self.task_config}
    self.decoder = self.decoder_factory(**factory_kwargs)  # pytype: disable=wrong-keyword-args  # trace-all-classes

  def __call__(self, inputs: Dict[str, Any]) -> Tuple[Any, Dict[str, Any]]:
    """Run the model on one batch and compute the loss and metrics.

    Args:
      inputs: Dictionary of input arrays. Required keys are "targets"
        [batch_size, sequence_length], "start_of_sequence" [batch_size] and
        "epoch" [batch_size]. Optional keys read here are "should_randomize",
        "loss_mask", "task_idx", "num_chars" and "nonzero_tokens".
        generate() reads further optional keys from the same dictionary.

    Returns:
      A tuple (loss, metrics). loss is the cross-entropy summed over all
      unmasked tokens on this device. metrics is a dictionary; several
      entries are 2-tuples, presumably (value, weight) pairs for weighted
      averaging -- confirm against MetricsSummary.
    """
    task_config = self.task_config

    input_tokens = inputs["targets"]                  # [b, seq_len]
    start_of_sequence = inputs["start_of_sequence"]   # [b]
    epochs = inputs["epoch"]                          # [b]

    # Randomize in train mode by default; the caller may override this.
    should_randomize = (self.mode == "train")
    if "should_randomize" in inputs:
      should_randomize = inputs["should_randomize"]

    if "loss_mask" in inputs:
      loss_mask = inputs["loss_mask"]                 # [b, seq_len]
    else:
      # A (1, 1) mask broadcasts against the token-based masks built below.
      loss_mask = jnp.ones((1, 1), dtype=jnp.bool_)

    input_tokens = jnp.asarray(input_tokens)
    assert input_tokens.ndim == 2
    assert input_tokens.shape[0] == task_config.batch_size
    assert input_tokens.shape[1] == task_config.sequence_length
    assert start_of_sequence.shape[0] == task_config.batch_size

    if self.randomize_tokens_ind and should_randomize:
      # Replace tokens, independently per position, by random non-zero ids.
      # The replacement ids and the replacement mask use separate PRNG keys.
      # NOTE(review): the rng stream name mirrors the `randomize_tokens`
      # branch below; the original code referenced an undefined key here.
      ind_key = self.make_rng("emb_randomize_once")
      index_key, mixing_key = jax.random.split(ind_key)
      random_tokens = jax.random.randint(
          index_key, shape=input_tokens.shape, minval=1,
          maxval=task_config.vocab_size)
      replace_mask = jax.random.bernoulli(
          mixing_key,
          p=self.p_randomize_vocab_per_instance,
          shape=input_tokens.shape,
      )
      input_tokens = jnp.where(replace_mask, random_tokens, input_tokens)

    logging.info("langmodel: Compiling model for mode %s", self.mode)
    logging.info("langmodel: input_tokens = %r", input_tokens)
    logging.info("langmodel: start_of_sequece = %r", start_of_sequence)
    logging.info("langmodel: epochs = %r", epochs)

    # The target outputs are the next character in each sequence.
    # Shift tokens left and pad with a zero at the end.
    # TODO(delesley): We don't predict the first token of each sequence.
    target_tokens = jnp.pad(input_tokens[:, 1:], [(0, 0), (0, 1)])
    if self.decoder.final_mlp_factory is not None:
      # With a final MLP the targets are the (unshifted) inputs, or fixed
      # probe ids. NOTE(review): ids 98..123 are presumably the token ids of
      # the letters being probed -- confirm against the vocabulary.
      target_tokens = input_tokens
      if self.decoder.probe_single_letter != -1:
        logging.info("probe_single_letter: probing letter = %r", 98 + self.decoder.probe_single_letter)
        target_tokens = jnp.ones_like(target_tokens) * (98 + self.decoder.probe_single_letter)
      elif self.decoder.probe_all_letters:
        # probe all letters
        target_tokens = jnp.tile(np.arange(98,124), [target_tokens.shape[1], target_tokens.shape[0]]).transpose(1,0)

    if self.randomize_tokens and should_randomize:
      orig_input_tokens = input_tokens
      num_examples = input_tokens.shape[0]
      seq_len = input_tokens.shape[1]
      vocab_size = task_config.vocab_size
      rnd_key = self.make_rng("emb_randomize_once")

      # Count how often each token id occurs in each example.
      multi_occurances = jnp.zeros((num_examples, vocab_size))
      for i in range(num_examples):
        multi_occurances = multi_occurances.at[i, input_tokens[i]].add(1)

      # Draw one independent permutation of the vocabulary per example.
      # jax.random.permutation replaces the removed jax.random.shuffle.
      indexes = []
      for _ in range(num_examples):
        rnd_key, subkey = jax.random.split(rnd_key)
        indexes.append(jax.random.permutation(subkey, vocab_size))
      shuffled_index = jnp.stack(indexes)

      # Use a fresh key so the mask is not correlated with the permutation.
      rnd_key, subkey = jax.random.split(rnd_key)
      vocab_mixing = jax.random.bernoulli(
          subkey,
          p=self.p_randomize_vocab_per_instance,
          shape=(num_examples, vocab_size),
      )

      rnd_key, subkey = jax.random.split(rnd_key)
      instance_mixing = jax.random.bernoulli(
          subkey,
          p=self.p_randomize_per_instance,
          shape=(num_examples, 1),
      )

      # Only remap tokens that occur more than once in a selected example.
      vocab_mixing = vocab_mixing * instance_mixing * (multi_occurances > 1)
      vocab_mixing = vocab_mixing.at[:, 0].set(0)  # do not shuffle padding

      new_token_index = (shuffled_index * vocab_mixing
                         + jnp.arange(vocab_size) * (1 - vocab_mixing))
      input_tokens = jnp.take_along_axis(new_token_index, input_tokens, axis=1)

      target_tokens = jnp.pad(input_tokens[:, 1:], [(0, 0), (0, 1)])

      # Change back first occurrence: the target that predicts the first
      # occurrence of each token is reset to the original token id.
      # NOTE(review): ids that never occur (first occurrence == seq_len) and
      # ids whose first occurrence is position 0 (index -1) all write to the
      # last target position -- confirm this is intended.
      for i in range(num_examples):
        first_occurances = jnp.ones(vocab_size) * seq_len
        first_occurances = first_occurances.at[orig_input_tokens[i]].min(
            jnp.arange(seq_len))
        target_tokens = target_tokens.at[
            i, first_occurances.astype(int) - 1].set(np.arange(vocab_size))

    logging.info("langmodel: target_tokens = %r", target_tokens)

    # Invoke the decoder stack.
    # The decoder will return pre-softmax logits for the predicted targets.
    (logits, _, d_metrics) = self.decoder(input_tokens=input_tokens,
                                          target_tokens=target_tokens,
                                          start_of_sequence=start_of_sequence,
                                          should_randomize=should_randomize)

    # Softmax cross-entropy loss on target tokens.
    logits = nn.log_softmax(logits, axis=-1)   # (b, seq_len, vocab_size)
    logging.info("langmodel: logits = %r", logits)
    soft_targets = common_utils.onehot(target_tokens, task_config.vocab_size)
    logging.info("langmodel: soft_targets = %r", soft_targets)

    losses = -jnp.sum(soft_targets * logits, axis=-1)  # (b, seq_len)
    logging.info("langmodel: losses = %r", losses)

    # Don't predict null tokens which are past the end-of-sequence.
    # Also don't predict the 0 at the end of the sequence.
    # TODO(delesley): Predict the final end-of-sequence marker.
    loss_mask = jnp.logical_and(loss_mask, input_tokens > 0)
    loss_mask = jnp.logical_and(loss_mask, target_tokens > 0)
    logging.info("langmodel: loss_mask = %r", loss_mask)

    losses = jnp.where(loss_mask, losses, 0.0)  # (batch_size, seq_len)
    loss = jnp.sum(losses)  # total loss on device

    token_count = jnp.sum(loss_mask)  # tokens on device
    token_count_nz = token_count + 1.0e-6
    loss_per_token = loss / token_count_nz
    bits_per_token = loss_per_token * 1.442695  # log(e)/log(2)
    accuracy = metric_utils.compute_accuracy_sum(logits, target_tokens,
                                                 loss_mask)
    accuracy = accuracy / token_count_nz  # Percent correct.
    epoch = jnp.mean(epochs)

    if self.mode.startswith("generate"):
      # Generate example text.
      logging.info("lang_model: text inference.")
      # gen_losses holds one summed loss per example, shape [b].
      gen_tokens, gen_losses, gen_logits = self.generate(
          inputs, task_config.sequence_length)

      token_correct = jnp.equal(gen_tokens, target_tokens)
      token_correct = jnp.logical_and(token_correct, loss_mask)
      # A sentence is correct when every unmasked token is correct.
      sentence_accuracy = (
          jnp.sum(token_correct, axis=1) == jnp.sum(loss_mask, axis=1))

      # Examples with task_idx == -1 are treated as invalid and excluded.
      is_valid = jnp.ones((input_tokens.shape[0]), dtype=jnp.bool_)
      if "task_idx" in inputs:
        is_valid = is_valid * (inputs["task_idx"] != -1)
      num_valid = jnp.sum(is_valid)
      num_valid_nz = num_valid + 1e-5

      # Return generated text, along with vizualizations and histograms.
      gen_token_count = jnp.sum(loss_mask * is_valid[:, None])
      gen_token_count_nz = gen_token_count + 1.0e-6
      # gen_losses is [b], so it is masked with is_valid directly; using
      # is_valid[:, None] here would broadcast to [b, b] and inflate the sum.
      gen_loss_per_token = jnp.sum(gen_losses * is_valid) / gen_token_count_nz

      sum_acc = jnp.sum(token_correct, axis=1)
      metrics = {
          "gen_tokens": gen_tokens,
          **d_metrics,
          "sentence_accuracy": (
              jnp.sum(sentence_accuracy * is_valid) / num_valid_nz,
              num_valid_nz),
          "num_examples": num_valid,
          "losses": (jnp.sum(gen_losses * is_valid) / num_valid_nz,
                     num_valid_nz),
          "nats_per_token": (gen_loss_per_token, gen_token_count),
          "accuracy": (jnp.sum(sum_acc * is_valid) / gen_token_count_nz,
                       gen_token_count),
      }

      if "task_idx" in inputs and self.mode.startswith("generate_syn_bigbench"):
        # The task group name is the mode without the "generate_syn_" prefix.
        indexed_tasks = tasks.ALL_INDEXED_TASKS[
            self.mode[len("generate_syn_"):]]
        num_tasks = len(indexed_tasks)
        all_bigbench_task_acc = jnp.zeros(num_tasks)
        all_bigbench_task_weight = jnp.zeros(num_tasks)
        all_bigbench_losses = jnp.zeros(num_tasks)
        all_bigbench_nats_per_token = jnp.zeros(num_tasks)
        all_bigbench_tacc = jnp.zeros(num_tasks)
        all_bigbench_tokens = jnp.zeros(num_tasks)
        for idx, task_idx in enumerate(inputs["task_idx"]):
          # Discard batch padding: invalid examples contribute zero.
          example_valid = task_idx != -1

          all_bigbench_task_acc = all_bigbench_task_acc.at[task_idx].add(
              sentence_accuracy[idx] * example_valid)
          all_bigbench_task_weight = all_bigbench_task_weight.at[task_idx].add(
              example_valid)
          all_bigbench_losses = all_bigbench_losses.at[task_idx].add(
              gen_losses[idx] * example_valid)
          all_bigbench_nats_per_token = (
              all_bigbench_nats_per_token.at[task_idx].add(
                  gen_losses[idx] * example_valid))
          all_bigbench_tacc = all_bigbench_tacc.at[task_idx].add(
              sum_acc[idx] * example_valid)
          all_bigbench_tokens = all_bigbench_tokens.at[task_idx].add(
              jnp.sum(loss_mask[idx] * example_valid))

        # Per-task averages. Local names are kept distinct from `loss` so
        # that the value returned below stays the total loss on device.
        for i, task_name in enumerate(indexed_tasks):
          example_weight = jnp.maximum(all_bigbench_task_weight[i], 1e-5)
          token_weight = jnp.maximum(all_bigbench_tokens[i], 1e-5)
          metrics[task_name + "_sentence_accuracy"] = (
              all_bigbench_task_acc[i] / example_weight, example_weight)
          metrics[task_name + "_num_examples"] = all_bigbench_task_weight[i]
          metrics[task_name + "_losses"] = (
              all_bigbench_losses[i] / example_weight, example_weight)
          metrics[task_name + "_nats_per_token"] = (
              all_bigbench_nats_per_token[i] / token_weight, token_weight)
          metrics[task_name + "_accuracy"] = (
              all_bigbench_tacc[i] / token_weight, token_weight)
        # Cross-task averages ("average_sentence_accuracy" and
        # "average_losses") are currently not emitted here.

      if self.output_logits:
        metrics["logits"] = logits
        metrics["gen_logits"] = gen_logits
      return (loss, metrics)

    # Just return metrics related to the loss.
    metrics = {
        "loss": loss,   # will be summed over devices
        "nats_per_token": (loss_per_token, token_count),
        "bits_per_token": (bits_per_token, token_count),
        "accuracy": (accuracy, token_count),
        "num_tokens_per_device": token_count,
        "num_tokens_per_batch": token_count,  # will be summed over devices
        "epoch": epoch,
    }

    # Compute bits per character if we have the number of characters.
    if "num_chars" in inputs:
      num_chars = jnp.sum(inputs["num_chars"])
      bits_per_char = loss / (num_chars + 1e-6) * 1.442695
      metrics["num_chars_per_device"] = num_chars
      metrics["num_chars_per_batch"] = num_chars  # will be summed over devices
      metrics["bits_per_char"] = (bits_per_char, num_chars)

    # Provided to make sure that the data pipeline and the the model agree
    # on the number of tokens with a loss.
    if "nonzero_tokens" in inputs:
      nonzero_tokens = jnp.sum(inputs["nonzero_tokens"])
      metrics["nonzero_tokens"] = nonzero_tokens

    if self.output_logits:
      metrics["logits"] = logits
    if self.output_token_losses:
      metrics["token_losses"] = losses

    return (loss, metrics)

  def generate(self, inputs: Dict[str, Any],
               sequence_length: int) -> Tuple[Array, Array, Array]:
    """Generate an output sequence.

    Decoding starts from the first ground-truth token of each example. When
    "loss_mask" is present in inputs, positions where the mask is False are
    fed the ground-truth target token instead of the sampled one, and
    contribute no loss.

    Args:
      inputs: the same as argument to __call__. Optional keys read here are
        "loss_mask", "sample_prng", "nucleus_cutoff" and "temperature".
      sequence_length: the length of sequence to generate.

    Returns:
      A tuple (output_tokens, losses, output_logits):
        output_tokens: generated tokens, shape (batch_size, sequence_length).
        losses: cross-entropy against the shifted ground-truth tokens, summed
          over all decoding steps, shape (batch_size,).
        output_logits: temperature-scaled logits used for sampling at each
          step, shape (batch_size, sequence_length, vocab_size).

    Raises:
      ValueError: If sample_method is not "sample", "nucleus" or "greedy".
    """
    # TODO(delesley): Add support for passing the prefix as an argument.
    # TODO(delesley): Add support for gumbel softmax, beam search.

    batch_size = self.task_config.batch_size
    vocab_size = self.task_config.vocab_size
    input_tokens = inputs["targets"]                  # [b, seq_len]
    start_of_sequence = inputs["start_of_sequence"]   # [b]
    has_loss_mask = "loss_mask" in inputs
    loss_mask = inputs["loss_mask"] if has_loss_mask else None
    target_tokens = jnp.pad(input_tokens[:, 1:], [(0, 0), (0, 1)])

    # Initialize decoder.
    if self.decoder.supports_generate():
      dstate = self.decoder.init_decoder_state(sequence_length,
                                               start_of_sequence)
    else:
      dstate = self.decoder.init_decoder_state_vanilla(sequence_length,
                                                       start_of_sequence)

    # TODO(delesley): Handle start-of-sequence in a better way.
    # There is no special token for start of sequence, so we grab the first
    # one from the ground-truth input data.
    first_token = input_tokens[:, 0:1]
    no_start_of_seq = jnp.array([False] * batch_size, dtype=jnp.bool_)
    sample_method = self.sample_method

    # The "sample" rng is always requested; inputs may override the key.
    sample_prng = self.make_rng("sample")
    if "sample_prng" in inputs:
      sample_prng = inputs["sample_prng"]

    nucleus_cutoff = 0.9
    temperature = 1
    if "nucleus_cutoff" in inputs:
      nucleus_cutoff = inputs["nucleus_cutoff"]
    if "temperature" in inputs:
      temperature = inputs["temperature"]

    # Autoregressive decoder step: consumes one token, emits the next.
    def loop_fn(scan_state: Any, i: Array) -> Tuple[Any, Tuple[Array, Array]]:
      prng = jax.random.fold_in(sample_prng, i)
      (dstate, input_token, losses) = scan_state
      # Ground-truth next token for this step, shape [b, 1].
      step_targets = jax.lax.dynamic_slice_in_dim(target_tokens, i, 1, axis=1)

      (logits, dstate, _) = self.decoder(input_tokens=input_token,
                                         target_tokens=None,
                                         start_of_sequence=no_start_of_seq,
                                         decoder_state=dstate,
                                         should_randomize=False)

      logits = logits / temperature
      soft_targets = common_utils.onehot(step_targets, vocab_size)
      loss = -jnp.sum(soft_targets * nn.log_softmax(logits, axis=-1), axis=-1)

      if sample_method == "sample":
        logging.info("Using categorical sampling.")
        output_token = jax.random.categorical(prng, logits, axis=-1)
      elif sample_method == "nucleus":
        logging.info("Using nucleus sampling.")
        probs = jax.nn.softmax(logits, axis=-1)
        # Sort the probs in descending order.
        sorted_indices = jnp.argsort(probs, axis=-1)[..., ::-1]
        sorted_probs = jnp.take_along_axis(probs, sorted_indices, axis=-1)
        # Mark sorted positions whose cumulative probability exceeds the
        # cutoff, then shift right by one so the token that crosses the
        # cutoff (and always the most likely token) is kept.
        cumulative_probs = jnp.cumsum(sorted_probs, axis=-1)
        cutoff_sorted = cumulative_probs > nucleus_cutoff
        cutoff_sorted = jnp.pad(cutoff_sorted[..., :-1],
                                [(0, 0), (0, 0), (1, 0)])
        # Map the mask back from sorted order to vocabulary order. This must
        # be done per example along the last axis; jnp.take without an axis
        # would index the flattened array instead.
        reverse_argsort = jnp.argsort(sorted_indices, axis=-1)
        cutoff_mask = jnp.take_along_axis(cutoff_sorted, reverse_argsort,
                                          axis=-1)
        # Set excluded logits to -inf. jnp.where is used because multiplying
        # a boolean mask by inf yields NaN (0 * inf) for the kept tokens.
        logits = jnp.where(cutoff_mask, -jnp.inf, logits)
        # Sample from the remaining tokens.
        output_token = jax.random.categorical(prng, logits, axis=-1)
      elif sample_method == "greedy":
        logging.info("Using greedy sampling.")
        output_token = jnp.argmax(logits, axis=-1)
      else:
        raise ValueError(f"Invalid sampling method: {sample_method}")
      logging.info("generate_loop_fn: output_token = %r", output_token)

      if has_loss_mask:
        # Outside the loss mask, feed the ground truth and count no loss.
        step_mask = jax.lax.dynamic_slice_in_dim(loss_mask, i, 1, axis=1)
        output_token = jnp.where(step_mask, output_token, step_targets)
        loss = jnp.where(step_mask, loss, jnp.zeros_like(loss))

      return ((dstate, output_token, losses + loss[:, 0]),
              (output_token, logits))

    # Scan over the sequence length.
    iterations = jnp.arange(sequence_length)
    initial_scan_state = (dstate, first_token,
                          jnp.zeros(input_tokens.shape[0]))
    ((_, _, losses), (output_tokens, output_logits)) = jax.lax.scan(
        loop_fn, initial_scan_state, iterations)
    logging.info("generate: output_tokens = %r", output_tokens)

    # Output_tokens has shape (sequence_length, batch_size, 1)
    assert output_tokens.shape == (sequence_length, batch_size, 1)
    output_tokens = jnp.reshape(output_tokens, (sequence_length, batch_size))
    output_tokens = output_tokens.transpose([1, 0])

    assert output_logits.shape == (sequence_length, batch_size, 1, vocab_size)
    output_logits = jnp.reshape(
        output_logits, (sequence_length, batch_size, vocab_size))
    output_logits = output_logits.transpose([1, 0, 2])

    return output_tokens, losses, output_logits
    # , output_logits
