import random
from nltk.grammar import Nonterminal

# Named grammar definitions, keyed by a short identifier.
# Each value is a multi-line string of production rules written as
# ``LHS -> RHS``; quoted symbols are terminals and ``|`` separates alternatives.
# Entries whose key starts with "pcfg_" additionally carry a bracketed
# probability after every alternative.
# NOTE(review): the rule syntax looks like the one accepted by nltk's
# CFG.fromstring / PCFG.fromstring — the parsing call is not in this file; confirm.
grammar_dict = {
    # One or more repetitions of the pair 'a' 'b' (right-recursive).
    "ab_repeated": """
        S -> 'a' 'b' | 'a' 'b' S
    """,

    # One or more repetitions of the triple 'a' 'b' 'c' (right-recursive).
    "abc_repeated": """
        S -> 'a' 'b' 'c' | 'a' 'b' 'c' S
    """,

    # n 'a's followed by n 'b's, n >= 1 (centre-embedded recursion).
    "anbn": """
        S -> 'a' 'b' | 'a' S 'b'
    """,

    # Strictly nested round parentheses only: no concatenation of groups.
    "balanced_parenthesis_nested": """
        S -> '(' ')' | '(' S ')'
    """,

    # Balanced strings over three bracket types, allowing both nesting and
    # concatenation (S -> S S).
    "balanced_parenthesis": """
        S -> '('')'
        S -> '['']'
        S -> '{''}'
        S -> '(' S ')'
        S -> '[' S ']'
        S -> '{' S '}'
        S -> S S
    """,

    # Probabilistic version of "ab_repeated": stop with 0.1, continue with 0.9.
    "pcfg_ab_repeated": """
        S -> 'a' 'b' [0.1] | 'a' 'b' S [0.9]
    """,

    
    # Probabilistic balanced brackets. S1 picks one bracket type (0.3 each) or
    # concatenates two S1 (0.1); each bracket type is empty with 0.3 and wraps
    # a nested S1 with 0.7.
    "pcfg_balanced_parenthesis": """
        S -> S1 [1]
        S1 -> A1 [0.3]
        S1 -> A2 [0.3]
        S1 -> A3 [0.3]
        S1 -> S1 S1 [0.1]
        A1 -> '('')' [0.3]
        A1 -> '('S1')' [0.7]
        A2 -> '['']' [0.3]
        A2 -> '['S1']' [0.7]
        A3 -> '{''}' [0.3]
        A3 -> '{'S1'}' [0.7]
        """,

    # Strings over {a, b} built symmetrically around the centre; the last A1
    # alternative has an empty right-hand side and terminates with 0.05.
    "pcfg_reverse_string": """
        S -> A1 [1]
        A1 -> 'a' A1 'a' [0.475]
        A1 -> 'b' A1 'b' [0.475]
        A1 ->  [0.05]
    """,

    # Strings over exactly two of the three letters a, b, c: A omits 'a',
    # B omits 'b', C omits 'c'. Each of A, B, C is listed twice under S.
    # NOTE(review): the six S probabilities sum to 0.9996, not 1 — this relies
    # on whatever tolerance the consuming parser applies; confirm.
    "pcfg_one_character_missing": """
        S -> B [0.1666]
        S -> C [0.1666]
        S -> A [0.1666]
        S -> C [0.1666]
        S -> A [0.1666]
        S -> B [0.1666]
        A -> 'b' A [0.495]
        A -> 'b' [0.005]
        A -> 'c' A [0.495]
        A -> 'c' [0.005]
        B -> 'a' B [0.495]
        B -> 'a' [0.005]
        B -> 'c' B [0.495]
        B -> 'c' [0.005]
        C -> 'a' C [0.495]
        C -> 'a' [0.005]
        C -> 'b' C [0.495]
        C -> 'b' [0.005]
    """,

    # Small English-like phrase-structure grammar. Every terminal string ends
    # with a space, and 'book ' is listed both as a Noun and as a Verb.
    "example_1": """
        S -> NP VP
        S -> Aux NP VP
        S -> VP
        NP -> Det NOM
        NOM -> Noun
        NOM -> Noun NOM
        VP -> Verb
        VP -> Verb NP
        Det -> 'that ' | 'this ' | 'a ' | 'the '
        Noun -> 'book ' | 'flight ' | 'meal ' | 'man '
        Verb -> 'book ' | 'include ' | 'read '
        Aux -> 'does '
    """, 


    # Same rules as "pcfg_ab_repeated" with a smaller stop probability (0.01).
    "pcfg_ab": """
        S -> 'a' 'b' [0.01]
        S -> 'a' 'b' S [0.99]
    """,

    # S commits (0.5 / 0.5) to either repeated 'a' 'b' blocks (A) or repeated
    # 'a' 'a' 'b' 'b' blocks (B); the two block types are never mixed.
    "pcfg_ab_aabb": """
            S -> A [0.5]
            S -> B [0.5]
            A -> 'a' 'b' [0.01]
            A -> 'a' 'b' A [0.99]
            B -> 'a' 'a' 'b' 'b' [0.01]
            B -> 'a' 'a' 'b' 'b' B [0.99]
    """,

    # Single-nonterminal variant in which 'a' 'b' and 'a' 'a' 'b' 'b' blocks
    # can be interleaved freely within one string.
    "pcfg_ab_aabb_mixed": """
            S -> 'a' 'b' [0.01]
            S -> 'a' 'a' 'b' 'b' [0.01]
            S -> 'a' 'b' S [0.49]
            S -> 'a' 'a' 'b' 'b' S [0.49]
    """,

    # Non-probabilistic layered grammar: A22 is built from A19-A21, those from
    # A16-A18, then A13-A15, A10-A12, and finally A7-A9, which expand to
    # terminals drawn from '1', '2', '3'. Every nonterminal has two alternatives.
    "cfg3b": """
        A22 -> A21 A20 
        A22 -> A20 A19 
        A19 -> A16 A17 A18 
        A19 -> A17 A18 A16 
        A20 -> A17 A16 A18 
        A20 -> A16 A17 
        A21 -> A18 A16 
        A21 -> A16 A18 A17 
        A16 -> A15 A13 
        A16 -> A13 A15 A14 
        A17 -> A14 A13 A15 
        A17 -> A15 A13 A14 
        A18 -> A15 A14 A13 
        A18 -> A14 A13 
        A13 -> A11 A12 
        A13 -> A12 A11 
        A14 -> A11 A10 A12 
        A14 -> A10 A11 A12 
        A15 -> A12 A11 A10 
        A15 -> A11 A12 A10 
        A10 -> A7 A9 A8 
        A10 -> A9 A8 A7 
        A11 -> A8 A7 A9 
        A11 -> A7 A8 A9 
        A12 -> A8 A9 A7 
        A12 -> A9 A7 A8 
        A7 -> '3' '1' 
        A7 -> '1' '2' '3' 
        A8 -> '3' '2' 
        A8 -> '3' '1' '2' 
        A9 -> '3' '2' '1' 
        A9 -> '2' '1'
    """,

    # "cfg3b" structure with probability 0.50 on every alternative.
    # NOTE(review): the A8/A9 terminals here are '4'-'9', unlike "cfg3b" where
    # they are drawn from '1'-'3' — confirm this difference is intended.
    "pcfg_cfg3b": """
        A22 -> A21 A20 [0.50]
        A22 -> A20 A19 [0.50]
        A19 -> A16 A17 A18 [0.50]
        A19 -> A17 A18 A16 [0.50]
        A20 -> A17 A16 A18 [0.50]
        A20 -> A16 A17 [0.50]
        A21 -> A18 A16 [0.50]
        A21 -> A16 A18 A17 [0.50]
        A16 -> A15 A13 [0.50]
        A16 -> A13 A15 A14 [0.50]
        A17 -> A14 A13 A15 [0.50]
        A17 -> A15 A13 A14 [0.50]
        A18 -> A15 A14 A13 [0.50]
        A18 -> A14 A13 [0.50]
        A13 -> A11 A12 [0.50]
        A13 -> A12 A11 [0.50]
        A14 -> A11 A10 A12 [0.50]
        A14 -> A10 A11 A12 [0.50]
        A15 -> A12 A11 A10 [0.50]
        A15 -> A11 A12 A10 [0.50]
        A10 -> A7 A9 A8 [0.50]
        A10 -> A9 A8 A7 [0.50]
        A11 -> A8 A7 A9 [0.50]
        A11 -> A7 A8 A9 [0.50]
        A12 -> A8 A9 A7 [0.50]
        A12 -> A9 A7 A8 [0.50]
        A7 -> '3' '1' [0.50]
        A7 -> '1' '2' '3' [0.50]
        A8 -> '6' '5' [0.50]
        A8 -> '6' '4' '5' [0.50]
        A9 -> '9' '8' '7' [0.50]
        A9 -> '8' '7' [0.50]
    """,


    # Reduced hierarchy rooted at S -> A16 (levels A16, A13-A15, A10-A12,
    # A7-A9). A7, A8 and A9 use disjoint digit sets ({1,2,3}, {4,5,6}, {7,8,9});
    # the first alternative of each nonterminal has 0.90, the second 0.10.
    "pcfg_cfg3b_disjoint_terminals_skewed_prob": """
        S -> A16 [1]
        A16 -> A15 A13 [0.90]
        A16 -> A13 A15 A14 [0.10]
        A13 -> A11 A12 [0.90]
        A13 -> A12 A11 [0.10]
        A14 -> A11 A10 A12 [0.90]
        A14 -> A10 A11 A12 [0.10]
        A15 -> A12 A11 A10 [0.90]
        A15 -> A11 A12 A10 [0.10]
        A10 -> A7 A9 A8 [0.90]
        A10 -> A9 A8 A7 [0.10]
        A11 -> A8 A7 A9 [0.90]
        A11 -> A7 A8 A9 [0.10]
        A12 -> A8 A9 A7 [0.90]
        A12 -> A9 A7 A8 [0.10]
        A7 -> '3' '1' [0.90]
        A7 -> '1' '2' '3' [0.10]
        A8 -> '6' '5' [0.90]
        A8 -> '6' '4' '5' [0.10]
        A9 -> '9' '8' '7' [0.90]
        A9 -> '8' '7' [0.10]
    """,


    # Same rules as the skewed variant above, with 0.50 on every alternative.
    "pcfg_cfg3b_disjoint_terminals": """
        S -> A16 [1]
        A16 -> A15 A13 [0.50]
        A16 -> A13 A15 A14 [0.50]
        A13 -> A11 A12 [0.50]
        A13 -> A12 A11 [0.50]
        A14 -> A11 A10 A12 [0.50]
        A14 -> A10 A11 A12 [0.50]
        A15 -> A12 A11 A10 [0.50]
        A15 -> A11 A12 A10 [0.50]
        A10 -> A7 A9 A8 [0.50]
        A10 -> A9 A8 A7 [0.50]
        A11 -> A8 A7 A9 [0.50]
        A11 -> A7 A8 A9 [0.50]
        A12 -> A8 A9 A7 [0.50]
        A12 -> A9 A7 A8 [0.50]
        A7 -> '3' '1' [0.50]
        A7 -> '1' '2' '3' [0.50]
        A8 -> '6' '5' [0.50]
        A8 -> '6' '4' '5' [0.50]
        A9 -> '9' '8' '7' [0.50]
        A9 -> '8' '7' [0.50]
    """,

    # "Equal length" variant: both A16 alternatives have three children and
    # every A7/A8/A9 alternative has exactly three terminals, so the two
    # alternatives of those nonterminals expand to the same number of symbols.
    # Probabilities are skewed 0.95 / 0.05.
    "pcfg_cfg3b_eq_len_skewed_prob": """
        S -> A16 [1]
        A16 -> A15 A14 A13 [0.95]
        A16 -> A13 A15 A14 [0.05]
        A13 -> A11 A12 [0.95]
        A13 -> A12 A11 [0.05]
        A14 -> A11 A10 A12 [0.95]
        A14 -> A10 A11 A12 [0.05]
        A15 -> A12 A11 A10 [0.95]
        A15 -> A11 A12 A10 [0.05]
        A10 -> A7 A9 A8 [0.95]
        A10 -> A9 A8 A7 [0.05]
        A11 -> A8 A7 A9 [0.95]
        A11 -> A7 A8 A9 [0.05]
        A12 -> A8 A9 A7 [0.95]
        A12 -> A9 A7 A8 [0.05]
        A7 -> '3' '1' '2' [0.95]
        A7 -> '1' '2' '3' [0.05]
        A8 -> '6' '5' '4' [0.95]
        A8 -> '6' '4' '5' [0.05]
        A9 -> '9' '8' '7' [0.95]
        A9 -> '8' '7' '9' [0.05]
    """,


    # Same rules as the equal-length variant above, with 0.50 everywhere.
    "pcfg_cfg3b_eq_len_uniform_prob": """
        S -> A16 [1]
        A16 -> A15 A14 A13 [0.50]
        A16 -> A13 A15 A14 [0.50]
        A13 -> A11 A12 [0.50]
        A13 -> A12 A11 [0.50]
        A14 -> A11 A10 A12 [0.50]
        A14 -> A10 A11 A12 [0.50]
        A15 -> A12 A11 A10 [0.50]
        A15 -> A11 A12 A10 [0.50]
        A10 -> A7 A9 A8 [0.50]
        A10 -> A9 A8 A7 [0.50]
        A11 -> A8 A7 A9 [0.50]
        A11 -> A7 A8 A9 [0.50]
        A12 -> A8 A9 A7 [0.50]
        A12 -> A9 A7 A8 [0.50]
        A7 -> '3' '1' '2' [0.50]
        A7 -> '1' '2' '3' [0.50]
        A8 -> '6' '5' '4' [0.50]
        A8 -> '6' '4' '5' [0.50]
        A9 -> '9' '8' '7' [0.50]
        A9 -> '8' '7' '9' [0.50]
    """,

    # "pcfg_cfg3b_disjoint_terminals" with one alternative removed: A11 keeps
    # only `A8 A7 A9` (probability 1); the `A7 A8 A9` alternative is absent.
    "pcfg_cfg3b_disjoint_terminals_one_rule_missing": """
        S -> A16 [1]
        A16 -> A15 A13 [0.50]
        A16 -> A13 A15 A14 [0.50]
        A13 -> A11 A12 [0.50]
        A13 -> A12 A11 [0.50]
        A14 -> A11 A10 A12 [0.50]
        A14 -> A10 A11 A12 [0.50]
        A15 -> A12 A11 A10 [0.50]
        A15 -> A11 A12 A10 [0.50]
        A10 -> A7 A9 A8 [0.50]
        A10 -> A9 A8 A7 [0.50]
        A11 -> A8 A7 A9 [1]
        A12 -> A8 A9 A7 [0.50]
        A12 -> A9 A7 A8 [0.50]
        A7 -> '3' '1' [0.50]
        A7 -> '1' '2' '3' [0.50]
        A8 -> '6' '5' [0.50]
        A8 -> '6' '4' '5' [0.50]
        A9 -> '9' '8' '7' [0.50]
        A9 -> '8' '7' [0.50]
    """
}

def _cfg3b_level_details():
    """Return a fresh level-annotation record for the S/A16..A7 hierarchy.

    The record has two views of the same assignment:

    * ``"level_to_nonterminals"`` maps a level number (4 down to 1) to the
      list of nonterminals on that level.
    * ``"nonterminal_to_level"`` maps each nonterminal to its level, with the
      start symbol ``S`` mapped to ``-100``.
      NOTE(review): ``-100`` looks like a sentinel for "not on any level" —
      confirm with the consumers of this table.

    A new dict (with new ``Nonterminal`` objects) is built on every call so
    the entries of ``grammar_details_dict`` do not share mutable state.
    """
    # Listed from the top of the hierarchy downwards; the order is preserved
    # in both resulting dicts.
    names_by_level = (
        (4, ("A16",)),
        (3, ("A15", "A14", "A13")),
        (2, ("A12", "A11", "A10")),
        (1, ("A9", "A8", "A7")),
    )

    by_level = {
        level: [Nonterminal(name) for name in names]
        for level, names in names_by_level
    }

    by_symbol = {Nonterminal("S"): -100}
    for level, names in names_by_level:
        for name in names:
            by_symbol[Nonterminal(name)] = level

    return {
        "level_to_nonterminals": by_level,
        "nonterminal_to_level": by_symbol,
    }


# Level metadata for the grammars in ``grammar_dict`` that share the
# S -> A16 hierarchy; keys match the corresponding ``grammar_dict`` keys.
grammar_details_dict = {
    "pcfg_cfg3b_disjoint_terminals": _cfg3b_level_details(),
    "pcfg_cfg3b_disjoint_terminals_skewed_prob": _cfg3b_level_details(),
}


def hierarchical_grammar_builder(depth, breadth, terminals=None, seed=10):
    """Generate a random layered PCFG and return its rules as one string.

    The grammar has the shape::

        S            -> S_{depth+1}                                   [1]
        S_{depth+1}  -> B_{depth} C_1_i E_{depth} T_1_i               [1/half]
        B_k / E_k    -> 1..breadth copies of B_{k-1} / E_{k-1}        [1/breadth]
        B_1 / E_1    -> a random terminal string                      [1/breadth]
        T_1_i        -> i-th terminal of the first half               [1]
        C_1_i        -> i-th terminal of the second half              [1]

    where ``half = len(terminals) // 2``. Base strings for ``B_1`` and
    ``E_1`` are drawn from two pools that share the same random prefixes;
    pool "b" strings end with a terminal from the first half and pool "e"
    strings end with a terminal from the second half.

    Args:
        depth: Number of B/E layers; must be a positive ``int``.
        breadth: Number of alternatives emitted per B/E nonterminal; must be
            a positive ``int``.
        terminals: Sequence of terminal strings, more than two of them.
            Defaults to ``["1", "2", "3", "4"]``. With an odd count the
            second half is one longer, and its last ``C_1_*`` rule is
            emitted but not referenced by any start rule.
        seed: Value passed to ``random.seed`` before any random draw.

    Returns:
        The rules joined by newlines, each in ``LHS -> RHS [prob]`` form.
        Probabilities are rounded to four decimals, so they may not sum to
        exactly 1 for a given left-hand side. Identical alternatives can
        appear more than once because sampling is done with replacement.

    Raises:
        AssertionError: If an argument fails the checks described above.

    Side effects:
        Prints the parameters to stdout and reseeds the module-level
        ``random`` generator.
    """
    # A ``None`` default avoids sharing one mutable list object across calls.
    if terminals is None:
        terminals = ["1", "2", "3", "4"]

    assert isinstance(depth, int)
    assert isinstance(breadth, int)
    assert depth > 0
    assert breadth > 0
    assert len(terminals) > 2

    print(f"Depth: {depth}, Breadth: {breadth}, Terminals: {terminals}")

    # Seeding the global generator is kept as-is so that code drawing from
    # ``random`` after this call still sees the same state as before.
    random.seed(seed)

    half = len(terminals) // 2
    terminal_t_list = terminals[:half]
    terminal_c_list = terminals[half:]

    terminal_t_rules = [
        f"T_1_{index} -> '{terminal}' [1]"
        for index, terminal in enumerate(terminal_t_list, start=1)
    ]
    terminal_c_rules = [
        f"C_1_{index} -> '{terminal}' [1]"
        for index, terminal in enumerate(terminal_c_list, start=1)
    ]

    # Candidate terminal strings for the bottom layer. Entry j of both pools
    # shares the same random prefix and differs only in its final terminal.
    base_strings = {"b": [], "e": []}
    for _ in range(breadth * 2):
        length = random.randint(1, breadth * 2 - 1)
        prefix = random.choices(terminals, k=length)
        base_strings["b"].append(prefix + random.choices(terminal_t_list, k=1))
        base_strings["e"].append(prefix + random.choices(terminal_c_list, k=1))

    rule_probability = round(1 / breadth, 4)

    def layer_rules(kind):
        """Return the rules of every layer of one side, top layer first."""
        assert kind in ["b", "e"]
        symbol = "B" if kind == "b" else "E"
        rules = []

        for level in range(depth, 1, -1):
            # Every candidate child is the same symbol; the pool and the
            # ``choices`` call are kept so the sequence of random draws (and
            # therefore the grammar produced for a given seed) is unchanged.
            child_pool = [f"{symbol}_{level - 1}"] * breadth
            for _ in range(breadth):
                repeat = random.randint(1, breadth)
                rhs = " ".join(random.choices(child_pool, k=repeat))
                rules.append(f"{symbol}_{level} -> {rhs} [{rule_probability}]")

        # Bottom layer: terminal strings sampled with replacement.
        for symbols in random.choices(base_strings[kind], k=breadth):
            rhs = " ".join(f"'{terminal}'" for terminal in symbols)
            rules.append(f"{symbol}_1 -> {rhs} [{rule_probability}]")

        return rules

    top_symbol = f"S_{depth + 1}"
    start_probability = round(1 / half, 4)
    start_rules = [f"S -> {top_symbol} [1]"]
    start_rules.extend(
        f"{top_symbol} -> B_{depth} C_1_{index} E_{depth} T_1_{index} "
        f"[{start_probability}]"
        for index in range(1, half + 1)
    )

    # The "b" side must be generated before the "e" side: both consume the
    # same random stream.
    all_rules = []
    all_rules.extend(start_rules)
    all_rules.extend(layer_rules("b"))
    all_rules.extend(layer_rules("e"))
    all_rules.extend(terminal_t_rules)
    all_rules.extend(terminal_c_rules)

    return "\n".join(all_rules)




# Example usage (the function already returns a newline-joined string):
# print(hierarchical_grammar_builder(6, 3, seed=0))










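# NOTE: everything below is a superseded earlier draft of
# hierarchical_grammar_builder, kept commented out for reference only.
# def hierarchical_grammar_builder(depth, breadth, seed=10):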
#     assert isinstance(depth, int)
#     assert isinstance(breadth, int)
#     assert depth > 0
#     assert breadth > 0

#     terminal_t_list = ['1', '2']
#     terminal_c_list = ['3', '4']
#     terminal_list = terminal_t_list + terminal_c_list
#     u_dict = {
#         "b" : [],
#         "e" : []
#     }

#     random.seed(seed)
#     terminal_t_rules = []   
#     terminal_c_rules = []
#     for i, terminal in enumerate(terminal_t_list):
#         terminal_t_rules.append(f"T{i+1} -> '{terminal}' [1]")
#     for i, terminal in enumerate(terminal_c_list):
#         terminal_c_rules.append(f"C{i+1} -> '{terminal}' [1]")
    
#     for i in range(breadth*2):
#         len = random.randint(1, breadth*2-1)
#         s = random.choices(terminal_list, k=len)
#         u_dict["b"].append(s + random.choices(terminal_t_list, k=1))
#         u_dict["e"].append(s + random.choices(terminal_c_list, k=1))
    
#     # print(u_dict)
#     # print(terminal_t_rules)
#     # print(terminal_c_rules)

#     def recursive_grammar_builder(depth, type, breadth=3):
#         # print("Called with depth", depth, "type", type, "breadth", breadth)
#         assert type in ["b", "e"]
#         assert breadth >= 1 and breadth <= 3
#         non_terminal = f"B_{depth}" if type == "b" else f"E_{depth}"
#         if depth == 0:
#             return []
#         elif depth == 1:
#             rules = []
#             for i, s in enumerate(random.choices(u_dict[type], k=breadth)):
#                 s = " ".join(f"'{terminal}'" for terminal in s)
#                 rules.append(f"{non_terminal}_{i+1} -> {s} [1]")
#             return rules
#         else:
            
#             non_terminals_child = [f"B_{depth-1}_{i+1}" if type == "b" else f"E_{depth-1}_{i+1}" for i in range(breadth)]
#             rules = []
#             for i in range(breadth):
#                 # rules.append(f"{non_terminal}_{i+1} -> {' '.join(random.choices(non_terminals_child, k=random.randint(1, breadth)))} [{round(1/breadth, 4)}]")
#                 rules.append(f"{non_terminal}_{i+1} -> {' '.join(random.choices(non_terminals_child, k=random.randint(1, breadth)))} [1]")
#             rules.extend(recursive_grammar_builder(depth-1, type, breadth))
#             return rules
        
#     start_rule = [
#         f"S -> S_{depth+1} [1]",
#         f"S_{depth+1} -> S_{depth+1}_{1} T1 [0.5]",
#         f"S_{depth+1} -> S_{depth+1}_{2} T2 [0.5]",
#     ]

#     s_rule = [
#         f"S_{depth+1}_1 -> B_{depth} C1 E_{depth} [1]",
#         f"S_{depth+1}_2 -> B_{depth} C2 E_{depth} [1]",
#     ]
    
#     b_rule = [f"B_{depth} -> B_{depth}_{i+1} [{round(1/breadth, 4)}]" for i in range(breadth)]
#     e_rule = [f"E_{depth} -> E_{depth}_{i+1} [{round(1/breadth, 4)}]" for i in range(breadth)]

#     # print("Start Rule:")
#     # print(start_rule)

#     # print("G Rule:")
#     # print(g_rule)


#     # print()
#     # print(list(recursive_grammar_builder(depth, "b")))
#     # print()
#     # print(recursive_grammar_builder(depth, "e"))

#     rule_builder = []
#     rule_builder.extend(start_rule)
#     rule_builder.extend(s_rule)
#     rule_builder.extend(b_rule)
#     rule_builder.extend(e_rule)
#     rule_builder.extend(recursive_grammar_builder(depth, "b"))
#     rule_builder.extend(recursive_grammar_builder(depth, "e"))
#     rule_builder.extend(terminal_t_rules)
#     rule_builder.extend(terminal_c_rules)

#     # print(rule_builder)

#     return "\n".join(rule_builder)




# # print("\n".join(hierarchical_grammar_builder(6, 3, 0)))



