from typing import Dict, Set, Optional, Callable, Any, Tuple, List
from dataclasses import dataclass
from enum import Enum
import math
import random
import re

from frame.environments.ground_truth_types import GroundTruthEntity, EntityType, is_prime
from frame.tools.z3_template import Z3Template


# Ground-truth arithmetic functions for the number-theory domain.
#
# Keys are canonical concept names (each equals the entry's `canonical_name`).
# `computational_implementation` evaluates the concept directly. `discovered_names`
# lists concept names built by production rules (iterate_/compose_/specialized_/
# matched_/size_of_) that count as a rediscovery of the entry; these are presumably
# matched by exact string comparison, so their formatting matters (confirm against
# the matcher). `z3_translation`, where present, builds a Z3Template for the concept.


def _num_divisors(n: int) -> int:
    """Return the number of positive divisors of ``n`` (0 when ``n <= 0``)."""
    return sum(1 for d in range(1, n + 1) if n % d == 0)


# Private, fixed-seed RNG for tau's synthetic non-examples. It keeps the ground truth
# reproducible across runs and avoids consuming the global `random` state at import.
_TAU_NONEXAMPLE_RNG = random.Random(0)


NT_FUNCTIONS: Dict[str, GroundTruthEntity] = {
    # Arithmetic concepts
    "add": GroundTruthEntity(
        canonical_name="add",
        entity_type=EntityType.CONCEPT,
        description="Addition of natural numbers",
        computational_implementation=lambda a, b: a + b,
        discovered_names={"iterate_(successor)"},
        z3_translation=lambda a, b: Z3Template("""
            params 2;
            bounded params 0;
            ReturnExpr x_0 + x_1;
            ReturnPred None;
        """,
        a, b
        )
    ),
    "add_three_numbers": GroundTruthEntity(
        canonical_name="add_three_numbers",
        entity_type=EntityType.CONCEPT,
        description="Addition of three natural numbers",
        computational_implementation=lambda a, b, c: a + b + c,
        discovered_names={"compose_(add_with_add_output_to_input_map={0: 0})",
                          "compose_(add_with_add_output_to_input_map={0: 1})",
                          "compose_(add_with_add_output_to_input_map={1: 0})",
                          "compose_(add_with_add_output_to_input_map={1: 1})"},
        z3_translation=lambda a, b, c: Z3Template("""
            params 3;
            bounded params 0;
            ReturnExpr x_0 + x_1 + x_2;
            ReturnPred None;
        """,
        a, b, c
        )
    ),
    "double": GroundTruthEntity(
        canonical_name="double",
        entity_type=EntityType.CONCEPT,
        description="Double of a natural number",
        computational_implementation=lambda a: a + a,
        discovered_names={"matched_(add_indices_[0, 1])", "specialized_(multiply_at_0_to_two)", "specialized_(multiply_at_1_to_two)"}
    ),
    "multiply": GroundTruthEntity(
        canonical_name="multiply",
        entity_type=EntityType.CONCEPT,
        description="Multiplication of natural numbers",
        computational_implementation=lambda a, b: a * b,
        discovered_names={"iterate_(add_with_zero)"},
        z3_translation=lambda a, b: Z3Template("""
            params 2;
            bounded params 0;
            ReturnExpr x_0 * x_1;
            ReturnPred None;
        """,
        a, b
        )
    ),
    "square": GroundTruthEntity(
        canonical_name="square",
        entity_type=EntityType.CONCEPT,
        description="Square of a natural number",
        computational_implementation=lambda a: a * a,
        discovered_names={"matched_(multiply_indices_[0, 1])", "specialized_(power_at_1_to_two)"}
    ),
    "power": GroundTruthEntity(
        canonical_name="power",
        entity_type=EntityType.CONCEPT,
        description="Exponentiation of natural numbers",
        computational_implementation=lambda a, n: a ** n,
        discovered_names={"iterate_(multiply_with_one)"},
        # NOTE(review): `^` is interpreted by the Z3Template DSL; confirm it means
        # exponentiation there (in z3's Python API `^` is XOR on bit-vectors).
        z3_translation=lambda a, n: Z3Template("""
            params 2;
            bounded params 0;
            ReturnExpr x_0 ^ x_1; 
            ReturnPred None;
        """,
        a, n
        )
    ),
    "cube": GroundTruthEntity(
        canonical_name="cube",
        entity_type=EntityType.CONCEPT,
        description="Cube of a natural number",
        computational_implementation=lambda a: a * a * a,
        discovered_names={"specialized_(power_at_1_to_three)"}
    ),
    "fourth_power": GroundTruthEntity(
        canonical_name="fourth_power",
        entity_type=EntityType.CONCEPT,
        description="Fourth power of a natural number",
        computational_implementation=lambda a: a * a * a * a,
        discovered_names={"specialized_(power_at_1_to_four)"}
    ),
    "triple": GroundTruthEntity(
        canonical_name="triple",
        entity_type=EntityType.CONCEPT,
        description="Triple of a natural number",
        computational_implementation=lambda a: a * 3,
        discovered_names={"specialized_(multiply_at_0_to_three)", "specialized_(multiply_at_1_to_three)"}
    ),
    "ternary_power": GroundTruthEntity(
        canonical_name="ternary_power",
        entity_type=EntityType.CONCEPT,
        description="Ternary power of a natural number",
        computational_implementation=lambda a: 3 ** a,
        discovered_names={"specialized_(power_at_0_to_three)"}
    ),
    "tau": GroundTruthEntity( # TODO(_; 3/23): In general this should probably not have a computational implementation because we cannot just use the examples known in the universe to compute it.
        canonical_name="tau",
        entity_type=EntityType.CONCEPT,
        description="Number of divisors function",
        computational_implementation=lambda n: _num_divisors(n),
        discovered_names={"size_of_(divides_indices_[0])"},
        new_examples={(n, _num_divisors(n)) for n in range(1, 21)},
        # Offsetting the true value by 1..10 guarantees every pair is a genuine non-example.
        new_nonexamples={(n, _TAU_NONEXAMPLE_RNG.randint(1, 10) + _num_divisors(n)) for n in range(1, 21)}
    ),
    "tau_of_tau": GroundTruthEntity( # Note(_; 3/23): HR discovered concept
        canonical_name="tau_of_tau",
        entity_type=EntityType.CONCEPT,
        description="Tau of tau",
        computational_implementation=lambda n: _num_divisors(_num_divisors(n)),
        discovered_names={"compose_(tau_with_tau_output_to_input_map={0: 0})"}
    ),
    "product_add": GroundTruthEntity(
        canonical_name="product_add",
        entity_type=EntityType.CONCEPT,
        description='Take product then add',
        computational_implementation=lambda a, b, c: a*b + c,
        discovered_names={"compose_(multiply_with_add_output_to_input_map={0: 0})",
                          "compose_(multiply_with_add_output_to_input_map={0: 1})"}
    ),
    "sum_two_products": GroundTruthEntity(
        canonical_name="sum_two_products",
        entity_type=EntityType.CONCEPT,
        description='Sum the products of two pairs of numbers',
        computational_implementation=lambda a, b, c, d: a*b + c*d,
        discovered_names={"compose_(multiply_with_product_add_output_to_input_map={0: 2})"}
    ),
    "power_add": GroundTruthEntity(
        canonical_name="power_add",
        entity_type=EntityType.CONCEPT,
        description="Take power and add",
        computational_implementation=lambda a, b, c: a**b + c,
        discovered_names={"compose_(power_with_add_output_to_input_map={0: 0})",
                          "compose_(power_with_add_output_to_input_map={0: 1})"}
    ),
    "sum_two_powers": GroundTruthEntity(
        canonical_name="sum_two_powers",
        entity_type=EntityType.CONCEPT,
        description="Sum of two numbers to respective powers",
        computational_implementation=lambda a, b, c, d: a**b + c**d,
        discovered_names={"compose_(power_with_power_add_output_to_input_map={0: 2})"}
    ),
}

