import argparse
import random
import os

import torch
from datasets import load_dataset

from set_seed import set_seed
from openai_api_setup import get_response_list_from_chatgpt
##### newly added #####
from google_api_setup import get_response_list_from_gemini, get_response_list_from_gemma_2
from fastchat_api_setup import get_response_list_from_vicuna, get_response_list_from_llama_2
from meta_api_setup import get_response_list_from_llama_3
from transformers import AutoModelForCausalLM, AutoTokenizer # for Llama-3 and Gemma-2
##### newly added #####
from generate_yes_no_index_list import generate_yes_no_index_list
import generate_queries_list
from save_data import save_data
from get_templates_list_from_daily_dialog import get_templates_list_from_daily_dialog

####################################################
def parse_argument():
    """Define the experiment's command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``data_split``, ``is_trim``, ``do_ablation``, ``single_turn``,
        ``model``, ``extract_answer``, ``insert_new_info``, ``exclude_tc``, ``num_gpus`` and
        ``max_gpu_memory``. ``model``, ``num_gpus`` and ``max_gpu_memory`` are ``None`` when omitted.
    """
    # Some of these may be deprecated by their providers in the future.
    supported_models = [
        # OpenAI
        'gpt-4-1106-preview', 'gpt-3.5-turbo-1106', 'gpt-3.5-turbo-0125',
        'gpt-4o-2024-08-06', 'gpt-4o-mini-2024-07-18',
        # Google
        'gemini-1.0-pro-001', 'gemma-2-2b-it', 'gemma-2-9b-it', 'gemma-2-27b-it',
        # FastChat-served checkpoints
        'vicuna-7b-v1.5-16k', 'vicuna-13b-v1.5-16k', 'vicuna-33b-v1.3',
        'Llama-2-7b-chat-hf', 'Llama-2-13b-chat-hf',
        # Meta (Hugging Face)
        'Meta-Llama-3-8B-Instruct',
    ]

    cli = argparse.ArgumentParser(description=__doc__)

    # Which data and which prompt variant.
    cli.add_argument('--data_split', required=True, choices=['train', 'validation'],
                     help='choose training set or validation set in CoQA dataset')
    cli.add_argument('--is_trim', action='store_true',
                     help='Whether to trim story')
    cli.add_argument('--do_ablation', action='store_true',
                     help='Whether to do ablation analysis')
    cli.add_argument('--single_turn', action='store_true',
                     help='Whether to squeeze multi-turn dialogues into single-turn (for comparison)')

    # Which model, and model-specific post-processing.
    cli.add_argument('--model', choices=supported_models,
                     help='specify the model, note that some models in choices may be deprecated in the future, now support models other than GPT')
    cli.add_argument('--extract_answer', action='store_true',
                     help='Whether to convert answer to Yes/No, note that we only use it when model is GPT-4')
    cli.add_argument('--insert_new_info', action='store_true',
                     help='Whether to insert new information into the story')
    cli.add_argument('--exclude_tc', action='store_true',
                     help='Whether to insert new information into the story AND exclude correction phase')

    # GPU settings, only meaningful for the FastChat-served models (vicuna, Llama-2).
    cli.add_argument('--num_gpus', type=int,
                     help='choose number of gpus (only available when model is vicuna, LLama-2)')
    cli.add_argument('--max_gpu_memory', type=int,
                     help='set maximum memory (GB) for each gpu, for example, set 14 for num_gpus=2 when running vicuna-13b-v1.5-16k, or set 7 for num_gpus=4 (float16 is used in FastChat)')

    return cli.parse_args()

args = parse_argument()
data_split = args.data_split  # "train" or "validation"
is_trim = args.is_trim
do_ablation = args.do_ablation
single_turn = args.single_turn
model = args.model  # None when --model is omitted; a default is filled in below
extract_answer = args.extract_answer
insert_new_info = args.insert_new_info
exclude_tc = args.exclude_tc

# --exclude_tc only refines the --insert_new_info setting.
if exclude_tc:
    assert insert_new_info, "--exclude_tc requires --insert_new_info"

