import torch

# Dataset split name; interpolated into the .pt file names below.
data_split = "validation"
# Rewrite round to process; interpolated into the .pt file names below.
rewrite_i = 0 # change to 0, 1, 2

# Log directory (not referenced in the visible part of this script).
log_file_path = "../data/log"
# Directory the answer .pt files are read from and the merged file is written to.
pt_file_path = "../data/pt"

####################################################

# CAM setting
# Four groups of integer example indices, all within range(464). Their union
# is computed below as combine_index_list and compared against idx_list.
# NOTE(review): the names suggest difficulty tiers -- confirm how they were derived.
# To run the CBA setting instead, comment these four lists out and enable the
# commented CBA lists that follow.
extremely_hard_index = [9, 12, 27, 40, 41, 44, 47, 49, 66, 71, 72, 81, 107, 116, 120, 121, 126, 129, 130, 134, 136, 147, 162, 167, 168, 190, 198, 214, 215, 222, 234, 237, 243, 248, 257, 258, 268, 273, 274, 282, 285, 292, 301, 310, 313, 317, 325, 338, 341, 353, 356, 385, 388, 391, 394, 404, 406, 409, 417, 435, 436, 461]
second_hard_index = [7, 10, 16, 18, 23, 35, 38, 42, 48, 52, 55, 59, 88, 96, 99, 103, 111, 115, 123, 128, 133, 150, 156, 160, 172, 174, 176, 177, 179, 185, 187, 191, 194, 195, 206, 213, 225, 227, 230, 238, 241, 255, 264, 277, 295, 298, 309, 312, 314, 332, 334, 354, 368, 381, 387, 401, 418, 428, 444, 448, 450, 451, 454, 458, 460, 463]
third_hard_index = [31, 165, 202, 259, 350, 386, 419, 432]
fourth_hard_index = [11, 17, 22, 50, 64, 78, 91, 104, 106, 112, 122, 158, 175, 181, 261, 286, 331, 335, 336, 357, 362, 364, 365, 370, 372, 378, 398, 402, 415, 422, 427, 433]

# CBA setting
#extremely_hard_index = [12, 13, 41, 44, 49, 55, 66, 68, 72, 126, 136, 137, 156, 158, 168, 174, 175, 176, 198, 203, 206, 215, 234, 243, 248, 257, 264, 268, 273, 274, 278, 282, 301, 317, 337, 341, 352, 353, 356, 382, 394, 406, 417, 422, 436, 461, 463]
#second_hard_index = [7, 37, 38, 42, 59, 71, 93, 96, 102, 103, 110, 122, 130, 142, 147, 167, 171, 179, 181, 190, 199, 213, 219, 225, 230, 237, 238, 241, 258, 295, 298, 313, 321, 328, 334, 350, 354, 365, 368, 374, 380, 385, 402, 418, 435, 448, 450, 451, 454, 460]
#third_hard_index = [81, 109, 123, 162, 182, 193, 222, 259, 292, 308, 325, 342, 386, 398]
#fourth_hard_index = [19, 20, 27, 36, 40, 48, 50, 77, 88, 115, 124, 125, 138, 165, 172, 207, 214, 261, 266, 277, 286, 294, 306, 310, 363, 364, 372, 401, 415, 427, 433, 444, 458]

# De-duplicated union of the four index groups, in ascending order.
combine_index_list = sorted(
    {
        *extremely_hard_index,
        *second_hard_index,
        *third_hard_index,
        *fourth_hard_index,
    }
)

# CAM setting
# Hand-written, ascending list of indices whose answers are stored in the
# first loaded .pt file. The assert that follows requires it to equal
# combine_index_list, so it must be switched together with the four index
# lists when changing between the CAM and CBA settings.
idx_list = [7, 9, 10, 11, 12, 16, 17, 18, 22, 23, 27, 31, 35, 38, 40, 41, 42, 44, 47, 48, 49, 50, 52, 55, 59, 64, 66, 71, 72, 78, 81, 88, 91, 96, 99, 103, 104, 106, 107, 111, 112, 115, 116, 120, 121, 122, 123, 126, 128, 129, 130, 133, 134, 136, 147, 150, 156, 158, 160, 162, 165, 167, 168, 172, 174, 175, 176, 177, 179, 181, 185, 187, 190, 191, 194, 195, 198, 202, 206, 213, 214, 215, 222, 225, 227, 230, 234, 237, 238, 241, 243, 248, 255, 257, 258, 259, 261, 264, 268, 273, 274, 277, 282, 285, 286, 292, 295, 298, 301, 309, 310, 312, 313, 314, 317, 325, 331, 332, 334, 335, 336, 338, 341, 350, 353, 354, 356, 357, 362, 364, 365, 368, 370, 372, 378, 381, 385, 386, 387, 388, 391, 394, 398, 401, 402, 404, 406, 409, 415, 417, 418, 419, 422, 427, 428, 432, 433, 435, 436, 444, 448, 450, 451, 454, 458, 460, 461, 463]
# CBA setting
#idx_list = [7, 12, 13, 19, 20, 27, 36, 37, 38, 40, 41, 42, 44, 48, 49, 50, 55, 59, 66, 68, 71, 72, 77, 81, 88, 93, 96, 102, 103, 109, 110, 115, 122, 123, 124, 125, 126, 130, 136, 137, 138, 142, 147, 156, 158, 162, 165, 167, 168, 171, 172, 174, 175, 176, 179, 181, 182, 190, 193, 198, 199, 203, 206, 207, 213, 214, 215, 219, 222, 225, 230, 234, 237, 238, 241, 243, 248, 257, 258, 259, 261, 264, 266, 268, 273, 274, 277, 278, 282, 286, 292, 294, 295, 298, 301, 306, 308, 310, 313, 317, 321, 325, 328, 334, 337, 341, 342, 350, 352, 353, 354, 356, 363, 364, 365, 368, 372, 374, 380, 382, 385, 386, 394, 398, 401, 402, 406, 415, 417, 418, 422, 427, 433, 435, 436, 444, 448, 450, 451, 454, 458, 460, 461, 463]
# Sanity check: the hand-written idx_list must match the computed union.
assert combine_index_list == idx_list

