from tqdm import tqdm
from get_various_story_attr import get_various_story_attr

# Delimiter placed between the story text and the trailing
# "Have you memorized the story?" question in the memorization prompts
# built by the functions in this module.
separator = "\n\n==========\n\n"

def T1(data, sel_idx_list):
    """Build "memorize the story" prompts with one span of each story masked.

    For every row whose entry in ``sel_idx_list`` is not ``-1``, the story
    returned by ``get_various_story_attr`` has the slice ``[s:e]`` replaced
    by the literal placeholder ``[X]`` and is then wrapped in the
    memorization prompt.

    Args:
        data: Dataset-like object exposing ``num_rows``; forwarded unchanged
            to ``get_various_story_attr``.
        sel_idx_list: Per-row selection index; ``-1`` marks a row to skip.

    Returns:
        A list of prompt strings, one per non-skipped row, in row order.
    """
    prompts = []
    for row in tqdm(range(data.num_rows)):
        picked = sel_idx_list[row]
        if picked != -1:
            # span_start/span_end are presumably the character offsets of the
            # supporting sentence inside the story -- TODO confirm in helper.
            story, _, _, span_start, span_end, _ = get_various_story_attr(data, row, picked)
            story = story[:span_start] + '[X]' + story[span_end:]
            prompts.append(
                f"Read and memorize the following story.\n\nStory:\n\n{story}{separator}Have you memorized the story?"
            )
    return prompts

def T2(data, sel_idx_list):
    """Build prompts asking the model to flip the answer to a question.

    Each prompt embeds the supporting snippet, the question and the original
    answer obtained from ``get_various_story_attr`` for the selected row.

    Args:
        data: Dataset-like object exposing ``num_rows``; forwarded unchanged
            to ``get_various_story_attr``.
        sel_idx_list: Per-row selection index; ``-1`` marks a row to skip.

    Returns:
        A list of prompt strings, one per non-skipped row, in row order.
    """
    prompts = []
    for row in tqdm(range(data.num_rows)):
        picked = sel_idx_list[row]
        if picked != -1:
            _, question, answer, _, _, support_sent = get_various_story_attr(data, row, picked)
            prompts.append(
                f"Given a snippet of text, a question, and an answer, I want you to flip the answer so that new answer is opposite to the original one.\n\nSnippet of text = {support_sent}\nQuestion: {question}\nAnswer: {answer}\n\nOutput new answer only and nothing else. Do not write explanations. Following rules above, new Answer = "
            )
    return prompts

def T3(data, sel_idx_list, new_sent_list):
    """Build prompts asking the model to write a replacement for ``[X]``.

    Each prompt quotes the row's question, the value the model should answer
    with afterwards (taken from ``new_sent_list``), and the original
    supporting sentence as an example.

    Args:
        data: Dataset-like object exposing ``num_rows``; forwarded unchanged
            to ``get_various_story_attr``.
        sel_idx_list: Per-row selection index; ``-1`` marks a row to skip.
        new_sent_list: Values indexed by the running count of non-skipped
            rows (not by the row number).

    Returns:
        A list of prompt strings, one per non-skipped row, in row order.
    """
    prompts = []
    for row in tqdm(range(data.num_rows)):
        picked = sel_idx_list[row]
        if picked == -1:
            continue
        _, question, _, _, _, support_sent = get_various_story_attr(data, row, picked)
        # The number of prompts emitted so far is exactly the position of
        # this row among the non-skipped rows.
        new_answer = new_sent_list[len(prompts)]
        prompts.append(
            f"I want you to fill the [X] part in one sentence based on new answer to complete the story. After completion, if I ask \"{question}\", you should answer \"{new_answer}\" based on the text you generated in the story. The sentences nearby [X] should be grammatically correct. Output the replacement of [X] only, do not output context before and after [X]. For example, original [X] = \"{support_sent}\".\n\nFollowing rules above, generate new snippet of text [X] (based on new answer) = "
        )
    return prompts