if insert_new_info:
    # The training split is not supported for now, since t_T1_modified() has not been adapted.
    assert data_split == "validation", "--insert_new_info only supports --data_split validation"
    assert not do_ablation and not is_trim and not single_turn, \
        "--insert_new_info cannot be combined with --do_ablation, --is_trim or --single_turn"

if not model:
    model = "gpt-3.5-turbo-0613"  # original configuration (not among the --model choices)

if extract_answer:
    assert "gpt-4" in model, "--extract_answer is only used with GPT-4 models"
    # The second-stage (extract-answer) pass reads the first-stage responses of the default setting
    # (`..._c{i}_p_chatgpt.pt` / `..._p_c{i}_chatgpt.pt`) and writes `..._i_i2_chatgpt.pt`. Neither
    # name encodes trim / ablation / single-turn / insertion, so running it with any of those variants
    # would silently mix settings and overwrite the default-setting results.
    assert not (insert_new_info or do_ablation or is_trim or single_turn), \
        "--extract_answer only supports the default setting (no --insert_new_info/--do_ablation/--is_trim/--single_turn)"

# The following combinations are hard-coded rather than explored.
if single_turn:
    assert not do_ablation and not is_trim, "--single_turn cannot be combined with --do_ablation or --is_trim"

# At most one of --do_ablation / --is_trim may be set.
assert not (do_ablation and is_trim), "--do_ablation and --is_trim are mutually exclusive"


num_gpus = args.num_gpus
max_gpu_memory = args.max_gpu_memory

# Locally hosted checkpoints (FastChat or Hugging Face) report their GPU settings; for API-served
# models the GPU flags are ignored, so warn if the user set them anyway.
if any(family in model for family in ('vicuna', 'Llama-2', 'Llama-3', 'gemma')):
    print("Set GPU configuration...")
    print(f"num_gpus: {num_gpus}, max_gpu_memory: {max_gpu_memory}GB")
elif num_gpus or max_gpu_memory:
    print(f"You set GPU-related parameters, but they are not used in {model}...")

# The Hugging Face models are loaded once here rather than inside get_response_list_from_llama_3:
# loading them there produced "WARNING:root:Some parameters are on the meta device device because
# they were offloaded to the cpu." (cause unclear — possibly memory not being released fast enough).
def _load_hf_chat_model(repo_id, temperature):
    """Load tokenizer + bfloat16 causal LM for ``repo_id`` and pin its generation config.

    Prints progress and the resulting generation config. Asserts the tokenizer ships a chat
    template, so prompts can be rendered with the format the model was trained on.
    Returns ``(tokenizer, model)``.
    """
    print(f'Load {model}\'s model and tokenizer...')
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    assert tokenizer.chat_template is not None
    hf_model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.bfloat16, device_map="auto")
    gen_cfg = hf_model.generation_config
    gen_cfg.temperature = temperature
    gen_cfg.top_p = 1
    print('Model\'s generation_config:')
    print(gen_cfg.to_dict())
    return tokenizer, hf_model


if 'Llama-3' in model:
    model_id = 'meta-llama/' + model
    device = torch.device('cuda')
    # 1e-8 instead of 0: a zero temperature raised
    # "ValueError: `temperature` (=0) has to be a strictly positive float, otherwise your next token scores will be invalid."
    hf_llama_3_tokenizer, hf_llama_3_model = _load_hf_chat_model(model_id, temperature=1e-8)

if 'gemma' in model:
    model_id = 'google/' + model
    device = torch.device('cuda')
    # NOTE(review): unlike Llama-3 this keeps temperature=0; if sampling is enabled at generation time
    # it would hit the ValueError noted above — confirm get_response_list_from_gemma_2 decodes greedily.
    hf_gemma_2_tokenizer, hf_gemma_2_model = _load_hf_chat_model(model_id, temperature=0)

####################################################

# CoQA split under evaluation.
_d_coqa = load_dataset('coqa')
data = _d_coqa[data_split]

