import os
from absl import app
from absl import flags
import haiku as hk
import jax.numpy as jnp
import numpy as np
import random
from scipy.stats import binned_statistic

import sys
sys.path.append('../')

from neural_networks_chomsky_hierarchy.experiments import constants
from neural_networks_chomsky_hierarchy.experiments import curriculum as curriculum_lib
from neural_networks_chomsky_hierarchy.experiments import training
from neural_networks_chomsky_hierarchy.experiments import utils
from neural_networks_chomsky_hierarchy.models import positional_encodings

# --- Training and range-evaluation flags. ---
# Positional arguments are (name, default, help); integer flags that must be
# positive carry an explicit `lower_bound`.
_BATCH_SIZE = flags.DEFINE_integer(
    'batch_size', 128, 'Training batch size.', lower_bound=1)
_MIN_LENGTH = flags.DEFINE_integer(
    'min_length', 1, 'Minimum training sequence length.', lower_bound=1)
_SEQUENCE_LENGTH = flags.DEFINE_integer(
    'sequence_length', 40, 'Maximum training sequence length.', lower_bound=1)
_MAX_RANGE_TEST_LENGTH = flags.DEFINE_integer(
    'max_range_test_length', 500, 'Maximum testing sequence length.',
    lower_bound=1)
_MIN_RANGE_TEST_LENGTH = flags.DEFINE_integer(
    'min_range_test_length', 1, 'Minimum testing sequence length.',
    lower_bound=1)
_RANGE_TEST_TOTAL_BATCH_SIZE = flags.DEFINE_integer(
    'range_test_total_batch_size', 512, 'Test total batch size.',
    lower_bound=1)
_RANGE_TEST_SUB_BATCH_SIZE = flags.DEFINE_integer(
    'range_test_sub_batch_size', 32, 'Test batch size.', lower_bound=1)
_TRAINING_STEPS = flags.DEFINE_integer(
    'training_steps', 1_000_000, 'Training steps.', lower_bound=1)
# List flags: `main` runs one training job per (learning rate, seed) pair and
# converts the entries with float()/int(), since command-line values arrive as
# strings.
_LEARNING_RATE = flags.DEFINE_list(
    'learning_rate', [1e-4, 3e-4, 5e-4], "Learning rate.")
_SEED = flags.DEFINE_list('seed', [0, 1, 2, 3, 4], 'Random seed.')
# --- Task selection flags. ---
# `vocab_size` and `mode` are only forwarded to the task builders that accept
# them (see the task dispatch in `main`).
_TASK = flags.DEFINE_string(
    'task', 'even_pairs',
    'Length generalization task (see `constants.py` for other tasks).')
_VOCAB_SIZE = flags.DEFINE_integer('vocab_size', 3, 'Vocab size')
_MODE = flags.DEFINE_string('mode', 'full', 'Mode for regular palindrome.')
# --- Model architecture, output-mode and logging flags. ---
_ARCHITECTURE = flags.DEFINE_string(
    'architecture', 'tape_rnn',
    'Model architecture (see `constants.py` for other architectures).')
_NUM_LAYERS = flags.DEFINE_integer(
    'num_layers', 5, 'Number of layers in transformer')
_IS_AUTOREGRESSIVE = flags.DEFINE_boolean(
    'is_autoregressive', False,
    'Whether to use autoregressive sampling or not.')
_HARD_ATTENTION = flags.DEFINE_boolean(
    'hard_attention', False, 'Whether to use hard attention or not.')
_STRICT_MASKING = flags.DEFINE_boolean(
    'strict_masking', False, 'Whether to use strict masking or not.')
_DOUBLE_MASKING = flags.DEFINE_boolean(
    'double_masking', False, 'Whether to use double masking.')
_CAUSAL_MASKING = flags.DEFINE_boolean(
    'causal_masking', False, 'Whether to use causal masking or not.')
