import functools

import haiku as hk

from neural_networks_chomsky_hierarchy.experiments import curriculum as curriculum_lib
from neural_networks_chomsky_hierarchy.models import ndstack_rnn
from neural_networks_chomsky_hierarchy.models import rnn
from neural_networks_chomsky_hierarchy.models import stack_rnn
from neural_networks_chomsky_hierarchy.models import tape_rnn
from neural_networks_chomsky_hierarchy.models import transformer
from neural_networks_chomsky_hierarchy.tasks.cs import binary_addition
from neural_networks_chomsky_hierarchy.tasks.cs import binary_multiplication
from neural_networks_chomsky_hierarchy.tasks.cs import bucket_sort
from neural_networks_chomsky_hierarchy.tasks.cs import compute_sqrt
from neural_networks_chomsky_hierarchy.tasks.cs import duplicate_string
from neural_networks_chomsky_hierarchy.tasks.cs import missing_duplicate_string
from neural_networks_chomsky_hierarchy.tasks.cs import odds_first
from neural_networks_chomsky_hierarchy.tasks.dcf import modular_arithmetic_brackets
from neural_networks_chomsky_hierarchy.tasks.dcf import reverse_string
from neural_networks_chomsky_hierarchy.tasks.dcf import solve_equation
from neural_networks_chomsky_hierarchy.tasks.dcf import stack_manipulation
from neural_networks_chomsky_hierarchy.tasks.regular import cycle_navigation
from neural_networks_chomsky_hierarchy.tasks.regular import even_pairs
from neural_networks_chomsky_hierarchy.tasks.regular import even_length
from neural_networks_chomsky_hierarchy.tasks.regular import modular_arithmetic
from neural_networks_chomsky_hierarchy.tasks.regular import parity_check
from neural_networks_chomsky_hierarchy.tasks.regular import star
from neural_networks_chomsky_hierarchy.tasks.regular import star_free
from neural_networks_chomsky_hierarchy.tasks.regular import local_threshold
from neural_networks_chomsky_hierarchy.tasks.regular import locally_testable
from neural_networks_chomsky_hierarchy.tasks.regular import strictly_local
from neural_networks_chomsky_hierarchy.tasks.regular import definite
from neural_networks_chomsky_hierarchy.tasks.regular import first
from neural_networks_chomsky_hierarchy.tasks.regular import first_two
from neural_networks_chomsky_hierarchy.tasks.regular import last
from neural_networks_chomsky_hierarchy.tasks.regular import piecewise_testable
from neural_networks_chomsky_hierarchy.tasks.regular import left_deterministic
from neural_networks_chomsky_hierarchy.tasks.regular import right_deterministic
from neural_networks_chomsky_hierarchy.tasks.regular import bounded_dyck
from neural_networks_chomsky_hierarchy.tasks.counter import anbn
from neural_networks_chomsky_hierarchy.tasks.counter import even
from neural_networks_chomsky_hierarchy.tasks.lm import left_deterministic_blm
from neural_networks_chomsky_hierarchy.tasks.lm import right_deterministic_blm
from neural_networks_chomsky_hierarchy.tasks.lm import left_deterministic_plm
from neural_networks_chomsky_hierarchy.tasks.lm import right_deterministic_plm

# Registry of model builders, keyed by the model name used to select them.
# The recurrent entries pre-bind `rnn_core` (and, for the stack/tape variants,
# the `inner_core` controller) onto `rnn.make_rnn`; the two transformer entries
# are the builder functions from the `transformer` module used as-is.
MODEL_BUILDERS = dict(
    rnn=functools.partial(rnn.make_rnn, rnn_core=hk.VanillaRNN),
    lstm=functools.partial(rnn.make_rnn, rnn_core=hk.LSTM),
    stack_rnn=functools.partial(
        rnn.make_rnn,
        rnn_core=stack_rnn.StackRNNCore,
        inner_core=hk.VanillaRNN,
    ),
    ndstack_rnn=functools.partial(
        rnn.make_rnn,
        rnn_core=ndstack_rnn.NDStackRNNCore,
        inner_core=hk.VanillaRNN,
    ),
    stack_lstm=functools.partial(
        rnn.make_rnn,
        rnn_core=stack_rnn.StackRNNCore,
        inner_core=hk.LSTM,
    ),
    transformer_encoder=transformer.make_transformer_encoder,
    transformer=transformer.make_transformer,
    tape_rnn=functools.partial(
        rnn.make_rnn,
        rnn_core=tape_rnn.TapeInputLengthJumpCore,
        inner_core=hk.VanillaRNN,
    ),
)

