import argparse
import random
import os

import torch
from datasets import load_dataset

from set_seed import set_seed
from openai_api_setup import get_response_list_from_chatgpt
##### newly added #####
from google_api_setup import get_response_list_from_gemini, get_response_list_from_gemma_2
from fastchat_api_setup import get_response_list_from_vicuna, get_response_list_from_llama_2
from meta_api_setup import get_response_list_from_llama_3
from transformers import AutoModelForCausalLM, AutoTokenizer # for Llama-3 and Gemma-2
##### newly added #####
from generate_yes_no_index_list import generate_yes_no_index_list
import generate_queries_list
from save_data import save_data
from get_templates_list_from_daily_dialog import get_templates_list_from_daily_dialog

####################################################
def parse_argument():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the fields ``data_split``, ``model``,
        ``extract_answer``, ``num_gpus``, ``max_gpu_memory``,
        ``resume_rewrite_i``, ``resume_setting`` and
        ``resume_correction_templates``. Optional fields that are not given
        on the command line are ``None`` (``extract_answer`` is ``False``).
    """
    # Every model name accepted by --model.
    supported_models = [
        'gpt-4-1106-preview',
        'gpt-3.5-turbo-0613',
        'gpt-3.5-turbo-1106',
        'gpt-3.5-turbo-0125',
        'gpt-4o-mini-2024-07-18',
        'gemma-2-2b-it',
        'gemma-2-9b-it',
        'gemma-2-27b-it',
        'vicuna-7b-v1.5-16k',
        'vicuna-13b-v1.5-16k',
        'vicuna-33b-v1.3',
        'Llama-2-7b-chat-hf',
        'Llama-2-13b-chat-hf',
        'Meta-Llama-3-8B-Instruct',
    ]

    parser = argparse.ArgumentParser(description=__doc__)

    # Required options.
    parser.add_argument(
        '--data_split',
        required=True,
        choices=['train', 'validation'],
        help='choose training set or validation set in CoQA dataset',
    )
    parser.add_argument(
        '--model',
        required=True,
        choices=supported_models,
        help='specify the model, note that some models in choices may be deprecated in the future, now support models other than GPT',
    )

    # Optional behaviour / hardware options.
    parser.add_argument(
        '--extract_answer',
        action='store_true',
        help='Whether to convert Tv answer to Yes/No',
    )
    parser.add_argument(
        '--num_gpus',
        type=int,
        help='choose number of gpus (only available when model is vicuna, LLama-2)',
    )
    parser.add_argument(
        '--max_gpu_memory',
        type=int,
        help='set maximum memory (GB) for each gpu, for example, set 14 for num_gpus=2 when running vicuna-13b-v1.5-16k, or set 7 for num_gpus=4 (float16 is used in FastChat)',
    )

    # Options used ONLY to resume an interrupted run (e.g. after a server crash);
    # a normal run leaves all of them unset.
    parser.add_argument(
        '--resume_rewrite_i',
        type=int,
        choices=[0, 1, 2],
        help='choose the first (0), second (1), or last (2) mturk rewrite new knowledge to start with.',
    )
    parser.add_argument(
        '--resume_setting',
        choices=['t1-tc-tp-ti-tv', 't1-tp-tc-ti-tv'],
        help='choose the experimental setting to start with (\'t1-tc-tp-ti-tv\' = CAM, \'t1-tp-tc-ti-tv\' = CBA).',
    )
    parser.add_argument(
        '--resume_correction_templates',
        type=int,
        choices=range(0,15),
        help='choose the correction templates to start with.',
    )

    return parser.parse_args()

# Parse the command line once and expose the options as module-level names.
args = parse_argument()
data_split = args.data_split # "train", "validation"
model = args.model
extract_answer = args.extract_answer

# Optional resume point; all three stay None for a normal (non-resumed) run.
resume_rewrite_i = args.resume_rewrite_i # newly added (09/09/2024)
resume_setting = args.resume_setting # newly added (09/09/2024)
resume_correction_templates = args.resume_correction_templates # newly added (09/09/2024)
# The check that used to live here ("all are None or any is not None") was true for
# every possible input, so it validated nothing. What actually needs guarding is a
# resume setting given without a template index: the experiment loop compares
# resume_correction_templates numerically in that situation, and None is not
# comparable with an int. Default to template 0, i.e. "start of that setting".
if resume_setting is not None and resume_correction_templates is None:
    resume_correction_templates = 0

num_gpus = args.num_gpus
max_gpu_memory = args.max_gpu_memory
if 'vicuna' in model or 'Llama-2' in model or 'Llama-3' in model or 'gemma' in model:
    # Locally hosted models: echo the GPU configuration that was requested.
    print("Set GPU configuration...")
    print(f"num_gpus: {num_gpus}, max_gpu_memory: {max_gpu_memory}GB")
else:
    # Remaining choices are the GPT models; GPU options have no effect there.
    if num_gpus or max_gpu_memory:
        print(f"You set GPU-related parameters, but they are not used in {model}...")

# NOTE: loading the model inside get_response_list_from_llama_3 produced
# "WARNING:root:Some parameters are on the meta device device because they were offloaded to the cpu."
# The cause is unclear (maybe memory is not released fast enough?); presumably that is
# why the Hugging Face models are loaded once here at module level.
if "Llama-3" in model:
    model_id = f"meta-llama/{model}"
    device = torch.device("cuda")
    print(f"Load {model}'s model and tokenizer...")
    hf_llama_3_tokenizer = AutoTokenizer.from_pretrained(model_id)
    # The tokenizer must ship the chat template the model was trained with;
    # otherwise a lot of manual prompt engineering would be required.
    assert hf_llama_3_tokenizer.chat_template is not None
    hf_llama_3_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # A temperature of exactly 0 was rejected with
    # "ValueError: `temperature` (=0) has to be a strictly positive float, otherwise your next token scores will be invalid."
    # so a tiny positive value is used instead.
    hf_llama_3_model.generation_config.temperature = 1e-8
    hf_llama_3_model.generation_config.top_p = 1
    print("Model's generation_config:")
    print(hf_llama_3_model.generation_config.to_dict())

if "gemma" in model:
    model_id = f"google/{model}"
    device = torch.device("cuda")
    print(f"Load {model}'s model and tokenizer...")
    hf_gemma_2_tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Same requirement as for Llama-3: the trained chat template has to exist.
    assert hf_gemma_2_tokenizer.chat_template is not None
    hf_gemma_2_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # NOTE(review): temperature is set to 0 here although the same value was rejected
    # for Llama-3 above -- confirm that get_response_list_from_gemma_2 decodes in a
    # mode where this is accepted.
    hf_gemma_2_model.generation_config.temperature = 0
    hf_gemma_2_model.generation_config.top_p = 1
    print("Model's generation_config:")
    print(hf_gemma_2_model.generation_config.to_dict())

####################################################

# Load CoQA and keep only the split requested on the command line.
_d_coqa = load_dataset("coqa")
data = _d_coqa[data_split]

# yn_qa_idx_list holds one (possibly empty) collection of candidate indices per entry.
# After seeding, one candidate is drawn at random from every non-empty entry;
# empty entries are marked with -1.
yn_qa_idx_list = generate_yes_no_index_list(data)
set_seed(0)
sel_idx_list = [random.choice(candidates) if candidates else -1 for candidates in yn_qa_idx_list]

# Directory holding the per-model, per-split .pt files; it has to exist already.
# (The matching log directory "../data/log/{model}/{data_split}" is currently not used.)
pt_file_path = f"../data/pt/{model}/{data_split}"
assert os.path.exists(pt_file_path)

# Number of entries in the MTurk file of each split (part of the file name below).
if data_split != "validation":
    assert data_split == "train"
    num_data = 1317
else:
    num_data = 464

mturk_responses = torch.load(f"../data/csv/MTurk_{data_split}/final_{data_split}_{num_data}.pt")
# num_data is not compared against the number of selected indices (entries != -1 in
# sel_idx_list): that check fails on train because not all data has been labelled.
# Every entry must carry exactly three rewrites.
assert all(len(mturk_responses[data_split][key]) == 3 for key in mturk_responses[data_split])

# Order in which the entries are visited when building the per-rewrite lists.
if data_split == "validation":
    rewrite_keys = range(num_data)
else:
    assert data_split == "train"
    rewrite_keys = sorted(mturk_responses[data_split].keys()) # [5000, 5001, ..., 6316]

# mturk_rewrite_lists[j] collects the j-th rewrite of every entry, in key order.
mturk_rewrite_lists = [
    [mturk_responses[data_split][key][slot] for key in rewrite_keys]
    for slot in range(3)
]
mturk_rewrite_0_list, mturk_rewrite_1_list, mturk_rewrite_2_list = mturk_rewrite_lists

####################################################

def determine_saved_file_name_for_vicuna_and_llama_2(data_split, rewrite_i, i, method, extract_answer):
    """Return the .pt file name under which a vicuna / Llama-2 run is saved.

    Args:
        data_split: split name embedded in the file name.
        rewrite_i: index of the MTurk rewrite.
        i: index of the correction template.
        method: "t1-tc-tp-ti-tv" or "t1-tp-tc-ti-tv"; decides whether the
            name contains "c{i}_p" or "p_c{i}".
        extract_answer: when truthy, the "_v2" variant of the name is returned.

    Returns:
        The file name (without directory).

    Raises:
        AssertionError: if ``method`` is neither of the two supported values.
    """
    correction_first = method == "t1-tc-tp-ti-tv"
    assert correction_first or method == "t1-tp-tc-ti-tv"

    if extract_answer:
        # Second-stage (answer extraction) outputs carry the extra "_v2" marker.
        if correction_first:
            return f"coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_c{i}_p_i_v_v2_chatgpt.pt"
        return f"coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_p_c{i}_i_v_v2_chatgpt.pt"

    if correction_first:
        return f"coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_c{i}_p_i_v_chatgpt.pt"
    return f"coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_p_c{i}_i_v_chatgpt.pt"

####################################################
# Test if ChatGPT can update answer based on new story
#
# For every MTurk rewrite, both turn orderings ('t1-tc-tp-ti-tv' = CAM, 't1-tp-tc-ti-tv' = CBA)
# and every correction template, this loop assembles the conversation history, sends it
# together with one final user turn to the selected model and stores the replies under
# pt_file_path.
#   * without --extract_answer the final user turn is Tv;
#   * with --extract_answer the previously saved Tv replies are appended to the history and
#     a follow-up Tv_2 turn is asked instead (results go to the "*_v2_chatgpt.pt" files).
# The resume_* options only skip work that precedes the requested resume point.

templates_list_daily_dialog = get_templates_list_from_daily_dialog()

for rewrite_i in range(len(mturk_rewrite_lists)):
    # Resume support: skip rewrites that come before the requested one.
    if resume_rewrite_i and rewrite_i < resume_rewrite_i: # newly added (09/09/2024)
        continue # newly added (09/09/2024)
    print(f"MTurk rewrite {rewrite_i}, data_split = {data_split}, extract_answer = {extract_answer}, model={model}")
    mturk_rewrite_list = mturk_rewrite_lists[rewrite_i]
    # (test) T1
    if data_split == "train":
        test_queries_t1_list = generate_queries_list.t_T1(data, sel_idx_list, start_idx=5000, data_split=data_split)
    else:
        assert data_split == "validation"
        test_queries_t1_list = generate_queries_list.t_T1(data, sel_idx_list)
    test_messages_t1_list = [[{"role": "user", "content": queries}] for queries in test_queries_t1_list]
    # Fixed assistant acknowledgement appended after every T1 user turn.
    test_response_t1 = [{"role": "assistant", "content": "Yes, I have memorized the story."}]
    merge_test_messages_t1_list = [test_msg_t1 + test_response_t1 for test_msg_t1 in test_messages_t1_list]
    # (test) Tp (previous QA pairs, consecutive)
    if data_split == "train":
        merge_test_messages_tp_list = generate_queries_list.t_Tp(data, sel_idx_list, start_idx=5000, data_split=data_split)
    else:
        assert data_split == "validation"
        merge_test_messages_tp_list = generate_queries_list.t_Tp(data, sel_idx_list)
    # (test) Tc (ideally, can insert anywhere)
    if data_split == "train":
        test_queries_correct_list_ensemble = generate_queries_list.t_Tc_ensemble(data, sel_idx_list, templates_list_daily_dialog, mturk_rewrite_list, start_idx=5000, data_split=data_split)
    else:
        assert data_split == "validation"
        test_queries_correct_list_ensemble = generate_queries_list.t_Tc_ensemble(data, sel_idx_list, templates_list_daily_dialog, mturk_rewrite_list)
    # One user message per correction template for every example.
    test_messages_correct_list_ensemble = [[[{"role": "user", "content": cur_query}] for cur_query in queries] for queries in test_queries_correct_list_ensemble]
    # play with ChatGPT and choose one reasonable response beforehand
    test_response_correct = [{"role": "assistant", "content": "No problem at all! I have updated my memory of the story with the correction you provided. Thank you for letting me know."}]
    merge_test_messages_tc_list_ensemble = [[j + test_response_correct for j in i] for i in test_messages_correct_list_ensemble]
    # (test) Ti
    if data_split == "train":
        test_queries_ti_list = generate_queries_list.t_Ti(data, sel_idx_list, start_idx=5000, data_split=data_split)
    else:
        assert data_split == "validation"
        test_queries_ti_list = generate_queries_list.t_Ti(data, sel_idx_list)
    test_messages_ti_list = [[{"role": "user", "content": queries}] for queries in test_queries_ti_list]
    # (test) Tv (seeking confirmation, or verification, e.g., "Really?")
    if data_split == "train":
        test_queries_tv_list = generate_queries_list.t_Tv(data, sel_idx_list, start_idx=5000, data_split=data_split)
    else:
        assert data_split == "validation"
        test_queries_tv_list = generate_queries_list.t_Tv(data, sel_idx_list)
    # All per-example lists must line up, and every example needs one Tc variant per template.
    assert len(merge_test_messages_t1_list) == len(merge_test_messages_tc_list_ensemble) == len(merge_test_messages_tp_list) == len(test_messages_ti_list) == len(test_queries_tv_list) #== len(list(filter(lambda x: x != -1, sel_idx_list)))
    assert all(len(i) == len(templates_list_daily_dialog) for i in merge_test_messages_tc_list_ensemble)
    # run experiment
    # combination: (1) "T1, Tc, Tp, Ti" + "Tv" (2) "T1, Tp, Tc, Ti" + "Tv"
    for method in ["t1-tc-tp-ti-tv", "t1-tp-tc-ti-tv"]:
        # Resuming at CBA: skip CAM, but only in the first resumed rewrite. resume_correction_templates
        # is reset to -1 below once the first template has been processed, which re-enables CAM for the
        # following rewrites. A missing --resume_correction_templates (None) is treated as "resume from
        # template 0"; comparing None with 0 directly would raise a TypeError.
        if method == "t1-tc-tp-ti-tv" and resume_setting == "t1-tp-tc-ti-tv" and (resume_correction_templates is None or resume_correction_templates >= 0): # newly added (09/09/2024)
            continue # newly added (09/09/2024)
        print(f"Method {method}")
        for i in range(len(templates_list_daily_dialog)):
            # Resume support: skip templates that come before the requested one.
            if resume_correction_templates and i < resume_correction_templates: # newly added (09/09/2024)
                continue # newly added (09/09/2024)
            else: # newly added (09/09/2024)
                # we set this < 0 so that we will not skip the rest...
                resume_correction_templates = -1 # newly added (09/09/2024)
            # Load the saved Ti replies for this rewrite / template / ordering
            # (presumably written by an earlier stage of the pipeline -- TODO confirm).
            if method == "t1-tc-tp-ti-tv":
                test_response_ti_text_list = torch.load(f"{pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_c{i}_p_chatgpt.pt")
            else:
                assert method == "t1-tp-tc-ti-tv"
                test_response_ti_text_list = torch.load(f"{pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_p_c{i}_chatgpt.pt")
            # (test) Ti
            test_response_ti_list = [[{"role": "assistant", "content": _}] for _ in test_response_ti_text_list]
            merge_test_messages_ti_list = [test_msg_ti + test_response_ti for test_msg_ti, test_response_ti in zip(test_messages_ti_list, test_response_ti_list)]
            if not extract_answer:
                # First stage: history ends with Ti, the final user turn is Tv.
                if method == "t1-tc-tp-ti-tv":
                    prev_messages_list = [t1 + tc[i] + tp + ti for t1, tc, tp, ti in zip(merge_test_messages_t1_list, merge_test_messages_tc_list_ensemble, merge_test_messages_tp_list, merge_test_messages_ti_list)]
                else:
                    assert method == "t1-tp-tc-ti-tv"
                    prev_messages_list = [t1 + tp + tc[i] + ti for t1, tc, tp, ti in zip(merge_test_messages_t1_list, merge_test_messages_tc_list_ensemble, merge_test_messages_tp_list, merge_test_messages_ti_list)]
                final_merge_messages_list = [prev + [{"role": "user", "content": queries}] for prev, queries in zip(prev_messages_list, test_queries_tv_list)]
                if 'gpt' in model:
                    test_final_response_list = get_response_list_from_chatgpt(final_merge_messages_list, model=model)
                    test_final_response_text_list = [response.choices[0].message.content for response in test_final_response_list]
                    #test_final_response_text_list = [response['choices'][0]['message']['content'] for response in test_final_response_list]
                elif 'vicuna' in model or 'Llama-2' in model: # use FastChat
                    save_file_name = determine_saved_file_name_for_vicuna_and_llama_2(data_split, rewrite_i, i, method, extract_answer)
                    assert save_file_name
                    if 'vicuna' in model:
                        get_response_list_from_vicuna(final_merge_messages_list, model=model, save_file_name=save_file_name, num_gpus=num_gpus, max_gpu_memory=max_gpu_memory)
                    else:
                        assert 'Llama-2' in model
                        get_response_list_from_llama_2(final_merge_messages_list, model=model, save_file_name=save_file_name, num_gpus=num_gpus, max_gpu_memory=max_gpu_memory)
                    continue # we do NOT save file in main.py, as they will be saved by huggingface_api.py (which should be located under FastChat/fastchat/serve/)
                elif 'Llama-3' in model:
                    # In here, model is a real, large model, while model_id is the name of the model
                    test_final_response_text_list = get_response_list_from_llama_3(final_merge_messages_list, model_id=model, tokenizer=hf_llama_3_tokenizer, model=hf_llama_3_model)
                    #test_final_response_list = test_final_response_text_list # dummy, used for save_data, or delete log file in save_data tho (they're primary used for GPT-3.5 but I did not even look at them lol)
                else:
                    assert 'gemma' in model
                    # In here, model is a real, large model, while model_id is the name of the model
                    test_final_response_text_list = get_response_list_from_gemma_2(final_merge_messages_list, model_id=model, tokenizer=hf_gemma_2_tokenizer, model=hf_gemma_2_model)
                    #test_final_response_list = test_final_response_text_list # dummy, used for save_data, or delete log file in save_data tho (they're primary used for GPT-3.5 but I did not even look at them lol)
                # Store the Tv reply texts; the raw-response log entry is disabled.
                if method == "t1-tc-tp-ti-tv":
                    save_data([#{"data": str(test_final_response_list), "name": f"{log_file_path}/coqa-{data_split}-yes-no-mturk-rewrite-{rewrite_i}-c{i}-p-i-v-chatgpt.log"},
                               {"data": test_final_response_text_list, "name": f"{pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_c{i}_p_i_v_chatgpt.pt"}
                              ])
                else:
                    assert method == "t1-tp-tc-ti-tv"
                    save_data([#{"data": str(test_final_response_list), "name": f"{log_file_path}/coqa-{data_split}-yes-no-mturk-rewrite-{rewrite_i}-p-c{i}-i-v-chatgpt.log"},
                               {"data": test_final_response_text_list, "name": f"{pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_p_c{i}_i_v_chatgpt.pt"}
                              ])
            else:
                assert extract_answer
                # Second stage: reload the Tv replies saved by the first stage.
                if method == "t1-tc-tp-ti-tv":
                    test_response_tv_text_list = torch.load(f"{pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_c{i}_p_i_v_chatgpt.pt")
                    print(f"Successfully load {pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_c{i}_p_i_v_chatgpt.pt")
                else:
                    assert method == "t1-tp-tc-ti-tv"
                    test_response_tv_text_list = torch.load(f"{pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_p_c{i}_i_v_chatgpt.pt")
                    print(f"Successfully load {pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_p_c{i}_i_v_chatgpt.pt")
                # (test) Tv (seeking confirmation, or verification, e.g., "Really?")
                test_messages_tv_list = [[{"role": "user", "content": queries}] for queries in test_queries_tv_list]
                test_response_tv_list = [[{"role": "assistant", "content": _}] for _ in test_response_tv_text_list]
                merge_test_messages_tv_list = [test_msg_tv + test_response_tv for test_msg_tv, test_response_tv in zip(test_messages_tv_list, test_response_tv_list)]
                # (test) Tv_2 (2nd-stage prompting, tell ChatGPT to classify and output the answer that starts with Yes/No/Unknown)
                if data_split == "train":
                    test_queries_tv_2_list = generate_queries_list.t_Tv_2(data, sel_idx_list, start_idx=5000, data_split=data_split)
                else:
                    assert data_split == "validation"
                    test_queries_tv_2_list = generate_queries_list.t_Tv_2(data, sel_idx_list)
                # History now ends with the Tv exchange; the final user turn is Tv_2.
                if method == "t1-tc-tp-ti-tv":
                    prev_messages_list = [t1 + tc[i] + tp + ti + tv for t1, tc, tp, ti, tv in zip(merge_test_messages_t1_list, merge_test_messages_tc_list_ensemble, merge_test_messages_tp_list, merge_test_messages_ti_list, merge_test_messages_tv_list)]
                else:
                    assert method == "t1-tp-tc-ti-tv"
                    prev_messages_list = [t1 + tp + tc[i] + ti + tv for t1, tp, tc, ti, tv in zip(merge_test_messages_t1_list, merge_test_messages_tp_list, merge_test_messages_tc_list_ensemble, merge_test_messages_ti_list, merge_test_messages_tv_list)]
                final_merge_messages_list = [prev + [{"role": "user", "content": queries}] for prev, queries in zip(prev_messages_list, test_queries_tv_2_list)]
                if 'gpt' in model:
                    test_final_response_list = get_response_list_from_chatgpt(final_merge_messages_list, model=model)
                    test_final_response_text_list = [response.choices[0].message.content for response in test_final_response_list]
                    #test_final_response_text_list = [response['choices'][0]['message']['content'] for response in test_final_response_list]
                elif 'vicuna' in model or 'Llama-2' in model: # use FastChat
                    save_file_name = determine_saved_file_name_for_vicuna_and_llama_2(data_split, rewrite_i, i, method, extract_answer)
                    assert save_file_name
                    if 'vicuna' in model:
                        get_response_list_from_vicuna(final_merge_messages_list, model=model, save_file_name=save_file_name, num_gpus=num_gpus, max_gpu_memory=max_gpu_memory)
                    else:
                        assert 'Llama-2' in model
                        get_response_list_from_llama_2(final_merge_messages_list, model=model, save_file_name=save_file_name, num_gpus=num_gpus, max_gpu_memory=max_gpu_memory)
                    continue # we do NOT save file in main.py, as they will be saved by huggingface_api.py (which should be located under FastChat/fastchat/serve/)
                elif 'Llama-3' in model:
                    # In here, model is a real, large model, while model_id is the name of the model
                    test_final_response_text_list = get_response_list_from_llama_3(final_merge_messages_list, model_id=model, tokenizer=hf_llama_3_tokenizer, model=hf_llama_3_model)
                    #test_final_response_list = test_final_response_text_list # dummy, used for save_data, or delete log file in save_data tho (they're primary used for GPT-3.5 but I did not even look at them lol)
                else:
                    assert 'gemma' in model
                    # In here, model is a real, large model, while model_id is the name of the model
                    test_final_response_text_list = get_response_list_from_gemma_2(final_merge_messages_list, model_id=model, tokenizer=hf_gemma_2_tokenizer, model=hf_gemma_2_model)
                    #test_final_response_list = test_final_response_text_list # dummy, used for save_data, or delete log file in save_data tho (they're primary used for GPT-3.5 but I did not even look at them lol)
                # Store the Tv_2 reply texts under the "_v2" file names.
                if method == "t1-tc-tp-ti-tv":
                    save_data([#{"data": str(test_final_response_list), "name": f"{log_file_path}/coqa-{data_split}-yes-no-mturk-rewrite-{rewrite_i}-c{i}-p-i-v-v2-chatgpt.log"},
                               {"data": test_final_response_text_list, "name": f"{pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_c{i}_p_i_v_v2_chatgpt.pt"}
                              ])
                else:
                    assert method == "t1-tp-tc-ti-tv"
                    save_data([#{"data": str(test_final_response_list), "name": f"{log_file_path}/coqa-{data_split}-yes-no-mturk-rewrite-{rewrite_i}-p-c{i}-i-v-v2-chatgpt.log"},
                               {"data": test_final_response_text_list, "name": f"{pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_p_c{i}_i_v_v2_chatgpt.pt"}
                              ])