def complete_prompt(concepts_str):
    """Build the few-shot prompt asking for one sentence that uses all concepts.

    The prompt consists of an instruction section, two worked examples and a
    final task section that ends with an open ``- Sentence:`` line.

    Args:
        concepts_str: Text inserted verbatim after ``- Concepts: `` in the task
            section. The worked examples wrap their concept lists in double
            quotes; no quotes are added here, so the caller must include them
            in ``concepts_str`` if the task line should look the same.

    Returns:
        The prompt string. It starts with a newline and has no trailing
        newline after ``- Sentence:``.
    """
    # One entry per prompt line; the leading empty entry yields the initial
    # newline once the entries are joined.
    prompt_lines = (
        "",
        "# Instruction",
        "",
        "Given several concepts (i.e., nouns or verbs), write a short and simple sentence that contains *all* the required words.",
        "The sentence should describe a common scene in daily life, and the concepts should be used in a natural way.",
        "",
        "# Examples",
        "",
        "## Example 1",
        '- Concepts: "dog, frisbee, catch, throw"',
        "- Sentence: The dog catches the frisbee when the boy throws it into the air.",
        "",
        "## Example 2",
        '- Concepts: "apple, place, tree, pick"',
        "- Sentence: A girl picks some apples from a tree and places them into her basket.",
        "",
        "# Your Task",
        "",
        f"- Concepts: {concepts_str}",
        "- Sentence:",
    )
    return "\n".join(prompt_lines)

def compute_token_variants_len(variants, tokenizer):
    """Measure how many tokens each variant produces.

    Args:
        variants: A list of groups, each group being a list of items that
            ``tokenizer.tokenize`` accepts (presumably surface-form strings
            of one concept; confirm against the caller).
        tokenizer: Any object exposing ``tokenize(item)`` whose result
            supports ``len``.

    Returns:
        A list of lists with the same shape as ``variants`` where every item
        is replaced by the length of its tokenization.
    """
    return [
        [len(tokenizer.tokenize(form)) for form in group]
        for group in variants
    ]

import heapq

def dijkstra_from_acceptance(matrix, acceptance_states):
    """Return the cheapest cost from every state to any accepting state.

    Args:
        matrix: Adjacency list of the automaton. ``matrix[i]`` holds the
            outgoing arcs of state ``i + 1`` as ``(cost, target)`` pairs,
            where ``target`` is a 1-based state number.
        acceptance_states: Iterable of 1-based accepting state numbers.

    Returns:
        A dict mapping each 1-based state number to its minimum total cost
        of reaching an accepting state; ``float('inf')`` when no accepting
        state is reachable.
    """
    num_states = len(matrix)

    # Reverse every arc so that one multi-source search started at the
    # accepting states gives each state's distance *to* acceptance.
    incoming = [[] for _ in range(num_states)]
    for source, outgoing in enumerate(matrix):
        for weight, target in outgoing:
            incoming[target - 1].append((weight, source))

    best = [float('inf')] * num_states
    frontier = []
    for accepting in acceptance_states:
        best[accepting - 1] = 0
        heapq.heappush(frontier, (0, accepting - 1))

    while frontier:
        reached_cost, node = heapq.heappop(frontier)
        if reached_cost > best[node]:
            # Stale heap entry: a cheaper route to this node was already settled.
            continue
        for weight, previous in incoming[node]:
            candidate = reached_cost + weight
            if candidate < best[previous]:
                best[previous] = candidate
                heapq.heappush(frontier, (candidate, previous))

    return {index + 1: cost for index, cost in enumerate(best)}

def commongen_dfa(variants, tokenizer):
    """Build the "use every concept" automaton and its distance-to-accept table.

    States are numbered from 1. With ``n = len(variants)`` concepts:

    * States ``1 .. 2**n`` track which concepts have been produced so far:
      state ``mask + 1`` corresponds to bitmask ``mask``, where concept ``i``
      owns bit ``n - 1 - i`` (the first concept is the most significant bit).
      State 1 is the start (nothing produced); state ``2**n`` means every
      concept has been produced.
    * State ``2**n + 1`` is a dead state (it only loops to itself).
    * State ``2**n + 2`` is the closing state, reached from the "all concepts
      produced" state through the first special arc.
    * State ``2**n + 3`` is the single accepting state, reached from the
      closing state through the second special arc.

    Every row of the transition matrix is a list of ``(cost, target)`` pairs:
    first one pair per variant of every concept, in input order, with
    ``cost`` equal to the variant's token count, followed by three special
    arcs of cost 1. NOTE(review): the meaning of the three special arcs is
    defined by the consumer of this table; they look like "sentence
    terminator", "end of sequence" and "any other token". Confirm against
    the caller.

    The table is generated for any number of concepts. The previous
    hand-written tables covered only 3, 4 and 5 concepts (identical output is
    produced for those) and any other count failed with ``UnboundLocalError``.
    The number of states grows as ``2**n + 3``, so keep ``n`` small.

    Args:
        variants: One entry per concept, each entry a list of alternative
            forms accepted by ``tokenizer.tokenize``.
        tokenizer: Object exposing ``tokenize``; passed through to
            ``compute_token_variants_len``.

    Returns:
        A dict with the keys ``"accepting_states"`` (list with the accepting
        state), ``"arcs"`` (list of all state numbers ``1 .. 2**n + 3``),
        ``"dist"`` (state number -> minimum cost to the accepting state, as
        computed by ``dijkstra_from_acceptance``) and ``"transition_matrix"``.
    """
    variants_len = compute_token_variants_len(variants, tokenizer)

    num_concepts = len(variants)
    full_state = 1 << num_concepts      # every concept has been produced
    dead_state = full_state + 1         # no way back to acceptance
    closing_state = full_state + 2      # waiting for the second special arc
    accept_state = full_state + 3

    accepting_states = [accept_state]
    arcs = list(range(1, accept_state + 1))

    transition_matrix = []

    # Progress states: one row per subset of already-produced concepts.
    for mask in range(full_state):
        state = mask + 1
        row = []
        for index, lengths in enumerate(variants_len):
            concept_bit = 1 << (num_concepts - 1 - index)
            # Producing a concept sets its bit; repeating one keeps the state.
            target = (mask | concept_bit) + 1
            row.extend((length, target) for length in lengths)

        if state == full_state:
            # Only once everything has been produced may the first special
            # arc advance towards acceptance; the second is still fatal.
            row += [(1, closing_state), (1, dead_state), (1, state)]
        else:
            # Both closing arcs are premature here; the third arc idles.
            row += [(1, dead_state), (1, dead_state), (1, state)]
        transition_matrix.append(row)

    all_lengths = [length for lengths in variants_len for length in lengths]

    # Dead state: every arc loops back to it.
    transition_matrix.append(
        [(length, dead_state) for length in all_lengths]
        + [(1, dead_state), (1, dead_state), (1, dead_state)]
    )
    # Closing state: only the second special arc leads to acceptance.
    transition_matrix.append(
        [(length, dead_state) for length in all_lengths]
        + [(1, dead_state), (1, accept_state), (1, dead_state)]
    )
    # Accepting state: absorbing.
    transition_matrix.append(
        [(length, accept_state) for length in all_lengths]
        + [(1, accept_state), (1, accept_state), (1, accept_state)]
    )

    dist = dijkstra_from_acceptance(transition_matrix, accepting_states)

    dfa_layer = {
        "accepting_states": accepting_states,
        "arcs": arcs,
        "dist": dist,
        "transition_matrix": transition_matrix
    }

    return dfa_layer