import argparse
import random
import os

import torch
from datasets import load_dataset

from set_seed import set_seed
from openai_api_setup import get_response_list_from_chatgpt
##### newly added #####
from google_api_setup import get_response_list_from_gemma_2 # do not test gemini in Tr
from fastchat_api_setup import get_response_list_from_vicuna, get_response_list_from_llama_2
from meta_api_setup import get_response_list_from_llama_3
from transformers import AutoModelForCausalLM, AutoTokenizer # for Llama-3 and Gemma-2
##### newly added #####
from generate_yes_no_index_list import generate_yes_no_index_list
import generate_queries_list
from save_data import save_data
from get_templates_list_from_daily_dialog import get_templates_list_from_daily_dialog

####################################################
def parse_argument():
    """Define this experiment's command-line interface and parse sys.argv.

    Returns:
        argparse.Namespace with data_split, model, is_trim, run_ti, run_oracle,
        num_gpus, max_gpu_memory, resume_rewrite_i, resume_setting and
        resume_correction_templates. Unset optional values are None; unset
        store_true flags are False.
    """
    supported_models = [
        'gpt-4-1106-preview', 'gpt-3.5-turbo-0613', 'gpt-3.5-turbo-1106', 'gpt-3.5-turbo-0125',
        'gpt-4o-mini-2024-07-18',
        'gemma-2-2b-it', 'gemma-2-9b-it', 'gemma-2-27b-it',
        'vicuna-7b-v1.5-16k', 'vicuna-13b-v1.5-16k', 'vicuna-33b-v1.3',
        'Llama-2-7b-chat-hf', 'Llama-2-13b-chat-hf',
        'Meta-Llama-3-8B-Instruct',
    ]
    # (flag, add_argument keyword options) in the order they should appear in --help.
    option_specs = [
        ('--data_split', dict(required=True, choices=['train', 'validation'],
                              help='choose training set or validation set in CoQA dataset')),
        ('--model', dict(required=True, choices=supported_models,
                         help='specify the model, note that some models in choices may be deprecated in the future, now support models other than GPT')),
        ('--is_trim', dict(action='store_true',
                           help='Whether to trim the story (so original support sentence is always at the end of the story)')),
        # An --instruct_update flag used to exist here; it is currently disabled.
        ('--run_ti', dict(action='store_true', help='Whether to test Ti turn')),
        ('--run_oracle', dict(action='store_true',
                              help='Whether to test the oracle performance of Recall (i.e., the old knowledge in the story is automatically replaced by us). Note that this flag can not be set along with run_ti')),
        # GPU options: only forwarded to the FastChat-served models (vicuna, Llama-2).
        ('--num_gpus', dict(type=int, help='choose number of gpus (only available when model is vicuna, LLama-2)')),
        ('--max_gpu_memory', dict(type=int,
                                  help='set maximum memory (GB) for each gpu, for example, set 14 for num_gpus=2 when running vicuna-13b-v1.5-16k, or set 7 for num_gpus=4 (float16 is used in FastChat)')),
        # Resume options: only needed to continue an interrupted run (e.g. after a server crash).
        ('--resume_rewrite_i', dict(type=int, choices=[0, 1, 2],
                                    help='choose the first (0), second (1), or last (2) mturk rewrite new knowledge to start with.')),
        ('--resume_setting', dict(choices=['t1-tc-tr', 't1-tp-tc-tr'],
                                  help="choose the experimental setting to start with ('t1-tc-tr' = CAM, 't1-tp-tc-tr' = CBA).")),
        ('--resume_correction_templates', dict(type=int, choices=range(0, 15),
                                               help='choose the correction templates to start with.')),
    ]
    cli = argparse.ArgumentParser(description=__doc__)
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()

args = parse_argument()
data_split = args.data_split # "train", "validation"
model = args.model
is_trim = args.is_trim
# The --instruct_update CLI flag is commented out in parse_argument(); the recall (Tr)
# prompt is always generated with instruct_update on.
instruct_update = True
run_ti = args.run_ti
run_oracle = args.run_oracle

# Ti and oracle-Ti runs are separate experiments; at most one of the flags may be set.
assert not (run_ti and run_oracle), "--run_ti and --run_oracle cannot be set together"