def t_T1(data, sel_idx_list, is_trim=False, do_ablation=False, start_idx=None, data_split=None):
    """Build test-time story prompts from the unmasked story.

    Unlike ``T1``, the story is not masked: either the whole story or the
    prefix ``story[:e]`` is used.

    Args:
        data: Dataset-like object exposing ``num_rows``; forwarded unchanged
            to ``get_various_story_attr``.
        sel_idx_list: Per-row selection index; ``-1`` marks a row to skip.
        is_trim: When truthy, only ``story[:e]`` is used instead of the full
            story.
        do_ablation: When truthy, the bare story text is returned without the
            memorization prompt around it.
        start_idx: Optional number of leading selected rows to skip. When
            truthy it is asserted to equal 5000. According to the original
            comments this exists for evaluation on the training data.
        data_split: Asserted to be ``"train"`` whenever ``start_idx`` is
            truthy.

    Returns:
        A list of strings, one per row that is neither unselected nor skipped
        by ``start_idx``.
    """
    if start_idx:
        assert start_idx == 5000
        assert data_split == "train"
    # Both operands are loop-invariant, so the check is evaluated once here.
    skip_leading = bool(start_idx) and data_split == "train"
    skipped = 0

    outputs = []
    for row in tqdm(range(data.num_rows)):
        picked = sel_idx_list[row]
        if picked == -1:
            continue
        if skip_leading and skipped < start_idx:
            skipped += 1
            continue
        story, _, _, _, span_end, _ = get_various_story_attr(data, row, picked)
        if is_trim:
            # Kept as in the original: truthy values that do not compare
            # equal to True are rejected.
            assert is_trim == True
            text = story[:span_end]
        else:
            text = story
        if do_ablation:
            outputs.append(text)
        else:
            outputs.append(
                f"Read and memorize the following story.\n\nStory:\n\n{text}{separator}Have you memorized the story?"
            )
    return outputs

def t_T1_modified(data, sel_idx_list, new_sent_list, is_trim):
    """Build memorization prompts with the span ``[s:e]`` replaced by new text.

    Args:
        data: Dataset-like object exposing ``num_rows``; forwarded unchanged
            to ``get_various_story_attr``.
        sel_idx_list: Per-row selection index; ``-1`` marks a row to skip.
        new_sent_list: Replacement texts, indexed by the running count of
            non-skipped rows (not by the row number).
        is_trim: When truthy, the story ends right after the replacement;
            otherwise the original tail ``story[e:]`` is appended.

    Returns:
        A list of prompt strings, one per non-skipped row, in row order.
    """
    prompts = []
    for row in tqdm(range(data.num_rows)):
        picked = sel_idx_list[row]
        if picked == -1:
            continue
        story, _, _, span_start, span_end, _ = get_various_story_attr(data, row, picked)
        # len(prompts) equals the number of non-skipped rows handled so far.
        edited = story[:span_start] + new_sent_list[len(prompts)]
        if not is_trim:
            edited = edited + story[span_end:]
        prompts.append(
            f"Read and memorize the following story.\n\nStory:\n\n{edited}{separator}Have you memorized the story?"
        )
    return prompts

def t_T1_insertion(data, sel_idx_list, new_sent_list, start_token='', end_token=''):
    """Build memorization prompts with a new sentence inserted at offset ``e``.

    The original story is kept intact. The new sentence, wrapped in
    ``start_token``/``end_token`` and padded with one space on each side, is
    inserted at position ``e``.

    Args:
        data: Dataset-like object exposing ``num_rows``; forwarded unchanged
            to ``get_various_story_attr``.
        sel_idx_list: Per-row selection index; ``-1`` marks a row to skip.
        new_sent_list: Sentences to insert, indexed by the running count of
            non-skipped rows (not by the row number).
        start_token: Text placed immediately before the inserted sentence.
        end_token: Text placed immediately after the inserted sentence.

    Returns:
        A list of prompt strings, one per non-skipped row, in row order.
    """
    prompts = []
    for row in tqdm(range(data.num_rows)):
        picked = sel_idx_list[row]
        if picked == -1:
            continue
        story, _, _, _, span_end, _ = get_various_story_attr(data, row, picked)
        # Original note: the tokens "will be properly replaced" -- presumably
        # by a later processing step outside this function.
        inserted = f" {start_token}{new_sent_list[len(prompts)]}{end_token} "
        story = story[:span_end] + inserted + story[span_end:]
        prompts.append(
            f"Read and memorize the following story.\n\nStory:\n\n{story}{separator}Have you memorized the story?"
        )
    return prompts

