import dataclasses
import random
from typing import Any, Callable, Mapping

from absl import logging
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import tqdm


# A batch of data as returned by `EvaluationParams.sample_batch`: a mapping
# from field name to array. `range_evaluation` reads its 'input' and 'output'
# entries.
_Batch = Mapping[str, jnp.ndarray]


@dataclasses.dataclass
class EvaluationParams:
  """The parameters used for range evaluation of networks.

  Attributes:
    model: The transformed Haiku model; `range_evaluation` jits and calls its
      `apply` function.
    params: The model parameters, passed as first argument to `model.apply`.
    accuracy_fn: Called as `accuracy_fn(outputs, targets)` on every sub-batch;
      the returned values are averaged per length.
    sample_batch: Called as `sample_batch(rng_key, batch_size, length)`; must
      return a batch with at least the 'input' and 'output' entries.
    min_length: Not read by `range_evaluation` in this file (presumably the
      minimum sequence length used elsewhere -- confirm against callers).
    min_test_length: The first sequence length to evaluate (inclusive).
    max_test_length: The last sequence length to evaluate (inclusive).
    total_batch_size: The number of sequences evaluated per length, split in
      `total_batch_size // sub_batch_size` sub-batches.
    sub_batch_size: The size of each sub-batch passed through the model.
    is_autoregressive: If True, the model is applied with an extra
      (uninitialized) output tensor and `sample=True`.
    hard_attention: If True and the model is not autoregressive, the model is
      applied with `hard_attention=True`.
    verbose: If strictly positive, a few sample predictions are logged during
      evaluation.
    probabilistic_language_model: If True, the accuracy is computed against
      the argmax (over the last axis) of the inputs instead of the batch
      outputs.
  """
  model: hk.Transformed
  params: hk.Params

  accuracy_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
  sample_batch: Callable[[jnp.ndarray, int, int], _Batch]

  min_length: int
  min_test_length: int
  max_test_length: int
  total_batch_size: int
  sub_batch_size: int  # We use this to avoid memory overflow.

  is_autoregressive: bool = False
  hard_attention: bool = True
  verbose: int = 0
  # NOTE: this default must be a plain bool. A trailing comma after `False`
  # would turn it into the tuple `(False,)`, which is truthy and would
  # silently enable the probabilistic language model path for every caller
  # relying on the default.
  probabilistic_language_model: bool = False


# Sample predictions are only logged for lengths that are a multiple of this.
_VERBOSE_LENGTH_STRIDE = 5
# Maximum number of correct (resp. incorrect) predictions logged per length.
_MAX_LOGGED_PER_OUTCOME = 5


def _log_sample_predictions(
    batch: Mapping[str, jnp.ndarray],
    outputs: jnp.ndarray,
    length: int,
    true_counts: dict[int, int],
    false_counts: dict[int, int],
) -> None:
  """Logs a few correct and a few incorrect predictions of a sub-batch.

  At most `_MAX_LOGGED_PER_OUTCOME` correct and `_MAX_LOGGED_PER_OUTCOME`
  incorrect predictions are logged per length, across all sub-batches.

  Args:
    batch: The evaluated sub-batch; its 'input' and 'output' entries are read.
    outputs: The model outputs for this sub-batch.
    length: The sequence length currently being evaluated.
    true_counts: Per-length number of correct predictions logged so far.
      Updated in place.
    false_counts: Per-length number of incorrect predictions logged so far.
      Updated in place.
  """
  true_counts.setdefault(length, 0)
  false_counts.setdefault(length, 0)

  def quotas_full() -> bool:
    return (true_counts[length] >= _MAX_LOGGED_PER_OUTCOME and
            false_counts[length] >= _MAX_LOGGED_PER_OUTCOME)

  if quotas_full():
    return

  # Move the decoded sequences to the host once, rather than triggering a
  # blocking device computation for every example of the sub-batch.
  inputs = np.asarray(batch['input'].argmax(axis=-1))
  predictions = np.asarray(outputs.argmax(axis=-1))
  targets = np.asarray(batch['output'].argmax(axis=-1))

  for x, y_pred, y_true in zip(inputs, predictions, targets):
    is_correct = bool(np.all(y_pred == y_true))
    counts = true_counts if is_correct else false_counts
    if counts[length] < _MAX_LOGGED_PER_OUTCOME:
      logging.info('x: %s, y_pred: %s, y_true: %s', x, y_pred, y_true)
      counts[length] += 1
      if quotas_full():
        # Nothing more can be logged for this length, stop scanning.
        break


def range_evaluation(
    eval_params: EvaluationParams,
    use_tqdm: bool = False,
) -> list[Mapping[str, Any]]:
  """Evaluates the model on longer, never seen strings and logs the results.

  For every length between `min_test_length` and `max_test_length` (both
  inclusive), `total_batch_size // sub_batch_size` sub-batches are sampled and
  passed through the model, and the per-sub-batch accuracies are averaged. Any
  remainder of that division is dropped. The Python, NumPy and Haiku random
  number generators are all seeded with 1, so the sampled data does not depend
  on previous calls.

  Args:
    eval_params: The evaluation parameters, see above.
    use_tqdm: Whether to use a progress bar with tqdm.

  Returns:
    The list of dicts containing the accuracies, one per evaluated length,
    with keys 'length' and 'accuracy'.
  """
  model = eval_params.model
  params = eval_params.params

  random.seed(1)
  np.random.seed(1)
  rng_seq = hk.PRNGSequence(1)

  # The keyword arguments passed below are Python flags, so they must be
  # declared static for jit.
  if eval_params.is_autoregressive:
    apply_fn = jax.jit(model.apply, static_argnames=('sample',))
  elif eval_params.hard_attention:
    apply_fn = jax.jit(model.apply, static_argnames=('hard_attention',))
  else:
    apply_fn = jax.jit(model.apply)

  num_sub_batches = eval_params.total_batch_size // eval_params.sub_batch_size
  log_samples = eval_params.verbose > 0

  results = []
  lengths = range(eval_params.min_test_length, eval_params.max_test_length + 1)
  if use_tqdm:
    lengths = tqdm.tqdm(lengths)
  true_counts = {}
  false_counts = {}
  for length in lengths:
    # We need to clear the cache of jitted functions, to avoid overflow as we
    # are jitting len(lengths) ones, which can be a lot.
    apply_fn.clear_cache()
    sub_accuracies = []
    for _ in range(num_sub_batches):
      batch = eval_params.sample_batch(
          next(rng_seq), eval_params.sub_batch_size, length)

      if eval_params.is_autoregressive:
        # The output tensor passed here is uninitialized; only its shape and
        # dtype are taken from the batch.
        outputs = apply_fn(
            params,
            next(rng_seq),
            batch['input'],
            jnp.empty_like(batch['output']),
            sample=True)
      elif eval_params.hard_attention:
        outputs = apply_fn(
            params, next(rng_seq), batch['input'], hard_attention=True)
      else:
        outputs = apply_fn(params, next(rng_seq), batch['input'])

      if eval_params.probabilistic_language_model:
        targets = batch['input'].argmax(axis=-1)
      else:
        targets = batch['output']
      sub_accuracies.append(eval_params.accuracy_fn(outputs, targets))

      if log_samples and length % _VERBOSE_LENGTH_STRIDE == 0:
        _log_sample_predictions(
            batch, outputs, length, true_counts, false_counts)

    log_data = {
        'length': length,
        'accuracy': np.mean(sub_accuracies),
    }
    logging.info(log_data)
    results.append(log_data)
  return results