def _is_perfect_square(a) -> bool:
    """Return True if ``a`` is the square of a non-negative integer.

    Uses exact integer arithmetic (``math.isqrt``) instead of ``a ** 0.5``. The
    float version misreports large non-squares (e.g. ``10**16 + 1``) as squares and
    produces a complex number (then a TypeError) for negative inputs.
    """
    if a < 0 or a != int(a):
        return False
    n = int(a)
    root = math.isqrt(n)
    return root * root == n


# Ground-truth predicates for the number-theory domain.
#
# Keys are canonical concept names (each equals the entry's `canonical_name`).
# `computational_implementation` returns a bool. `discovered_names` lists the concept
# names built by production rules that count as a rediscovery of the entry; these
# are presumably matched by exact string comparison (confirm against the matcher).
# Predicates that take a modulus guard `a != 0` before `%`, matching `divides`, so
# they return False/True instead of raising ZeroDivisionError.
NT_PREDICATES: Dict[str, GroundTruthEntity] = {
    "divides": GroundTruthEntity(
        canonical_name="divides",
        entity_type=EntityType.CONCEPT,
        description="Divisibility relation",
        computational_implementation=lambda a, b: a != 0 and b % a == 0,
        discovered_names={"exists_(multiply_indices_[0])", "exists_(multiply_indices_[1])"} # Note(_; 3/22): Duplication occurs because multiplication commutes.
    ),
    "is_square": GroundTruthEntity(
        canonical_name="is_square",
        entity_type=EntityType.CONCEPT,
        description="Squareness predicate",
        computational_implementation=lambda a: _is_perfect_square(a),
        discovered_names={"exists_(square_indices_[0])"}
    ),
    "is_nonzero_square": GroundTruthEntity(
        canonical_name="is_nonzero_square",
        entity_type=EntityType.CONCEPT,
        description="Nonzero squareness predicate",
        computational_implementation=lambda a: _is_perfect_square(a) and a != 0,
        discovered_names={"compose_(is_square_with_gt_zero_shared_vars={0: 0})",
                          "compose_(gt_zero_with_is_square_shared_vars={0: 0})"}
    ),
    "not_square": GroundTruthEntity(
        canonical_name="not_square",
        entity_type=EntityType.CONCEPT,
        description="not Squareness predicate",
        computational_implementation=lambda a: not _is_perfect_square(a),
        discovered_names={"not_(is_square)"}
    ),
    "is_even": GroundTruthEntity( # TODO(_; 3/22): Failure in specialize for some reason, fix this. (when using two obtained via specialize on successor)
        canonical_name="is_even", # NOTE(_; 4/25): Specialize works,
        entity_type=EntityType.CONCEPT,
        description="Evenness predicate",
        computational_implementation=lambda n: n % 2 == 0,
        discovered_names={"exists_(double_indices_[0])", "specialized_(divides_at_0_to_two)", "not_(is_odd)"}
    ),
    "is_odd": GroundTruthEntity(
        canonical_name="is_odd",
        entity_type=EntityType.CONCEPT,
        description="Oddness predicate",
        computational_implementation=lambda n: n % 2 == 1,
        discovered_names={"not_(is_even)", "exists_(double_add_one_indices_[0])"}
    ),
    "leq": GroundTruthEntity(
        canonical_name="leq",
        entity_type=EntityType.CONCEPT,
        description="Less than or equal to predicate",
        computational_implementation=lambda a, b: a <= b,
        discovered_names={"exists_(add_indices_[0])", "exists_(add_indices_[1])"} # Note(_; 3/22): Duplication occurs because addition commutes.
    ),
    "gt": GroundTruthEntity(
        canonical_name="gt",
        entity_type=EntityType.CONCEPT,
        description="Greater than predicate",
        computational_implementation=lambda a, b: a > b,
        discovered_names={"not_(leq)"}
    ),
    "ne": GroundTruthEntity(
        canonical_name="ne",
        entity_type=EntityType.CONCEPT,
        description="Not equal predicate",
        computational_implementation=lambda a, b: a != b,
        discovered_names={"not_(eq)"}
    ),
    "lt": GroundTruthEntity(
        canonical_name="lt",
        entity_type=EntityType.CONCEPT,
        description="Less than predicate",
        computational_implementation=lambda a, b: a < b,
        discovered_names={"compose_(ne_with_leq_shared_vars={0: 0, 1: 1})",
                          "compose_(leq_with_ne_shared_vars={0: 0, 1: 1})",
                          "compose_(leq_with_ne_shared_vars={0: 1, 1: 0})"}
    ),
    "geq": GroundTruthEntity(
        canonical_name="geq",
        entity_type=EntityType.CONCEPT,
        description="Greater than or equal to predicate",
        # Was `a < b` (i.e. `lt`); `geq` is the negation of `lt`, as `not_(lt)` says.
        computational_implementation=lambda a, b: a >= b,
        discovered_names={"not_(lt)"}
    ),
    "not_divides": GroundTruthEntity( # Note(_; 3/25): This is a ground truth concept, but its name does not differ. I think it is important we indicate that the model finds this.
        canonical_name="not_divides",
        entity_type=EntityType.CONCEPT,
        description="Not divides predicate",
        # Exact negation of `divides` (which treats a == 0 as never dividing), so
        # a == 0 yields True instead of raising ZeroDivisionError.
        computational_implementation=lambda a, b: a == 0 or b % a != 0,
        discovered_names={"not_(divides)"}
    ),
    "is_nonunit_divisor_of": GroundTruthEntity(
        canonical_name="is_nonunit_divisor_of",
        entity_type=EntityType.CONCEPT,
        description="Is nonunit divisor of predicate",
        # Check `a > 1` first so a == 0 short-circuits before the modulo.
        computational_implementation=lambda a, b: a > 1 and b % a == 0,
        # The spaced "{0: 0}" form matches every other compose_ name in this file;
        # the unspaced original is kept for backward compatibility.
        discovered_names={"compose_(divides_with_geq_two_shared_vars={0:0})",
                          "compose_(divides_with_geq_two_shared_vars={0: 0})",
                          "compose_(divides_with_gt_one_shared_vars={0: 0})"}
    ),
    "is_proper_divisor_of": GroundTruthEntity(
        canonical_name="is_proper_divisor_of",
        entity_type=EntityType.CONCEPT,
        description="Is proper divisor of predicate",
        # Check `a > 1` first so a == 0 short-circuits before the modulo.
        # NOTE(review): this excludes a == 1, whereas "divides(a, b) and b > a" would
        # include it; confirm which definition of "proper divisor" is intended.
        computational_implementation=lambda a, b: a > 1 and b % a == 0 and a != b,
        # Spaced variant added to match the "{0: 1, 1: 0}" formatting used elsewhere.
        discovered_names={"compose_(divides_with_gt_shared_vars={0:1,1:0})",
                          "compose_(divides_with_gt_shared_vars={0: 1, 1: 0})"}
    ),
    "is_prime": GroundTruthEntity(
        canonical_name="is_prime",
        entity_type=EntityType.CONCEPT,
        description='Only divisors are itself and one',
        computational_implementation=lambda n: is_prime(n),
        discovered_names={"specialized_(tau_output_eq_two)"}
    ),
    "division_algo_condition": GroundTruthEntity(
        canonical_name="division_algo_condition",
        entity_type=EntityType.CONCEPT,
        description='aq +r = b and r < a',
        computational_implementation=lambda a, q, r, b: a*q + r == b and r < a,
        discovered_names={"compose_(eq_product_add_with_lt_shared_vars={0: 1, 2: 0})"}
    ),
    "is_mod": GroundTruthEntity(
        canonical_name="is_mod",
        entity_type=EntityType.CONCEPT,
        description='True if b = c mod a. Require b <= c',
        # `a != 0` guard avoids ZeroDivisionError (mod 0 is undefined -> False).
        computational_implementation=lambda a, b, c: a != 0 and b <= c and c % a == b,
        discovered_names={"exists_(product_add_indices_[0])",
                          "exists_(product_add_indices_[1])"}
    )
}