def t_Tp(data, sel_idx_list, is_label=False, start_idx=None, data_split=None, model='gpt'):
    """Build chat-history message lists from the first question/answer pairs.

    For each kept row, the first ``sel_idx_list[i]`` entries of
    ``data['questions'][i]`` and ``data['answers'][i]['input_text']`` are
    turned into alternating user/assistant messages.

    Args:
        data: Dataset-like object exposing ``num_rows`` and supporting the
            ``data['questions']`` / ``data['answers']`` lookups above.
        sel_idx_list: Per-row selection index; ``-1`` marks a row to skip.
        is_label: When truthy, messages are prefixed with ``Q<n>: `` /
            ``A<n>: `` (1-based).
        start_idx: Optional number of leading selected rows to skip. When
            truthy it is asserted to equal 5000. According to the original
            comments this exists for evaluation on the training data.
        data_split: Asserted to be ``"train"`` whenever ``start_idx`` is
            truthy.
        model: Model name. If it contains ``'gpt'``, messages use the
            ``content`` key and the ``assistant`` role; otherwise it is
            asserted to contain ``'gemini'`` and messages use ``parts`` and
            the ``model`` role.

    Returns:
        A list with one message list (a list of dicts) per kept row.
    """
    if start_idx:
        assert start_idx == 5000
        assert data_split == "train"
    # Both operands are loop-invariant, so the check is evaluated once here.
    skip_leading = bool(start_idx) and data_split == "train"
    skipped = 0

    conversations = []
    for row in tqdm(range(data.num_rows)):
        if sel_idx_list[row] == -1:
            continue
        if skip_leading and skipped < start_idx:
            skipped += 1
            continue
        history = []
        questions = data['questions'][row][:sel_idx_list[row]]
        answers = data['answers'][row]['input_text'][:sel_idx_list[row]]
        for turn, (question, answer) in enumerate(zip(questions, answers)):
            # The model check stays inside the loop so that, as before, it
            # only fires when there is at least one turn to emit.
            if 'gpt' in model:
                text_key, reply_role = "content", "assistant"
            else:
                assert 'gemini' in model
                text_key, reply_role = "parts", "model"
            if is_label:
                user_text = f"Q{turn+1}: {question}"
                reply_text = f"A{turn+1}: {answer}"
            else:
                user_text, reply_text = question, answer
            history.append({"role": "user", text_key: user_text})
            history.append({"role": reply_role, text_key: reply_text})
        conversations.append(history)
    return conversations

