import argparse
import random
import os

import pandas as pd
import torch
from tqdm import tqdm
from statistics import stdev

from datasets import load_dataset

from set_seed import set_seed
from generate_yes_no_index_list import generate_yes_no_index_list
from get_various_story_attr import get_various_story_attr
from get_templates_list_from_daily_dialog import get_templates_list_from_daily_dialog
from load_result_pt import load_result_pt
from compare_old_new_answer import compare_old_new_answer
from update_answer_list import update_answer_list

####################################################
def parse_argument():
    """Declare the evaluation script's command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace: ``data_split`` (required), ``model`` and
        ``rewrite_i`` (both ``None`` when omitted) and one boolean per
        ``store_true`` switch declared below.
    """
    supported_models = [
        'gpt-3.5-turbo-0301', 'gpt-3.5-turbo-0613', 'gpt-3.5-turbo-1106', 'gpt-3.5-turbo-0125',
        'gpt-4-1106-preview', 'gpt-4o-2024-08-06', 'gpt-4o-mini-2024-07-18',
        'gemma-2-2b-it', 'gemma-2-9b-it', 'gemma-2-27b-it',
        'vicuna-7b-v1.5-16k', 'vicuna-13b-v1.5-16k', 'vicuna-33b-v1.3',
        'Llama-2-7b-chat-hf', 'Llama-2-13b-chat-hf', 'Meta-Llama-3-8B-Instruct',
    ]
    # Boolean switches, listed in the order they appear in --help.
    # NOTE: an --update_answer switch used to exist; it is intentionally not declared.
    switches = [
        ('--is_trim', 'Whether to trim story'),
        ('--do_ablation', 'Whether to analyze ablation results'),
        ('--single_turn', 'Whether to squeeze multi-turn dialogues into single-turn (for comparison)'),
        ('--extract_answer', 'Whether to convert answer to Yes/No, note that we only use it when model is GPT-4'),
        ('--insert_new_info', 'Whether to insert new information into the story'),
        ('--exclude_tc', 'Whether to insert new information into the story AND exclude correction phase'),
        ('--report_concise', 'Whether to report the average performance in three run. Will overwrite --rewrite_i flag by running it three times.'),
        ('--load_from_scratch', 'Whether to load the CoQA train/validation from scratch (note that loading training data from scratch will take around 2hr 20 mins...)'),
        ('--extract_last_also', 'Whether to extract the last token if we cannot decide the answer is different in the first token (if an LLM response is ... so the answer to the last question is "No." <- this is frequently found in Llama-2-7b-chat-hf) [DEFAULT SHALL BE FALSE AS WE DO NOT SET THIS IN OUR PREVIOUS RESULTS]'),
    ]

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--data_split',
        required=True,
        choices=['train', 'validation', 'all'],
        help='Choose training set or validation set in CoQA dataset, all means we evaluate both (note that setting all flag will need to ensure report_concise is set to True and not load_from_scratch (too slow))',
    )
    parser.add_argument(
        '--model',
        choices=supported_models,
        help='specify the model, note that some models in choices may be deprecated in the future, now support models other than GPT',
    )
    parser.add_argument(
        '--rewrite_i',
        type=int,
        choices=range(3),
        help='Choose 1st, 2nd, or 3rd response of MTurk, if not specified, will use ChatGPT response',
    )
    for option, description in switches:
        parser.add_argument(option, action='store_true', help=description)

    return parser.parse_args()

# Parse the command line once and mirror every option into a module-level name
# that the rest of the script reads directly.
args = parse_argument()
data_split = args.data_split  # "train", "validation" or "all"
model = args.model or "gpt-3.5-turbo-0613"  # original configuration when --model is omitted
rewrite_i = args.rewrite_i
update_answer = True  # the --update_answer switch is disabled, so this is always on
is_trim, do_ablation, single_turn = args.is_trim, args.do_ablation, args.single_turn
extract_answer = args.extract_answer
insert_new_info, exclude_tc = args.insert_new_info, args.exclude_tc
report_concise = args.report_concise  # can save time for analysis
load_from_scratch = args.load_from_scratch  # can save time for analysis
# Default shall be False, as it was not set for the previous results; only set
# it for Llama-2-7b-chat-hf in Verification or Recall.
extract_last_also = args.extract_last_also