# For each story, choose one of its candidate (presumably yes/no) question indices at random;
# -1 marks stories with no candidates. set_seed(0) runs right before the draw so the selection is
# reproducible (assuming set_seed seeds `random` — see set_seed.py).
yn_qa_idx_list = generate_yes_no_index_list(data)
set_seed(0)
sel_idx_list = []
for _candidates in yn_qa_idx_list:
    if _candidates:
        sel_idx_list.append(random.choice(_candidates))
    else:
        sel_idx_list.append(-1)

# Per-model, per-split directory that response lists (*.pt) are written to / read from;
# it must already exist.
pt_file_path = f"../data/pt/{model}/{data_split}"
assert os.path.exists(pt_file_path)


# Number of CoQA stories that received MTurk rewrites in each split.
if data_split == "validation":
    num_data = 464
else:
    assert data_split == "train"
    num_data = 1317

# Each story has exactly three MTurk rewrites.
mturk_responses = torch.load(f'../data/csv/MTurk_{data_split}/final_{data_split}_{num_data}.pt')
_split_responses = mturk_responses[data_split]
assert all(len(_split_responses[k]) == 3 for k in _split_responses)
# NOTE: `num_data == len([idx for idx in sel_idx_list if idx != -1])` is not checked, because not
# every training story has been labelled yet.

# Validation stories are indexed 0..num_data-1; training stories are keyed by their sorted CoQA
# index, e.g. [5000, 5001, ..., 6316].
_story_keys = range(num_data) if data_split == "validation" else sorted(_split_responses.keys())

# mturk_rewrite_lists[r][n] is rewrite r of the n-th story.
mturk_rewrite_lists = [[_split_responses[k][r] for k in _story_keys] for r in range(3)]
mturk_rewrite_0_list, mturk_rewrite_1_list, mturk_rewrite_2_list = mturk_rewrite_lists



####################################################

def determine_saved_file_name_for_vicuna_and_llama_2(data_split, rewrite_i, i, method, insert_new_info, do_ablation, is_trim, single_turn, exclude_tc):
    """Return the ``.pt`` file name for one (rewrite, template, turn ordering, setting) run.

    The name encodes the turn ordering (``c{i}_p`` for "t1-tc-tp-ti", ``p_c{i}`` for
    "t1-tp-tc-ti") followed by a setting tag, checked in priority order: insertion
    (``_insert`` / ``_insert_exclude_tc``), ``_single_turn``, ``_ablation``, ``_trim``, or none.
    Vicuna / Llama-2 runs pass this name to FastChat, which saves the responses itself.

    Raises:
        AssertionError: for an unknown ``method``, or when insertion / single-turn is combined
            with flags that setting does not support.
    """
    if method == "t1-tc-tp-ti":
        order_tag = f"c{i}_p"
    else:
        assert method == "t1-tp-tc-ti"
        order_tag = f"p_c{i}"

    if insert_new_info:
        assert not do_ablation and not is_trim and not single_turn
        # For "t1-tp-tc-ti" + exclude_tc the conversation equals the "t1-tc-tp-ti" one
        # (there is no Tc turn), so rerunning it only costs time.
        setting_tag = "_insert_exclude_tc" if exclude_tc else "_insert"
    elif single_turn:
        assert not do_ablation and not is_trim
        setting_tag = "_single_turn"
    elif do_ablation:
        setting_tag = "_ablation"
    elif is_trim:
        setting_tag = "_trim"
    else:
        setting_tag = ""

    return f"coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_{order_tag}{setting_tag}_chatgpt.pt"

####################################################

# Key/value names used to build one chat turn. For example, in ChatGPT:
# {"role": "user", "content": "How are you?"} -> {ROLE_KEY: ROLE_VALUE_USER, CONTENT_KEY: "How are you?"}
# {"role": "assistant", "content": "I am fine, thank you."} -> {ROLE_KEY: ROLE_VALUE_BOT, CONTENT_KEY: "I am fine, thank you."}
ROLE_KEY = "role"
ROLE_VALUE_USER = "user"
ROLE_VALUE_BOT = "assistant"
CONTENT_KEY = "content"
if 'gemini' in model:
    # Gemini calls the assistant role "model" and stores message text under "parts".
    # Bug fix: this used to assign a misspelled name (`ROLE_BOT_VALUE`) that nothing reads, so
    # Gemini conversations were built with the "assistant" role.
    ROLE_VALUE_BOT = "model"
    CONTENT_KEY = "parts"