# Resume point for continuing an interrupted run; all three are None for a fresh run.
resume_rewrite_i = args.resume_rewrite_i
resume_setting = args.resume_setting
resume_correction_templates = args.resume_correction_templates
# The three --resume_* arguments describe a single resume point, so they must be given
# together or not at all (a lone --resume_setting would leave the template index undefined).
_resume_args = (resume_rewrite_i, resume_setting, resume_correction_templates)
assert all(a is None for a in _resume_args) or all(a is not None for a in _resume_args), \
    "--resume_rewrite_i, --resume_setting and --resume_correction_templates must be given together"

num_gpus = args.num_gpus
max_gpu_memory = args.max_gpu_memory
# NOTE(review): in this file num_gpus/max_gpu_memory are only forwarded to the vicuna and
# Llama-2 (FastChat) calls; Llama-3 and Gemma are loaded with device_map="auto" instead.
if 'vicuna' in model or 'Llama-2' in model or 'Llama-3' in model or 'gemma' in model:
    print("Set GPU configuration...")
    print(f"num_gpus: {num_gpus}, max_gpu_memory: {max_gpu_memory}GB")
elif num_gpus or max_gpu_memory:
    # API models (GPT) ignore GPU settings; tell the user instead of failing.
    print(f"You set GPU-related parameters, but they are not used in {model}...")

# Local Hugging Face chat models (Llama-3, Gemma-2) are loaded once here, when the script
# starts, and the objects are passed to the get_response_list_from_* helpers in the main loop.
# Loading inside get_response_list_from_llama_3 produced "WARNING:root:Some parameters are on
# the meta device device because they were offloaded to the cpu." -- possibly memory not being
# released fast enough (unconfirmed).
def _load_hf_chat_model(hub_org, model_name, temperature):
    """Load `hub_org/model_name` as a bfloat16 causal LM together with its tokenizer.

    Asserts that the tokenizer ships a chat template, sets the model's generation_config
    to the given temperature and top_p = 1, prints that config, and returns
    (model_id, device, tokenizer, model).
    """
    repo_id = hub_org + '/' + model_name
    cuda_device = torch.device('cuda')
    print(f'Load {model_name}\'s model and tokenizer...')
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    # Require the chat template the model was trained with; otherwise prompts would need manual engineering.
    assert tokenizer.chat_template is not None
    causal_lm = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.bfloat16, device_map="auto")
    gen_config = causal_lm.generation_config
    gen_config.temperature = temperature
    gen_config.top_p = 1
    print('Model\'s generation_config:')
    print(gen_config.to_dict())
    return repo_id, cuda_device, tokenizer, causal_lm


if 'Llama-3' in model:
    # temperature=0 raises "ValueError: `temperature` (=0) has to be a strictly positive float", hence 1e-8.
    model_id, device, hf_llama_3_tokenizer, hf_llama_3_model = _load_hf_chat_model('meta-llama', model, 1e-8)

if 'gemma' in model:
    # NOTE(review): Gemma keeps temperature=0 although Llama-3 needed 1e-8 to avoid the error
    # above; presumably fine only if Gemma generation does not sample -- confirm.
    model_id, device, hf_gemma_2_tokenizer, hf_gemma_2_model = _load_hf_chat_model('google', model, 0)

####################################################

# Load the CoQA split and, with a fixed seed, pick one yes/no question per story
# (-1 marks stories without any yes/no question).
_d_coqa = load_dataset('coqa')
data = _d_coqa[data_split]

yn_qa_idx_list = generate_yes_no_index_list(data)
set_seed(0)
sel_idx_list = [random.choice(yn_indices) if yn_indices else -1 for yn_indices in yn_qa_idx_list]

# Model outputs are written to / read from this directory; it must already exist.
pt_file_path = f"../data/pt/{model}/{data_split}"
assert os.path.exists(pt_file_path)

# Size of the MTurk annotation for each split (it is part of the annotation file name).
assert data_split in ("validation", "train")
num_data = 464 if data_split == "validation" else 1317

mturk_responses = torch.load(f'../data/csv/MTurk_{data_split}/final_{data_split}_{num_data}.pt')
_responses_by_story = mturk_responses[data_split]
# Not every selected story is annotated on train, so no check against sel_idx_list here.
# Each annotated story must carry exactly three MTurk rewrites.
assert all(len(_responses_by_story[story_key]) == 3 for story_key in _responses_by_story)

