import argparse
import random
import re

# this package is used if discard_long_input is set
import tiktoken # count number of token
# enc = tiktoken.encoding_for_model("gpt-3.5-turbo-0613")
# e.g., len(enc.encode("hello world")) == 2
import torch
from tqdm import tqdm
from datasets import load_dataset

from set_seed import set_seed
from openai_api_setup import get_response_list_from_chatgpt, get_error_messages_list
from generate_yes_no_index_list import generate_yes_no_index_list
import generate_queries_list
from get_various_story_attr import get_various_story_attr
from save_data import save_data
from get_templates_list_from_daily_dialog import get_templates_list_from_daily_dialog

####################################################
def parse_argument(argv=None):
    """Build the command-line parser for this experiment and parse the arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so the existing no-argument call is
            unchanged; passing a list allows parsing without touching sys.argv.

    Returns:
        argparse.Namespace with ``data_split``, ``rewrite_i``, ``max_iter``,
        ``is_trim``, ``is_recall`` and ``instruct_delete``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--data_split', required=True, choices=['train', 'validation'], help='choose training set or validation set in CoQA dataset')
    # default=0: rewrite_i is used as a list index (mturk_rewrite_lists[rewrite_i]) and
    # inside output file names, so it must never be None when the flag is omitted.
    parser.add_argument('--rewrite_i', type=int, default=0, choices=range(0, 3), help='Choose 1st, 2nd, or 3rd response of MTurk')
    parser.add_argument('--max_iter', required=True, type=int, choices=range(1, 6), help='set maximum of iteration in Td. Note that it will increase the fee')
    parser.add_argument('--is_trim', action='store_true', help='Whether to trim the story (so original support sentence is always at the end of the story)')
    parser.add_argument('--is_recall', action='store_true', help='Whether to recall the story first then perform Td')
    #parser.add_argument('--is_unrolling', action='store_true', help='Whether to append all chat history. Note that I do not automatically trim the data so you should set this flag with care, or you will always see the error message: openai.error.InvalidRequestError: This model\'s maximum context length is 16385 tokens... Please reduce the length of the messages or completion.')
    parser.add_argument('--instruct_delete', action='store_true', help='Whether to instruct ChatGPT in deletion process, if this flag is set, we do not label previous QA pairs as they are presented in the deletion conversation so no need to worry about the "reference" in the dialogue session')
    # TODO
    #parser.add_argument('--discard_long_input', action='store_true', help='Whether to discard the oldest conversation if the input exceed 4096 (or 16,384) in deletion process')
    args = parser.parse_args(argv)
    return args

# Expose the parsed CLI flags as module-level globals used by the rest of the script.
args = parse_argument()
data_split = args.data_split # "train", "validation"
rewrite_i = args.rewrite_i # which MTurk rewrite to use (0-2); used as a list index and in output file names
max_iter = args.max_iter # upper bound on the update round (cur_iter) in the Td loop
is_trim = args.is_trim
is_recall = args.is_recall
#is_unrolling = args.is_unrolling
instruct_delete = args.instruct_delete
#discard_long_input = args.discard_long_input

# Hard-coded settings for running without the CLI. This is a bare string literal,
# so none of these assignments are executed.
"""
data_split = "validation"
rewrite_i = 0
max_iter = 3
is_trim = False
is_recall = True
#is_unrolling = False
instruct_delete = True
#discard_long_input = False
"""

# Sanity check mirroring choices=range(1, 6) of --max_iter.
assert 1 <= max_iter <= 5

#if instruct_delete:
#    assert discard_long_input # as it can easily exceed 16k tokens

# NOT DONE
#assert instruct_delete != True

####################################################

# Load the requested CoQA split.
_d_coqa = load_dataset('coqa')
data = _d_coqa[data_split]


# For each story, pick one candidate question index at random, or -1 when the
# story has no candidates. Presumably the candidates are yes/no questions (per the
# helper's name), and set_seed(0) seeds `random` so the choice is reproducible.
yn_qa_idx_list = generate_yes_no_index_list(data)
set_seed(0)
sel_idx_list = [random.choice(i) if i else -1 for i in yn_qa_idx_list]


log_file_path = "../data/log" # not referenced in the code visible here
pt_file_path = "../data/pt"

# The MTurk rewrites come from a fixed validation-only file, and the assert below
# requires exactly 464 selected items for the current split.
# NOTE(review): presumably this means --data_split train fails here; confirm.
mturk_responses = torch.load('../data/csv/MTurk_validation/final_validation_464.pt')
assert 464 == len(list(filter(lambda x: x != -1, sel_idx_list)))
# Every item must have exactly three rewrites (selected via --rewrite_i 0-2).
assert all(len(mturk_responses["validation"][k]) == 3 for k in mturk_responses["validation"])
# Transpose into one list per rewrite index.
mturk_rewrite_0_list = [mturk_responses["validation"][k][0] for k in range(464)]
mturk_rewrite_1_list = [mturk_responses["validation"][k][1] for k in range(464)]
mturk_rewrite_2_list = [mturk_responses["validation"][k][2] for k in range(464)]

mturk_rewrite_lists = [mturk_rewrite_0_list,
                       mturk_rewrite_1_list,
                       mturk_rewrite_2_list
                      ]

# Keep only the dialogue templates whose indices are listed in top_k_idx.
templates_list_daily_dialog = get_templates_list_from_daily_dialog()
top_k_idx = [2] # [7, 2, 1] # test all in the future (REALLY EXPENSIVE)
templates_list_daily_dialog = [templates_list_daily_dialog[i] for i in top_k_idx]
assert len(top_k_idx) == len(templates_list_daily_dialog)

# test if exists first
# Fail fast, before any API calls, if a stored recalled-story file is missing
# (torch.load raises on a missing path). The loaded value is discarded here; the
# same file is loaded again inside the main loop.
if is_recall:
    for i in range(len(templates_list_daily_dialog)):
        _ = torch.load(f"{pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_p_c{top_k_idx[i]}_r_chatgpt_recall_story_trim_0.pt")
####################################################


# For every story with a selected question (sel_idx_list[i] != -1), collect:
# the story, the QA turns that precede the selected question, the selected
# question and answer, and old_support_sent (6th value of get_various_story_attr).
# All lists are aligned by item position (0..cnt-1).
story_list = []
prev_questions_answers_list = []
question_list = []
answer_list = []
old_sent_list = []
cnt = 0
for i in tqdm(range(data.num_rows)):
    if sel_idx_list[i] == -1:
        continue # no question selected for this story
    story, question, answer, _, _, old_support_sent = get_various_story_attr(data, i, sel_idx_list[i])
    # process previous qa
    # Only the turns before the selected question are kept.
    prev_questions = data['questions'][i][:sel_idx_list[i]]
    prev_answers = data['answers'][i]['input_text'][:sel_idx_list[i]]
    prev_questions_answers = []
    for j, (q, a) in enumerate(zip(prev_questions, prev_answers)):
        if instruct_delete:
            # Unlabeled "question\n\nanswer"; with --instruct_delete the pair text is
            # embedded directly in the Td prompt, so no Qn/An label is needed.
            prev_questions_answers.append(f"{q}\n\n{a}")
        else:
            # 1-based labels so Td prompts can refer to the "Q{j} A{j} pair".
            prev_questions_answers.append(f"Q{j+1}: {q}\n\nA{j+1}: {a}")
    story_list.append(story)
    prev_questions_answers_list.append(prev_questions_answers)
    question_list.append(question)
    answer_list.append(answer)
    old_sent_list.append(old_support_sent)
    cnt += 1

assert cnt == len(story_list) == len(prev_questions_answers_list) == len(question_list) == len(answer_list) == len(old_sent_list)

####################################################
# Test if ChatGPT can update answer based on new story

# Q/A label patterns, compiled once and applied in this order. Each is anchored at
# a word boundary (\b) so that ordinary words ending in "a"/"q" before a colon
# (e.g. "Canada:", "idea:", "Iraq:") are not truncated inside story text.
_QA_PREFIX_PATTERNS = [
    re.compile(pattern)
    for pattern in (
        r'\b[qQ]\d+:',
        r'\b[aA]\d+:',
        r'\b[qQ]:',
        r'\b[aA]:',
        r'\b[qQ]uestion\d+:',
        r'\b[aA]nswer\d+:',
    )
]


def remove_qa_prefix(s):
    """Strip QA labels such as "Q:", "A1:", "Question2:" or "Answer2:" from ``s``.

    Only the label itself is removed; surrounding whitespace is left untouched,
    e.g. "Q1: Is it red?\\nA1: yes" -> " Is it red?\\n yes".
    """
    for pattern in _QA_PREFIX_PATTERNS:
        s = pattern.sub('', s)
    return s

# 1st/2nd stage prompting
def _ordinal(n):
    """Return ``n`` with its English ordinal suffix, e.g. 1 -> "1st", 12 -> "12th", 22 -> "22nd"."""
    # 11, 12, 13 (and 111, 112, ...) take "th" although they end in 1/2/3.
    if 10 <= n % 100 <= 20:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return f"{n}{suffix}"


def generate_queries_Td(j, cur_iter, cur_info, cur_hist, instruct_delete, is_recall, stage=None):
    """Build one Td (update) prompt for a single history entry.

    Stage 1 asks whether the entry contradicts the correction; stage 2 asks the
    model to modify the entry.

    Args:
        j: History index; 0 is the story, j >= 1 is the j-th previous QA pair.
        cur_iter: Update round of ``cur_info``. Round 0 refers to the story/QA pair
            as originally presented; later rounds refer to the "most recent modified" one.
        cur_info: Correction text; QA labels are stripped with remove_qa_prefix first.
        cur_hist: Current text of the history entry. It is embedded in the prompt
            only when ``instruct_delete`` is set.
        instruct_delete: Embed ``cur_hist`` in the prompt and refer to it generically
            as "story" / "QA pair".
        is_recall: For the story at round 0, call it the "new story" rather than the
            "story with the correction".
        stage: 1 or 2. Required despite the ``None`` default.

    Returns:
        (queries, nth_str): the prompt text and the ordinal of ``j`` ("0th", "1st", ...).

    Raises:
        ValueError: if ``stage`` is not 1 or 2.
    """
    if stage not in (1, 2):
        raise ValueError(f"stage must be 1 or 2, got {stage!r}")
    # remove "Q:", "A:" etc.
    cur_info = remove_qa_prefix(cur_info)
    nth_str = _ordinal(j)
    if j == 0:
        if instruct_delete:
            # The story text is shown in the prompt, so refer to it generically.
            story_info = f"Story = \"\"\"\n{cur_hist}\n\"\"\"\n\n"
            story_str = "story"
        else:
            story_info = ""
            if cur_iter != 0:
                story_str = "most recent modified story"
            elif is_recall:
                story_str = "new story"
            else:
                story_str = "story with the correction"
        if stage == 1:
            queries = f"{story_info}Correction = \"\"\"\n{cur_info}\n\"\"\"\n\nWhich parts in the {story_str} contradict the correction? If the {story_str} entails the correction, output 'NO MODIFICATION'. Let\'s read the story line by line. List all the contradictions one by one, if any."
        else:
            # NOTE(review): this wording ("the correction entail the story") is the
            # reverse of the QA stage-2 prompt ("so that it entails the correction").
            # It is kept verbatim to preserve the experiment's prompt; confirm the intent.
            queries = f"Can you modify the {story_str}, one by one, so that the correction entail the {story_str}?"
    else:
        if instruct_delete:
            # The QA pair text is shown in the prompt, so refer to it generically.
            qa_info = f"QA pair = \"\"\"\n{cur_hist}\n\"\"\"\n\n"
            qa_str = "QA pair"
        else:
            qa_info = ""
            qa_str = f"Q{j} A{j} pair" if cur_iter == 0 else f"most recent modified Q{j} A{j} pair"
        if stage == 1:
            queries = f"{qa_info}Correction = \"\"\"\n{cur_info}\n\"\"\"\n\nDoes the {qa_str} contradict the correction?\nIf the {qa_str} entails the correction, output 'NO MODIFICATION'. If the {qa_str} contradicts the correction, explain why they are contradictory in one sentence. If they are in a neutral relation, output 'NO MODIFICATION'. Let\'s think step by step."
        else:
            queries = f"Can you modify the {qa_str} so that it entails the correction? DO NOT modify the QA pair by copying the correction. Let\'s think step by step."
    return queries, nth_str

"""
def generate_queries_Td(j, cur_iter, cur_info, cur_hist, instruct_delete, is_recall):
    # note that we do not use cur_hist for now, because they will only be used in instruct_delete=True
    # TODO: add instruct_delete
    # remove "Q:", "A:" etc.
    cur_info = remove_qa_prefix(cur_info)
    nth_str = "0th"
    if j == 0:
        if cur_iter == 0:
            if is_recall:
                story_str = "new story"
            else:
                story_str = "story with the correction"
            queries = f"Correction = \"\"\"\n{cur_info}\n\"\"\"\n\nDoes the {story_str} entail the correction?\nIf the {story_str} entails the correction, output 'NO MODIFICATION'. Otherwise, output the modified story so that it entails the correction." # Let\'s read the story line by line.
        else:
            queries = f"Correction = \"\"\"\n{cur_info}\n\"\"\"\n\nDoes the most recent modified story entail the correction?\nIf the most recent modified story entails the correction, output 'NO MODIFICATION'. Otherwise, output new modified story so that it entails the correction." # Let\'s read the story line by line.
    else:
        if j == 1:
            nth_str = "1st"
        elif j == 2:
            nth_str = "2nd"
        elif j == 3:
            nth_str = "3rd"
        else:
            nth_str = f"{j}th"
        if cur_iter == 0:
            queries = f"Correction = \"\"\"\n{cur_info}\n\"\"\"\n\nDoes the {nth_str} QA pair contradict the correction?\nIf the {nth_str} QA pair entails the correction, output 'NO MODIFICATION'. If the {nth_str} QA pair contradicts the correction, output the modified {nth_str} QA pair so that it does not contradict the correction. If they are in a neutral relation, output 'NO MODIFICATION'. DO NOT modify the {nth_str} QA pair by copying the correction. Let\'s think step by step."
        else:
            queries = f"Correction = \"\"\"\n{cur_info}\n\"\"\"\n\nDoes the most recent modified {nth_str} QA pair contradict the correction?\nIf the modified {nth_str} QA pair entails the correction, output 'NO MODIFICATION'. If the modified {nth_str} QA pair contradicts the correction, output new modified {nth_str} QA pair so that it does not contradict the correction. If they are in a neutral relation, output 'NO MODIFICATION'. DO NOT modify the {nth_str} QA pair by copying the correction. Let\'s think step by step."
    return queries, nth_str