def t_Tc_ensemble(data, sel_idx_list, templates_list_daily_dialog, new_sent_list, start_idx=None, data_split=None):
    """Fill every correction template for each kept row.

    In each template, ``[old_completion]`` is replaced by the row's
    supporting sentence and ``[new_completion]`` by the matching entry of
    ``new_sent_list``.

    Args:
        data: Dataset-like object exposing ``num_rows``; forwarded unchanged
            to ``get_various_story_attr``.
        sel_idx_list: Per-row selection index; ``-1`` marks a row to skip.
        templates_list_daily_dialog: Iterable of template strings that may
            contain the two placeholders.
        new_sent_list: Replacement texts, indexed by the running count of
            kept rows. Rows skipped through ``start_idx`` do not advance
            this count.
        start_idx: Optional number of leading selected rows to skip. When
            truthy it is asserted to equal 5000. According to the original
            comments this exists for evaluation on the training data.
        data_split: Asserted to be ``"train"`` whenever ``start_idx`` is
            truthy.

    Returns:
        A list with one list of filled templates per kept row.
    """
    if start_idx:
        assert start_idx == 5000
        assert data_split == "train"
    # Both operands are loop-invariant, so the check is evaluated once here.
    skip_leading = bool(start_idx) and data_split == "train"
    skipped = 0

    ensembles = []
    for row in tqdm(range(data.num_rows)):
        picked = sel_idx_list[row]
        if picked == -1:
            continue
        if skip_leading and skipped < start_idx:
            skipped += 1
            continue
        _, _, _, _, _, support_sent = get_various_story_attr(data, row, picked)
        filled_templates = []
        for template in templates_list_daily_dialog:
            filled = template
            if "[old_completion]" in filled:
                filled = filled.replace("[old_completion]", support_sent)
            # new_sent_list is only indexed when the placeholder is present;
            # len(ensembles) is the number of rows already emitted.
            if "[new_completion]" in filled:
                filled = filled.replace("[new_completion]", new_sent_list[len(ensembles)])
            filled_templates.append(filled)
        ensembles.append(filled_templates)
    return ensembles

def t_Ti(data, sel_idx_list, start_idx=None, data_split=None):
    """Collect the question of each kept row as the test query.

    Args:
        data: Dataset-like object exposing ``num_rows``; forwarded unchanged
            to ``get_various_story_attr``.
        sel_idx_list: Per-row selection index; ``-1`` marks a row to skip.
        start_idx: Optional number of leading selected rows to skip. When
            truthy it is asserted to equal 5000. According to the original
            comments this exists for evaluation on the training data.
        data_split: Asserted to be ``"train"`` whenever ``start_idx`` is
            truthy.

    Returns:
        A list of questions, one per kept row, in row order.
    """
    if start_idx:
        assert start_idx == 5000
        assert data_split == "train"
    # Both operands are loop-invariant, so the check is evaluated once here.
    skip_leading = bool(start_idx) and data_split == "train"
    skipped = 0

    questions = []
    for row in tqdm(range(data.num_rows)):
        picked = sel_idx_list[row]
        if picked == -1:
            continue
        if skip_leading and skipped < start_idx:
            skipped += 1
            continue
        _, question, _, _, _, _ = get_various_story_attr(data, row, picked)
        questions.append(question)
    return questions

def t_Tr(data, sel_idx_list, instruct_update=False, start_idx=None, data_split=None):
    """Produce one fixed "repeat the story" query per kept row.

    Args:
        data: Dataset-like object exposing ``num_rows``.
        sel_idx_list: Per-row selection index; ``-1`` marks a row to skip.
        instruct_update: When truthy, the query asks for the new story with
            the correction; otherwise it asks for the story.
        start_idx: Optional number of leading selected rows to skip. When
            truthy it is asserted to equal 5000. According to the original
            comments this exists for evaluation on the training data.
        data_split: Asserted to be ``"train"`` whenever ``start_idx`` is
            truthy.

    Returns:
        A list holding the same query string once per kept row.
    """
    if start_idx:
        assert start_idx == 5000
        assert data_split == "train"
    # Both operands are loop-invariant, so the check is evaluated once here.
    skip_leading = bool(start_idx) and data_split == "train"
    skipped = 0

    # The query does not depend on the row, so it is chosen once up front.
    if instruct_update:
        # Alternative phrasings noted in the original:
        # "Reiterate the new story with the correction."
        # "Summarize the new story with the correction."
        query = "What's the new story with the correction? Output new story and nothing else."
    else:
        # Alternative phrasings noted in the original:
        # "Reiterate the story." / "Summarize the story."
        query = "What's the story? Output the story and nothing else."

    queries = []
    for row in tqdm(range(data.num_rows)):
        if sel_idx_list[row] == -1:
            continue
        if skip_leading and skipped < start_idx:
            skipped += 1
            continue
        queries.append(query)
    return queries

