# We follow a similar naming convention to the paper "Language is All a Graph Needs"
# In the current version, each template key consists of 3 terms (task type, input mode, output mode), e.g., as passed via cmd.

# The 1st term - task type
# c: node classification
# l: link prediction

# The 2nd term - input mode:
# A_B_C_D.
# A: How many predicted labels from the teacher model (GNN) are provided, e.g., 2l.
# B: Which information about the target node is used, e.g., tc means title + content.
# C: Which information about the neighbors is used, e.g., tl means titles and predicted labels.
# D: Whether the augmented TAPE text is used, e.g., tape means the TAPE text is used.

# The 3rd term - output mode:
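# A_B
# A: ground-truth label (g), distilled ranked labels (d), or distilled ranked labels + probabilities (p)
# B: x (nothing extra), r (add a rationale), or 2l (name the top-2 GNN candidates, then the ground-truth label)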

import random

# Supported node-classification input modes:
# (tape_mode, neighbor_mode) -> (allowed input_labels, allowed target_modes).
_C_INPUT_MODES = {
    ('x', 'x'): (('0l', '2l'), ('t', 'c', 'tc', 'ct')),
    ('tape', 'x'): (('0l', '1l', '2l'), ('t', 'c', 'tc', 'ct', 'x')),
    ('x', 't'): (('0l', '2l'), ('t', 'c', 'tc', 'ct')),
    ('tape', 't'): (('0l', '1l', '2l'), ('t', 'c', 'tc', 'ct')),
    ('x', 'tl'): (('0l', '2l'), ('t', 'c', 'tc', 'ct')),
    ('tape', 'tl'): (('0l', '2l'), ('t', 'c', 'tc', 'ct')),
    ('tape', 'rag'): (('0l', '1l', '2l'), ('t', 'c', 'tc', 'ct')),
    ('tape', 'ragt'): (('0l', '2l'), ('t', 'c', 'tc', 'ct')),
}

_INVALID_MODE_MSG = "Invalid task or input/output mode."


def _lookup_template(name):
    """Return the module-level template function called ``name``.

    A supported mode whose template has not been written yet raises the
    documented ValueError instead of leaking a NameError to the caller.
    """
    template = globals().get(name)
    if template is None:
        raise ValueError(f"{_INVALID_MODE_MSG} Template '{name}' is not defined.")
    return template


def get_template(task, input_mode, output_mode):
    """Select the (prompt, input, output) template functions for a mode.

    Args:
        task: 'c' for node classification. Link prediction ('l') has no
            templates yet and is rejected like any other unknown task.
        input_mode: '<labels>_<target>_<neighbors>_<tape>', e.g. '2l_tc_t_tape'.
            Supported combinations are listed in ``_C_INPUT_MODES``.
        output_mode: '<label>_<extra>'. <label> is 'g', 'd' or 'p'; <extra> is
            'x' (nothing extra), 'r' (rationale) or '2l' (top-2 GNN candidates
            followed by the ground truth; <label> is ignored in that case).
            A bare '<label>' is treated as '<label>_x'.

    Returns:
        Tuple ``(prompt_temp, input_temp, output_temp)`` of template callables.

    Raises:
        ValueError: if the task/mode combination is unsupported, a mode string
            is malformed, or a required template is not defined in this module.
    """
    if task != 'c':
        raise ValueError(_INVALID_MODE_MSG)

    # Output side: ground truth / distilled ranking / ranking + probabilities.
    output_parts = output_mode.split('_')
    label_mode = output_parts[0]
    extra_mode = output_parts[1] if len(output_parts) > 1 else 'x'
    if extra_mode == '2l':
        prompt_name, output_name = 'c_g_prompt', 'c_g_2l_output'
    elif extra_mode in ('x', 'r') and label_mode in ('g', 'd', 'p'):
        infix = '_r' if extra_mode == 'r' else ''
        prompt_name = f'c_{label_mode}{infix}_prompt'
        output_name = f'c_{label_mode}{infix}_output'
    else:
        raise ValueError(_INVALID_MODE_MSG)

    # Input side: names follow c_<labels>_<target>_<neighbors>_<tape>_input.
    input_parts = input_mode.split('_')
    if len(input_parts) < 4:
        raise ValueError(_INVALID_MODE_MSG)
    input_labels, target_mode, neighbor_mode, tape_mode = input_parts[:4]
    allowed = _C_INPUT_MODES.get((tape_mode, neighbor_mode))
    if allowed is None or input_labels not in allowed[0] or target_mode not in allowed[1]:
        raise ValueError(_INVALID_MODE_MSG)
    input_name = f'c_{input_labels}_{target_mode}_{neighbor_mode}_{tape_mode}_input'

    return (
        _lookup_template(prompt_name),
        _lookup_template(input_name),
        _lookup_template(output_name),
    )

""" %%%%%%%%%%%%%%%%%%%%%%%%%%%%
Node classification templates
%%%%%%%%%%%%%%%%%%%%%%%%%%%% """
# label_set, title_list, content_list, label_list, label_and_prob_list, neighbors_list, rationale_list, gpt_list, raw_label_and_prob_list = lists

""" ****************** Classification prompt template ****************** """
def c_g_prompt(lists, i):
    """Instruction asking only for the category.

    ``lists`` and ``i`` are unused; they keep the shared template signature.
    """
    lines = (
        "Classify the research paper according to the provided information.",
    )
    return "\n".join(lines) + "\n"

def c_d_prompt(lists, i):
    """Instruction asking for the top-3 categories in descending order.

    ``lists`` and ``i`` are unused; they keep the shared template signature.
    """
    lines = (
        "Classify the research paper according to the provided information.",
        "Provide the top-3 categories (in descending order).",
    )
    return "\n".join(lines) + "\n"

def c_p_prompt(lists, i):
    """Instruction asking for the top-3 categories with probabilities.

    ``lists`` and ``i`` are unused; they keep the shared template signature.
    """
    lines = (
        "Classify the research paper according to the provided information.",
        "Provide the top-3 categories (in descending order) with corresponding probabilities.",
    )
    return "\n".join(lines) + "\n"

def c_g_r_prompt(lists, i):
    """Instruction asking for the category plus a rationale.

    ``lists`` and ``i`` are unused; they keep the shared template signature.
    """
    lines = (
        "Classify the research paper according to the provided information.",
        "Provide a rationale for your classification.",
    )
    return "\n".join(lines) + "\n"

def c_d_r_prompt(lists, i):
    """Instruction asking for the top-3 categories plus a rationale.

    ``lists`` and ``i`` are unused; they keep the shared template signature.
    """
    lines = (
        "Classify the research paper according to the provided information.",
        "Provide the top-3 categories (in descending order) and a rationale for your classification.",
    )
    return "\n".join(lines) + "\n"

def c_p_r_prompt(lists, i):
    """Instruction asking for the top-3 categories with probabilities plus a rationale.

    ``lists`` and ``i`` are unused; they keep the shared template signature.
    """
    lines = (
        "Classify the research paper according to the provided information.",
        "Provide the top-3 categories (in descending order) with corresponding probabilities and a rationale for your classification.",
    )
    return "\n".join(lines) + "\n"

""" ****************** Output template ****************** """
def c_g_output(lists, i):
    """Target text: the ground-truth label of node ``i`` (column 3 of ``lists``)."""
    return "{}".format(lists[3][i])

def c_g_2l_output(lists, i):
    """Target text: repeat the first two GNN candidates, then state the ground truth.

    ``lists[8][i]`` holds 'label|prob' strings; only the label part is shown,
    and the two candidates are listed in reversed order.
    """
    truth = lists[3][i]
    candidates = [f"({entry.split('|')[0]})" for entry in lists[8][i]]
    shown = ', '.join(reversed(candidates[:2]))
    return f"Among the candidates: {shown}, the correct label is: {truth}"

def c_d_output(lists, i):
    """Target text: the distilled ranked labels of node ``i``, comma-separated.

    ``lists[4][i]`` holds 'label|prob' strings; the probabilities are dropped.
    """
    ranked = lists[4][i]
    return ', '.join([entry.split('|')[0] for entry in ranked])

def c_p_output(lists, i):
    """Target text: distilled ranked labels with their probabilities, e.g. 'A (50%)'.

    ``lists[4][i]`` holds 'label|prob' strings.
    """
    pieces = []
    for entry in lists[4][i]:
        fields = entry.split('|')
        pieces.append(f"{fields[0]} ({fields[1]}%)")
    return ', '.join(pieces)

def c_g_r_output(lists, i):
    """Target text: the ground-truth label (column 3), then the rationale (column 6)."""
    return "{}\nRationale: {}".format(lists[3][i], lists[6][i])

def c_d_r_output(lists, i):
    """Target text: distilled ranked labels (column 4), then the rationale (column 6)."""
    labels = ', '.join(entry.split('|')[0] for entry in lists[4][i])
    return f"{labels}\nRationale: {lists[6][i]}"

def c_p_r_output(lists, i):
    """Target text: distilled ranked labels with probabilities, then the rationale.

    ``lists[4][i]`` holds 'label|prob' strings; ``lists[6][i]`` is the rationale.
    """
    formatted = []
    for entry in lists[4][i]:
        fields = entry.split('|')
        formatted.append(f"{fields[0]} ({fields[1]}%)")
    rationale = lists[6][i]
    return ', '.join(formatted) + f"\nRationale: {rationale}"


""" ****************** Single input templates ****************** """
def c_0l_t_x_x_input(lists, i):
    """Input text: only the target title (column 1), followed by a period."""
    return "Title: {}.".format(lists[1][i])

def c_0l_c_x_x_input(lists, i):
    """Input text: only the target abstract (column 2), followed by a period."""
    return "Abstract: {}.".format(lists[2][i])

def c_0l_tc_x_x_input(lists, i):
    """Input text: target title, then target abstract (columns 1 and 2)."""
    title, content = lists[1][i], lists[2][i]
    return "\n".join([f"Title: {title}", f"Abstract: {content}"])

def c_0l_ct_x_x_input(lists, i):
    """Input text: target abstract first, then target title (reversed order)."""
    title, content = lists[1][i], lists[2][i]
    return "\n".join([f"Abstract: {content}", f"Title: {title}"])

def c_0l_x_x_tape_input(lists, i):
    """Input text: only the TAPE GPT response for node ``i`` (column 7)."""
    return "{}".format(lists[7][i])

def c_0l_tc_x_tape_input(lists, i):
    """Input text: TAPE GPT response, then target title and abstract."""
    gpt_text = lists[7][i]
    title, content = lists[1][i], lists[2][i]
    return "\n".join([f"{gpt_text}", f"Title: {title}", f"Abstract: {content}"])

def c_1l_tc_x_tape_input(lists, i):
    """Input text: TAPE GPT response, title, abstract, and the top-1 GNN candidate.

    The candidate comes from the first 'label|prob' entry of ``lists[8][i]``.
    """
    gpt_text = lists[7][i]
    title, content = lists[1][i], lists[2][i]
    ranked = [f"({entry.split('|')[0]})" for entry in lists[8][i]]
    choices = ', '.join(reversed(ranked[:1]))
    return "\n".join([
        f"{gpt_text}",
        f"Title: {title}",
        f"Abstract: {content}",
        f"Choose from: {choices}",
    ])

def c_2l_tc_x_tape_input(lists, i):
    """Input text: TAPE GPT response, title, abstract, and two GNN candidates.

    The candidates are the first two 'label|prob' entries of ``lists[8][i]``,
    listed in reversed order.
    """
    gpt_text = lists[7][i]
    title, content = lists[1][i], lists[2][i]
    ranked = [f"({entry.split('|')[0]})" for entry in lists[8][i]]
    choices = ', '.join(reversed(ranked[:2]))
    return "\n".join([
        f"{gpt_text}",
        f"Title: {title}",
        f"Abstract: {content}",
        f"Choose from: {choices}",
    ])

def c_0l_ct_x_tape_input(lists, i):
    """Input text: abstract, then TAPE GPT response, then title."""
    gpt_text = lists[7][i]
    title, content = lists[1][i], lists[2][i]
    return "\n".join([f"Abstract: {content}", f"{gpt_text}", f"Title: {title}"])

def c_2l_ct_x_tape_input(lists, i):
    """Input text: abstract, TAPE GPT response, title, then two GNN candidates.

    The candidates are the first two 'label|prob' entries of ``lists[8][i]``,
    listed in reversed order.
    """
    gpt_text = lists[7][i]
    title, content = lists[1][i], lists[2][i]
    ranked = [f"({entry.split('|')[0]})" for entry in lists[8][i]]
    choices = ', '.join(reversed(ranked[:2]))
    return "\n".join([
        f"Abstract: {content}",
        f"{gpt_text}",
        f"Title: {title}",
        f"Choose from: {choices}",
    ])

def c_3l_ct_x_tape_input(lists, i):
    """Input text: abstract, TAPE GPT response, title, then three GNN candidates.

    The candidates are the first three 'label|prob' entries of ``lists[8][i]``,
    listed in reversed order. get_template has no '3l' branch, so this
    template is only reachable by a direct call.
    """
    gpt_text = lists[7][i]
    title, content = lists[1][i], lists[2][i]
    ranked = [f"({entry.split('|')[0]})" for entry in lists[8][i]]
    choices = ', '.join(reversed(ranked[:3]))
    return "\n".join([
        f"Abstract: {content}",
        f"{gpt_text}",
        f"Title: {title}",
        f"Choose from: {choices}",
    ])


""" ****************** Multiple input templates ****************** """

def c_0l_t_t_x_input(lists, i):
    """Input text: target title plus the parenthesised titles of its neighbors.

    ``lists[5][i]`` holds neighbor indices into the title column ``lists[1]``.
    """
    titles = lists[1]
    refs = ', '.join(f"({titles[n]})" for n in lists[5][i])
    return "\n".join([f"Title: {titles[i]}", f"References: {refs}"])


def c_2l_t_t_x_input(lists, i):
    """Input text: target title, neighbor titles, and two GNN candidates.

    The candidates are the first two 'label|prob' entries of ``lists[8][i]``,
    listed in reversed order (no shuffling).
    """
    titles = lists[1]
    refs = ', '.join(f"({titles[n]})" for n in lists[5][i])
    ranked = [f"({entry.split('|')[0]})" for entry in lists[8][i]]
    choices = ', '.join(reversed(ranked[:2]))
    return "\n".join([
        f"Title: {titles[i]}",
        f"References: {refs}",
        f"Choose from: {choices}",
    ])

def c_0l_t_tl_x_input(lists, i):
    """Input text: target title plus each neighbor's title and top distilled label.

    A neighbor's label is the label part of its first 'label|prob' entry in
    ``lists[4]``. The template intentionally does not read the GNN candidates
    (``lists[8]``); the previous version computed them and then discarded them.
    """
    titles, distilled = lists[1], lists[4]
    title = titles[i]
    neighbor_title_and_label = [
        f"({titles[x]}, {distilled[x][0].split('|')[0]})" for x in lists[5][i]
    ]
    return f"""Title: {title}\nReferences: {', '.join(neighbor_title_and_label)}"""

def c_2l_t_tl_x_input(lists, i):
    """Input text: target title, neighbor (title, top distilled label) pairs, two GNN candidates.

    A neighbor's label is the label part of its first ``lists[4]`` entry. The
    candidates are the first two entries of ``lists[8][i]``, listed in reversed order.
    """
    titles, distilled = lists[1], lists[4]
    refs = ', '.join(
        f"({titles[n]}, {distilled[n][0].split('|')[0]})" for n in lists[5][i]
    )
    ranked = [f"({entry.split('|')[0]})" for entry in lists[8][i]]
    choices = ', '.join(reversed(ranked[:2]))
    return "\n".join([
        f"Title: {titles[i]}",
        f"References: {refs}",
        f"Choose from: {choices}",
    ])

def c_0l_tc_t_x_input(lists, i):
    """Input text: target title and abstract, then neighbor titles (ends with a period)."""
    titles = lists[1]
    content = lists[2][i]
    refs = ', '.join(f"({titles[n]})" for n in lists[5][i])
    return "\n".join([
        f"Title: {titles[i]}",
        f"Abstract: {content}",
        f"References: {refs}.",
    ])

def c_0l_tc_tl_x_input(lists, i):
    """Input text: target title and abstract, then neighbor (title, top distilled label) pairs."""
    titles, distilled = lists[1], lists[4]
    content = lists[2][i]
    refs = ', '.join(
        f"({titles[n]}, {distilled[n][0].split('|')[0]})" for n in lists[5][i]
    )
    return "\n".join([
        f"Title: {titles[i]}",
        f"Abstract: {content}",
        f"References: {refs}.",
    ])

def c_2l_tc_t_x_input(lists, i):
    """Input text: title, abstract, neighbor titles, then two GNN candidates (reversed)."""
    titles = lists[1]
    content = lists[2][i]
    refs = ', '.join(f"({titles[n]})" for n in lists[5][i])
    ranked = [f"({entry.split('|')[0]})" for entry in lists[8][i]]
    choices = ', '.join(reversed(ranked[:2]))
    return "\n".join([
        f"Title: {titles[i]}",
        f"Abstract: {content}",
        f"References: {refs}",
        f"Choose from: {choices}",
    ])

def c_0l_ct_tl_x_input(lists, i):
    """Input text: abstract, neighbor (title, top distilled label) pairs, then title with a period."""
    titles, distilled = lists[1], lists[4]
    content = lists[2][i]
    refs = ', '.join(
        f"({titles[n]}, {distilled[n][0].split('|')[0]})" for n in lists[5][i]
    )
    return "\n".join([
        f"Abstract: {content}",
        f"References: {refs}",
        f"Title: {titles[i]}.",
    ])

def c_0l_t_t_tape_input(lists, i):
    """Input text: target title, TAPE GPT response, then neighbor titles (ends with a period).

    Only the title of the target node is used; its abstract is not included.
    """
    titles = lists[1]
    refs = ', '.join(f"({titles[n]})" for n in lists[5][i])
    gpt_text = lists[7][i]
    return "\n".join([
        f"Title: {titles[i]}",
        f"{gpt_text}",
        f"References: {refs}.",
    ])

def c_1l_t_t_tape_input(lists, i):
    """Input text: title, TAPE GPT response, neighbor titles, then the top-1 GNN candidate."""
    titles = lists[1]
    refs = ', '.join(f"({titles[n]})" for n in lists[5][i])
    gpt_text = lists[7][i]
    ranked = [f"({entry.split('|')[0]})" for entry in lists[8][i]]
    choices = ', '.join(reversed(ranked[:1]))
    return "\n".join([
        f"Title: {titles[i]}",
        f"{gpt_text}",
        f"References: {refs}",
        f"Choose from: {choices}",
    ])

def c_2l_t_t_tape_input(lists, i):
    """Input text: title, TAPE GPT response, neighbor titles, then two GNN candidates (reversed)."""
    titles = lists[1]
    refs = ', '.join(f"({titles[n]})" for n in lists[5][i])
    gpt_text = lists[7][i]
    ranked = [f"({entry.split('|')[0]})" for entry in lists[8][i]]
    choices = ', '.join(reversed(ranked[:2]))
    return "\n".join([
        f"Title: {titles[i]}",
        f"{gpt_text}",
        f"References: {refs}",
        f"Choose from: {choices}",
    ])

def c_0l_t_tl_tape_input(lists, i):
    """Input text: title, neighbor (title, top distilled label) pairs, then TAPE GPT response."""
    titles, distilled = lists[1], lists[4]
    refs = ', '.join(
        f"({titles[n]}, {distilled[n][0].split('|')[0]})" for n in lists[5][i]
    )
    gpt_text = lists[7][i]
    return "\n".join([
        f"Title: {titles[i]}",
        f"References: {refs}",
        f"{gpt_text}",
    ])

def c_2l_t_tl_tape_input(lists, i):
    """Input text: title, neighbor (title, label) pairs, TAPE GPT response, two GNN candidates."""
    titles, distilled = lists[1], lists[4]
    refs = ', '.join(
        f"({titles[n]}, {distilled[n][0].split('|')[0]})" for n in lists[5][i]
    )
    gpt_text = lists[7][i]
    ranked = [f"({entry.split('|')[0]})" for entry in lists[8][i]]
    choices = ', '.join(reversed(ranked[:2]))
    return "\n".join([
        f"Title: {titles[i]}",
        f"References: {refs}",
        f"{gpt_text}",
        f"Choose from: {choices}",
    ])

def c_0l_tc_t_tape_input(lists, i):
    """Input text: title, abstract, TAPE GPT response, then neighbor titles (ends with a period)."""
    titles = lists[1]
    content = lists[2][i]
    refs = ', '.join(f"({titles[n]})" for n in lists[5][i])
    gpt_text = lists[7][i]
    return "\n".join([
        f"Title: {titles[i]}",
        f"Abstract: {content}",
        f"{gpt_text}",
        f"References: {refs}.",
    ])

def c_2l_tc_t_tape_input(lists, i):
    """Input text: title, abstract, TAPE GPT response, neighbor titles, two GNN candidates."""
    titles = lists[1]
    content = lists[2][i]
    refs = ', '.join(f"({titles[n]})" for n in lists[5][i])
    gpt_text = lists[7][i]
    ranked = [f"({entry.split('|')[0]})" for entry in lists[8][i]]
    choices = ', '.join(reversed(ranked[:2]))
    return "\n".join([
        f"Title: {titles[i]}",
        f"Abstract: {content}",
        f"{gpt_text}",
        f"References: {refs}",
        f"Choose from: {choices}",
    ])

def c_0l_ct_t_tape_input(lists, i):
    """Input text: abstract, neighbor titles, TAPE GPT response, then title with a period."""
    titles = lists[1]
    content = lists[2][i]
    refs = ', '.join(f"({titles[n]})" for n in lists[5][i])
    gpt_text = lists[7][i]
    return "\n".join([
        f"Abstract: {content}",
        f"References: {refs}",
        f"{gpt_text}",
        f"Title: {titles[i]}.",
    ])

def c_2l_ct_t_tape_input(lists, i):
    """Input text: abstract, neighbor titles, TAPE GPT response, title, two GNN candidates."""
    titles = lists[1]
    content = lists[2][i]
    refs = ', '.join(f"({titles[n]})" for n in lists[5][i])
    gpt_text = lists[7][i]
    ranked = [f"({entry.split('|')[0]})" for entry in lists[8][i]]
    choices = ', '.join(reversed(ranked[:2]))
    return "\n".join([
        f"Abstract: {content}",
        f"References: {refs}",
        f"{gpt_text}",
        f"Title: {titles[i]}",
        f"Choose from: {choices}",
    ])

def c_0l_tc_tl_tape_input(lists, i):
    """Input text: title, abstract, TAPE GPT response, then neighbor (title, label) pairs with a period."""
    titles, distilled = lists[1], lists[4]
    content = lists[2][i]
    refs = ', '.join(
        f"({titles[n]}, {distilled[n][0].split('|')[0]})" for n in lists[5][i]
    )
    gpt_text = lists[7][i]
    return "\n".join([
        f"Title: {titles[i]}",
        f"Abstract: {content}",
        f"{gpt_text}",
        f"References: {refs}.",
    ])

def c_2l_tc_tl_tape_input(lists, i):
    """Input text: title, abstract, neighbor (title, label) pairs, TAPE GPT response, two GNN candidates."""
    titles, distilled = lists[1], lists[4]
    content = lists[2][i]
    refs = ', '.join(
        f"({titles[n]}, {distilled[n][0].split('|')[0]})" for n in lists[5][i]
    )
    gpt_text = lists[7][i]
    ranked = [f"({entry.split('|')[0]})" for entry in lists[8][i]]
    choices = ', '.join(reversed(ranked[:2]))
    return "\n".join([
        f"Title: {titles[i]}",
        f"Abstract: {content}",
        f"References: {refs}",
        f"{gpt_text}",
        f"Choose from: {choices}",
    ])

def c_0l_ct_tl_tape_input(lists, i):
    """Input text: abstract, neighbor (title, label) pairs, TAPE GPT response, then title with a period.

    The target's content comes before its title (reversed order).
    """
    titles, distilled = lists[1], lists[4]
    content = lists[2][i]
    refs = ', '.join(
        f"({titles[n]}, {distilled[n][0].split('|')[0]})" for n in lists[5][i]
    )
    gpt_text = lists[7][i]
    return "\n".join([
        f"Abstract: {content}",
        f"References: {refs}",
        f"{gpt_text}",
        f"Title: {titles[i]}.",
    ])

def c_2l_ct_tl_tape_input(lists, i):
    """Input text: abstract, neighbor (title, label) pairs, TAPE GPT response, title, two GNN candidates."""
    titles, distilled = lists[1], lists[4]
    content = lists[2][i]
    refs = ', '.join(
        f"({titles[n]}, {distilled[n][0].split('|')[0]})" for n in lists[5][i]
    )
    gpt_text = lists[7][i]
    ranked = [f"({entry.split('|')[0]})" for entry in lists[8][i]]
    choices = ', '.join(reversed(ranked[:2]))
    return "\n".join([
        f"Abstract: {content}",
        f"References: {refs}",
        f"{gpt_text}",
        f"Title: {titles[i]}",
        f"Choose from: {choices}",
    ])

""" ****************** RAG input templates ****************** """
def c_0l_t_rag_tape_input(lists, i, context):
    """Input text: title, TAPE GPT response, then the retrieved ``context`` as references."""
    title, gpt_text = lists[1][i], lists[7][i]
    return "\n".join([f"Title: {title}", f"{gpt_text}", f"References: {context}"])

def c_2l_t_rag_tape_input(lists, i, context):
    """Input text: title, TAPE GPT response, retrieved ``context``, two GNN candidates (reversed)."""
    title, gpt_text = lists[1][i], lists[7][i]
    ranked = [f"({entry.split('|')[0]})" for entry in lists[8][i]]
    choices = ', '.join(reversed(ranked[:2]))
    return "\n".join([
        f"Title: {title}",
        f"{gpt_text}",
        f"References: {context}",
        f"Choose from: {choices}",
    ])

def c_0l_tc_rag_tape_input(lists, i, context):
    """Input text: title, abstract, TAPE GPT response, then the retrieved ``context``."""
    title, content, gpt_text = lists[1][i], lists[2][i], lists[7][i]
    return "\n".join([
        f"Title: {title}",
        f"Abstract: {content}",
        f"{gpt_text}",
        f"References: {context}",
    ])

def c_2l_tc_rag_tape_input(lists, i, context):
    """Input text: title, abstract, TAPE GPT response, retrieved ``context``, two GNN candidates."""
    title, content, gpt_text = lists[1][i], lists[2][i], lists[7][i]
    ranked = [f"({entry.split('|')[0]})" for entry in lists[8][i]]
    choices = ', '.join(reversed(ranked[:2]))
    return "\n".join([
        f"Title: {title}",
        f"Abstract: {content}",
        f"{gpt_text}",
        f"References: {context}",
        f"Choose from: {choices}",
    ])

def c_0l_ct_rag_tape_input(lists, i, context):
    """Input text: abstract, retrieved ``context``, TAPE GPT response, then title."""
    title, content, gpt_text = lists[1][i], lists[2][i], lists[7][i]
    return "\n".join([
        f"Abstract: {content}",
        f"References: {context}",
        f"{gpt_text}",
        f"Title: {title}",
    ])

def c_2l_ct_rag_tape_input(lists, i, context):
    """Input text: abstract, retrieved ``context``, TAPE GPT response, title, two GNN candidates."""
    title, content, gpt_text = lists[1][i], lists[2][i], lists[7][i]
    ranked = [f"({entry.split('|')[0]})" for entry in lists[8][i]]
    choices = ', '.join(reversed(ranked[:2]))
    return "\n".join([
        f"Abstract: {content}",
        f"References: {context}",
        f"{gpt_text}",
        f"Title: {title}",
        f"Choose from: {choices}",
    ])

def c_0l_t_ragt_tape_input(lists, i, context):
    """Input text: title, TAPE GPT response, then retrieved ``context`` plus neighbor titles.

    The references line joins the non-empty parts with ', ', so a node without
    neighbors (or an empty context) no longer produces a dangling comma.
    """
    title = lists[1][i]
    gpt_response = lists[7][i]
    neighbor_titles = ', '.join(f"({lists[1][x]})" for x in lists[5][i])
    references = ', '.join(part for part in (f"{context}", neighbor_titles) if part)
    return f"""Title: {title}\n{gpt_response}\nReferences: {references}"""

def c_2l_t_ragt_tape_input(lists, i, context):
    """Input text: title, TAPE GPT response, ``context`` plus neighbor titles, two GNN candidates.

    The references line joins the non-empty parts with ', ', so a node without
    neighbors (or an empty context) no longer produces a dangling comma. The
    candidates are the first two 'label|prob' entries of ``lists[8][i]``, reversed.
    """
    title = lists[1][i]
    gpt_response = lists[7][i]
    neighbor_titles = ', '.join(f"({lists[1][x]})" for x in lists[5][i])
    references = ', '.join(part for part in (f"{context}", neighbor_titles) if part)
    top_candidates = [f"({x.split('|')[0]})" for x in lists[8][i]][:2]
    return f"""Title: {title}\n{gpt_response}\nReferences: {references}\nChoose from: {', '.join(top_candidates[::-1])}"""

def c_0l_tc_ragt_tape_input(lists, i, context):
    """Input text: title, abstract, TAPE GPT response, then ``context`` plus neighbor titles.

    The references line joins the non-empty parts with ', ', so a node without
    neighbors (or an empty context) no longer produces a dangling comma.
    """
    title = lists[1][i]
    content = lists[2][i]
    gpt_response = lists[7][i]
    neighbor_titles = ', '.join(f"({lists[1][x]})" for x in lists[5][i])
    references = ', '.join(part for part in (f"{context}", neighbor_titles) if part)
    return f"""Title: {title}\nAbstract: {content}\n{gpt_response}\nReferences: {references}"""

def c_2l_tc_ragt_tape_input(lists, i, context):
    """Input text: title, abstract, TAPE GPT response, ``context`` plus neighbor titles, two GNN candidates.

    The references line joins the non-empty parts with ', ', so a node without
    neighbors (or an empty context) no longer produces a dangling comma. The
    candidates are the first two 'label|prob' entries of ``lists[8][i]``, reversed.
    """
    title = lists[1][i]
    content = lists[2][i]
    gpt_response = lists[7][i]
    neighbor_titles = ', '.join(f"({lists[1][x]})" for x in lists[5][i])
    references = ', '.join(part for part in (f"{context}", neighbor_titles) if part)
    top_candidates = [f"({x.split('|')[0]})" for x in lists[8][i]][:2]
    return f"""Title: {title}\nAbstract: {content}\n{gpt_response}\nReferences: {references}\nChoose from: {', '.join(top_candidates[::-1])}"""

def c_0l_ct_ragt_tape_input(lists, i, context):
    """Input text: abstract, ``context`` plus neighbor titles, TAPE GPT response, then title.

    The references line joins the non-empty parts with ', ', so a node without
    neighbors (or an empty context) no longer produces a dangling comma.
    """
    title = lists[1][i]
    content = lists[2][i]
    gpt_response = lists[7][i]
    neighbor_titles = ', '.join(f"({lists[1][x]})" for x in lists[5][i])
    references = ', '.join(part for part in (f"{context}", neighbor_titles) if part)
    return f"""Abstract: {content}\nReferences: {references}\n{gpt_response}\nTitle: {title}"""

def c_2l_ct_ragt_tape_input(lists, i, context):
    """Input text: abstract, ``context`` plus neighbor titles, TAPE GPT response, title, two GNN candidates.

    The references line joins the non-empty parts with ', ', so a node without
    neighbors (or an empty context) no longer produces a dangling comma. The
    candidates are the first two 'label|prob' entries of ``lists[8][i]``, reversed.
    """
    title = lists[1][i]
    content = lists[2][i]
    gpt_response = lists[7][i]
    neighbor_titles = ', '.join(f"({lists[1][x]})" for x in lists[5][i])
    references = ', '.join(part for part in (f"{context}", neighbor_titles) if part)
    top_candidates = [f"({x.split('|')[0]})" for x in lists[8][i]][:2]
    return f"""Abstract: {content}\nReferences: {references}\n{gpt_response}\nTitle: {title}\nChoose from: {', '.join(top_candidates[::-1])}"""

""" %%%%%%%%%%%%%%%%%%%%%%%%%%%%
Link prediction templates
%%%%%%%%%%%%%%%%%%%%%%%%%%%% """
# Link prediction prompt template
# TODO: add link prediction templates; get_template does not support task 'l' yet.


""" %%%%%%%%%%%%%%%%%%%%%%%%%%%%
Rationale query templates (only input is needed)
%%%%%%%%%%%%%%%%%%%%%%%%%%%% """

# CoT Reasoning templates, used for querying Llama3-8B-Instruct
def get_rationale_template(task, input_mode, output_mode):
    """Return ``(prompt_temp, input_temp)`` used to query rationales.

    Args:
        task: 'c' (node classification) or 'l' (link prediction).
        input_mode: 's' or 'm'; only 'm' has templates so far.
        output_mode: 'g', 'd' or 'p' (ground truth, distilled ranking,
            distilled ranking with probabilities).

    Raises:
        ValueError: for an unknown task or mode (explicit checks instead of
            ``assert``, which is stripped under ``python -O``).
        NotImplementedError: for a valid combination that has no template yet
            (task 'l' or input_mode 's'), instead of failing later with an
            unbound local variable.
    """
    if task not in ('c', 'l'):
        raise ValueError(f"Unknown rationale task: {task!r}")
    if input_mode not in ('s', 'm'):
        raise ValueError(f"Unknown rationale input mode: {input_mode!r}")
    if output_mode not in ('g', 'd', 'p'):
        raise ValueError(f"Unknown rationale output mode: {output_mode!r}")
    if task != 'c' or input_mode != 'm':
        raise NotImplementedError(
            f"No rationale template for task={task!r}, input_mode={input_mode!r}."
        )

    input_temps = {'g': rationale_c_m_g, 'd': rationale_c_m_d, 'p': rationale_c_m_p}
    return rationale_prompt, input_temps[output_mode]

def rationale_prompt():
    """Return the system instruction used when asking for classification rationales."""
    return ("You are a text classification assistant. "
            "Please help explain the following classification results.")

def rationale_c_m_g(title, content, neighbor_titles, label_and_prob_list):
    """Ask for a short explanation of why the paper has its top label.

    Only the label part of the first 'label|prob' entry is used.
    """
    neighbors = ', '.join(f"({name})" for name in neighbor_titles)
    top_label = label_and_prob_list[0].split('|')[0]
    return (
        f"({title}) is connected with {neighbors}. "
        f"Can you provide a short explanation why ({title}, {content}) "
        f"is categorized as {top_label}? "
        "The classification is surely correct. "
        "Try to focus on the keywords and the citation relationships."
    )

def rationale_c_m_d(title, content, neighbor_titles, label_and_prob_list):
    """Ask for a concise rationale for the paper's top-3 ranked categories.

    Neighbor titles are listed one per line; the labels are the label parts of
    the first three 'label|prob' entries, each wrapped in double quotes.
    """
    references = ''.join(f"{name}\n" for name in neighbor_titles)
    top3 = [entry.split('|')[0] for entry in label_and_prob_list][:3]
    quoted = ', '.join(f'"{label}"' for label in top3)
    return (
        f"Title: {title}\n"
        f"Abstract: {content}\n"
        f"References: {references}. "
        f"The top-3 categories (in descending order) of the above paper are {quoted}. "
        "Can you provide a short but accurate rationale about its categories? "
        "Please be as concise as possible."
    )

def rationale_c_m_p(title, content, neighbor_titles, label_and_prob_list):
    """Ask for a short explanation of the paper's ranked labels with probabilities.

    Each 'label|prob' entry is rendered as 'label (prob%)'.
    """
    neighbors = ', '.join(f"({name})" for name in neighbor_titles)
    ranked = ', '.join(
        f"{fields[0]} ({fields[1]}%)"
        for fields in (entry.split('|') for entry in label_and_prob_list)
    )
    return (
        f"({title}) is connected with {neighbors}. "
        f"Can you provide a short explanation why ({title}, {content}) "
        f"is categorized as {ranked}? "
        "The classification is surely correct. "
        "Try to focus on the keywords and the citation relationships."
    )