"""

def extract_story_or_qa_from_final_response_text(final_merge_messages, test_final_response_text, j, nth_str, instruct_delete):
    """Ask ChatGPT to restate only the modified story or QA pair from its stage-2 reply.

    A short follow-up chat is built from the last three messages of
    ``final_merge_messages`` (at the call site: stage-1 query, stage-1 reply and
    stage-2 query) plus the stage-2 reply. It ends with a request for the bare
    modified story (``j == 0``) or a two-line question/answer pair (``j >= 1``).

    Args:
        final_merge_messages: Chat so far as a list of {"role", "content"} dicts; at
            least three messages are required.
        test_final_response_text: The assistant's stage-2 reply.
        j: History index (0 = story, >= 1 = QA pair).
        nth_str: Ordinal of ``j``, used in the QA request when ``instruct_delete`` is off.
        instruct_delete: Refer to the QA pair generically instead of by ordinal.

    Returns:
        The message content of the first choice in the API response.

    Raises:
        ValueError: if fewer than three context messages are given.
        RuntimeError: if the API helper returns one of its error messages
            (see get_error_messages_list) instead of a response.
    """
    if len(final_merge_messages) < 3:
        raise ValueError(f"need at least 3 context messages, got {len(final_merge_messages)}")
    # Resend only the tail of the conversation to keep this extraction call short.
    new_chat_session = final_merge_messages[-3:] + [{"role": "assistant", "content": test_final_response_text}]
    if j == 0:
        follow_up = "Therefore, what is the modified story? Output the modified story and nothing else."
    elif instruct_delete:
        follow_up = "Therefore, what is the modified QA pair? Your response must contain two lines only. The first line is the question, and the second line is the answer. Output the modified QA pair and nothing else."
    else:
        follow_up = f"Therefore, what is the modified {nth_str} QA pair? Your response must contain two lines only. The first line is the question, and the second line is the answer. Output the modified {nth_str} QA pair and nothing else."
    new_chat_session.append({"role": "user", "content": follow_up})
    #print(f"new_chat_session = {new_chat_session}") # debugging
    response = get_response_list_from_chatgpt([new_chat_session], model="gpt-3.5-turbo-16k-0613", max_tokens=512)
    # Same error check the main loop applies. Without it, an error message string
    # would fail below with an opaque "string indices must be integers" TypeError.
    if any(response[0] == err_msg for err_msg in get_error_messages_list()):
        raise RuntimeError(f"ChatGPT extraction request failed: {response[0]!r}")
    return response[0]['choices'][0]['message']['content']

# Hard-case item indices produced by data_error_analysis.py (see the end of that
# file): cases where modifying the existing knowledge proved very hard, kept for
# targeted testing and debugging. They are grouped by difficulty tier and
# presumably index the per-item loop (ii) below.
extremely_hard_index = [9, 12, 27, 40, 41, 44, 47, 49, 66, 71, 72, 81, 107, 116, 120, 121, 126, 129, 130, 134, 136, 147, 162, 167, 168, 190, 198, 214, 215, 222, 234, 237, 243, 248, 257, 258, 268, 273, 274, 282, 285, 292, 301, 310, 313, 317, 325, 338, 341, 353, 356, 385, 388, 391, 394, 404, 406, 409, 417, 435, 436, 461]
second_hard_index = [7, 10, 16, 18, 23, 35, 38, 42, 48, 52, 55, 59, 88, 96, 99, 103, 111, 115, 123, 128, 133, 150, 156, 160, 172, 174, 176, 177, 179, 185, 187, 191, 194, 195, 206, 213, 225, 227, 230, 238, 241, 255, 264, 277, 295, 298, 309, 312, 314, 332, 334, 354, 368, 381, 387, 401, 418, 428, 444, 448, 450, 451, 454, 458, 460, 463]
third_hard_index = [31, 165, 202, 259, 350, 386, 419, 432]
fourth_hard_index = [11, 17, 22, 50, 64, 78, 91, 104, 106, 112, 122, 158, 175, 181, 261, 286, 331, 335, 336, 357, 362, 364, 365, 370, 372, 378, 398, 402, 415, 422, 427, 433]

# Union of all tiers, deduplicated and in ascending order.
combine_index_list = sorted(
    set(extremely_hard_index).union(second_hard_index, third_hard_index, fourth_hard_index)
)

def remove_duplicate_element(l):
    """Return ``l`` without repeated entries, keeping the first occurrence of each.

    Entries are compared via ``tuple(entry)``. Unlike a set round-trip, the
    relative order is preserved, and the surviving entries are the original
    objects rather than copies.

    Example: [['test', 1], ['test', 1], ['another test', 1]]
             -> [['test', 1], ['another test', 1]]
    """
    # dicts keep insertion order; setdefault only records the first entry per key.
    first_seen = {}
    for entry in l:
        first_seen.setdefault(tuple(entry), entry)
    return list(first_seen.values())


#for rewrite_i in range(len(mturk_rewrite_lists)):
# Main experiment: run the IC-MRE update loop for the chosen MTurk rewrite.
# For each item, the chat context consists of:
#   - T1: the story plus a canned "memorized" reply;
#   - Tp: the previous QA turns;
#   - Tc: the correction under one dialogue template plus a canned acknowledgement;
#   - Tr (only when is_recall): a previously stored recalled story.
# The correction seeds a FIFO queue. Each popped piece of information is checked
# against every history entry (the story at j = 0, QA pairs at j >= 1) with the
# two-stage Td prompts. Changed entries are re-queued up to round max_iter. The
# final history and the full chat are saved per item to a .pt file.
for rewrite_i in [rewrite_i]:
    mturk_rewrite_list = mturk_rewrite_lists[rewrite_i] # new_sent_list
    # (test) T1
    test_queries_t1_list = generate_queries_list.t_T1(data, sel_idx_list, is_trim=is_trim)
    test_messages_t1_list = [[{"role": "user", "content": queries}] for queries in test_queries_t1_list]
    test_response_t1 = [{"role": "assistant", "content": "Yes, I have memorized the story."}]
    merge_test_messages_t1_list = [test_msg_t1 + test_response_t1 for test_msg_t1 in test_messages_t1_list]
    # (test) Tp (previous QA pairs, consecutive)
    if instruct_delete:
        merge_test_messages_tp_list = generate_queries_list.t_Tp(data, sel_idx_list)
    else:
        merge_test_messages_tp_list = generate_queries_list.t_Tp(data, sel_idx_list, is_label=True) # NOTE THAT is_label does not include in previous experiment
    # (test) Tc (ideally, can insert anywhere)
    test_queries_correct_list_ensemble = generate_queries_list.t_Tc_ensemble(data, sel_idx_list, templates_list_daily_dialog, mturk_rewrite_list)
    test_messages_correct_list_ensemble = [[[{"role": "user", "content": cur_query}] for cur_query in queries] for queries in test_queries_correct_list_ensemble]
    # play with ChatGPT and choose one reasonable response beforehand
    test_response_correct = [{"role": "assistant", "content": "No problem at all! I have updated my memory of the story with the correction you provided. Thank you for letting me know."}]
    merge_test_messages_tc_list_ensemble = [[j + test_response_correct for j in i] for i in test_messages_correct_list_ensemble]
    # (test) Ti
    # Note that we do not test Ti here; that is done in error_analysis_Td.py. However, we can test Qi at the end of the iteration (for loop) to check whether GPT-3.5 updated its knowledge.
    test_queries_ti_list = generate_queries_list.t_Ti(data, sel_idx_list)
    # (test) Td (delete, ask ChatGPT to test if the story, previous QA entails the correction in Tc)
    assert len(merge_test_messages_t1_list) == len(merge_test_messages_tp_list) == len(merge_test_messages_tc_list_ensemble) == len(test_queries_ti_list) == len(list(filter(lambda x: x != -1, sel_idx_list)))
    assert all(len(i) == len(templates_list_daily_dialog) for i in merge_test_messages_tc_list_ensemble)
    # run experiment
    # combination: (1) T1, Tp, Tc, Td, Ti (2) T1, Tp, Tc, Tr, Td, Ti
    for method in ["t1-tp-tc-tr-td-ti"]: # ["t1-tp-tc-td-ti", "t1-tp-tc-tr-td-ti"]
        if method == "t1-tp-tc-tr-td-ti":
            assert is_recall
        else:
            assert not is_recall
        for i in range(len(templates_list_daily_dialog)):
            if is_recall:
                # (test) Tr (recall, note that I used previous stored story, if you want to generate from scratch, you should modify the code, see reference in main_mturk_with_Tr.py)
                test_queries_tr_list = generate_queries_list.t_Tr(data, sel_idx_list, instruct_update=True)
                test_messages_tr_list = [[{"role": "user", "content": queries}] for queries in test_queries_tr_list]
                test_response_tr_text_list = torch.load(f"{pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_p_c{top_k_idx[i]}_r_chatgpt_recall_story_trim_0.pt")
                test_response_tr_list = [[{"role": "assistant", "content": _}] for _ in test_response_tr_text_list]
                merge_test_messages_tr_list = [test_msg_tr + test_response_tr for test_msg_tr, test_response_tr in zip(test_messages_tr_list, test_response_tr_list)]
                assert len(merge_test_messages_tr_list) == len(list(filter(lambda x: x != -1, sel_idx_list)))
                prev_messages_list = [t1 + tp + tc[i] + tr for t1, tp, tc, tr in zip(merge_test_messages_t1_list, merge_test_messages_tp_list, merge_test_messages_tc_list_ensemble, merge_test_messages_tr_list)]
                # History starts from the recalled story instead of the original one.
                history_list = [[story] + prev_qa for story, prev_qa in zip(test_response_tr_text_list, prev_questions_answers_list)]
            else:
                prev_messages_list = [t1 + tp + tc[i] for t1, tp, tc in zip(merge_test_messages_t1_list, merge_test_messages_tp_list, merge_test_messages_tc_list_ensemble)]
                history_list = [[story] + prev_qa for story, prev_qa in zip(story_list, prev_questions_answers_list)]
            # IC-MRE algorithm
            # Each queue entry is [information_text, update_round]; it is seeded with the correction query at round 0.
            new_info_queue_list = [[[tc[i][0]['content'], 0]] for tc in merge_test_messages_tc_list_ensemble] # Q.APPEND(S)
            #new_info_queue_list = [[_, 0] for _ in mturk_rewrite_list] # Q.APPEND(S)
            assert len(new_info_queue_list) == len(history_list)
            n = len(new_info_queue_list)
            for ii in tqdm(range(n)):
                unexpected_error_occurs = False
                #if ii in combine_index_list: # run the remaining data (other than index in combine_index_list)!
                #if ii not in combine_index_list: # run a part of data # debugging
                #    continue # debugging
                #print(f"before history = {history_list[ii]}") # debugging
                cur_iter = 0
                while len(new_info_queue_list[ii]) > 0: # WHILE Q:
                    if unexpected_error_occurs:
                        break
                    new_info_queue_list[ii] = remove_duplicate_element(new_info_queue_list[ii])
                    #print(f"Queue: {new_info_queue_list[ii]}") # debugging
                    cur_info, cur_iter = new_info_queue_list[ii][0] # S = Q.POPLEFT()
                    #print(f"cur_iter = {cur_iter}; cur_info = {cur_info}") # debugging
                    del new_info_queue_list[ii][0]
                    #print(f"Queue after popped: {new_info_queue_list[ii]}") # debugging
                    for j, cur_hist in enumerate(history_list[ii]): # FOR T IN HISTORY:
                        # IF RELATED (T,S) AND INCONSISTENT(T,S)
                        #print(f"j = {j}, cur_hist (abridged to first 150 chars) = {cur_hist[:150]}") # debugging
                        queries, nth_str = generate_queries_Td(j, cur_iter, cur_info, cur_hist, instruct_delete, is_recall, stage=1)
                        #print("==========") # debugging
                        #print(f"queries = {queries}") # debugging
                        #print("==========") # debugging
                        test_messages_td = [{"role": "user", "content": queries}]
                        final_merge_messages = prev_messages_list[ii] + test_messages_td
                        #print(f"final_merge_messages = {final_merge_messages}") # debugging
                        #print("==========") # debugging
                        #s = input("Pause. Press any keys to continue, Or 'q' to quit\n") # debugging
                        #if s.lower() == 'q': # debugging
                        #    assert False  # debugging
                        if j == 0:
                            test_final_response = get_response_list_from_chatgpt([final_merge_messages], model="gpt-3.5-turbo-16k-0613", max_tokens=512)
                        else:
                            test_final_response = get_response_list_from_chatgpt([final_merge_messages], model="gpt-3.5-turbo-16k-0613") # max_tokens = 128
                        if any(test_final_response[0] == err_msg for err_msg in get_error_messages_list()):
                            tqdm.write(f"\"{test_final_response[0]}\", skip data {ii}")
                            unexpected_error_occurs = True
                            break
                        test_final_response_text = test_final_response[0]['choices'][0]['message']['content']
                        assert type(test_final_response_text) is str
                        #print(f"test_final_response_text = {test_final_response_text}") # debugging
                        #print("==========") # debugging
                        is_in_2nd_stage = False # test if 2nd stage prompting in delete process is used
                        # Go to stage 2 only if the stage-1 reply shows none of the known "no contradiction" phrasings.
                        if ("NO MODIFICATION" not in test_final_response_text) and ("pair does not contradict the correction" not in test_final_response_text) and ("there is no contradiction" not in test_final_response_text.lower()) and ("there are no contradictions" not in test_final_response_text.lower()) and ("in a neutral relation to the correction" not in test_final_response_text) and ("they are in a neutral relation" not in test_final_response_text.lower()): # T = UPDATE(T,S)
                            is_in_2nd_stage = True
                            ####################
                            # 2nd stage prompting
                            final_merge_messages += [{"role": "assistant", "content": test_final_response_text}]
                            queries, nth_str = generate_queries_Td(j, cur_iter, cur_info, cur_hist, instruct_delete, is_recall, stage=2)
                            test_messages_td_2 = [{"role": "user", "content": queries}]
                            final_merge_messages += test_messages_td_2
                            #print(f"final_merge_messages = {final_merge_messages}") # debugging
                            #print("==========") # debugging
                            #s = input("Pause. Press any keys to continue, Or 'q' to quit\n") # debugging
                            #if s.lower() == 'q': # debugging
                            #    assert False  # debugging
                            if j == 0:
                                test_final_response = get_response_list_from_chatgpt([final_merge_messages], model="gpt-3.5-turbo-16k-0613", max_tokens=512)
                            else:
                                test_final_response = get_response_list_from_chatgpt([final_merge_messages], model="gpt-3.5-turbo-16k-0613")
                            if any(test_final_response[0] == err_msg for err_msg in get_error_messages_list()):
                                tqdm.write(f"\"{test_final_response[0]}\", skip data {ii}")
                                unexpected_error_occurs = True
                                break
                            test_final_response_text_2 = test_final_response[0]['choices'][0]['message']['content']
                            assert type(test_final_response_text_2) is str
                            #print(f"test_final_response_text_2 = {test_final_response_text_2}") # debugging
                            #print("==========") # debugging
                            ####################
                            newly_generated_info = extract_story_or_qa_from_final_response_text(final_merge_messages, test_final_response_text_2, j, nth_str, instruct_delete)
                            newly_generated_info = remove_qa_prefix(newly_generated_info)
                            assert type(newly_generated_info) is str
                            #print(f"newly_generated_info (remove qa prefix, if exists) = {newly_generated_info}") # debugging
                            #print("==========") # debugging
                            # Q.APPEND(T)
                            # Record whether THIS update pushed an item, so that the dedup below can
                            # only undo its own push. Previously it deleted the queue's last element
                            # whenever the queue was non-empty, dropping an unrelated pending item when
                            # the append was skipped because cur_iter + 1 > max_iter.
                            appended_to_queue = False
                            if cur_iter + 1 <= max_iter and "NO MODIFICATION".lower() not in newly_generated_info.lower(): # NOTE THAT even if proper instruction, ChatGPT sometimes still does not output 'NO MODIFICATION' if they are in a neutral relation
                                #print(f"next iter {cur_iter+1} does not exceed max_iter ({max_iter}), append in Queue") # debugging
                                new_info_queue_list[ii].append([newly_generated_info, cur_iter+1])
                                appended_to_queue = True
                                #print(f"new Queue: {new_info_queue_list[ii]}") # debugging
                            # UPDATE HISTORY BY T
                            if "NO MODIFICATION".lower() not in newly_generated_info.lower(): # NOTE THAT even if proper instruction, ChatGPT sometimes still does not output 'NO MODIFICATION' if they are in a neutral relation
                                # correct the QA pair's label
                                if j != 0 and "\n" in newly_generated_info and len(newly_generated_info.split('\n')) == 2:
                                    tmp_q, tmp_a = newly_generated_info.split('\n')
                                    if ":" not in tmp_q and ":" not in tmp_a:
                                        tmp_q = f"Q{j}:{tmp_q}"
                                        tmp_a = f"A{j}:{tmp_a}"
                                    newly_generated_info = tmp_q + "\n" + tmp_a
                                if newly_generated_info == history_list[ii][j] and appended_to_queue: # unchanged entry: drop the item we just queued (no need to re-check it, saves API cost)
                                    print(f"Delete last element in Queue, as it is the same as the previous...") # debugging
                                    del new_info_queue_list[ii][-1] # newly added
                                history_list[ii][j] = newly_generated_info
                                #print(f"update history: {history_list[ii]}") # debugging
                        # UPDATE DIALOGUE
                        prev_messages_list[ii].extend(test_messages_td + [{"role": "assistant", "content": test_final_response_text}])
                        if is_in_2nd_stage:
                            prev_messages_list[ii].extend(test_messages_td_2 + [{"role": "assistant", "content": test_final_response_text_2}])
                        #print("==========") # debugging
                        #print(f"extend chat history: {prev_messages_list[ii]}") # debugging
                        #print("==========") # debugging
                        #s = input("Pause. Press any keys to continue, Or 'q' to quit\n") # debugging
                        #if s.lower() == 'q': # debugging
                        #    assert False  # debugging
                    #print(f"display updated history: {history_list[ii]}") # debugging
                    #print("==========") # debugging
                    # end for
                #print(f"after history = {history_list[ii]}") # debugging
                #print("==========") # debugging
                # end while
                # save data
                # Partial results are saved even on an API error (is_abort=True).
                if unexpected_error_occurs:
                    if method == "t1-tp-tc-tr-td-ti":
                        torch.save({"is_abort": True, "index": ii, "delete_iter": cur_iter+1, "is_trim": is_trim, "is_recall": is_recall, "instruct_delete": instruct_delete, "modified_story_and_prev_qas": history_list[ii], "all_chat_history": prev_messages_list[ii]}, f"{pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_p_c{top_k_idx[i]}_r_d_chatgpt_data_index_{ii}_metadata.pt")
                        tqdm.write(f"Fail to run entirely, but still save {pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_p_c{top_k_idx[i]}_r_d_chatgpt_data_index_{ii}_metadata.pt!")
                else:
                    assert not unexpected_error_occurs
                    if method == "t1-tp-tc-tr-td-ti":
                        torch.save({"is_abort": False, "index": ii, "delete_iter": cur_iter+1, "is_trim": is_trim, "is_recall": is_recall, "instruct_delete": instruct_delete, "modified_story_and_prev_qas": history_list[ii], "all_chat_history": prev_messages_list[ii]}, f"{pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_p_c{top_k_idx[i]}_r_d_chatgpt_data_index_{ii}_metadata.pt")
                        tqdm.write(f"Successfully save {pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_p_c{top_k_idx[i]}_r_d_chatgpt_data_index_{ii}_metadata.pt!")