_COMPUTATION_STEPS_MULT = flags.DEFINE_integer(
    'computation_steps_mult', 0,
    'The amount of computation tokens to append to the input tape (defined'
    ' as a multiple of the input length)',
    lower_bound=0)
_LAYER_NORM = flags.DEFINE_bool(
    'layer_norm', True, 'Whether to use layer normalization or not.')
_SUB_LENGTH = flags.DEFINE_integer(
    'sub_length', 2,
    'Length of subsequence in piecewise testable languages or substring in'
    ' locally testable languages.')
_ADD_EOS = flags.DEFINE_boolean('add_eos', True, 'Whether to add eos token.')
_EMBEDDING_DIM = flags.DEFINE_integer(
    'embedding_dim', 64, 'Embedding dimension of transformer.')
_NUM_HEADS = flags.DEFINE_integer('num_heads', 8, 'Number of heads.')
_VERBOSE = flags.DEFINE_integer('verbose', 0, 'Set to 1 for debugging.')
_SAVE_HIDDEN_STATES = flags.DEFINE_bool(
    'save_hidden_states', False, 'Save hidden states.')
_BOOLEAN_LANGUAGE_MODEL = flags.DEFINE_bool(
    'boolean_language_model', False, 'Boolean language modeling.')
_PROBABILISTIC_LANGUAGE_MODEL = flags.DEFINE_bool(
    'probabilistic_language_model', False, 'Probabilistic language modeling.')
_NUM_STATES = flags.DEFINE_integer(
    'num_states', 2, 'Number of states for blm.')
# Most architecture parameters depend on the architecture, so they are not all
# defined via flags; `main` fills in the rest per architecture. See
# `constants.py` for the required values.