# "all" is only supported through the concise report on the cached data.
assert data_split != "all" or (report_concise is True and not load_from_scratch)

if extract_last_also:
    # Make the user acknowledge the non-default answer extraction before running.
    _ = input(
        "Setting extract_last_also if the first token does not start with Yes/No... Note that DEFAULT SHALL BE FALSE AS WE DO NOT SET THIS IN OUR PREVIOUS RESULTS, you should only set this for Llama-2-7b-chat-hf in Verification or Recall to replicate our other experimental results...\nPress any key to continue..."
    )

# --exclude_tc only makes sense together with --insert_new_info.
assert not exclude_tc or insert_new_info

# The combinations below are hard-coded; other combinations could be tried but
# are deliberately not supported.
if insert_new_info:
    # The training set is not supported here (t_T1_modified() has not been adapted).
    assert data_split == "validation"
    assert not (do_ablation or is_trim or single_turn)

if single_turn:
    assert not (do_ablation or is_trim)

# At most one of do_ablation / is_trim may be set.
assert not do_ablation or not is_trim

if data_split == "train":
    # Offset into the training subset, used when evaluating on the training data.
    start_idx = 5000

#log_file_path = "../data/log"
#pt_file_path = "../data/pt"

# Any ablation-style switch moves the stored results into the "Tc_ablation"
# directory; the plain baseline lives in "Tc".
if is_trim or do_ablation or single_turn or insert_new_info or exclude_tc:
    pt_file_path = f"../data/pt/{model}/{data_split}/Tc_ablation"
else:
    pt_file_path = f"../data/pt/{model}/{data_split}/Tc"

# The directory is only required to exist for a concrete split ("all" is exempt).
assert data_split == "all" or os.path.exists(pt_file_path)

print(f"Evaluating the {model} model baseline (OTC) performance on KEIC {data_split}....")
print(f"Other configs: is_trim = {is_trim}, do_ablation = {do_ablation}, single_turn = {single_turn}, extract_answer = {extract_answer}, insert_new_info = {insert_new_info}, exclude_tc = {exclude_tc}")

#################### start load_from_scratch ####################

if load_from_scratch:
    # Rebuild the evaluation subset from the raw CoQA split, then check that it
    # matches the cached copy on disk.
    data = load_dataset('coqa')[data_split]

    # One index is drawn per entry with a fixed seed; -1 marks an empty entry.
    yn_qa_idx_list = generate_yes_no_index_list(data)
    set_seed(0)
    sel_idx_list = [random.choice(candidates) if candidates else -1 for candidates in yn_qa_idx_list]

    question_list, answer_list, old_sent_list = [], [], []
    # The train-only offset is NOT applied inside this loop: the answers are
    # updated below on the full list, and skipping here would misalign indices.
    for story_idx in tqdm(range(data.num_rows)):
        picked = sel_idx_list[story_idx]
        if picked == -1:
            continue
        _, question, answer, _, _, support_sent = get_various_story_attr(data, story_idx, picked)
        question_list.append(question)
        answer_list.append(answer)
        old_sent_list.append(support_sent)

    cnt = len(question_list)
    assert cnt == len(answer_list) == len(old_sent_list)

    # Tip: the story behind entry i can be recovered through the positions of
    # sel_idx_list whose value is >= 0.

    if update_answer:
        answer_list = update_answer_list(data_split, answer_list)
        print('Answer updated!')

    if data_split == "train":
        # Only the part of the training subset from start_idx onwards is evaluated.
        question_list, answer_list, old_sent_list = (
            question_list[start_idx:],
            answer_list[start_idx:],
            old_sent_list[start_idx:],
        )
        cnt = 1317
        assert cnt == len(question_list) == len(answer_list) == len(old_sent_list)

    ##### for testing #####
    # The rebuilt lists must match the cached pre-processed data exactly.
    pre_processed_coqa_data = torch.load(
        f"pre_processed_coqa_{data_split}_data_1317_for_evaluation.pt"
        if data_split == "train"
        else f"pre_processed_coqa_{data_split}_data_464_for_evaluation.pt"
    )
    assert question_list == pre_processed_coqa_data["question_list"]
    assert answer_list == pre_processed_coqa_data["answer_list"]
    assert old_sent_list == pre_processed_coqa_data["old_sent_list"]
    assert cnt == pre_processed_coqa_data["cnt"]
    print("After comparison, they are the same...")