####################################################
# Test if ChatGPT can update answer based on new story

def _gemini_response_texts(responses):
    """Return the text of each Gemini response, or its block-reason name when no text is available.

    The `response.parts` quick accessor raises (e.g. ValueError when no candidate was returned
    because the prompt was blocked), and checking `len(response.candidates)` first was not enough,
    so the fallback is exception-driven. It stays best-effort (one output per input), but no longer
    swallows BaseException (KeyboardInterrupt, SystemExit) like the previous bare `except:`.
    """
    texts = []
    for response in responses:
        try:
            texts.append(response.parts[0].text)
        except Exception:
            texts.append(response.prompt_feedback.block_reason.name)
    return texts


def _save_response_texts(response_text_list, file_name):
    """Persist one list of response strings as ``{pt_file_path}/{file_name}`` via ``save_data``."""
    save_data([{"data": response_text_list, "name": f"{pt_file_path}/{file_name}"}])


# The training-split query builders are called with start_idx=5000 and the split name, while the
# validation split relies on the builders' defaults, so these keywords are only passed for "train".
_split_kwargs = {"start_idx": 5000, "data_split": data_split} if data_split == "train" else {}
# t_Tp additionally needs to know when it must emit Gemini-style messages.
_tp_kwargs = {**_split_kwargs, "model": 'gemini'} if 'gemini' in model else dict(_split_kwargs)

templates_list_daily_dialog = get_templates_list_from_daily_dialog()