def _make_bins(max_range_test_length: int, sequence_length: int):
    """Returns sorted, unique bin edges for the binned accuracy report.

    The edges split [0, max_range_test_length] into roughly five equal
    intervals, plus one extra edge at `sequence_length`. This separates
    positions below the training length from those at or above it.

    Args:
      max_range_test_length: Largest evaluated sequence length.
      sequence_length: Largest training sequence length.

    Returns:
      A strictly increasing list of integer bin edges.
    """
    # A zero step (max_range_test_length < 5) would make `range` raise.
    interval = max(1, max_range_test_length // 5)
    edges = set(range(0, max_range_test_length + interval, interval))
    edges.add(sequence_length)
    # `binned_statistic` needs monotonically increasing edges. Sorting also
    # handles sequence_length >= interval, where inserting it at index 1
    # would produce a non-increasing sequence.
    return sorted(edges)


def _architecture_params(save_dir):
    """Returns the keyword arguments for the model builder of --architecture.

    Args:
      save_dir: Directory for hidden-state dumps, or None. Only forwarded to
        the 'transformer_encoder' architecture.
    """
    architecture = _ARCHITECTURE.value
    if architecture == 'transformer_encoder':
        pos = 'NONE'  # Key into the positional-encoding tables.
        return {
            'positional_encodings': positional_encodings.POS_ENC_TABLE[pos],
            'positional_encodings_params':
                positional_encodings.POS_ENC_PARAMS_TABLE[pos],
            'causal_masking': _CAUSAL_MASKING.value,
            'strict_masking': _STRICT_MASKING.value,
            'double_masking': _DOUBLE_MASKING.value,
            'num_layers': _NUM_LAYERS.value,
            'layer_norm': _LAYER_NORM.value,
            'embedding_dim': _EMBEDDING_DIM.value,
            'num_heads': _NUM_HEADS.value,
            'save_hidden_states': _SAVE_HIDDEN_STATES.value,
            'save_dir': save_dir,
        }
    if architecture == 'tape_rnn':
        return {'hidden_size': 256, 'memory_cell_size': 8, 'memory_size': 40}
    if architecture == 'stack_rnn':
        return {'hidden_size': 256, 'stack_cell_size': 8, 'stack_size': 256}
    return {'hidden_size': 256}


def _build_task():
    """Instantiates --task, forwarding the task-specific flags it accepts."""
    name = _TASK.value
    builder = constants.TASK_BUILDERS[name]
    if name == 'regular_palindrome':
        return builder(_SEQUENCE_LENGTH.value, _MODE.value)
    if name == 'definite':
        return builder(_VOCAB_SIZE.value)
    if name == 'locally_testable':
        return builder(_SUB_LENGTH.value)
    if name == 'piecewise_testable':
        return builder(_VOCAB_SIZE.value, _SUB_LENGTH.value)
    if 'left_deterministic' in name:
        return builder(_NUM_STATES.value)
    return builder()


def _wrap_model(model, task, single_output):
    """Wraps the raw model according to the output-mode flags.

    The flags are checked in priority order: autoregressive, boolean LM,
    probabilistic LM, no-EOS, and finally the default empty-targets wrapper.
    """
    if _IS_AUTOREGRESSIVE.value:
        # NOTE(review): with --architecture=transformer the model is returned
        # unwrapped; confirm that this is intentional.
        if _ARCHITECTURE.value != 'transformer':
            model = utils.make_model_with_targets_as_input(
                model, _COMPUTATION_STEPS_MULT.value
            )
            model = utils.add_sampling_to_autoregressive_model(
                model, single_output
            )
        return model
    if _BOOLEAN_LANGUAGE_MODEL.value:
        return utils.make_boolean_language_model(model)
    if _PROBABILISTIC_LANGUAGE_MODEL.value:
        return utils.make_probabilistic_language_model(model)
    if not _ADD_EOS.value:
        return utils.make_model_with_single_output(model)
    return utils.make_model_with_empty_targets(
        model, task, _COMPUTATION_STEPS_MULT.value, single_output
    )


def _make_metric_fns(task):
    """Builds loss and accuracy functions from the task's pointwise ones.

    Binding `task` through this factory, instead of closing over a variable
    that the run loop reassigns, ties each pair of functions to its own task.
    """
    def loss_fn(output, target):
        loss = jnp.mean(
            jnp.sum(task.pointwise_loss_fn(output, target), axis=-1)
        )
        return loss, {}

    def accuracy_fn(output, target):
        return task.accuracy_fn(output, target).mean()

    return loss_fn, accuracy_fn


def _mean_ood_accuracy(accuracies):
    """Mean accuracy over evaluated lengths longer than --sequence_length.

    Assumes accuracies[i] is for length --min_range_test_length + i (TODO:
    confirm against training.TrainingWorker). The offset therefore uses the
    range-test minimum, not the training --min_length.
    """
    start = max(0, _SEQUENCE_LENGTH.value - _MIN_RANGE_TEST_LENGTH.value + 1)
    return np.mean(accuracies[start:])


def _print_summary(scores, binned_scores):
    """Prints max/min/mean/std of the scalar and binned scores across runs."""
    print("maximum score: {}".format(max(scores)))
    print("minimum score: {}".format(min(scores)))
    print("average score: {}".format(np.mean(scores)))
    print("std score: {}".format(np.std(scores)))
    print("maximum binned score: {}".format(np.max(binned_scores, axis=0)))
    print("minimum binned score: {}".format(np.min(binned_scores, axis=0)))
    print("average binned score: {}".format(np.mean(binned_scores, axis=0)))
    print("std binned score: {}".format(np.std(binned_scores, axis=0)))


def main(unused_argv) -> None:
    """Trains and range-evaluates one model per (learning rate, seed) pair.

    For each run, seeds `random` and `np.random` and builds the task and
    model from the flags. It then trains with `training.TrainingWorker` and
    prints the out-of-distribution accuracy (lengths > --sequence_length)
    together with a binned accuracy profile. Summary statistics across all
    runs are printed at the end.

    Raises:
      app.UsageError: If --learning_rate or --seed is an empty list.
    """
    print("Task: {}".format(_TASK.value))
    print(_STRICT_MASKING.value)

    # Parse the list flags up front so that malformed values fail before any
    # training starts. Command-line values arrive as strings.
    learning_rates = [float(lr) for lr in _LEARNING_RATE.value]
    seeds = [int(seed) for seed in _SEED.value]
    if not learning_rates or not seeds:
        raise app.UsageError(
            '--learning_rate and --seed must each list at least one value.'
        )

    bins = _make_bins(_MAX_RANGE_TEST_LENGTH.value, _SEQUENCE_LENGTH.value)
    scores = []
    binned_scores = []
    for lr in learning_rates:
        for seed in seeds:
            random.seed(seed)
            np.random.seed(seed)
            if _SAVE_HIDDEN_STATES.value:
                save_dir = "./outputs/{}/{}/{}_{}".format(
                    _TASK.value, _NUM_STATES.value, lr, seed
                )
                os.makedirs(save_dir, exist_ok=True)
            else:
                save_dir = None
            architecture_params = _architecture_params(save_dir)

            curriculum = curriculum_lib.UniformCurriculum(
                values=list(range(_MIN_LENGTH.value, _SEQUENCE_LENGTH.value + 1))
            )
            task = _build_task()

            # Create the model.
            single_output = task.output_length(10) == 1
            model = constants.MODEL_BUILDERS[_ARCHITECTURE.value](
                output_size=task.output_size,
                return_all_outputs=True,
                **architecture_params,
            )
            model = hk.transform(_wrap_model(model, task, single_output))
            loss_fn, accuracy_fn = _make_metric_fns(task)

            training_params = training.ClassicTrainingParams(
                seed=seed,
                model_init_seed=seed,
                training_steps=_TRAINING_STEPS.value,
                log_frequency=100,
                length_curriculum=curriculum,
                batch_size=_BATCH_SIZE.value,
                task=task,
                model=model,
                loss_fn=loss_fn,
                learning_rate=lr,
                accuracy_fn=accuracy_fn,
                compute_full_range_test=True,
                min_length=_MIN_LENGTH.value,
                min_range_test_length=_MIN_RANGE_TEST_LENGTH.value,
                max_range_test_length=_MAX_RANGE_TEST_LENGTH.value,
                range_test_total_batch_size=_RANGE_TEST_TOTAL_BATCH_SIZE.value,
                range_test_sub_batch_size=_RANGE_TEST_SUB_BATCH_SIZE.value,
                is_autoregressive=_IS_AUTOREGRESSIVE.value,
                hard_attention=_HARD_ATTENTION.value,
                verbose=_VERBOSE.value,
                probabilistic_language_model=_PROBABILISTIC_LANGUAGE_MODEL.value,
            )

            training_worker = training.TrainingWorker(
                training_params, use_tqdm=True
            )
            _, eval_results, eval_results_hard, _ = training_worker.run()

            accuracies = [r['accuracy'] for r in eval_results]
            score = _mean_ood_accuracy(accuracies)
            scores.append(score)

            # Zero-based positions (length - 1), so that with the default
            # --min_range_test_length=1 this equals range(len(accuracies)).
            # Deriving the positions from len(accuracies) keeps x and values
            # the same length.
            positions = (
                np.arange(len(accuracies)) + _MIN_RANGE_TEST_LENGTH.value - 1
            )
            binned_score = binned_statistic(
                positions, accuracies, 'mean', bins=bins
            ).statistic
            binned_scores.append(binned_score)
            print(f'lr {lr} seed {seed} Accuracy: {score} Binned Accuracy: {binned_score}')
            if eval_results_hard is not None:
                hard_accuracies = [r['accuracy'] for r in eval_results_hard]
                hard_score = _mean_ood_accuracy(hard_accuracies)
                print(f'lr {lr} seed {seed} Accuracy hard: {hard_score}')
    _print_summary(scores, binned_scores)

if __name__ == "__main__":
    # Let absl parse the command-line flags, then call `main`.
    app.run(main)