#################### end load_from_scratch ####################

elif data_split == "all":
    # Concatenate the cached train and validation subsets (train first).
    pre_processed_coqa_data_train = torch.load("pre_processed_coqa_train_data_1317_for_evaluation.pt")
    question_list = pre_processed_coqa_data_train["question_list"]
    answer_list = pre_processed_coqa_data_train["answer_list"]
    old_sent_list = pre_processed_coqa_data_train["old_sent_list"]
    cnt = pre_processed_coqa_data_train["cnt"]

    pre_processed_coqa_data_validation = torch.load("pre_processed_coqa_validation_data_464_for_evaluation.pt")
    for merged, key in (
        (question_list, "question_list"),
        (answer_list, "answer_list"),
        (old_sent_list, "old_sent_list"),
    ):
        merged.extend(pre_processed_coqa_data_validation[key])
    cnt += pre_processed_coqa_data_validation["cnt"]

    assert len(question_list) == len(answer_list) == len(old_sent_list) == cnt == 1781
    print("Successfully load from pre-processed CoQA data!")

else:
    # A single split: read its cached subset as-is.
    if data_split != "train":
        assert data_split == "validation"
    pre_processed_coqa_data = torch.load(
        f"pre_processed_coqa_{data_split}_data_1317_for_evaluation.pt"
        if data_split == "train"
        else f"pre_processed_coqa_{data_split}_data_464_for_evaluation.pt"
    )
    question_list, answer_list, old_sent_list, cnt = (
        pre_processed_coqa_data[key]
        for key in ("question_list", "answer_list", "old_sent_list", "cnt")
    )
    print("Successfully load from pre-processed CoQA data!")


templates_list_daily_dialog = get_templates_list_from_daily_dialog()


# Accumulator filled by print_top_k_result and only reported when report_concise
# is set: method -> top-k cutoff -> one row of percentages per evaluated run.
concise_result = {
    method: {k: [] for k in (1, 3, 5, len(templates_list_daily_dialog))}
    for method in ("t1-tc-tp-ti", "t1-tp-tc-ti")
}

