import heapq

def complete_prompt_ordered_commongen(concepts_str):
    """Return the few-shot prompt for ordered concept-to-sentence generation.

    The prompt has three sections: an instruction, two worked examples, and a
    task section. ``concepts_str`` is interpolated verbatim after
    ``- Concepts:`` in the task section. The text starts with a newline and
    ends with ``- Sentence:`` so the model's continuation is the sentence.
    """
    instruction = (
        "\n"
        "# Instruction\n"
        "\n"
        "Given several concepts (i.e., nouns or verbs), write a short and simple sentence that contains *all* the required words in the given order.\n"
        "The sentence should describe a common scene in daily life, and the concepts should be used in a natural way.\n"
        "\n"
    )
    examples = (
        "# Examples\n"
        "\n"
        "## Example 1\n"
        '- Concepts: "dog, frisbee, catch, throw"\n'
        "- Sentence: The dog eagerly chased the frisbee trying to catch it after its owner threw it.\n"
        "\n"
        "## Example 2\n"
        '- Concepts: "apple, place, tree, pick"\n'
        "- Sentence: I found an apple in a place near a tree and I picked it up.\n"
        "\n"
    )
    task = (
        "# Your Task\n"
        "\n"
        f"- Concepts: {concepts_str}\n"
        "- Sentence:"
    )
    return instruction + examples + task

def compute_token_variants_len(variants, tokenizer):
    """Return the token count of every surface variant of every concept.

    Args:
        variants: A sequence of concepts; each concept is a sequence of
            variant texts that are passed to ``tokenizer.tokenize``.
        tokenizer: Any object exposing ``tokenize(text)`` that returns a
            sized sequence of tokens.

    Returns:
        A list of lists mirroring ``variants``: ``result[i][j]`` equals
        ``len(tokenizer.tokenize(variants[i][j]))``.
    """
    return [
        [len(tokenizer.tokenize(text)) for text in concept_variants]
        for concept_variants in variants
    ]

def dijkstra_from_acceptance(matrix, acceptance_states):
    """Compute each state's cheapest total cost to reach an accepting state.

    Runs Dijkstra's algorithm on the reversed graph, seeded with every
    accepting state at distance 0. Like any Dijkstra, it assumes
    non-negative arc costs.

    Args:
        matrix: ``matrix[u]`` lists the outgoing arcs of state ``u + 1`` as
            ``(cost, target)`` pairs, where ``target`` is a 1-based state id.
        acceptance_states: 1-based ids of the accepting states.

    Returns:
        A dict mapping every 1-based state id to its minimum cost;
        ``float('inf')`` for states that cannot reach an accepting state.
    """
    num_states = len(matrix)

    # reverse_edges[t] holds (source_index, cost) for every arc into state t
    # (0-based), so a search outward from acceptance walks arcs backwards.
    reverse_edges = [[] for _ in range(num_states)]
    for src_index, outgoing in enumerate(matrix):
        for weight, target in outgoing:
            reverse_edges[target - 1].append((src_index, weight))

    best = [float('inf')] * num_states
    frontier = []
    for accept_id in acceptance_states:
        best[accept_id - 1] = 0
        heapq.heappush(frontier, (0, accept_id - 1))

    while frontier:
        settled, node = heapq.heappop(frontier)
        if settled > best[node]:
            # Stale heap entry: a cheaper path to this node was already found.
            continue
        for prev_index, weight in reverse_edges[node]:
            candidate = settled + weight
            if candidate < best[prev_index]:
                best[prev_index] = candidate
                heapq.heappush(frontier, (candidate, prev_index))

    return {index + 1: cost for index, cost in enumerate(best)}

def _ordered_commongen_transition_rows(variants_len):
    """Build the transition matrix rows of the ordered-concept DFA.

    ``variants_len[i]`` holds the token counts of concept ``i``'s variants.
    Row ``r`` describes state ``r + 1``. Each row lists ``(cost, target)``
    pairs. There is one pair per variant of every concept, in concept order,
    with cost equal to that variant's token count. Three trailing cost-1
    pairs follow for extra symbols. Those are presumably sentence-final
    punctuation, end-of-sequence and "any other token", in that order.
    NOTE(review): confirm the symbol meanings against the DFA's consumer.

    With k concepts there are k + 4 states:
      * 1: no concept matched yet;
      * 2: sink (an out-of-order concept or premature extra symbol);
      * s + 2 for 1 <= s <= k: the first s concepts matched in order;
      * k + 3: after the first extra symbol once all concepts matched;
      * k + 4: accepting, absorbing.
    """
    num_concepts = len(variants_len)
    sink = 2
    pre_accept = num_concepts + 3
    accept = num_concepts + 4

    def progress_state(seen):
        # State 2 is reserved for the sink, so progress states skip it.
        return 1 if seen == 0 else seen + 2

    def every_variant_to(target):
        return [(length, target) for lengths in variants_len for length in lengths]

    progress_rows = []
    for seen in range(num_concepts + 1):
        here = progress_state(seen)
        row = []
        for idx, lengths in enumerate(variants_len):
            if idx < seen:
                target = here  # re-mentioning an already matched concept: stay
            elif idx == seen:
                target = progress_state(seen + 1)  # next concept in order: advance
            else:
                target = sink  # a later concept out of order: reject
            row.extend((length, target) for length in lengths)
        # The first extra symbol only leads on once every concept was matched.
        first_extra_target = pre_accept if seen == num_concepts else sink
        row += [(1, first_extra_target), (1, sink), (1, here)]
        progress_rows.append(row)

    sink_row = every_variant_to(sink) + [(1, sink), (1, sink), (1, sink)]
    pre_accept_row = every_variant_to(sink) + [(1, sink), (1, accept), (1, sink)]
    accept_row = every_variant_to(accept) + [(1, accept), (1, accept), (1, accept)]

    # Order rows by state id: 1, 2 (sink), 3..k+2, k+3, k+4.
    return [progress_rows[0], sink_row] + progress_rows[1:] + [pre_accept_row, accept_row]

def ordered_commongen_dfa(variants, tokenizer):
    """Build a DFA accepting outputs that mention all concepts in order.

    For 3, 4 or 5 concepts the result is identical to the previous
    hand-written tables. Other concept counts (>= 1) follow the same pattern
    instead of crashing with ``UnboundLocalError``.

    Args:
        variants: Sequence of concepts, each a sequence of variant texts
            (tokenized via ``compute_token_variants_len``).
        tokenizer: Object exposing ``tokenize(text)``.

    Returns:
        A dict with these keys:
          * ``"accepting_states"``: ``[k + 4]``.
          * ``"arcs"``: the list of all state ids ``1..k + 4``.
          * ``"dist"``: for each state id, the minimum total arc cost to
            reach acceptance (see ``dijkstra_from_acceptance``).
          * ``"transition_matrix"``: per-state lists of
            ``(cost, target_state)`` pairs.

    Raises:
        ValueError: If ``variants`` is empty.
    """
    if len(variants) == 0:
        raise ValueError("ordered_commongen_dfa requires at least one concept")

    variants_len = compute_token_variants_len(variants, tokenizer)
    transition_matrix = _ordered_commongen_transition_rows(variants_len)

    num_states = len(transition_matrix)
    accepting_states = [num_states]
    arcs = list(range(1, num_states + 1))

    dist = dijkstra_from_acceptance(transition_matrix, accepting_states)

    return {
        "accepting_states": accepting_states,
        "arcs": arcs,
        "dist": dist,
        "transition_matrix": transition_matrix,
    }