if data_split == "validation":
    _story_keys = range(num_data)
else:
    _story_keys = sorted(_responses_by_story.keys())  # [5000, 5001, ..., 6316]

# mturk_rewrite_lists[r][n] is rewrite r (0, 1, 2) of the n-th story in key order.
mturk_rewrite_lists = [[_responses_by_story[story_key][r] for story_key in _story_keys] for r in range(3)]
mturk_rewrite_0_list, mturk_rewrite_1_list, mturk_rewrite_2_list = mturk_rewrite_lists

templates_list_daily_dialog = get_templates_list_from_daily_dialog()
# To run only a subset of correction templates, filter this list by index, e.g.
#     top_k_idx = [0, 4, 6, 9, 10, 11, 12, 13, 14]
#     templates_list_daily_dialog = [templates_list_daily_dialog[i] for i in top_k_idx]
# but then output file names MUST use "..._c{top_k_idx[i]}_..." instead of "c{i}", because the
# main loop names results by position in this list.


####################################################

def determine_saved_file_name_for_vicuna_and_llama_2(data_split, rewrite_i, i, method, is_trim, run_ti, run_oracle):
    """Return the .pt file name for one (rewrite, correction template, method) result.

    Args:
        data_split: CoQA split name, embedded verbatim.
        rewrite_i: index of the MTurk rewrite, embedded verbatim.
        i: index of the correction template (rendered as "c{i}").
        method: "t1-tc-tr"; any other value gets the "t1-tp-tc-tr" naming
            (a "p_" before "c{i}").
        is_trim: embedded as "trim_{is_trim}"; only recall (Tr) names contain it.
        run_ti: name the Ti result; checked before run_oracle.
        run_oracle: name the oracle Ti result ("..._oracle_chatgpt.pt").

    With neither run_ti nor run_oracle set, the recall (Tr) result name is returned.
    """
    prefix = f"coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_"
    tp_before_tc = method != "t1-tc-tr"

    if not (run_ti or run_oracle):
        # Recall (Tr) output: "..._c{i}_r_..." or "..._p_c{i}_r_..."
        layout = f"p_c{i}" if tp_before_tc else f"c{i}"
        return prefix + layout + f"_r_chatgpt_recall_story_trim_{is_trim}.pt"

    # Ti output: "t1-tp-tc-tr-ti" -> "p_c{i}_r_", "t1-tc-tr-tp-ti" -> "c{i}_r_p_"
    layout = f"p_c{i}_r_" if tp_before_tc else f"c{i}_r_p_"
    suffix = "chatgpt.pt" if run_ti else "oracle_chatgpt.pt"
    return prefix + layout + suffix


####################################################
# Main experiment: test whether the model updates its answer after the story is corrected.
# For every MTurk rewrite, conversation layout (method) and correction template, the final
# user turn is either the recall question Tr (default) or the follow-up question Ti
# (--run_ti / --run_oracle). Responses are saved as .pt files under pt_file_path.

# A --resume_setting given without --resume_correction_templates would make the `>= 0`
# comparison below raise TypeError on None; treat it as "resume from template 0".
if resume_setting is not None and resume_correction_templates is None:
    resume_correction_templates = 0

# Integer tag used in recall-output file names ("..._trim_1.pt" / "..._trim_0.pt"). Kept
# separate from is_trim so the flag handed to t_T1 is the same value for every rewrite.
trim_tag = 1 if is_trim else 0

# Keyword arguments shared by the generate_queries_list helpers: on train they are called
# with start_idx=5000 (the MTurk train keys start at 5000), on validation with the defaults.
if data_split == "train":
    split_kwargs = {"start_idx": 5000, "data_split": data_split}
else:
    assert data_split == "validation"
    split_kwargs = {}
# is_trim / instruct_update are only passed when they are set.
t1_kwargs = {"is_trim": is_trim} if is_trim else {}
tr_kwargs = {"instruct_update": instruct_update} if instruct_update else {}