def print_top_k_result(method, diff_result, df, k, n, results=None):
    """Print and record the majority-vote statistics of the top-``k`` templates.

    The first ``k`` index labels of ``df`` select rows of
    ``diff_result[method]``. Each row holds one verdict per example. The
    selected rows are combined example by example:

    * upper bound: examples for which at least one selected row holds ``1``;
    * vote: the sum of the selected verdicts, reported as the number of
      examples whose sum is positive, negative or zero.

    Args:
        method: Key into ``diff_result`` and into ``results``.
        diff_result: Mapping ``method -> list of per-template verdict lists``.
        df: DataFrame whose index order ranks the templates (best first).
        k: Number of top-ranked templates to combine.
        n: Denominator used to turn counts into percentages.
        results: Accumulator ``method -> k -> list of rows``. Defaults to the
            module-level ``concise_result`` so existing callers are unchanged.

    Returns:
        None. One row ``[pos %, neg %, 100 - pos % - neg %, upper-bound %]``
        is appended to ``results[method][k]``.

    Raises:
        ZeroDivisionError: If ``n`` is zero.
        KeyError: If ``results[method]`` has no entry for ``k``.
    """
    if results is None:
        results = concise_result  # module-level accumulator of this file

    def pct(count):
        return round((count / n) * 100, 2)

    print(f"top {k}:")
    top_k_idx_list = df.index.tolist()[:k]
    selected_rows = [diff_result[method][idx] for idx in top_k_idx_list]
    # Transpose: one tuple of verdicts (one per selected template) per example.
    per_example = list(zip(*selected_rows))

    upper_bound = sum(1 for verdicts in per_example if 1 in verdicts)
    vote = [sum(verdicts) for verdicts in per_example]  # majority voting
    pos = sum(1 for total in vote if total > 0)
    neg = sum(1 for total in vote if total < 0)
    zero = sum(1 for total in vote if total == 0)

    pos_pct, neg_pct, zero_pct, upper_pct = pct(pos), pct(neg), pct(zero), pct(upper_bound)
    print(f"upper bound: {upper_bound} ({upper_pct}%)")
    print(f"(+,-,0) = ({pos} ({pos_pct}%), {neg} ({neg_pct}%), {zero} ({zero_pct}%))")

    # Sanity check before anything is recorded, so a bad row is never stored.
    assert isinstance(pos_pct, float) and isinstance(neg_pct, float) and isinstance(zero_pct, float)
    # The "unknown" share is derived from the two rounded shares so that the
    # three recorded values sum to 100 after rounding.
    results[method][k].append([pos_pct, neg_pct, round(100 - pos_pct - neg_pct, 2), upper_pct])

def do_error_analysis(method, diff_result, templates_list_daily_dialog, n):
    """Print the per-template verdict counts for ``method`` and its top-k summaries.

    For every template index the number of ``1``, ``-1`` and ``0`` verdicts in
    ``diff_result[method][i]`` is counted, and the templates are ranked by their
    ``1`` count (descending). ``print_top_k_result`` is then invoked for the
    cutoffs 1, 3, 5 and "all templates".

    Args:
        method: Either ``"t1-tc-tp-ti"`` or ``"t1-tp-tc-ti"``.
        diff_result: Mapping ``method -> list of per-template verdict lists``.
        templates_list_daily_dialog: The templates; only its length is used.
        n: Denominator forwarded to ``print_top_k_result`` for percentages.

    Raises:
        AssertionError: If ``method`` is not one of the two supported names.
    """
    assert method in ["t1-tc-tp-ti", "t1-tp-tc-ti"]
    n_templates = len(templates_list_daily_dialog)
    per_template = [diff_result[method][i] for i in range(n_templates)]

    print("method:", method)
    df = pd.DataFrame({
        '+1': [verdicts.count(1) for verdicts in per_template],
        '-1': [verdicts.count(-1) for verdicts in per_template],
        '0': [verdicts.count(0) for verdicts in per_template],
    }).sort_values(by='+1', ascending=False)
    print(df)

    # De-duplicate the cutoffs while keeping their order: when the number of
    # templates is itself 1, 3 or 5, analysing that k twice would record two
    # rows per run for the same cutoff.
    for k in dict.fromkeys([1, 3, 5, n_templates]):
        print_top_k_result(method, diff_result, df, k, n)

#################### start report_concise ####################
_METHODS = ("t1-tc-tp-ti", "t1-tp-tc-ti")


def _load_dd_result(split, pt_dir, rewrite_idx):
    """Load the stored model answers of one split, forwarding the CLI switches."""
    return load_result_pt(templates_list_daily_dialog, split, pt_dir, rewrite_idx,
                          is_trim=is_trim,
                          do_ablation=do_ablation,
                          single_turn=single_turn,
                          extract_answer=extract_answer,
                          insert_new_info=insert_new_info,
                          exclude_tc=exclude_tc)


def _score_answer(old_a, new_a, retry_on_last_token):
    """Compare a gold answer with a model answer.

    When ``retry_on_last_token`` is set and the first comparison yields 0, the
    comparison is repeated on the last space-separated token of ``new_a``.
    NOTE(review): reuses the first verdict instead of calling
    compare_old_new_answer a second time; this assumes that helper is
    deterministic and side-effect free -- confirm.
    """
    verdict = compare_old_new_answer(old_a, new_a)
    if retry_on_last_token and verdict == 0:
        verdict = compare_old_new_answer(old_a, new_a.split(' ')[-1])
    return verdict