def t_Tr_oracle(data, sel_idx_list, new_sent_list, start_idx=None, data_split=None):
    """Build the corrected stories, with the span ``[s:e]`` replaced by new text.

    The result is the bare story text without any prompt around it.

    Args:
        data: Dataset-like object exposing ``num_rows``; forwarded unchanged
            to ``get_various_story_attr``.
        sel_idx_list: Per-row selection index; ``-1`` marks a row to skip.
        new_sent_list: Replacement texts, indexed by the running count of
            kept rows. Rows skipped through ``start_idx`` do not advance
            this count.
        start_idx: Optional number of leading selected rows to skip. When
            truthy it is asserted to equal 5000. According to the original
            comments this exists for evaluation on the training data.
        data_split: Asserted to be ``"train"`` whenever ``start_idx`` is
            truthy.

    Returns:
        A list of story strings, one per kept row, in row order.
    """
    if start_idx:
        assert start_idx == 5000
        assert data_split == "train"
    # Both operands are loop-invariant, so the check is evaluated once here.
    skip_leading = bool(start_idx) and data_split == "train"
    skipped = 0

    stories = []
    for row in tqdm(range(data.num_rows)):
        picked = sel_idx_list[row]
        if picked == -1:
            continue
        if skip_leading and skipped < start_idx:
            skipped += 1
            continue
        story, _, _, span_start, span_end, _ = get_various_story_attr(data, row, picked)
        # len(stories) equals the number of kept rows handled so far.
        stories.append(story[:span_start] + new_sent_list[len(stories)] + story[span_end:])
    return stories

def t_Tv(data, sel_idx_list, start_idx=None, data_split=None):
    """Produce one fixed "Really?" follow-up query per kept row.

    Args:
        data: Dataset-like object exposing ``num_rows``.
        sel_idx_list: Per-row selection index; ``-1`` marks a row to skip.
        start_idx: Optional number of leading selected rows to skip. When
            truthy it is asserted to equal 5000. According to the original
            comments this exists for evaluation on the training data.
        data_split: Asserted to be ``"train"`` whenever ``start_idx`` is
            truthy.

    Returns:
        A list holding the same query string once per kept row.
    """
    if start_idx:
        assert start_idx == 5000
        assert data_split == "train"
    # Both operands are loop-invariant, so the check is evaluated once here.
    skip_leading = bool(start_idx) and data_split == "train"
    skipped = 0

    # Alternatives noted in the original: "Really?" and the empty string.
    query = "Really? Let's think about the update."

    queries = []
    for row in tqdm(range(data.num_rows)):
        if sel_idx_list[row] == -1:
            continue
        if skip_leading and skipped < start_idx:
            skipped += 1
            continue
        queries.append(query)
    return queries

def t_Tv_2(data, sel_idx_list, start_idx=None, data_split=None):
    """Produce one fixed Yes/No follow-up query per kept row.

    Args:
        data: Dataset-like object exposing ``num_rows``.
        sel_idx_list: Per-row selection index; ``-1`` marks a row to skip.
        start_idx: Optional number of leading selected rows to skip. When
            truthy it is asserted to equal 5000. According to the original
            comments this exists for evaluation on the training data.
        data_split: Asserted to be ``"train"`` whenever ``start_idx`` is
            truthy.

    Returns:
        A list holding the same query string once per kept row.
    """
    if start_idx:
        assert start_idx == 5000
        assert data_split == "train"
    # Both operands are loop-invariant, so the check is evaluated once here.
    skip_leading = bool(start_idx) and data_split == "train"
    skipped = 0

    query = "Therefore, based on your previous response, your answer to the last question is more likely to be 'Yes', 'No'? You must output 'Yes' or 'No' first."

    queries = []
    for row in tqdm(range(data.num_rows)):
        if sel_idx_list[row] == -1:
            continue
        if skip_leading and skipped < start_idx:
            skipped += 1
            continue
        queries.append(query)
    return queries
