import argparse
import random

import torch
from tqdm import tqdm

from datasets import load_dataset

from set_seed import set_seed
from generate_yes_no_index_list import generate_yes_no_index_list
from get_various_story_attr import get_various_story_attr
import generate_queries_list
from compare_old_new_answer import compare_old_new_answer
from update_answer_list import update_answer_list

####################################################

def parse_argument(argv=None):
    """Parse the command-line options of this evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``data_split``, ``data_subset``, ``setting``,
        ``conv_flow`` (all str) and ``rewrite_i`` (int, one of 0, 1, 2).
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--data_split', required=True, choices=['train', 'validation'], help='choose training set or validation set in CoQA dataset')
    # The subset is drawn from the MTurk examples (not from a CoQA split):
    # "hard" = the hard-to-edit indices, "easy" = the rest, "all" = every example.
    parser.add_argument('--data_subset', required=True, choices=['all', 'hard', 'easy'], help='choose which MTurk examples to evaluate: all, hard (hard-to-edit indices), or easy (the remaining ones)')
    parser.add_argument('--setting', required=True, choices=['CAM', 'CBA'], help='choose CAM or CBA setting in Recall methodology')
    parser.add_argument('--conv_flow', required=True, choices=['Tfm-Tpm-Ti', 'Tf-Tp-Tc-Tr-Td-Ti'], help='choose Deletion methodology and evaluate it')
    # NOTE: answer updating used to be an --update_answer flag; it is now the
    # module-level constant update_answer.
    parser.add_argument('--rewrite_i', required=True, type=int, choices=range(0, 3), help='Choose 1st, 2nd, or 3rd response of MTurk')
    return parser.parse_args(argv)

args = parse_argument()
# Expose the parsed options as plain module-level names.
data_split, data_subset, setting, conv_flow, rewrite_i = (
    args.data_split,   # "train" | "validation"
    args.data_subset,  # "all" | "hard" | "easy"
    args.setting,      # "CAM" | "CBA"
    args.conv_flow,    # "Tfm-Tpm-Ti" | "Tf-Tp-Tc-Tr-Td-Ti"
    args.rewrite_i,    # 0, 1 or 2: which MTurk rewrite to evaluate
)

# Gold-answer updating is always enabled (it used to be an --update_answer flag).
update_answer = True
# NOTE(review): not referenced in the visible part of this script.
mturk_responses = torch.load('../data/csv/MTurk_validation/final_validation_464.pt')

####################################################

# Load the requested CoQA split.
_d_coqa = load_dataset('coqa')
data = _d_coqa[data_split]

# For each story, pick one candidate question index at random (seeded for
# reproducibility), or -1 when the story has no candidates.
# Presumably the candidates are the yes/no questions of that story — see
# generate_yes_no_index_list.
yn_qa_idx_list = generate_yes_no_index_list(data)
set_seed(0)
sel_idx_list = []
for candidates in yn_qa_idx_list:
    sel_idx_list.append(random.choice(candidates) if candidates else -1)

# Output/input directories (log_file_path is not referenced in the visible part of this script).
log_file_path = "../data/log"
pt_file_path = "../data/pt"
####################################################
# start error analysis

# For every story that has a selected question, collect aligned per-example
# inputs: entry k of each list below belongs to the k-th such story.
story_list = []
prev_questions_answers_list = []
question_list = []
answer_list = []
old_sent_list = []
cnt = 0
for i in tqdm(range(data.num_rows)):
    qa_idx = sel_idx_list[i]
    if qa_idx == -1:
        continue  # no question was selected for this story
    story, question, answer, s, e, old_support_sent = get_various_story_attr(data, i, qa_idx)
    # Highlight every occurrence of the supporting sentence in red (ANSI escapes).
    assert old_support_sent in story
    story = f'\033[0;31m{old_support_sent}\033[m'.join(story.split(old_support_sent))
    # Format the conversation history that precedes the selected question.
    # Fetch the row once: in many `datasets` versions `data['questions']`
    # materializes the whole column, which made this loop quadratic.
    row = data[i]
    prev_questions = row['questions'][:qa_idx]
    prev_answers = row['answers']['input_text'][:qa_idx]
    prev_questions_answers = "".join(
        f"Q{j}: {q}\nA{j}: {a}\n\n"
        for j, (q, a) in enumerate(zip(prev_questions, prev_answers))
    )
    story_list.append(story)
    prev_questions_answers_list.append(prev_questions_answers)
    question_list.append(question)
    answer_list.append(answer)
    old_sent_list.append(old_support_sent)
    cnt += 1

assert cnt == len(story_list) == len(prev_questions_answers_list) == len(question_list) == len(answer_list) == len(old_sent_list)

####################################################

# Optionally post-process the gold answers via update_answer_list (per the
# retired --update_answer help text: so that they start with Yes/No).
answer_list = update_answer_list(data_split, answer_list) if update_answer else answer_list
if update_answer:
    print('Answer updated!')

####################################################

#from pprint import pprint

# Generated by data_error_analysis.py (at the end of that file): indices of examples whose existing knowledge is really hard to modify, kept for targeted testing and debugging.
# Number of MTurk evaluation examples (cf. final_validation_464.pt); valid
# example indices are 0 .. NUM_MTURK_EXAMPLES - 1.
NUM_MTURK_EXAMPLES = 464

# Hard-example index tables, bucketed by difficulty, per setting.
if setting == "CAM":
    extremely_hard_index = [9, 12, 27, 40, 41, 44, 47, 49, 66, 71, 72, 81, 107, 116, 120, 121, 126, 129, 130, 134, 136, 147, 162, 167, 168, 190, 198, 214, 215, 222, 234, 237, 243, 248, 257, 258, 268, 273, 274, 282, 285, 292, 301, 310, 313, 317, 325, 338, 341, 353, 356, 385, 388, 391, 394, 404, 406, 409, 417, 435, 436, 461]
    second_hard_index = [7, 10, 16, 18, 23, 35, 38, 42, 48, 52, 55, 59, 88, 96, 99, 103, 111, 115, 123, 128, 133, 150, 156, 160, 172, 174, 176, 177, 179, 185, 187, 191, 194, 195, 206, 213, 225, 227, 230, 238, 241, 255, 264, 277, 295, 298, 309, 312, 314, 332, 334, 354, 368, 381, 387, 401, 418, 428, 444, 448, 450, 451, 454, 458, 460, 463]
    third_hard_index = [31, 165, 202, 259, 350, 386, 419, 432]
    fourth_hard_index = [11, 17, 22, 50, 64, 78, 91, 104, 106, 112, 122, 158, 175, 181, 261, 286, 331, 335, 336, 357, 362, 364, 365, 370, 372, 378, 398, 402, 415, 422, 427, 433]
elif setting == "CBA":
    extremely_hard_index = [12, 13, 41, 44, 49, 55, 66, 68, 72, 126, 136, 137, 156, 158, 168, 174, 175, 176, 198, 203, 206, 215, 234, 243, 248, 257, 264, 268, 273, 274, 278, 282, 301, 317, 337, 341, 352, 353, 356, 382, 394, 406, 417, 422, 436, 461, 463]
    second_hard_index = [7, 37, 38, 42, 59, 71, 93, 96, 102, 103, 110, 122, 130, 142, 147, 167, 171, 179, 181, 190, 199, 213, 219, 225, 230, 237, 238, 241, 258, 295, 298, 313, 321, 328, 334, 350, 354, 365, 368, 374, 380, 385, 402, 418, 435, 448, 450, 451, 454, 460]
    third_hard_index = [81, 109, 123, 162, 182, 193, 222, 259, 292, 308, 325, 342, 386, 398]
    fourth_hard_index = [19, 20, 27, 36, 40, 48, 50, 77, 88, 115, 124, 125, 138, 165, 172, 207, 214, 261, 266, 277, 286, 294, 306, 310, 363, 364, 372, 401, 415, 427, 433, 444, 458]
else:
    # Raise rather than assert: asserts are stripped under `python -O`.
    raise ValueError(f"unknown setting: {setting!r}")

combine_index_list = sorted(set(extremely_hard_index + second_hard_index + third_hard_index + fourth_hard_index))

# Explicit "hard" index list, kept as a cross-check of the tables above.
if setting == "CAM":
    idx_list = [7, 9, 10, 11, 12, 16, 17, 18, 22, 23, 27, 31, 35, 38, 40, 41, 42, 44, 47, 48, 49, 50, 52, 55, 59, 64, 66, 71, 72, 78, 81, 88, 91, 96, 99, 103, 104, 106, 107, 111, 112, 115, 116, 120, 121, 122, 123, 126, 128, 129, 130, 133, 134, 136, 147, 150, 156, 158, 160, 162, 165, 167, 168, 172, 174, 175, 176, 177, 179, 181, 185, 187, 190, 191, 194, 195, 198, 202, 206, 213, 214, 215, 222, 225, 227, 230, 234, 237, 238, 241, 243, 248, 255, 257, 258, 259, 261, 264, 268, 273, 274, 277, 282, 285, 286, 292, 295, 298, 301, 309, 310, 312, 313, 314, 317, 325, 331, 332, 334, 335, 336, 338, 341, 350, 353, 354, 356, 357, 362, 364, 365, 368, 370, 372, 378, 381, 385, 386, 387, 388, 391, 394, 398, 401, 402, 404, 406, 409, 415, 417, 418, 419, 422, 427, 428, 432, 433, 435, 436, 444, 448, 450, 451, 454, 458, 460, 461, 463]
else:
    idx_list = [7, 12, 13, 19, 20, 27, 36, 37, 38, 40, 41, 42, 44, 48, 49, 50, 55, 59, 66, 68, 71, 72, 77, 81, 88, 93, 96, 102, 103, 109, 110, 115, 122, 123, 124, 125, 126, 130, 136, 137, 138, 142, 147, 156, 158, 162, 165, 167, 168, 171, 172, 174, 175, 176, 179, 181, 182, 190, 193, 198, 199, 203, 206, 207, 213, 214, 215, 219, 222, 225, 230, 234, 237, 238, 241, 243, 248, 257, 258, 259, 261, 264, 266, 268, 273, 274, 277, 278, 282, 286, 292, 294, 295, 298, 301, 306, 308, 310, 313, 317, 321, 325, 328, 334, 337, 341, 342, 350, 352, 353, 354, 356, 363, 364, 365, 368, 372, 374, 380, 382, 385, 386, 394, 398, 401, 402, 406, 415, 417, 418, 422, 427, 433, 435, 436, 444, 448, 450, 451, 454, 458, 460, 461, 463]
assert idx_list == combine_index_list

# Select the example indices to evaluate.
if data_subset == "hard":
    pass  # keep the hard indices as-is
elif data_subset == "easy":
    _hard_index_set = set(idx_list)
    idx_list = [k for k in range(NUM_MTURK_EXAMPLES) if k not in _hard_index_set]
elif data_subset == "all":
    idx_list = list(range(NUM_MTURK_EXAMPLES))
else:
    raise ValueError(f"unknown data_subset: {data_subset!r}")


# Load the saved metadata of every selected example and bucket the example
# indices by run outcome: key 1 = run not aborted, key 0 = run aborted.
idx_dict = {1: [], 0: []}  # 1: success, 0: fail
data_list = []
data_dict = {}

for i in idx_list:
    # keys = ['is_abort', 'index', 'delete_iter', 'modified_story_and_prev_qas', 'all_chat_history']
    # NOTE(review): this path hard-codes "coqa_validation" while the answer files
    # loaded later use data_split — confirm the train split is never evaluated here.
    meta = torch.load(f'{pt_file_path}/coqa_validation_yes_no_mturk_rewrite_{rewrite_i}_p_c7_r_d_chatgpt_data_index_{i}_metadata.pt')
    assert meta['index'] == i
    outcome = 0 if meta['is_abort'] else 1
    idx_dict[outcome].append(i)
    data_dict[i] = meta
    data_list.append(meta)

data_list.sort(key=lambda entry: entry['index'])

# Disabled interactive viewer for successful runs (needs `from pprint import pprint`):
# for i in idx_dict[1]:
#     print(f"index: {data_dict[i]['index']}")
#     print("story:")
#     print(data_dict[i]['modified_story_and_prev_qas'][0])
#     pprint(data_dict[i]['modified_story_and_prev_qas'][1:])
#     _ = input("==========================")

def _starts_with_yes_no(text):
    """Return True if *text* begins with "yes" or "no" (case-insensitive)."""
    return "yes" in text[:3].lower() or "no" in text[:2].lower()


def _merge_two_stage_answers(first_stage, second_stage):
    """Combine 1st-stage (Ti) answers with 2nd-stage (answer-extraction) answers.

    A 1st-stage answer is kept when it starts with Yes/No and does not contain
    the "NO MODIFICATION" sentinel (any case). Otherwise the aligned 2nd-stage
    answer is used. If that answer does not start with Yes/No but ends with one,
    "Yes. " / "No. " is prepended.

    Returns a new list. Entries of *first_stage* beyond the length of
    *second_stage* are kept unchanged (``zip`` truncation).
    """
    merged = list(first_stage)
    for k, (first, second) in enumerate(zip(first_stage, second_stage)):
        # Compare case-insensitively so that replies such as "No modification
        # needed" count as the sentinel rather than as a "No" answer.
        if "no modification" not in first.lower() and _starts_with_yes_no(first):
            continue
        if not _starts_with_yes_no(second):
            # Sometimes the answer is at the end of the reply.
            if "yes" in second[-6:].lower():
                second = "Yes. " + second
            elif "no" in second[-5:].lower():
                second = "No. " + second
        merged[k] = second
    return merged


def _percent(count, total):
    """Return count/total as a percentage rounded to 2 decimals (nan if total is 0)."""
    return round(count / total * 100, 2) if total else float("nan")


# File-name stems of the (1st-stage answers, 2nd-stage extracted answers)
# produced by each conversation flow.
_CONV_FLOW_ANSWER_STEMS = {
    "Tfm-Tpm-Ti": ("t1m_tpm_ti_chatgpt", "extract_answer_t1m_tpm_ti_chatgpt"),
    "Tf-Tp-Tc-Tr-Td-Ti": ("p_c7_r_d_append_chatgpt", "p_c7_r_d_append_extract_answer_chatgpt"),
}
if conv_flow not in _CONV_FLOW_ANSWER_STEMS:
    raise ValueError(f"unknown conv_flow: {conv_flow!r}")

_answer_stem, _extract_stem = _CONV_FLOW_ANSWER_STEMS[conv_flow]
_answer_prefix = f"{pt_file_path}/coqa_{data_split}_mturk_test_464_rewrite_{rewrite_i}_"
cur_new_ans_list = torch.load(f"{_answer_prefix}{_answer_stem}.pt")
cur_new_ans_list_2 = torch.load(f"{_answer_prefix}{_extract_stem}.pt")
# If the 1st stage prompting (Ti) does not start with Yes/No, use the response of
# the 2nd stage prompting instead; otherwise the 2nd stage output is ignored.
cur_new_ans_list = _merge_two_stage_answers(cur_new_ans_list, cur_new_ans_list_2)

cur_new_ans_list_part = [cur_new_ans_list[i] for i in idx_list]
answer_list_part = [answer_list[i] for i in idx_list]
diff_result = [compare_old_new_answer(old_a, new_a) for old_a, new_a in zip(answer_list_part, cur_new_ans_list_part)]

# Summarize how many comparisons came out as +1, -1 and 0.
_n_total = len(diff_result)
_n_pos, _n_neg, _n_zero = diff_result.count(1), diff_result.count(-1), diff_result.count(0)
print(f"(+1,-1,0) = ({_n_pos} ({_percent(_n_pos, _n_total)}%), {_n_neg} ({_percent(_n_neg, _n_total)}%), {_n_zero} ({_percent(_n_zero, _n_total)}%))")

####################################################
# 1 where the editing run finished, -1 where it was aborted.
success_run = [1 if not data_dict[i]['is_abort'] else -1 for i in idx_list]
num_success_run = success_run.count(1)
num_success_run_and_change = sum(1 for ok, delta in zip(success_run, diff_result) if ok == 1 and delta == 1)
# NOTE: despite the label, the printed value is a fraction in [0, 1], not a
# percentage. Guard against every run having aborted.
_success_change_ratio = num_success_run_and_change / num_success_run if num_success_run else float("nan")
print(f"Percent of change in success run = {_success_change_ratio} ({num_success_run_and_change}/{num_success_run})")