for rewrite_i, mturk_rewrite_list in enumerate(mturk_rewrite_lists):
    if resume_rewrite_i and rewrite_i < resume_rewrite_i:
        continue  # resuming: this rewrite was already finished
    print(f"MTurk rewrite {rewrite_i}, data_split={data_split}, is_trim={is_trim}, run_ti={run_ti}, run_oracle={run_oracle}, model={model}")
    # (test) T1: the story, followed by a fixed "memorized" acknowledgement.
    test_queries_t1_list = generate_queries_list.t_T1(data, sel_idx_list, **t1_kwargs, **split_kwargs)
    test_messages_t1_list = [[{"role": "user", "content": queries}] for queries in test_queries_t1_list]
    test_response_t1 = [{"role": "assistant", "content": "Yes, I have memorized the story."}]
    merge_test_messages_t1_list = [test_msg_t1 + test_response_t1 for test_msg_t1 in test_messages_t1_list]
    # (test) Tp: previous QA pairs (consecutive); used directly as chat turns.
    merge_test_messages_tp_list = generate_queries_list.t_Tp(data, sel_idx_list, **split_kwargs)
    # (test) Tc: for each story, one correction message per daily-dialog template
    # (ideally it could be inserted anywhere in the conversation).
    test_queries_correct_list_ensemble = generate_queries_list.t_Tc_ensemble(data, sel_idx_list, templates_list_daily_dialog, mturk_rewrite_list, **split_kwargs)
    test_messages_correct_list_ensemble = [[[{"role": "user", "content": cur_query}] for cur_query in queries] for queries in test_queries_correct_list_ensemble]
    # Fixed assistant reply to the correction, chosen beforehand by trying ChatGPT.
    test_response_correct = [{"role": "assistant", "content": "No problem at all! I have updated my memory of the story with the correction you provided. Thank you for letting me know."}]
    merge_test_messages_tc_list_ensemble = [[tc_msgs + test_response_correct for tc_msgs in per_story] for per_story in test_messages_correct_list_ensemble]
    if run_ti or run_oracle:
        # (test) Ti: the follow-up question asked as the final turn.
        test_queries_ti_list = generate_queries_list.t_Ti(data, sel_idx_list, **split_kwargs)
    # (test) Tr: recall, i.e. ask the model to summarize the (corrected) story.
    test_queries_tr_list = generate_queries_list.t_Tr(data, sel_idx_list, **tr_kwargs, **split_kwargs)
    if run_oracle:
        # Oracle Tr answers, where the old knowledge in the story has already been replaced.
        test_response_tr_oracle_list = generate_queries_list.t_Tr_oracle(data, sel_idx_list, mturk_rewrite_list, **split_kwargs)
    assert len(merge_test_messages_t1_list) == len(merge_test_messages_tc_list_ensemble) == len(merge_test_messages_tp_list) == len(test_queries_tr_list)
    if run_ti or run_oracle:
        assert len(test_queries_ti_list) == len(merge_test_messages_t1_list)
    if run_oracle:
        assert len(test_response_tr_oracle_list) == len(merge_test_messages_t1_list)
    assert all(len(per_story) == len(templates_list_daily_dialog) for per_story in merge_test_messages_tc_list_ensemble)
    # Conversation layouts (Tc and Tr always go together):
    #   "t1-tc-tr"    -> T1, Tc, Tr   (with Ti: T1, Tc, Tr, Tp, Ti)
    #   "t1-tp-tc-tr" -> T1, Tp, Tc, Tr   (with Ti: T1, Tp, Tc, Tr, Ti)
    # "T1, Tc, Tp, Tr" is expected to be worse than "T1, Tc, Tr", so it is not tested.
    for method in ["t1-tc-tr", "t1-tp-tc-tr"]:
        if method == "t1-tc-tr" and resume_setting == "t1-tp-tc-tr" and resume_correction_templates >= 0:
            continue  # resuming inside "t1-tp-tc-tr": "t1-tc-tr" of this rewrite is done
        print(f"Method {method}")
        for i in range(len(templates_list_daily_dialog)):
            if resume_correction_templates and i < resume_correction_templates:
                continue  # resuming: this correction template was already done
            # The resume point has been reached; a negative value disables all skip checks above.
            resume_correction_templates = -1
            # Output file name (same scheme for every model); FastChat models save under it themselves.
            # NOTE(review): Ti/oracle names do not contain the trim tag, so trimmed and untrimmed
            # Ti runs write to the same file.
            save_file_name = determine_saved_file_name_for_vicuna_and_llama_2(data_split, rewrite_i, i, method, trim_tag, run_ti, run_oracle)
            if not run_ti and not run_oracle:
                # Final turn is Tr: the model should restate the story with the correction applied.
                if method == "t1-tc-tr":
                    prev_messages_list = [t1 + tc[i] for t1, tc in zip(merge_test_messages_t1_list, merge_test_messages_tc_list_ensemble)]
                else:
                    assert method == "t1-tp-tc-tr"
                    prev_messages_list = [t1 + tp + tc[i] for t1, tc, tp in zip(merge_test_messages_t1_list, merge_test_messages_tc_list_ensemble, merge_test_messages_tp_list)]
                final_queries_list = test_queries_tr_list
                gen_kwargs = {"max_tokens": 512}
            else:
                assert run_ti or run_oracle
                # Final turn is Ti, preceded by a Tr exchange. The Tr answer is the model's own saved
                # recall output (--run_ti) or the oracle story (--run_oracle).
                test_messages_tr_list = [[{"role": "user", "content": queries}] for queries in test_queries_tr_list]
                if run_ti:
                    recall_file_path = f"{pt_file_path}/{determine_saved_file_name_for_vicuna_and_llama_2(data_split, rewrite_i, i, method, trim_tag, False, False)}"
                    test_response_tr_text_list = torch.load(recall_file_path)
                    print(f"Successfully load {recall_file_path}")
                else:
                    assert run_oracle
                    test_response_tr_text_list = test_response_tr_oracle_list
                test_response_tr_list = [[{"role": "assistant", "content": text}] for text in test_response_tr_text_list]
                merge_test_messages_tr_list = [test_msg_tr + test_response_tr for test_msg_tr, test_response_tr in zip(test_messages_tr_list, test_response_tr_list)]
                if method == "t1-tc-tr":  # "t1-tc-tr-tp-ti"
                    prev_messages_list = [t1 + tc[i] + tr + tp for t1, tc, tr, tp in zip(merge_test_messages_t1_list, merge_test_messages_tc_list_ensemble, merge_test_messages_tr_list, merge_test_messages_tp_list)]
                else:  # "t1-tp-tc-tr-ti"
                    assert method == "t1-tp-tc-tr"
                    prev_messages_list = [t1 + tp + tc[i] + tr for t1, tp, tc, tr in zip(merge_test_messages_t1_list, merge_test_messages_tp_list, merge_test_messages_tc_list_ensemble, merge_test_messages_tr_list)]
                final_queries_list = test_queries_ti_list
                gen_kwargs = {}  # no max_tokens is passed for Ti; the callee's default applies
            final_merge_messages_list = [prev + [{"role": "user", "content": queries}] for prev, queries in zip(prev_messages_list, final_queries_list)]
            if 'gpt' in model:
                test_final_response_list = get_response_list_from_chatgpt(final_merge_messages_list, model=model, **gen_kwargs)
                test_final_response_text_list = [response.choices[0].message.content for response in test_final_response_list]
            elif 'vicuna' in model or 'Llama-2' in model:  # use FastChat
                if 'vicuna' in model:
                    get_response_list_from_vicuna(final_merge_messages_list, model=model, save_file_name=save_file_name, num_gpus=num_gpus, max_gpu_memory=max_gpu_memory)
                else:
                    assert 'Llama-2' in model
                    get_response_list_from_llama_2(final_merge_messages_list, model=model, save_file_name=save_file_name, num_gpus=num_gpus, max_gpu_memory=max_gpu_memory)
                continue  # results are NOT saved here; huggingface_api.py (under FastChat/fastchat/serve/) saves them
            elif 'Llama-3' in model:
                # Here `model=` is the loaded HF model object and `model_id=` is its name.
                test_final_response_text_list = get_response_list_from_llama_3(final_merge_messages_list, model_id=model, tokenizer=hf_llama_3_tokenizer, model=hf_llama_3_model, **gen_kwargs)
            else:
                assert 'gemma' in model
                # Here `model=` is the loaded HF model object and `model_id=` is its name.
                test_final_response_text_list = get_response_list_from_gemma_2(final_merge_messages_list, model_id=model, tokenizer=hf_gemma_2_tokenizer, model=hf_gemma_2_model, **gen_kwargs)
            save_data([{"data": test_final_response_text_list, "name": f"{pt_file_path}/{save_file_name}"}])