"""Constants for the generalization project."""

import functools

import haiku as hk

from models import ndstack_rnn
from models import rnn
from models import stack_rnn
from models import tape_rnn
from models import transformer
from tasks.cs import binary_addition
from tasks.cs import binary_multiplication
from tasks.cs import bucket_sort
from tasks.cs import compute_sqrt
from tasks.cs import duplicate_string
from tasks.cs import missing_duplicate_string
from tasks.cs import odds_first
from tasks.dcf import modular_arithmetic_brackets
from tasks.dcf import reverse_string
from tasks.dcf import solve_equation
from tasks.dcf import stack_manipulation
from tasks.regular import cycle_navigation
from tasks.regular import even_pairs
from tasks.regular import modular_arithmetic
from tasks.regular import parity_check
from training import curriculum as curriculum_lib

# Registry of model builders, keyed by the name used to select a model.
# Every recurrent variant is `rnn.make_rnn` with its `rnn_core` (and, for the
# stack/tape variants, an `inner_core`) pre-bound through functools.partial.
# The two transformer variants are the `transformer` module's own factory
# functions, used unchanged.
MODEL_BUILDERS = {
    'rnn': functools.partial(rnn.make_rnn, rnn_core=hk.VanillaRNN),
    'lstm': functools.partial(rnn.make_rnn, rnn_core=hk.LSTM),
    # For the stack/tape cores, `inner_core` is a second Haiku cell class passed
    # to make_rnn. It is presumably the controller cell that the memory core
    # wraps; verify against models/rnn.py.
    'stack_rnn': functools.partial(
        rnn.make_rnn,
        rnn_core=stack_rnn.StackRNNCore,
        inner_core=hk.VanillaRNN,
    ),
    'ndstack_rnn': functools.partial(
        rnn.make_rnn,
        rnn_core=ndstack_rnn.NDStackRNNCore,
        inner_core=hk.VanillaRNN,
    ),
    'stack_lstm': functools.partial(
        rnn.make_rnn,
        rnn_core=stack_rnn.StackRNNCore,
        inner_core=hk.LSTM,
    ),
    'transformer_encoder': transformer.make_transformer_encoder,
    'transformer': transformer.make_transformer,
    'tape_rnn': functools.partial(
        rnn.make_rnn,
        rnn_core=tape_rnn.TapeInputLengthJumpCore,
        inner_core=hk.VanillaRNN,
    ),
}

# Registry of curricula, keyed by the name used to select one. Each value is
# the corresponding curriculum type from `training.curriculum` (imported as
# `curriculum_lib`). The types are stored uninstantiated; presumably the
# caller constructs them.
CURRICULUM_BUILDERS = {
    'fixed':
        curriculum_lib.FixedCurriculum,
    'regular_increase':
        curriculum_lib.RegularIncreaseCurriculum,
    'reverse_exponential':
        curriculum_lib.ReverseExponentialCurriculum,
    'uniform':
        curriculum_lib.UniformCurriculum,
}

# Registry of task constructors, keyed by task name. The keys are the same set
# as TASK_LEVELS. Insertion order is kept as-is in case callers iterate over
# the registry.
TASK_BUILDERS = {
    'modular_arithmetic': modular_arithmetic.ModularArithmetic,
    'parity_check': parity_check.ParityCheck,
    'even_pairs': even_pairs.EvenPairs,
    'cycle_navigation': cycle_navigation.CycleNavigation,
    # Pre-binds mult=True. This presumably enables multiplication among the
    # operators; confirm in tasks/dcf/modular_arithmetic_brackets.py.
    'modular_arithmetic_brackets': functools.partial(
        modular_arithmetic_brackets.ModularArithmeticBrackets,
        mult=True,
    ),
    'reverse_string': reverse_string.ReverseString,
    'missing_duplicate_string':
        missing_duplicate_string.MissingDuplicateString,
    'duplicate_string': duplicate_string.DuplicateString,
    'binary_addition': binary_addition.BinaryAddition,
    'binary_multiplication': binary_multiplication.BinaryMultiplication,
    'compute_sqrt': compute_sqrt.ComputeSqrt,
    'odds_first': odds_first.OddsFirst,
    'solve_equation': solve_equation.SolveEquation,
    'stack_manipulation': stack_manipulation.StackManipulation,
    'bucket_sort': bucket_sort.BucketSort,
}

# Maps each task name to its level label. Each label matches the `tasks.<level>`
# subpackage the task is imported from. The labels presumably stand for
# regular, deterministic context-free and context-sensitive languages; confirm
# with the task package docs. Tasks are grouped by level, and the groups and
# their members appear in the dict's insertion order.
TASK_LEVELS = {
    task: level
    for level, tasks in (
        ('regular', (
            'modular_arithmetic',
            'parity_check',
            'even_pairs',
            'cycle_navigation',
        )),
        ('dcf', (
            'modular_arithmetic_brackets',
            'reverse_string',
            'stack_manipulation',
            'solve_equation',
        )),
        ('cs', (
            'missing_duplicate_string',
            'compute_sqrt',
            'duplicate_string',
            'binary_addition',
            'binary_multiplication',
            'odds_first',
            'bucket_sort',
        )),
    )
    for task in tasks
}

# Maps a positional-encoding name to the matching
# `transformer.PositionalEncodings` member. Every key is spelled exactly like
# the member attribute it selects, so each entry is looked up by that name.
# The tuple fixes both the set of accepted names and their order.
POS_ENC_TABLE = {
    name: getattr(transformer.PositionalEncodings, name)
    for name in ('NONE', 'SIN_COS', 'RELATIVE', 'ALIBI', 'ROTARY')
}