# Ground-truth constants and unary threshold predicates for the number-theory domain.
#
# Zero-arity entries (`one`, `two`, ...) return a constant. The other entries are unary
# predicates obtained by fixing one argument of a binary comparison. In
# `discovered_names`, "specialized_(OP_at_I_to_K)" fixes argument I of OP to K;
# e.g. "specialized_(lt_at_0_to_one)" is `1 < b`, i.e. `b > 1`.
NT_NUMBERS: Dict[str, GroundTruthEntity] = {
    "eq_zero": GroundTruthEntity(
        canonical_name='eq_zero',
        entity_type=EntityType.CONCEPT,
        description='Is equal to 0 / less than  1',
        computational_implementation=lambda a: a == 0,
        discovered_names={"specialized_(eq_at_0_to_zero)", "specialized_(eq_at_1_to_zero)",
                          "specialized_(lt_at_1_to_one)", "specialized_(leq_at_1_to_zero)",
                          "specialized_(gt_at_0_to_one)", "specialized_(geq_at_0_to_zero)"}
    ),
    "geq_zero": GroundTruthEntity(
        canonical_name="geq_zero",
        entity_type=EntityType.CONCEPT,
        description="Greater than or equal to zero predicate",
        computational_implementation=lambda a: a >= 0,
        discovered_names={"specialized_(leq_at_0_to_zero)", "specialized_(geq_at_1_to_zero)"}
    ),
    "gt_zero": GroundTruthEntity(
        canonical_name="gt_zero",
        entity_type=EntityType.CONCEPT,
        description="Greater than zero predicate / positive / not equal zero",
        # Was `a > 1`, which duplicated `gt_one` and contradicted every discovered name below.
        computational_implementation=lambda a: a > 0,
        discovered_names={"specialized_(gt_at_1_to_zero)", "specialized_(geq_at_1_to_one)",
                          "specialized_(leq_at_0_to_one)", "specialized_(lt_at_0_to_zero)",
                          "not_(eq_zero)",
                          "specialized_(ne_at_0_to_zero)", "specialized_(ne_at_1_to_zero)"}
    ),
    "one": GroundTruthEntity(
        canonical_name="one",
        entity_type=EntityType.CONCEPT,
        description="The natural number 1",
        computational_implementation=lambda: 1,
        discovered_names={"specialized_(successor_at_0_to_zero)"}
    ),
    # lt_one is the same predicate as eq_zero, so it has no separate entry.
    "eq_one": GroundTruthEntity(
        canonical_name='eq_one',
        entity_type=EntityType.CONCEPT,
        description='Is equal to 1',
        computational_implementation=lambda a: a == 1,
        discovered_names={"specialized_(eq_at_0_to_one)", "specialized_(eq_at_1_to_one)"}
    ),
    "ne_one": GroundTruthEntity(
        canonical_name='ne_one',
        entity_type=EntityType.CONCEPT,
        description='not equal to 1',
        computational_implementation=lambda a: a != 1,
        discovered_names={"specialized_(ne_at_0_to_one)", "specialized_(ne_at_1_to_one)",
                          "not_(eq_one)"}
    ),
    "gt_one": GroundTruthEntity(
        canonical_name="gt_one",
        entity_type=EntityType.CONCEPT,
        description="Greater than one predicate",
        computational_implementation=lambda a: a > 1,
        discovered_names={"specialized_(gt_at_1_to_one)", "specialized_(geq_at_1_to_two)",
                          "specialized_(leq_at_0_to_two)", "specialized_(lt_at_0_to_one)"}
    ),
    "two": GroundTruthEntity(
        canonical_name="two",
        entity_type=EntityType.CONCEPT,
        description="The natural number 2",
        computational_implementation=lambda: 2,
        discovered_names={"specialized_(successor_at_0_to_one)"}
    ),
    # NOTE(review): the key "lt_two" differs from canonical_name "leq_one" (every other
    # entry uses matching names); left unchanged because callers may rely on either.
    "lt_two": GroundTruthEntity(
        canonical_name='leq_one',
        entity_type=EntityType.CONCEPT,
        description='< 2, <= 1',
        computational_implementation=lambda a: a <= 1,
        # "geq_at_0_to_one" (1 >= b) replaces "geq_at_0_to_three" (3 >= b), which is
        # not equivalent to b <= 1.
        discovered_names={"specialized_(leq_at_1_to_one)", "specialized_(lt_at_1_to_two)",
                          "specialized_(geq_at_0_to_one)", "specialized_(gt_at_0_to_two)"}
    ),
    "eq_two": GroundTruthEntity(
        canonical_name='eq_two',
        entity_type=EntityType.CONCEPT,
        description='Is equal to 2',
        computational_implementation=lambda a: a == 2,
        discovered_names={"specialized_(eq_at_0_to_two)", "specialized_(eq_at_1_to_two)"}
    ),
    "ne_two": GroundTruthEntity(
        canonical_name='ne_two',
        entity_type=EntityType.CONCEPT,
        description='not equal to 2',
        computational_implementation=lambda a: a != 2,
        discovered_names={"specialized_(ne_at_0_to_two)", "specialized_(ne_at_1_to_two)",
                          "not_(eq_two)"}
    ),
    "gt_two": GroundTruthEntity(
        canonical_name="gt_two",
        entity_type=EntityType.CONCEPT,
        description="Greater than two predicate",
        computational_implementation=lambda a: a > 2,
        discovered_names={"specialized_(gt_at_1_to_two)", "specialized_(geq_at_1_to_three)",
                          "specialized_(leq_at_0_to_three)", "specialized_(lt_at_0_to_two)"}
    ),
    "three": GroundTruthEntity(
        canonical_name="three",
        entity_type=EntityType.CONCEPT,
        description="The natural number 3",
        computational_implementation=lambda: 3,
        discovered_names={"specialized_(successor_at_0_to_two)"}
    ),
    "four": GroundTruthEntity(
        canonical_name="four",
        entity_type=EntityType.CONCEPT,
        description="The natural number 4",
        computational_implementation=lambda: 4,
        discovered_names={"specialized_(successor_at_0_to_three)"}
    ),
    "five": GroundTruthEntity(
        canonical_name="five",
        entity_type=EntityType.CONCEPT,
        description="The natural number 5",
        computational_implementation=lambda: 5,
        discovered_names={"specialized_(successor_at_0_to_four)"}
    ),
    "six": GroundTruthEntity(
        canonical_name="six",
        entity_type=EntityType.CONCEPT,
        description="The natural number 6",
        computational_implementation=lambda: 6,
        discovered_names={"specialized_(successor_at_0_to_five)"}
    ),
    "gt_five": GroundTruthEntity(
        canonical_name="gt_five",
        entity_type=EntityType.CONCEPT,
        description="Greater than five predicate",
        computational_implementation=lambda a: a > 5,
        discovered_names={"specialized_(gt_at_1_to_five)", "specialized_(geq_at_1_to_six)",
                          "specialized_(leq_at_0_to_six)", "specialized_(lt_at_0_to_five)"}
    ),
}