# Registry of curriculum classes from `curriculum_lib`, keyed by the name used
# to select a curriculum.
CURRICULUM_BUILDERS = dict(
    fixed=curriculum_lib.FixedCurriculum,
    regular_increase=curriculum_lib.RegularIncreaseCurriculum,
    reverse_exponential=curriculum_lib.ReverseExponentialCurriculum,
    uniform=curriculum_lib.UniformCurriculum,
)

# Registry of task classes, keyed by task name. Values are the task classes
# themselves, except `modular_arithmetic_brackets`, which is pre-bound with
# `mult=True`.
TASK_BUILDERS = dict(
    # From `tasks.regular`.
    modular_arithmetic=modular_arithmetic.ModularArithmetic,
    parity_check=parity_check.ParityCheck,
    even_pairs=even_pairs.EvenPairs,
    star=star.Star,
    star_free=star_free.StarFree,
    local_threshold=local_threshold.LocalThreshold,
    locally_testable=locally_testable.LocallyTestable,
    strictly_local=strictly_local.StrictlyLocal,
    definite=definite.Definite,
    first=first.First,
    first_two=first_two.FirstTwo,
    last=last.Last,
    piecewise_testable=piecewise_testable.PiecewiseTestable,
    left_deterministic=left_deterministic.LeftDeterministic,
    right_deterministic=right_deterministic.RightDeterministic,
    bounded_dyck=bounded_dyck.BoundedDyck,
    cycle_navigation=cycle_navigation.CycleNavigation,
    # From `tasks.dcf` / `tasks.cs` (interleaved to keep the original order).
    modular_arithmetic_brackets=functools.partial(
        modular_arithmetic_brackets.ModularArithmeticBrackets,
        mult=True,
    ),
    reverse_string=reverse_string.ReverseString,
    missing_duplicate_string=missing_duplicate_string.MissingDuplicateString,
    duplicate_string=duplicate_string.DuplicateString,
    binary_addition=binary_addition.BinaryAddition,
    binary_multiplication=binary_multiplication.BinaryMultiplication,
    compute_sqrt=compute_sqrt.ComputeSqrt,
    odds_first=odds_first.OddsFirst,
    solve_equation=solve_equation.SolveEquation,
    stack_manipulation=stack_manipulation.StackManipulation,
    bucket_sort=bucket_sort.BucketSort,
    # From `tasks.counter`, then `tasks.regular` (`even_length`).
    anbn=anbn.AnBn,
    even=even.Even,
    even_length=even_length.EvenLength,
    # From `tasks.lm`.
    left_deterministic_blm=left_deterministic_blm.LeftDeterministicBLM,
    right_deterministic_blm=right_deterministic_blm.RightDeterministicBLM,
    left_deterministic_plm=left_deterministic_plm.LeftDeterministicPLM,
    right_deterministic_plm=right_deterministic_plm.RightDeterministicPLM,
)

# Maps each task name to its level label. Each label is the name of the
# `neural_networks_chomsky_hierarchy.tasks.<level>` subpackage that the task is
# imported from. Every key of TASK_BUILDERS should have an entry here;
# previously only the first 15 tasks did, so a level lookup for any other
# registered task raised KeyError.
# The original entries come first, in their original order, so existing
# iteration order is preserved; entries added later are appended after them.
TASK_LEVELS = {
    'modular_arithmetic': 'regular',
    'parity_check': 'regular',
    'even_pairs': 'regular',
    'cycle_navigation': 'regular',
    'modular_arithmetic_brackets': 'dcf',
    'reverse_string': 'dcf',
    'stack_manipulation': 'dcf',
    'solve_equation': 'dcf',
    'missing_duplicate_string': 'cs',
    'compute_sqrt': 'cs',
    'duplicate_string': 'cs',
    'binary_addition': 'cs',
    'binary_multiplication': 'cs',
    'odds_first': 'cs',
    'bucket_sort': 'cs',
    # Tasks registered in TASK_BUILDERS that were missing a level.
    # Imported from `tasks.regular`.
    'star': 'regular',
    'star_free': 'regular',
    'local_threshold': 'regular',
    'locally_testable': 'regular',
    'strictly_local': 'regular',
    'definite': 'regular',
    'first': 'regular',
    'first_two': 'regular',
    'last': 'regular',
    'piecewise_testable': 'regular',
    'left_deterministic': 'regular',
    'right_deterministic': 'regular',
    'bounded_dyck': 'regular',
    # Imported from `tasks.counter`.
    'anbn': 'counter',
    'even': 'counter',
    # Imported from `tasks.regular`.
    'even_length': 'regular',
    # Imported from `tasks.lm`. NOTE(review): these are labelled by package;
    # confirm whether downstream consumers expect a Chomsky level instead.
    'left_deterministic_blm': 'lm',
    'right_deterministic_blm': 'lm',
    'left_deterministic_plm': 'lm',
    'right_deterministic_plm': 'lm',
}