# Test whether the model updates its answer after being told about a correction to the story.
for rewrite_i, mturk_rewrite_list in enumerate(mturk_rewrite_lists):
    print(f"MTurk rewrite {rewrite_i}, data_split={data_split}, is_trim={is_trim}, do_ablation={do_ablation}, single_turn={single_turn}, insert_new_info={insert_new_info}, exclude_tc={exclude_tc}, model={model}")

    # (test) T1: present the story.
    if insert_new_info:
        # If story = "ABC DEF GHI JKL..." and GHI is wrong (should be XYZ), the modified story looks like
        # "ABC DEF GHI (Actually, XYZ) JKL...": the correction is appended right after the original
        # support sentence. The builder wraps the MTurk rewrite in start/end tokens to mark that spot;
        # the marker is replaced by the Tc template text further below.
        assert data_split == "validation"
        start_token = "[START]"
        end_token = "[END]"
        test_queries_t1_list = generate_queries_list.t_T1_insertion(data, sel_idx_list, mturk_rewrite_list, start_token=start_token, end_token=end_token)
    else:
        test_queries_t1_list = generate_queries_list.t_T1(data, sel_idx_list, is_trim=is_trim, do_ablation=do_ablation, **_split_kwargs)
    test_messages_t1_list = [[{ROLE_KEY: ROLE_VALUE_USER, CONTENT_KEY: queries}] for queries in test_queries_t1_list]
    # In the ablation setting the canned assistant replies are left empty.
    test_response_t1 = [{ROLE_KEY: ROLE_VALUE_BOT, CONTENT_KEY: "" if do_ablation else "Yes, I have memorized the story."}]
    merge_test_messages_t1_list = [msgs + test_response_t1 for msgs in test_messages_t1_list]

    # (test) Tp: previous QA pairs (consecutive).
    merge_test_messages_tp_list = generate_queries_list.t_Tp(data, sel_idx_list, **_tp_kwargs)

    # (test) Tc: the correction turn, one variant per daily-dialog template (ideally insertable anywhere).
    test_queries_correct_list_ensemble = generate_queries_list.t_Tc_ensemble(data, sel_idx_list, templates_list_daily_dialog, mturk_rewrite_list, **_split_kwargs)
    test_messages_correct_list_ensemble = [[[{ROLE_KEY: ROLE_VALUE_USER, CONTENT_KEY: cur_query}] for cur_query in queries] for queries in test_queries_correct_list_ensemble]
    # Canned reply to the correction, chosen beforehand by playing with ChatGPT.
    test_response_correct = [{ROLE_KEY: ROLE_VALUE_BOT, CONTENT_KEY: "" if do_ablation else "No problem at all! I have updated my memory of the story with the correction you provided. Thank you for letting me know."}]
    merge_test_messages_tc_list_ensemble = [[msgs + test_response_correct for msgs in per_story] for per_story in test_messages_correct_list_ensemble]

    # (test) Ti: the final question.
    test_queries_ti_list = generate_queries_list.t_Ti(data, sel_idx_list, **_split_kwargs)
    assert len(merge_test_messages_t1_list) == len(merge_test_messages_tc_list_ensemble) == len(merge_test_messages_tp_list) == len(test_queries_ti_list)
    assert all(len(per_story) == len(templates_list_daily_dialog) for per_story in merge_test_messages_tc_list_ensemble)

    # Run both turn orderings: (1) T1, Tc, Tp, Ti and (2) T1, Tp, Tc, Ti.
    for method in ["t1-tc-tp-ti", "t1-tp-tc-ti"]:
        tc_first = method == "t1-tc-tp-ti"
        for i in range(len(templates_list_daily_dialog)):
            if not tc_first and exclude_tc and insert_new_info:
                # Without a Tc turn both orderings reduce to T1 (insert), Tp, Ti; the first ordering's
                # results are also saved under this ordering's file name, so nothing to run here.
                break

            # Conversation history before the final question.
            if not insert_new_info:
                if tc_first:
                    prev_messages_list = [t1 + tc[i] + tp for t1, tc, tp in zip(merge_test_messages_t1_list, merge_test_messages_tc_list_ensemble, merge_test_messages_tp_list)]
                else:
                    prev_messages_list = [t1 + tp + tc[i] for t1, tc, tp in zip(merge_test_messages_t1_list, merge_test_messages_tc_list_ensemble, merge_test_messages_tp_list)]
            else:
                prev_messages_list = []
                for j, (t1, tc, tp) in enumerate(zip(merge_test_messages_t1_list, merge_test_messages_tc_list_ensemble, merge_test_messages_tp_list)):
                    marker = f"{start_token}{mturk_rewrite_list[j]}{end_token}"
                    assert marker in t1[0][CONTENT_KEY]
                    # Put the i-th correction template, in parentheses, where the rewrite was marked.
                    story_with_correction = t1[0][CONTENT_KEY].replace(marker, f"({tc[i][0][CONTENT_KEY]})")
                    turns = [{ROLE_KEY: ROLE_VALUE_USER, CONTENT_KEY: story_with_correction}] + test_response_t1
                    if exclude_tc:
                        turns = turns + tp
                    elif tc_first:
                        turns = turns + tc[i] + tp
                    else:
                        turns = turns + tp + tc[i]
                    prev_messages_list.append(turns)

            if not extract_answer:
                final_merge_messages_list = [prev + [{ROLE_KEY: ROLE_VALUE_USER, CONTENT_KEY: queries}] for prev, queries in zip(prev_messages_list, test_queries_ti_list)]
                if single_turn:
                    # Squeeze the whole dialogue into one user turn, one line per original turn.
                    final_merge_messages_text_list = ['\n'.join([turn[CONTENT_KEY] for turn in final_merge_messages]) for final_merge_messages in final_merge_messages_list]
                    final_merge_messages_list = [[{ROLE_KEY: ROLE_VALUE_USER, CONTENT_KEY: queries}] for queries in final_merge_messages_text_list]

                # Single source of truth for output file names (also used by the FastChat path).
                save_file_name = determine_saved_file_name_for_vicuna_and_llama_2(data_split, rewrite_i, i, method, insert_new_info, do_ablation, is_trim, single_turn, exclude_tc)

                if 'gpt' in model:
                    test_final_response_list = get_response_list_from_chatgpt(final_merge_messages_list, model=model)
                    test_final_response_text_list = [response.choices[0].message.content for response in test_final_response_list]
                elif 'vicuna' in model or 'Llama-2' in model:
                    # FastChat path: responses are saved by huggingface_api.py (under
                    # FastChat/fastchat/serve/) using save_file_name, so nothing is saved here.
                    assert save_file_name
                    if 'vicuna' in model:
                        get_response_list_from_vicuna(final_merge_messages_list, model=model, save_file_name=save_file_name, num_gpus=num_gpus, max_gpu_memory=max_gpu_memory)
                    else:
                        get_response_list_from_llama_2(final_merge_messages_list, model=model, save_file_name=save_file_name, num_gpus=num_gpus, max_gpu_memory=max_gpu_memory)
                    continue
                elif 'Llama-3' in model:
                    # Here `model_id` is the model name and `model` the loaded Hugging Face model.
                    test_final_response_text_list = get_response_list_from_llama_3(final_merge_messages_list, model_id=model, tokenizer=hf_llama_3_tokenizer, model=hf_llama_3_model)
                elif 'gemma' in model:
                    # Here `model_id` is the model name and `model` the loaded Hugging Face model.
                    test_final_response_text_list = get_response_list_from_gemma_2(final_merge_messages_list, model_id=model, tokenizer=hf_gemma_2_tokenizer, model=hf_gemma_2_model)
                else:
                    assert 'gemini' in model
                    test_final_response_list = get_response_list_from_gemini(final_merge_messages_list, model=model)
                    test_final_response_text_list = _gemini_response_texts(test_final_response_list)

                _save_response_texts(test_final_response_text_list, save_file_name)
                if insert_new_info and exclude_tc:
                    # Duplicate the results under the "t1-tp-tc-ti" name for error_analysis.py, which
                    # reads both settings (CAM and CBA) at once.
                    duplicate_name = determine_saved_file_name_for_vicuna_and_llama_2(data_split, rewrite_i, i, "t1-tp-tc-ti", insert_new_info, do_ablation, is_trim, single_turn, exclude_tc)
                    _save_response_texts(test_final_response_text_list, duplicate_name)
            else:
                # Second-stage prompting (GPT-4 only; also enforced when the CLI arguments are
                # validated): replay the recorded first-stage answer and ask the model to classify it
                # as an answer starting with Yes/No/Unknown.
                assert extract_answer and 'gpt-4' in model
                order_tag = f"c{i}_p" if tc_first else f"p_c{i}"
                test_response_ti_text_list = torch.load(f"{pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_{order_tag}_chatgpt.pt")
                # (test) Ti, together with the recorded first-stage answer.
                test_messages_ti_list = [[{ROLE_KEY: ROLE_VALUE_USER, CONTENT_KEY: queries}] for queries in test_queries_ti_list]
                test_response_ti_list = [[{ROLE_KEY: ROLE_VALUE_BOT, CONTENT_KEY: text}] for text in test_response_ti_text_list]
                merge_test_messages_ti_list = [msg + resp for msg, resp in zip(test_messages_ti_list, test_response_ti_list)]
                # (test) Ti_2: re-use generate_queries_list.t_Tv_2() for the classification prompt.
                test_queries_ti_2_list = generate_queries_list.t_Tv_2(data, sel_idx_list, **_split_kwargs)
                if tc_first:
                    prev_messages_list = [t1 + tc[i] + tp + ti for t1, tc, tp, ti in zip(merge_test_messages_t1_list, merge_test_messages_tc_list_ensemble, merge_test_messages_tp_list, merge_test_messages_ti_list)]
                else:
                    prev_messages_list = [t1 + tp + tc[i] + ti for t1, tp, tc, ti in zip(merge_test_messages_t1_list, merge_test_messages_tp_list, merge_test_messages_tc_list_ensemble, merge_test_messages_ti_list)]
                final_merge_messages_list = [prev + [{ROLE_KEY: ROLE_VALUE_USER, CONTENT_KEY: queries}] for prev, queries in zip(prev_messages_list, test_queries_ti_2_list)]
                test_final_response_list = get_response_list_from_chatgpt(final_merge_messages_list, model=model)
                test_final_response_text_list = [response.choices[0].message.content for response in test_final_response_list]
                _save_response_texts(test_final_response_text_list, f"coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_{order_tag}_i_i2_chatgpt.pt")