# Every index in range(464) that is absent from idx_list, ascending.
easy_idx_list = sorted(set(range(464)).difference(idx_list))


# Each input file name encodes how many answers it is expected to hold: one
# file for the indices in idx_list, one for the indices in easy_idx_list.
idx_answers_path = f"{pt_file_path}/coqa_{data_split}_mturk_test_{len(idx_list)}_rewrite_{rewrite_i}_t1m_tpm_ti_chatgpt.pt"
easy_answers_path = f"{pt_file_path}/coqa_{data_split}_mturk_test_{len(easy_idx_list)}_rewrite_{rewrite_i}_t1m_tpm_ti_chatgpt.pt"

# NOTE: torch.load unpickles the file contents -- only load trusted files.
cur_new_ans_list = torch.load(idx_answers_path)
easy_cur_new_ans_list = torch.load(easy_answers_path)
# extract answer
#cur_new_ans_list = torch.load(f"{pt_file_path}/coqa_{data_split}_mturk_test_{len(idx_list)}_rewrite_{rewrite_i}_extract_answer_t1m_tpm_ti_chatgpt.pt")
#easy_cur_new_ans_list = torch.load(f"{pt_file_path}/coqa_{data_split}_mturk_test_{len(easy_idx_list)}_rewrite_{rewrite_i}_extract_answer_t1m_tpm_ti_chatgpt.pt")



# Merge the two answer lists into one list ordered by example index.
# cur_new_ans_list[k] is the answer for index idx_list[k], and
# easy_cur_new_ans_list[k] is the answer for index easy_idx_list[k].
#
# Each answer list must line up one-to-one with its index list. Fail loudly
# on a mismatch rather than hitting an IndexError mid-merge (file too short)
# or silently dropping trailing answers (file too long).
if len(cur_new_ans_list) != len(idx_list):
    raise ValueError(
        f"expected {len(idx_list)} answers for idx_list, "
        f"got {len(cur_new_ans_list)}"
    )
if len(easy_cur_new_ans_list) != len(easy_idx_list):
    raise ValueError(
        f"expected {len(easy_idx_list)} answers for easy_idx_list, "
        f"got {len(easy_cur_new_ans_list)}"
    )

# Map every example index to its answer. Unlike popping from the front of
# the lists, this leaves the input lists intact and avoids the O(n) cost of
# each `del lst[0]`.
answer_by_index = dict(zip(idx_list, cur_new_ans_list))
answer_by_index.update(zip(easy_idx_list, easy_cur_new_ans_list))

# An index appearing twice (inside one list or across both) would collapse
# into a single dict key, so a size mismatch exposes it. Raised explicitly
# because `assert` is stripped when Python runs with -O.
if len(answer_by_index) != len(idx_list) + len(easy_idx_list):
    raise ValueError(
        "idx_list and easy_idx_list must be disjoint and free of duplicates"
    )

# Emit answers in ascending index order; for sorted, disjoint index lists
# this is exactly the interleaving a two-way merge produces.
ret = [answer_by_index[i] for i in sorted(answer_by_index)]


# Write the merged answers; the file name encodes the merged list's length.
merged_answers_path = f"{pt_file_path}/coqa_{data_split}_mturk_test_{len(ret)}_rewrite_{rewrite_i}_t1m_tpm_ti_chatgpt.pt"
torch.save(ret, merged_answers_path)
# extract answer
#torch.save(ret, f"{pt_file_path}/coqa_{data_split}_mturk_test_{len(ret)}_rewrite_{rewrite_i}_extract_answer_t1m_tpm_ti_chatgpt.pt")