import argparse
import random

#import pandas as pd
import tiktoken # count number of tokens
import torch
from tqdm import tqdm

from datasets import load_dataset

from set_seed import set_seed
from openai_api_setup import get_response_list_from_chatgpt
from generate_yes_no_index_list import generate_yes_no_index_list
from get_various_story_attr import get_various_story_attr
import generate_queries_list
from compare_old_new_answer import compare_old_new_answer
from update_answer_list import update_answer_list

####################################################

####################################################
def parse_argument(argv=None):
    """Parse the command-line options that configure this evaluation run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with the attributes ``data_split``,
        ``data_subset``, ``template_index``, ``rewrite_i``, ``run_ti`` and
        ``run_append``.

    Raises:
        SystemExit: raised by argparse for missing/invalid options, including
            when ``--run_ti`` and ``--run_append`` are given together.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--data_split', required=True, choices=['train', 'validation'], help='choose training set or validation set in CoQA dataset')
    parser.add_argument('--data_subset', required=True, choices=['all', 'hard', 'easy'], help='choose training subset or validation subset in CoQA dataset')
    parser.add_argument('--template_index', required=True, type=int, choices=range(0, 15), help='choose the correction template index (0~14)') # newly added
    #parser.add_argument('--update_answer', action='store_true', help='Whether to update answer in train/validation answer (ensure that the gold answers will start with Yes/No)')
    parser.add_argument('--rewrite_i', required=True, type=int, choices=range(0, 3), help='Choose 1st, 2nd, or 3rd response of MTurk')
    # The two run modes exclude each other; let argparse reject the combination
    # with a usage message instead of relying on an assert (stripped under -O).
    run_mode = parser.add_mutually_exclusive_group()
    run_mode.add_argument('--run_ti', action='store_true', help='Whether to evaluate T1m-Tpm-Ti from scratch')
    run_mode.add_argument('--run_append', action='store_true', help='Whether to append and evaluate Ti to the previous (long) chat history')
    return parser.parse_args(argv)

# Resolve the run configuration from the command line.
args = parse_argument()
data_split, data_subset = args.data_split, args.data_subset  # "train"/"validation", "all"/"hard"/"easy"
template_index, rewrite_i = args.template_index, args.rewrite_i
run_ti, run_append = args.run_ti, args.run_append

# For interactive debugging the same settings can be hard-coded instead, e.g.:
#data_split = "validation" # "train", "validation"
#data_subset = "all" # "hard", "easy", "all"
#rewrite_i = 2 # change to 0, 1, 2
#run_ti = False # change to True to run T1m-Tpm-Ti from scratch
#run_append = True

# Always enabled: the answer list is passed through update_answer_list later on.
update_answer = True

# Running T1m-Tpm-Ti from scratch and appending to the old chat history are
# mutually exclusive modes.
assert not (run_ti and run_append)

# MTurk rewrite responses for the validation set.
mturk_responses = torch.load('../data/csv/MTurk_validation/final_validation_464.pt')
####################################################

# Load CoQA and keep only the split requested on the command line.
data = load_dataset('coqa')[data_split]

# One candidate list per dataset row (presumably the indices of its yes/no
# questions); a row with an empty candidate list is marked with -1.
yn_qa_idx_list = generate_yes_no_index_list(data)
set_seed(0)  # fix the seed right before sampling so the selection is reproducible
sel_idx_list = []
for candidates in yn_qa_idx_list:
    sel_idx_list.append(random.choice(candidates) if candidates else -1)

# Output locations used by the rest of the script.
log_file_path = "../data/log"
pt_file_path = "../data/pt"
####################################################
# start error analysis

# Parallel lists, one entry per dataset row that has a selected question
# (rows whose selected index is -1 are skipped).
story_list = []
prev_questions_answers_list = []
question_list = []
answer_list = []
old_sent_list = []
cnt = 0

# Fetch the two columns once: indexing the dataset by column name inside the
# loop would re-materialise the whole column on every iteration.
questions_column = data['questions']
answers_column = data['answers']

for i in tqdm(range(data.num_rows)):
    sel_idx = sel_idx_list[i]
    if sel_idx == -1:
        continue
    # The two middle return values are not needed here.
    story, question, answer, _s, _e, old_support_sent = get_various_story_attr(data, i, sel_idx)
    # process story: wrap every occurrence of the supporting sentence in ANSI
    # red so that it stands out when the story is printed to a terminal
    assert old_support_sent in story
    story = f'\033[0;31m{old_support_sent}\033[m'.join(story.split(old_support_sent))
    # process previous qa: render every QA turn that precedes the selected one
    prev_questions = questions_column[i][:sel_idx]
    prev_answers = answers_column[i]['input_text'][:sel_idx]
    prev_questions_answers = "".join(
        f"Q{j}: {q}\nA{j}: {a}\n\n"
        for j, (q, a) in enumerate(zip(prev_questions, prev_answers))
    )
    story_list.append(story)
    prev_questions_answers_list.append(prev_questions_answers)
    question_list.append(question)
    answer_list.append(answer)
    old_sent_list.append(old_support_sent)
    cnt += 1

# All lists must have grown in lockstep.
assert cnt == len(story_list) == len(prev_questions_answers_list) == len(question_list) == len(answer_list) == len(old_sent_list)

####################################################

if update_answer:
    answer_list = update_answer_list(data_split, answer_list)
    print('Answer updated!')

####################################################

#from pprint import pprint

# Index lists generated by data_error_analysis.py (at the end of the file): examples for which
# it is really hard to modify the existing knowledge; also useful for debugging.
# The four groups are merged (deduplicated and sorted) into combine_index_list below.
# CAM setting
extremely_hard_index = [9, 12, 27, 40, 41, 44, 47, 49, 66, 71, 72, 81, 107, 116, 120, 121, 126, 129, 130, 134, 136, 147, 162, 167, 168, 190, 198, 214, 215, 222, 234, 237, 243, 248, 257, 258, 268, 273, 274, 282, 285, 292, 301, 310, 313, 317, 325, 338, 341, 353, 356, 385, 388, 391, 394, 404, 406, 409, 417, 435, 436, 461]
second_hard_index = [7, 10, 16, 18, 23, 35, 38, 42, 48, 52, 55, 59, 88, 96, 99, 103, 111, 115, 123, 128, 133, 150, 156, 160, 172, 174, 176, 177, 179, 185, 187, 191, 194, 195, 206, 213, 225, 227, 230, 238, 241, 255, 264, 277, 295, 298, 309, 312, 314, 332, 334, 354, 368, 381, 387, 401, 418, 428, 444, 448, 450, 451, 454, 458, 460, 463]
third_hard_index = [31, 165, 202, 259, 350, 386, 419, 432]
fourth_hard_index = [11, 17, 22, 50, 64, 78, 91, 104, 106, 112, 122, 158, 175, 181, 261, 286, 331, 335, 336, 357, 362, 364, 365, 370, 372, 378, 398, 402, 415, 422, 427, 433]

# CBA setting (alternative lists; swap these in together with the CBA idx_list below)
#extremely_hard_index = [12, 13, 41, 44, 49, 55, 66, 68, 72, 126, 136, 137, 156, 158, 168, 174, 175, 176, 198, 203, 206, 215, 234, 243, 248, 257, 264, 268, 273, 274, 278, 282, 301, 317, 337, 341, 352, 353, 356, 382, 394, 406, 417, 422, 436, 461, 463]
#second_hard_index = [7, 37, 38, 42, 59, 71, 93, 96, 102, 103, 110, 122, 130, 142, 147, 167, 171, 179, 181, 190, 199, 213, 219, 225, 230, 237, 238, 241, 258, 295, 298, 313, 321, 328, 334, 350, 354, 365, 368, 374, 380, 385, 402, 418, 435, 448, 450, 451, 454, 460]
#third_hard_index = [81, 109, 123, 162, 182, 193, 222, 259, 292, 308, 325, 342, 386, 398]
#fourth_hard_index = [19, 20, 27, 36, 40, 48, 50, 77, 88, 115, 124, 125, 138, 165, 172, 207, 214, 261, 266, 277, 286, 294, 306, 310, 363, 364, 372, 401, 415, 427, 433, 444, 458]

# Sorted union of the four groups above.
combine_index_list = sorted(list(set(extremely_hard_index + second_hard_index + third_hard_index + fourth_hard_index)))

# Hard-coded copy of the union, kept so that a mismatch between the groups and
# the expected index list is caught by the assert below.
# CAM setting
idx_list = [7, 9, 10, 11, 12, 16, 17, 18, 22, 23, 27, 31, 35, 38, 40, 41, 42, 44, 47, 48, 49, 50, 52, 55, 59, 64, 66, 71, 72, 78, 81, 88, 91, 96, 99, 103, 104, 106, 107, 111, 112, 115, 116, 120, 121, 122, 123, 126, 128, 129, 130, 133, 134, 136, 147, 150, 156, 158, 160, 162, 165, 167, 168, 172, 174, 175, 176, 177, 179, 181, 185, 187, 190, 191, 194, 195, 198, 202, 206, 213, 214, 215, 222, 225, 227, 230, 234, 237, 238, 241, 243, 248, 255, 257, 258, 259, 261, 264, 268, 273, 274, 277, 282, 285, 286, 292, 295, 298, 301, 309, 310, 312, 313, 314, 317, 325, 331, 332, 334, 335, 336, 338, 341, 350, 353, 354, 356, 357, 362, 364, 365, 368, 370, 372, 378, 381, 385, 386, 387, 388, 391, 394, 398, 401, 402, 404, 406, 409, 415, 417, 418, 419, 422, 427, 428, 432, 433, 435, 436, 444, 448, 450, 451, 454, 458, 460, 461, 463]
# CBA setting
#idx_list = [7, 12, 13, 19, 20, 27, 36, 37, 38, 40, 41, 42, 44, 48, 49, 50, 55, 59, 66, 68, 71, 72, 77, 81, 88, 93, 96, 102, 103, 109, 110, 115, 122, 123, 124, 125, 126, 130, 136, 137, 138, 142, 147, 156, 158, 162, 165, 167, 168, 171, 172, 174, 175, 176, 179, 181, 182, 190, 193, 198, 199, 203, 206, 207, 213, 214, 215, 219, 222, 225, 230, 234, 237, 238, 241, 243, 248, 257, 258, 259, 261, 264, 266, 268, 273, 274, 277, 278, 282, 286, 292, 294, 295, 298, 301, 306, 308, 310, 313, 317, 321, 325, 328, 334, 337, 341, 342, 350, 352, 353, 354, 356, 363, 364, 365, 368, 372, 374, 380, 382, 385, 386, 394, 398, 401, 402, 406, 415, 417, 418, 422, 427, 433, 435, 436, 444, 448, 450, 451, 454, 458, 460, 461, 463]
# Sanity check: the hard-coded list must equal the sorted union of the groups.
assert idx_list == combine_index_list
# Narrow idx_list to the requested subset:
#   "hard" -> keep the hard indices as they are,
#   "easy" -> every remaining index (ascending),
#   "all"  -> every index.
# NOTE(review): 464 is presumably the number of examples in
# final_validation_464.pt — confirm before changing either.
num_examples = 464
if data_subset == "easy":
    hard_indices = set(idx_list)
    idx_list = [i for i in range(num_examples) if i not in hard_indices]
elif data_subset == "all":
    idx_list = list(range(num_examples))
elif data_subset != "hard":
    # Unreachable through the CLI (argparse restricts the choices); raise
    # explicitly rather than assert, which is stripped under `python -O`.
    raise ValueError(f"Unknown data_subset: {data_subset!r}")


# Load the saved per-example metadata of the rewrite run.
idx_dict = {1: [], 0: []} # 1: success, 0: fail (aborted)
data_list = []
data_dict = {}

for example_idx in idx_list:
    # metadata keys: 'is_abort', 'index', 'delete_iter', 'modified_story_and_prev_qas', 'all_chat_history'
    metadata = torch.load(f'{pt_file_path}/coqa_validation_yes_no_mturk_rewrite_{rewrite_i}_p_c{template_index}_r_d_chatgpt_data_index_{example_idx}_metadata.pt')
    # the file name and the stored index must agree
    assert metadata['index'] == example_idx
    idx_dict[0 if metadata['is_abort'] else 1].append(example_idx)
    data_dict[example_idx] = metadata
    data_list.append(metadata)

# keep the records ordered by their example index
data_list.sort(key=lambda record: record['index'])

"""
for i in idx_dict[1]:
    print(f"index: {data_dict[i]['index']}")
    print("story:")
    print(data_dict[i]['modified_story_and_prev_qas'][0])
    pprint(data_dict[i]['modified_story_and_prev_qas'][1:])
    _ = input("==========================")
