import argparse
import random
import os

import pandas as pd
import torch
from tqdm import tqdm
from statistics import stdev

from datasets import load_dataset

from set_seed import set_seed
from generate_yes_no_index_list import generate_yes_no_index_list
from get_various_story_attr import get_various_story_attr
from get_templates_list_from_daily_dialog import get_templates_list_from_daily_dialog
#from load_result_pt import load_result_pt # this is for baseline data only
from compare_old_new_answer import compare_old_new_answer
from update_answer_list import update_answer_list

####################################################
def parse_argument():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``data_split``, ``model``,
        ``rewrite_i``, ``report_concise``, ``load_from_scratch`` and
        ``extract_last_also``.
    """
    supported_models = [
        'gpt-3.5-turbo-0613',
        'gpt-3.5-turbo-0125',
        'gpt-4o-mini-2024-07-18',
        'gemma-2-2b-it',
        'gemma-2-9b-it',
        'gemma-2-27b-it',
        'vicuna-7b-v1.5-16k',
        'vicuna-13b-v1.5-16k',
        'Llama-2-7b-chat-hf',
        'Llama-2-13b-chat-hf',
        'Meta-Llama-3-8B-Instruct',
    ]
    # Boolean switches, registered in this order after --rewrite_i.
    # (There is deliberately no --update_answer switch: the script always
    # updates the gold answers.)
    boolean_flags = (
        (
            '--report_concise',
            'Whether to report the average performance in three run. Will overwrite --rewrite_i flag by running it three times.',
        ),
        (
            '--load_from_scratch',
            'Whether to load the CoQA train/validation from scratch (note that loading training data from scratch will take around 2hr 20 mins...)',
        ),
        (
            '--extract_last_also',
            'Whether to extract the last token if we cannot decide the answer is different in the first token (if an LLM response is ... so the answer to the last question is "No." <- this is frequently found in Llama-2-7b-chat-hf) [DEFAULT SHALL BE FALSE AS WE DO NOT SET THIS IN OUR PREVIOUS RESULTS]',
        ),
    )

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--data_split',
        required=True,
        choices=['train', 'validation', 'all'],
        help='Choose training set or validation set in CoQA dataset, all means we evaluate both (note that setting all flag will need to ensure report_concise is set to True and not load_from_scratch (too slow))',
    )
    parser.add_argument(
        '--model',
        choices=supported_models,
        help='specify the model, note that some models in choices may be deprecated in the future, now support models other than GPT',
    )
    parser.add_argument(
        '--rewrite_i',
        type=int,
        choices=range(0, 3),
        help='Choose 1st, 2nd, or 3rd response of MTurk',
    )
    for flag_name, flag_help in boolean_flags:
        parser.add_argument(flag_name, action='store_true', help=flag_help)

    return parser.parse_args()

# Unpack the parsed command-line options into module-level settings that the
# rest of the script reads.
args = parse_argument()
data_split = args.data_split  # "train", "validation" or "all"
model = args.model
rewrite_i = args.rewrite_i
update_answer = True  # always enabled; the parser has no --update_answer switch
report_concise = args.report_concise
load_from_scratch = args.load_from_scratch
# Keep this False to replicate earlier results; per the flag's help text it is
# only meant for Llama-2-7b-chat-hf in Verification or Recall.
extract_last_also = args.extract_last_also

# Validate the option combination with explicit exceptions instead of bare
# asserts: asserts are removed under `python -O` and carry no message.
if model is None:
    # Without a model the result path below would silently become
    # "../data/pt/None/...".
    raise ValueError("--model is required")

if data_split == "all" and (not report_concise or load_from_scratch):
    raise ValueError("--data_split all requires --report_concise and cannot be combined with --load_from_scratch")

if extract_last_also:
    # Pause so the user has to acknowledge the non-default setting; input()
    # returns once Enter is pressed.
    _ = input("Setting extract_last_also if the first token does not start with Yes/No... Note that DEFAULT SHALL BE FALSE AS WE DO NOT SET THIS IN OUR PREVIOUS RESULTS, you should only set this for Llama-2-7b-chat-hf in Verification or Recall to replicate our other experimental results...\nPress Enter to continue...")

if not report_concise and not isinstance(rewrite_i, int):
    raise ValueError("--rewrite_i is required unless --report_concise is set")

if data_split == "train":
    # Offset used when evaluating the training split: the load-from-scratch
    # path slices its lists with [start_idx:].
    start_idx = 5000

# Directory holding the saved model responses for this model/split.
pt_file_path = f"../data/pt/{model}/{data_split}/Tv"

if data_split != "all" and not os.path.exists(pt_file_path):
    raise FileNotFoundError(f"Result directory not found: {pt_file_path}")

print(f"Evaluating the {model} model Verification performance on KEIC {data_split}....")

#################### start load_from_scratch ####################

if load_from_scratch:
    # Rebuild the evaluation lists directly from the CoQA split, then verify
    # them against the cached .pt fixture on disk.
    _d_coqa = load_dataset('coqa')
    data = _d_coqa[data_split]

    # Pick one entry from every non-empty candidate list; empty candidate
    # lists are marked with -1. set_seed(0) is called right before sampling
    # (presumably to make random.choice reproducible -- confirm in set_seed).
    yn_qa_idx_list = generate_yes_no_index_list(data)
    set_seed(0)
    sel_idx_list = [random.choice(candidates) if candidates else -1 for candidates in yn_qa_idx_list]

    # Collect question / gold answer / supporting sentence for every row that
    # has a selected entry.
    question_list, answer_list, old_sent_list = [], [], []
    cnt = 0
    for i in tqdm(range(data.num_rows)):
        picked = sel_idx_list[i]
        if picked == -1:
            continue
        _, question, answer, _, _, support_sent = get_various_story_attr(data, i, picked)
        question_list.append(question)
        answer_list.append(answer)
        old_sent_list.append(support_sent)
        cnt += 1

    assert cnt == len(question_list) == len(answer_list) == len(old_sent_list)

    # To view the story behind entry j, map j back through the rows whose
    # sel_idx_list value is >= 0 and index data['story'] with that row.

    # Original author's note: updating the answers has no effective impact on
    # the result.
    if update_answer:
        answer_list = update_answer_list(data_split, answer_list)
        print('Answer updated!')

    if data_split == "train":
        # Training evaluation only uses the entries from start_idx onwards,
        # which is expected to leave exactly 1317 items.
        question_list = question_list[start_idx:]
        answer_list = answer_list[start_idx:]
        old_sent_list = old_sent_list[start_idx:]
        cnt = 1317
        assert cnt == len(question_list) == len(answer_list) == len(old_sent_list)

    # Sanity check: the freshly built lists must equal the cached fixture.
    fixture_path = (
        f"pre_processed_coqa_{data_split}_data_1317_for_evaluation.pt"
        if data_split == "train"
        else f"pre_processed_coqa_{data_split}_data_464_for_evaluation.pt"
    )
    pre_processed_coqa_data = torch.load(fixture_path)
    assert question_list == pre_processed_coqa_data["question_list"]
    assert answer_list == pre_processed_coqa_data["answer_list"]
    assert old_sent_list == pre_processed_coqa_data["old_sent_list"]
    assert cnt == pre_processed_coqa_data["cnt"]
    print("After comparison, they are the same...")

elif data_split == "all":
    # Fast path for both splits: concatenate the cached train fixture with the
    # cached validation fixture (train entries first).
    pre_processed_coqa_data_train = torch.load(f"pre_processed_coqa_train_data_1317_for_evaluation.pt")
    question_list = pre_processed_coqa_data_train["question_list"]
    answer_list = pre_processed_coqa_data_train["answer_list"]
    old_sent_list = pre_processed_coqa_data_train["old_sent_list"]
    cnt = pre_processed_coqa_data_train["cnt"]

    pre_processed_coqa_data_validation = torch.load(f"pre_processed_coqa_validation_data_464_for_evaluation.pt")
    question_list.extend(pre_processed_coqa_data_validation["question_list"])
    answer_list.extend(pre_processed_coqa_data_validation["answer_list"])
    old_sent_list.extend(pre_processed_coqa_data_validation["old_sent_list"])
    cnt += pre_processed_coqa_data_validation["cnt"]

    # 1317 train + 464 validation entries.
    assert len(question_list) == len(answer_list) == len(old_sent_list) == cnt == 1781
    print("Successfully load from pre-processed CoQA data!")

else:
    # Fast path for a single split: read the cached fixture for that split.
    if data_split == "train":
        pre_processed_coqa_data = torch.load(f"pre_processed_coqa_{data_split}_data_1317_for_evaluation.pt")
    else:
        assert data_split == "validation"
        pre_processed_coqa_data = torch.load(f"pre_processed_coqa_{data_split}_data_464_for_evaluation.pt")
    question_list = pre_processed_coqa_data["question_list"]
    answer_list = pre_processed_coqa_data["answer_list"]
    old_sent_list = pre_processed_coqa_data["old_sent_list"]
    cnt = pre_processed_coqa_data["cnt"]
    print("Successfully load from pre-processed CoQA data!")


templates_list_daily_dialog = get_templates_list_from_daily_dialog()

# Accumulator for the concise report: for each method and each top-k setting
# (1, 3, 5 and "all templates") it collects one result row per evaluated run.
# It is filled by print_top_k_result and only summarised when report_concise
# is set.
concise_result = {
    method_name: {top_k: [] for top_k in [1, 3, 5, len(templates_list_daily_dialog)]}
    for method_name in ("t1-tc-tp-ti-tv", "t1-tp-tc-ti-tv")
}

def load_result_pt(templates_list_daily_dialog, data_split, pt_file_path, rewrite_i):
    """Load the saved model responses for both methods and every template.

    For each method and each template index ``i`` one ``.pt`` file is read
    from ``pt_file_path``. The two methods differ only in where the template
    index sits in the file name (``..._c{i}_p_i_v_...`` vs.
    ``..._p_c{i}_i_v_...``).

    Args:
        templates_list_daily_dialog: Sequence of templates; only its length
            is used, to decide how many files to load per method.
        data_split: Split name embedded in the file names.
        pt_file_path: Directory containing the ``.pt`` files.
        rewrite_i: Index of the rewrite to load; must be 0, 1 or 2.

    Returns:
        dict mapping each method name (``"t1-tc-tp-ti-tv"``,
        ``"t1-tp-tc-ti-tv"``) to a list with one loaded object per template,
        in template order.

    Raises:
        ValueError: If ``rewrite_i`` is not 0, 1 or 2. This is checked once,
            before any file is touched, rather than via ``assert False``,
            which would be stripped under ``python -O``.
    """
    if rewrite_i not in (0, 1, 2):
        raise ValueError(f"rewrite_i must be 0, 1 or 2, got {rewrite_i!r}")

    # File-name tail per method; "{i}" is replaced by the template index.
    suffix_by_method = {
        "t1-tc-tp-ti-tv": "c{i}_p_i_v_v2_chatgpt.pt",
        "t1-tp-tc-ti-tv": "p_c{i}_i_v_v2_chatgpt.pt",
    }
    prefix = f"{pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_"
    num_templates = len(templates_list_daily_dialog)

    # NOTE: torch.load unpickles the file contents; only point this at result
    # files you produced yourself.
    return {
        method: [torch.load(prefix + suffix.format(i=i)) for i in range(num_templates)]
        for method, suffix in suffix_by_method.items()
    }

def print_top_k_result(method, diff_result, df, k, n):
    """Print the combined result of the first ``k`` templates in ``df`` order.

    The label lists of the first ``k`` rows of ``df`` are combined per example
    by summing the labels (a vote). The function prints how many examples end
    up with a positive, negative or zero sum, plus an "upper bound": the
    number of examples for which at least one selected template has label 1.
    The percentages are also appended to the module-level ``concise_result``.

    Args:
        method: Key into ``diff_result`` and ``concise_result``.
        diff_result: dict mapping method -> list (one entry per template) of
            per-example labels; the code counts 1 / -1 / 0 values.
        df: DataFrame whose index order determines which templates are "top".
        k: Number of leading templates to combine; also the key used in
            ``concise_result[method]``.
        n: Denominator used to turn counts into percentages.
    """
    def as_percent(count):
        return round((count / n) * 100, 2)

    print(f"top {k}:")
    chosen_rows = [diff_result[method][idx] for idx in df.index.tolist()[:k]]
    # Transpose: one tuple per example holding that example's label from
    # every chosen template.
    per_example = list(zip(*chosen_rows))

    upper_bound = sum(1 for labels in per_example if labels.count(1) > 0)
    vote = [sum(labels) for labels in per_example]  # majority voting
    pos = sum(1 for total in vote if total > 0)
    neg = sum(1 for total in vote if total < 0)
    zero = sum(1 for total in vote if total == 0)

    print(f"upper bound: {upper_bound} ({as_percent(upper_bound)}%)")
    print(f"(+,-,0) = ({pos} ({as_percent(pos)}%), {neg} ({as_percent(neg)}%), {zero} ({as_percent(zero)}%))")

    pos_pct = as_percent(pos)
    neg_pct = as_percent(neg)
    # The third value is derived from the two rounded percentages so the
    # three always add up to 100; concise_result is a module-level global.
    concise_result[method][k].append([pos_pct, neg_pct, round(100 - pos_pct - neg_pct, 2), as_percent(upper_bound)])
    assert isinstance(pos_pct, float) and isinstance(neg_pct, float) and isinstance(as_percent(zero), float)

def do_error_analysis(method, diff_result, templates_list_daily_dialog, n):
    """Print per-template label counts for ``method`` and the top-k summaries.

    Builds a table with the number of 1 / -1 / 0 labels per template, sorts
    it by the number of 1 labels (descending), prints it, and then calls
    ``print_top_k_result`` for k = 1, 3, 5 and the total number of templates.

    Args:
        method: One of ``"t1-tc-tp-ti-tv"`` or ``"t1-tp-tc-ti-tv"``.
        diff_result: dict mapping method -> list (one entry per template) of
            per-example labels.
        templates_list_daily_dialog: Sequence of templates; only its length
            is used.
        n: Denominator forwarded to ``print_top_k_result`` for percentages.
    """
    assert method in ["t1-tc-tp-ti-tv", "t1-tp-tc-ti-tv"]

    num_templates = len(templates_list_daily_dialog)
    labels_per_template = [diff_result[method][t] for t in range(num_templates)]
    tally = {
        '+1': [labels.count(1) for labels in labels_per_template],
        '-1': [labels.count(-1) for labels in labels_per_template],
        '0': [labels.count(0) for labels in labels_per_template],
    }

    print("method:", method)
    df = pd.DataFrame(tally).sort_values(by='+1', ascending=False)
    print(df)

    for k in (1, 3, 5, num_templates):
        print_top_k_result(method, diff_result, df, k, n)

#################### start report_concise ####################
def _label_new_answers(gold_answers, new_answers_per_template, num_templates, use_last_token):
    """Compare each template's new answers with the gold answers.

    Returns one list per template holding the value of
    ``compare_old_new_answer`` for every (gold, new) pair. When
    ``use_last_token`` is set and the comparison on the full response returns
    0, the comparison is done on the last space-separated token of the
    response instead.
    """
    labels_per_template = []
    for template_idx in range(num_templates):
        cur_new_ans_list = new_answers_per_template[template_idx]
        assert len(cur_new_ans_list) == len(gold_answers)
        labels = []
        for old_a, new_a in zip(gold_answers, cur_new_ans_list):
            if use_last_token and compare_old_new_answer(old_a, new_a) == 0:
                candidate = new_a.split(' ')[-1]
            else:
                candidate = new_a
            labels.append(compare_old_new_answer(old_a, candidate))
        labels_per_template.append(labels)
    return labels_per_template


_methods = ["t1-tc-tp-ti-tv", "t1-tp-tc-ti-tv"]

if report_concise:
    # Evaluate all three rewrites and report mean / standard deviation.
    for rewrite_i in [0, 1, 2]:
        if data_split in ["train", "validation"]:
            dd_result = load_result_pt(templates_list_daily_dialog, data_split, pt_file_path, rewrite_i)
        else:
            assert data_split == "all"
            # "all": append the validation responses to the train responses,
            # template by template.
            dd_result = load_result_pt(templates_list_daily_dialog, "train", f"../data/pt/{model}/train/Tv", rewrite_i)
            dd_result_validation = load_result_pt(templates_list_daily_dialog, "validation", f"../data/pt/{model}/validation/Tv", rewrite_i)
            for method in _methods:
                for i in range(len(templates_list_daily_dialog)):
                    dd_result[method][i].extend(dd_result_validation[method][i])

        diff_result = {name: [[] for _ in range(len(templates_list_daily_dialog))] for name in _methods}

        # Label and analyse one method at a time.
        for method in _methods:
            diff_result[method] = _label_new_answers(answer_list, dd_result[method], len(templates_list_daily_dialog), extract_last_also)
            do_error_analysis(method, diff_result, templates_list_daily_dialog, cnt)
            print("===============")

    print(f"FINAL RAW RESULT: {concise_result}")
    print(f"MODEL: {model}")
    print("AVERAGED RESULT FOR TABLE:")
    for method, runs_by_k in concise_result.items():
        df = pd.DataFrame(columns=['top-k', 'update', 'no update', 'unknown', 'upper bound'])
        print(f"METHOD: {method}")
        for k, runs in runs_by_k.items():
            # Three runs, each a row of: update, no update, unknown, upper bound.
            assert len(runs) == 3 and all(len(run) == 4 for run in runs)
            means = [round(sum(values)/len(values), 2) for values in zip(*runs)]
            update, no_update, unknown, upper_bound = means
            # Recompute "unknown" from the rounded means so the three add up to 100.
            unknown = round(100.0 - update - no_update, 2)
            assert round(update + no_update + unknown, 2) == 100.00
            update_std, no_update_std, unknown_std, upper_bound_std = [round(stdev(values), 2) for values in zip(*runs)]
            row = [
                int(k),
                f"{update:05.2f}({update_std:04.2f})",
                f"{no_update:05.2f}({no_update_std:04.2f})",
                f"{unknown:05.2f}({unknown_std:04.2f})",
                f"{upper_bound:05.2f}({upper_bound_std:04.2f})",
            ]
            df.loc[len(df)] = pd.Series(row, index=df.columns)
        print(df)

else:
    # Single-rewrite mode: label both methods first, then print each analysis.
    dd_result = load_result_pt(templates_list_daily_dialog, data_split, pt_file_path, rewrite_i)
    diff_result = {name: [[] for _ in range(len(templates_list_daily_dialog))] for name in _methods}
    for method in _methods:
        diff_result[method] = _label_new_answers(answer_list, dd_result[method], len(templates_list_daily_dialog), extract_last_also)
    for method in _methods:
        do_error_analysis(method, diff_result, templates_list_daily_dialog, cnt)
        print("===============")