def _diff_for_method(answers_per_template, gold_answers, n_templates, retry_on_last_token):
    """Return one verdict list per template for a single method.

    Raises AssertionError if a template's answer list does not have exactly one
    answer per gold answer.
    """
    rows = []
    for i in range(n_templates):
        cur_new_ans_list = answers_per_template[i]
        assert len(cur_new_ans_list) == len(gold_answers)
        rows.append([
            _score_answer(old_a, new_a, retry_on_last_token)
            for old_a, new_a in zip(gold_answers, cur_new_ans_list)
        ])
    return rows


if report_concise:
    # Evaluate all three rewrites (this overrides any --rewrite_i value) and
    # report mean and standard deviation over the three runs.
    for rewrite_i in [0, 1, 2]:
        if data_split in ["train", "validation"]:
            dd_result = _load_dd_result(data_split, pt_file_path, rewrite_i)
        else:
            assert data_split == "all"
            # Same directory rule as pt_file_path, applied to each split.
            pt_subdir = "Tc_ablation" if any([is_trim, do_ablation, single_turn, insert_new_info, exclude_tc]) else "Tc"
            dd_result = _load_dd_result("train", f"../data/pt/{model}/train/{pt_subdir}", rewrite_i)
            dd_result_validation = _load_dd_result("validation", f"../data/pt/{model}/validation/{pt_subdir}", rewrite_i)
            # Train answers first, then validation, matching answer_list's order.
            for method in _METHODS:
                for i in range(len(templates_list_daily_dialog)):
                    dd_result[method][i].extend(dd_result_validation[method][i])

        diff_result = {name: [] for name in _METHODS}
        for method in _METHODS:
            diff_result[method] = _diff_for_method(
                dd_result[method], answer_list, len(templates_list_daily_dialog), extract_last_also
            )
            do_error_analysis(method, diff_result, templates_list_daily_dialog, cnt)
            print("===============")

    print(f"FINAL RAW RESULT: {concise_result}")
    print(f"MODEL: {model}")
    print("AVERAGED RESULT FOR TABLE:")
    for method in concise_result:
        df = pd.DataFrame(columns=['top-k', 'update', 'no update', 'unknown', 'upper bound'])
        print(f"METHOD: {method}")
        for k, runs in concise_result[method].items():
            # Three runs, each a row of (update, no update, unknown, upper bound).
            assert len(runs) == 3 and all(len(run) == 4 for run in runs)
            per_metric = list(zip(*runs))
            update, no_update, _, upper_bound = [round(sum(values) / len(values), 2) for values in per_metric]
            # Re-derive "unknown" from the rounded means so the three shares sum to 100.
            unknown = round(100.0 - update - no_update, 2)
            assert round(update + no_update + unknown, 2) == 100.00
            update_std, no_update_std, unknown_std, upper_bound_std = [round(stdev(values), 2) for values in per_metric]
            df.loc[len(df)] = pd.Series([int(k), f"{update:05.2f}({update_std:04.2f})", f"{no_update:05.2f}({no_update_std:04.2f})", f"{unknown:05.2f}({unknown_std:04.2f})", f"{upper_bound:05.2f}({upper_bound_std:04.2f})"], index=df.columns)
        print(df)


#################### end report_concise ####################

else:
    assert not report_concise  # make sure it does not go here...
    assert data_split != "all"  # make sure it does not go here...

    # Single run: rewrite_i comes from the command line (may be None).
    dd_result = _load_dd_result(data_split, pt_file_path, rewrite_i)

    diff_result = {
        name: _diff_for_method(
            dd_result[name], answer_list, len(templates_list_daily_dialog), extract_last_also
        )
        for name in _METHODS
    }

    for method in _METHODS:
        do_error_analysis(method, diff_result, templates_list_daily_dialog, cnt)
        print("===============")