"""

def _strip_speaker_label(text):
    """Return the text after the first ':' (whitespace-stripped), or None if there is no ':'."""
    if ":" not in text:
        return None
    return text[text.index(":") + 1:].strip()


def _qa_text_to_messages(qa_text):
    """Convert one stored "Q...: ... / A...: ..." text into chat messages.

    Two layouts are recognised:
      * question and answer on two lines (blank lines are dropped first when
        the text has three or more lines);
      * question and answer on one line, split at the last '?'.
    Anything else is treated as an invalid pair.

    Returns:
        A list with up to two dicts: a "user" message for the question and an
        "assistant" message for the answer. A side is omitted when it has no
        ':' label; an empty list is returned for an invalid pair.
    """
    lines = qa_text.split("\n")
    if len(lines) >= 3:
        qa_text = "\n".join(line for line in lines if line)
    if (":" in qa_text) and ("Q" in qa_text) and ("A" in qa_text) and (len(qa_text.split("\n")) == 2):
        question, answer = qa_text.split("\n")
    elif "?" in qa_text:
        qa_text = qa_text.replace("??", "?")
        parts = qa_text.split("?")
        if len(parts) > 2:
            # the question itself contains question marks: everything before
            # the last '?' is the question, the rest is the answer
            print(f"Ill format of QA: {qa_text}") # debugging
        # NOTE(review): the '?' that separates question and answer is dropped
        # from the question text, as in the original implementation.
        question, answer = "?".join(parts[:-1]), parts[-1]
    else: # skip invalid QA pair
        return []
    messages = []
    for role, raw_text in (("user", question), ("assistant", answer)):
        content = _strip_speaker_label(raw_text)
        if content is not None:
            messages.append({"role": role, "content": content})
    return messages


if run_ti:
    # Element 0 of 'modified_story_and_prev_qas' is the modified story; the
    # remaining elements are the (modified) previous QA pairs.
    modified_story_list = [entry['modified_story_and_prev_qas'][0] for entry in data_list]
    modified_prev_questions_answers_list = []
    for entry in data_list:
        tmp = []
        for qa_text in entry['modified_story_and_prev_qas'][1:]:
            tmp.extend(_qa_text_to_messages(qa_text))
        modified_prev_questions_answers_list.append(tmp)

    assert len(modified_story_list) == len(modified_prev_questions_answers_list) == len(idx_list)


if run_ti:
    print(f"Run experiments: MTurk {rewrite_i}")
    # similar to more_error_analysis_mturk.py
    # Conversation layout: T1m (modified story) -> Tpm (modified previous QAs) -> Ti (target question)
    separator = "\n\n==========\n\n"

    # (test) T1 modified: ask the model to memorize the story, followed by a canned confirmation
    test_response_t1 = [{"role": "assistant", "content": "Yes, I have memorized the story."}]
    merge_test_messages_t1_list = []
    for modified_story in modified_story_list:
        story_prompt = f"Read and memorize the following story.\n\nStory:\n\n{modified_story}{separator}Have you memorized the story?"
        merge_test_messages_t1_list.append([{"role": "user", "content": story_prompt}] + test_response_t1)

    # (test) Tp (previous QA pairs, consecutive)
    merge_test_messages_tp_list = list(modified_prev_questions_answers_list)

    # (test) Ti: keep only the queries whose position is one of the selected indices
    selected_positions = set(idx_list)
    test_queries_ti_list = [
        queries
        for position, queries in enumerate(generate_queries_list.t_Ti(data, sel_idx_list))
        if position in selected_positions
    ]
    assert test_queries_ti_list == [question_list[i] for i in idx_list]
    assert len(merge_test_messages_t1_list) == len(merge_test_messages_tp_list) == len(test_queries_ti_list) #== len(list(filter(lambda x: x != -1, sel_idx_list)))

    # run experiment
    # T1 (modified), Tp (modified), Ti, Ti2 (2nd-stage prompting)
    final_merge_messages_list = [
        t1 + tp + [{"role": "user", "content": queries}]
        for t1, tp, queries in zip(merge_test_messages_t1_list, merge_test_messages_tp_list, test_queries_ti_list)
    ]
    test_final_response_list = get_response_list_from_chatgpt(final_merge_messages_list)
    test_final_response_text_list = [response['choices'][0]['message']['content'] for response in test_final_response_list]
    # NOTE: an existing file at this path is overwritten.
    torch.save(test_final_response_text_list, f"{pt_file_path}/coqa_{data_split}_mturk_test_{len(idx_list)}_rewrite_{rewrite_i}_c{template_index}_t1m_tpm_ti_chatgpt.pt")

    # 2nd-stage prompting (answer extraction): replay the first-stage reply and ask for an explicit Yes/No
    answer_extraction_str = "Therefore, based on your previous response, your answer to the last question is more likley to be 'Yes', 'No'? You must output 'Yes' or 'No' first."
    final_merge_messages_list_2 = [
        prev + [{"role": "assistant", "content": reply}, {"role": "user", "content": answer_extraction_str}]
        for prev, reply in zip(final_merge_messages_list, test_final_response_text_list)
    ]
    test_final_response_list_2 = get_response_list_from_chatgpt(final_merge_messages_list_2)
    test_final_response_text_list_2 = [response['choices'][0]['message']['content'] for response in test_final_response_list_2]
    # NOTE: an existing file at this path is overwritten.
    torch.save(test_final_response_text_list_2, f"{pt_file_path}/coqa_{data_split}_mturk_test_{len(idx_list)}_rewrite_{rewrite_i}_c{template_index}_extract_answer_t1m_tpm_ti_chatgpt.pt")

####################################################

if not run_append:
    # Evaluate the saved T1m-Tpm-Ti responses (1st stage) together with the
    # saved answer-extraction responses (2nd stage).
    cur_new_ans_list = torch.load(f"{pt_file_path}/coqa_{data_split}_mturk_test_{len(idx_list)}_rewrite_{rewrite_i}_c{template_index}_t1m_tpm_ti_chatgpt.pt")
    cur_new_ans_list_2 = torch.load(f"{pt_file_path}/coqa_{data_split}_mturk_test_{len(idx_list)}_rewrite_{rewrite_i}_c{template_index}_extract_answer_t1m_tpm_ti_chatgpt.pt")

    # If the 1st-stage response does not start with Yes/No (or is the
    # "NO MODIFICATION" marker), use the 2nd-stage response instead.
    for i, (x, y) in enumerate(zip(cur_new_ans_list, cur_new_ans_list_2)):
        # NOTE(review): the second test looks for the lower-case marker in the
        # *raw* response; `x.lower()` may have been intended — confirm.
        first_stage_unusable = (
            (x == "NO MODIFICATION")
            or ("NO MODIFICATION".lower() in x)
            or (("yes" not in x[:3].lower()) and ("no" not in x[:2].lower()))
        )
        if not first_stage_unusable:
            continue
        if "yes" not in y[:3].lower() and "no" not in y[:2].lower():
            # sometimes the answer is at the end..
            if "yes" in y[-6:].lower():
                y = "Yes. " + y
            elif "no" in y[-5:].lower():
                y = "No. " + y
        cur_new_ans_list[i] = y

    answer_list_part = [answer_list[i] for i in idx_list]
    assert len(answer_list_part) == len(cur_new_ans_list)
    # compare_old_new_answer yields a code per example; 1, -1 and 0 are tallied below
    diff_result = [compare_old_new_answer(old_a, new_a) for old_a, new_a in zip(answer_list_part, cur_new_ans_list)]

    def _percent(count):
        """Share of `count` among all compared answers, in percent (0.0 when there are none)."""
        total = len(diff_result)
        return round((count / total) * 100, 2) if total else 0.0

    print(f"(+1,-1,0) = ({diff_result.count(1)} ({_percent(diff_result.count(1))}%), {diff_result.count(-1)} ({_percent(diff_result.count(-1))}%), {diff_result.count(0)} ({_percent(diff_result.count(0))}%))")

    ####################################################
    # 1 for examples whose rewrite run was not aborted, -1 otherwise
    success_run = [1 if not data_dict[i]['is_abort'] else -1 for i in idx_list]

    num_success_run = success_run.count(1)
    num_success_run_and_change = sum([1 for x, y in zip(success_run, diff_result) if x == 1 and y == 1])

    # Avoid a ZeroDivisionError when every run was aborted.
    change_ratio = num_success_run_and_change / num_success_run if num_success_run else float("nan")
    print(f"Percent of change in success run = {change_ratio} ({num_success_run_and_change}/{num_success_run})")


    """
    ####################################################
    # DO NOT INCLUDE UPDATE?
    ####################################################

    print("After answer extraction:")

    #cur_new_ans_list = torch.load(f"{pt_file_path}/coqa_{data_split}_mturk_test_{len(idx_list)}_rewrite_{rewrite_i}_t1m_tpm_ti_chatgpt.pt")
    cur_new_ans_list = torch.load(f"{pt_file_path}/coqa_{data_split}_mturk_test_{len(idx_list)}_rewrite_{rewrite_i}_extract_answer_t1m_tpm_ti_chatgpt.pt")

    answer_list_part = [answer_list[i] for i in idx_list]
    assert len(answer_list_part) == len(cur_new_ans_list)
    diff_result = []
    for old_a, new_a in zip(answer_list_part, cur_new_ans_list):
        diff_result.append(compare_old_new_answer(old_a, new_a))

    f = lambda x: round((x/len(diff_result))*100, 2)
    print(f"(+1,-1,0) = ({diff_result.count(1)} ({f(diff_result.count(1))}%), {diff_result.count(-1)} ({f(diff_result.count(-1))}%), {diff_result.count(0)} ({f(diff_result.count(0))}%))")

    ####################################################
    success_run = [1 if not data_dict[i]['is_abort'] else -1 for i in idx_list]

    num_success_run = success_run.count(1)
    num_success_run_and_change = sum([1 for x, y in zip(success_run, diff_result) if x == 1 and y == 1])

    print(f"Percent of change in success run = {num_success_run_and_change/num_success_run} ({num_success_run_and_change}/{num_success_run})")

    """

####################################################

if run_append:
    enc = tiktoken.encoding_for_model("gpt-3.5-turbo-0613")
    # e.g., len(enc.encode("hello world")) == 2
    # NOTE(review): `enc` is not used below — exceed_16k counts characters,
    # not tokens. Left unchanged so results stay comparable; confirm intent.
    print(f"Run experiments: MTurk {rewrite_i}")

    # (test) Ti: keep only the queries whose position is one of the selected indices
    selected_positions = set(idx_list)
    test_queries_ti_list = generate_queries_list.t_Ti(data, sel_idx_list)
    test_queries_ti_list = [queries for i, queries in enumerate(test_queries_ti_list) if i in selected_positions]
    assert test_queries_ti_list == [question_list[i] for i in idx_list]

    def exceed_16k(messages):
        """Return True when the summed message content length plus a 128 margin exceeds 16384."""
        return sum(len(message['content']) for message in messages) + 128 > 16384

    # Over-long histories lose their first two messages (once; the result is
    # not re-checked against the limit).
    prev_messages_list = [i['all_chat_history'][2:] if exceed_16k(i['all_chat_history']) else i['all_chat_history'] for i in data_list]
    final_merge_messages_list = [prev + [{"role": "user", "content": queries}] for prev, queries in zip(prev_messages_list, test_queries_ti_list)]
    test_final_response_list = get_response_list_from_chatgpt(final_merge_messages_list, model="gpt-3.5-turbo-16k-0613")
    test_final_response_text_list = [response['choices'][0]['message']['content'] for response in test_final_response_list]
    # NOTE: an existing file at this path is overwritten.
    torch.save(test_final_response_text_list, f"{pt_file_path}/coqa_{data_split}_mturk_test_{len(idx_list)}_rewrite_{rewrite_i}_p_c{template_index}_r_d_append_chatgpt.pt")

    # 2nd-stage prompting (answer extraction)
    final_merge_messages_list_2 = [prev + [{"role": "assistant", "content": queries}] for prev, queries in zip(final_merge_messages_list, test_final_response_text_list)]
    answer_extraction_str = "Therefore, based on your previous response, your answer to the last question is more likley to be 'Yes', 'No'? You must output 'Yes' or 'No' first."
    final_merge_messages_list_2 = [prev + [{"role": "user", "content": answer_extraction_str}] for prev in final_merge_messages_list_2]
    final_merge_messages_list_2 = [i[2:] if exceed_16k(i) else i for i in final_merge_messages_list_2]
    test_final_response_list_2 = get_response_list_from_chatgpt(final_merge_messages_list_2, model="gpt-3.5-turbo-16k-0613")
    test_final_response_text_list_2 = [response['choices'][0]['message']['content'] for response in test_final_response_list_2]
    # NOTE: an existing file at this path is overwritten.
    torch.save(test_final_response_text_list_2, f"{pt_file_path}/coqa_{data_split}_mturk_test_{len(idx_list)}_rewrite_{rewrite_i}_p_c{template_index}_r_d_append_extract_answer_chatgpt.pt")
    ####################################################
    cur_new_ans_list = torch.load(f"{pt_file_path}/coqa_{data_split}_mturk_test_{len(idx_list)}_rewrite_{rewrite_i}_p_c{template_index}_r_d_append_chatgpt.pt")
    cur_new_ans_list_2 = torch.load(f"{pt_file_path}/coqa_{data_split}_mturk_test_{len(idx_list)}_rewrite_{rewrite_i}_p_c{template_index}_r_d_append_extract_answer_chatgpt.pt")
    # If the 1st stage prompting (Ti) does not start with Yes/No, we use the response in 2nd stage prompting instead. Otherwise, the output of 2nd stage prompting is ignored.
    for i, (x, y) in enumerate(zip(cur_new_ans_list, cur_new_ans_list_2)):
        # NOTE(review): the second test looks for the lower-case marker in the
        # *raw* response; `x.lower()` may have been intended — confirm.
        if (x == "NO MODIFICATION") or ("NO MODIFICATION".lower() in x) or (("yes" not in x[:3].lower()) and ("no" not in x[:2].lower())):
            if "yes" not in y[:3].lower() and "no" not in y[:2].lower():
                # sometimes the answer is at the end..
                if "yes" in y[-6:].lower():
                    y = "Yes. " + y
                elif "no" in y[-5:].lower():
                    y = "No. " + y
            cur_new_ans_list[i] = y
    answer_list_part = [answer_list[i] for i in idx_list]
    assert len(answer_list_part) == len(cur_new_ans_list)
    # compare_old_new_answer yields a code per example; 1, -1 and 0 are tallied below
    diff_result = [compare_old_new_answer(old_a, new_a) for old_a, new_a in zip(answer_list_part, cur_new_ans_list)]

    def _percent(count):
        """Share of `count` among all compared answers, in percent (0.0 when there are none)."""
        total = len(diff_result)
        return round((count / total) * 100, 2) if total else 0.0

    print(f"(+1,-1,0) = ({diff_result.count(1)} ({_percent(diff_result.count(1))}%), {diff_result.count(-1)} ({_percent(diff_result.count(-1))}%), {diff_result.count(0)} ({_percent(diff_result.count(0))}%))")
    # 1 for examples whose rewrite run was not aborted, -1 otherwise
    success_run = [1 if not data_dict[i]['is_abort'] else -1 for i in idx_list]
    num_success_run = success_run.count(1)
    num_success_run_and_change = sum([1 for x, y in zip(success_run, diff_result) if x == 1 and y == 1])
    # Avoid a ZeroDivisionError when every run was aborted.
    change_ratio = num_success_run_and_change / num_success_run if num_success_run else float("nan")
    print(f"Percent of change in success run = {change_ratio} ({num_success_run_and_change}/{num_success_run})")
