import argparse
import random
import os

import pandas as pd
import torch
from tqdm import tqdm
from statistics import stdev

from datasets import load_dataset

from set_seed import set_seed
from generate_yes_no_index_list import generate_yes_no_index_list
from get_various_story_attr import get_various_story_attr
from get_templates_list_from_daily_dialog import get_templates_list_from_daily_dialog
#from load_result_pt import load_result_pt # this is for baseline data only
from compare_old_new_answer import compare_old_new_answer
from update_answer_list import update_answer_list

####################################################
def parse_argument():
    """Parse the command-line options of this evaluation script.

    Returns:
        argparse.Namespace: with attributes ``data_split``, ``model``,
        ``rewrite_i``, ``run_oracle``, ``report_concise``,
        ``load_from_scratch`` and ``extract_last_also``.

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            a value is outside its ``choices``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--data_split', required=True, choices=['train', 'validation', 'all'], help='Choose training set or validation set in CoQA dataset, all means we evaluate both (note that setting all flag will need to ensure report_concise is set to True and not load_from_scratch (too slow))')
    # The model name is interpolated into every result path. Leaving it
    # optional produced a "../data/pt/None/..." path that only failed later,
    # so it is rejected here with a normal usage error instead.
    parser.add_argument('--model', required=True, choices=['gpt-3.5-turbo-0613', 'gpt-3.5-turbo-0125', 'gpt-4o-mini-2024-07-18', 'gemma-2-2b-it', 'gemma-2-9b-it', 'gemma-2-27b-it', 'vicuna-7b-v1.5-16k', 'vicuna-13b-v1.5-16k', 'vicuna-33b-v1.3', 'Llama-2-7b-chat-hf', 'Llama-2-13b-chat-hf', 'Meta-Llama-3-8B-Instruct'], help='specify the model, note that some models in choices may be deprecated in the future, now support models other than GPT')
    # Optional on the command line; the script itself insists on an int
    # whenever --report_concise is not given.
    parser.add_argument('--rewrite_i', type=int, choices=range(0, 3), help='Choose 1st, 2nd, or 3rd response of MTurk')
    parser.add_argument('--run_oracle', action='store_true', help='Whether to evaluate the oracle performance of Recall (i.e., the old knowledge in the story is automatically replaced by us).')
    parser.add_argument('--report_concise', action='store_true', help='Whether to report the average performance in three run. Will overwrite --rewrite_i flag by running it three times.')
    parser.add_argument('--load_from_scratch', action='store_true', help='Whether to load the CoQA train/validation from scratch (note that loading training data from scratch will take around 2hr 20 mins...)')
    parser.add_argument('--extract_last_also', action='store_true', help='Whether to extract the last token if we cannot decide the answer is different in the first token (if an LLM response is ... so the answer to the last question is "No." <- this is frequently found in Llama-2-7b-chat-hf) [DEFAULT SHALL BE FALSE AS WE DO NOT SET THIS IN OUR PREVIOUS RESULTS]')
    # The former --update_answer, --is_trim and --instruct_update flags are
    # intentionally no longer registered.
    return parser.parse_args()

# ---- Resolve the run configuration from the command line ----
args = parse_argument()
data_split, model, rewrite_i = args.data_split, args.model, args.rewrite_i
update_answer = True  # no longer a CLI flag: the gold answers are always updated
run_oracle = args.run_oracle
report_concise = args.report_concise  # averaged three-run report; saves time for analysis
load_from_scratch = args.load_from_scratch  # rebuild from raw CoQA instead of the cached .pt file
# Keep this off to reproduce the earlier results; it is meant only for
# Llama-2-7b-chat-hf in Verification or Recall.
extract_last_also = args.extract_last_also

# Evaluating both splits together is only supported for the concise report
# on top of the pre-processed data.
assert data_split != "all" or (report_concise is True and not load_from_scratch)

if extract_last_also:
    # Block until the user acknowledges the non-default setting.
    _ = input("Setting extract_last_also if the first token does not start with Yes/No... Note that DEFAULT SHALL BE FALSE AS WE DO NOT SET THIS IN OUR PREVIOUS RESULTS, you should only set this for Llama-2-7b-chat-hf in Verification or Recall to replicate our other experimental results...\nPress any key to continue...")

# A single (non-concise) run needs an explicit MTurk rewrite index.
assert report_concise or isinstance(rewrite_i, int)

if data_split == "train":
    # Offset used to slice the train examples when they are rebuilt from scratch.
    start_idx = 5000

# Oracle runs read their results from the ablation directory.
pt_file_path = (
    f"../data/pt/{model}/{data_split}/Tr_ablation"
    if run_oracle
    else f"../data/pt/{model}/{data_split}/Tr"
)
# For "all" the per-split directories are built later, so only check real splits here.
assert data_split == "all" or os.path.exists(pt_file_path)

print(f"Evaluating the {model} model Recall performance on KEIC {data_split} (run_oracle is {run_oracle})....")

#################### start load_from_scratch ####################

if load_from_scratch:
    # Rebuild the evaluation examples directly from the raw CoQA split.
    _d_coqa = load_dataset('coqa')
    data = _d_coqa[data_split]

    # One (possibly empty) candidate list per dataset row; seed first so the
    # random pick per row is reproducible. -1 marks a row without candidates.
    yn_qa_idx_list = generate_yes_no_index_list(data)
    set_seed(0)
    sel_idx_list = [random.choice(candidates) if candidates else -1 for candidates in yn_qa_idx_list]

    # Collect question / gold answer / supporting sentence for every selected row.
    question_list, answer_list, old_sent_list = [], [], []
    for row in tqdm(range(data.num_rows)):
        picked = sel_idx_list[row]
        if picked == -1:
            continue
        _, question, answer, _, _, support_sent = get_various_story_attr(data, row, picked)
        question_list.append(question)
        answer_list.append(answer)
        old_sent_list.append(support_sent)
    cnt = len(question_list)

    assert cnt == len(question_list) == len(answer_list) == len(old_sent_list)

    if update_answer:
        answer_list = update_answer_list(data_split, answer_list)
        print('Answer updated!')

    if data_split == "train":
        # Only the train examples from start_idx onwards are evaluated.
        question_list, answer_list, old_sent_list = (
            examples[start_idx:] for examples in (question_list, answer_list, old_sent_list)
        )
        cnt = 1317
        assert cnt == len(question_list) == len(answer_list) == len(old_sent_list)

    # Sanity check: the rebuilt lists must equal the cached pre-processed data.
    reference_file = (
        f"pre_processed_coqa_{data_split}_data_1317_for_evaluation.pt"
        if data_split == "train"
        else f"pre_processed_coqa_{data_split}_data_464_for_evaluation.pt"
    )
    pre_processed_coqa_data = torch.load(reference_file)
    for key, rebuilt in (
        ("question_list", question_list),
        ("answer_list", answer_list),
        ("old_sent_list", old_sent_list),
    ):
        assert rebuilt == pre_processed_coqa_data[key]
    assert cnt == pre_processed_coqa_data["cnt"]
    print("After comparison, they are the same...")

else:
    # Fast path: read the cached pre-processed examples.
    if data_split == "all":
        # Train examples first, then the validation examples appended to them.
        pre_processed_coqa_data_train = torch.load("pre_processed_coqa_train_data_1317_for_evaluation.pt")
        pre_processed_coqa_data_validation = torch.load("pre_processed_coqa_validation_data_464_for_evaluation.pt")
        question_list = pre_processed_coqa_data_train["question_list"]
        answer_list = pre_processed_coqa_data_train["answer_list"]
        old_sent_list = pre_processed_coqa_data_train["old_sent_list"]
        question_list.extend(pre_processed_coqa_data_validation["question_list"])
        answer_list.extend(pre_processed_coqa_data_validation["answer_list"])
        old_sent_list.extend(pre_processed_coqa_data_validation["old_sent_list"])
        cnt = pre_processed_coqa_data_train["cnt"] + pre_processed_coqa_data_validation["cnt"]
        assert len(question_list) == len(answer_list) == len(old_sent_list) == cnt == 1781
    else:
        if data_split == "train":
            cache_file = f"pre_processed_coqa_{data_split}_data_1317_for_evaluation.pt"
        else:
            assert data_split == "validation"
            cache_file = f"pre_processed_coqa_{data_split}_data_464_for_evaluation.pt"
        pre_processed_coqa_data = torch.load(cache_file)
        question_list = pre_processed_coqa_data["question_list"]
        answer_list = pre_processed_coqa_data["answer_list"]
        old_sent_list = pre_processed_coqa_data["old_sent_list"]
        cnt = pre_processed_coqa_data["cnt"]
    print("Successfully load from pre-processed CoQA data!")


# ---- Choose which DailyDialog templates are evaluated ----
templates_list_daily_dialog = get_templates_list_from_daily_dialog()
if model == 'gpt-3.5-turbo-0613' and data_split in ['all', 'train']:
    # Only these templates were run for gpt-3.5-turbo-0613 on the train data.
    top_k_idx = [7, 2, 1, 5, 8, 3]
else:
    top_k_idx = list(range(15))
templates_list_daily_dialog = [templates_list_daily_dialog[template_id] for template_id in top_k_idx]
assert len(top_k_idx) == len(templates_list_daily_dialog)

# Per method and per top-k cut-off, one result row is appended for every
# evaluated rewrite; the table is only printed when report_concise is set.
concise_result = {
    method_name: {top_k: [] for top_k in [1, 3, 5, len(templates_list_daily_dialog)]}
    for method_name in ("t1-tc-tr-tp-ti", "t1-tp-tc-tr-ti")
}


def load_result_pt(templates_list_daily_dialog, data_split, pt_file_path, rewrite_i, run_oracle):
    """Load the saved model answers for both methods and every template.

    Args:
        templates_list_daily_dialog: only its length is used; it fixes how
            many result files are read per method.
        data_split: split name embedded in the file names.
        pt_file_path: directory that holds the ``.pt`` result files.
        rewrite_i: MTurk rewrite index; must be 0, 1 or 2.
        run_oracle: when true, the ``_oracle`` variants of the files are read.

    Returns:
        dict mapping each method name to a list with one loaded object per
        template, in the order of the module-level ``top_k_idx``.

    Raises:
        AssertionError: if ``rewrite_i`` is not 0, 1 or 2 (including None).
    """
    template_count = len(templates_list_daily_dialog)
    dd_result = {}
    for method in ("t1-tc-tr-tp-ti", "t1-tp-tc-tr-ti"):
        loaded = []
        for slot in range(template_count):
            if rewrite_i not in [0, 1, 2]:
                print("Should not be here...")
                assert False
            # The files are named after the original template id, not the slot.
            template_id = top_k_idx[slot]
            if method == "t1-tc-tr-tp-ti":
                if run_oracle:
                    file_name = f'{pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_c{template_id}_r_p_oracle_chatgpt.pt'
                else:
                    file_name = f'{pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_c{template_id}_r_p_chatgpt.pt'
            else:
                assert method == "t1-tp-tc-tr-ti"
                if run_oracle:
                    file_name = f'{pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_p_c{template_id}_r_oracle_chatgpt.pt'
                else:
                    file_name = f'{pt_file_path}/coqa_{data_split}_yes_no_mturk_rewrite_{rewrite_i}_p_c{template_id}_r_chatgpt.pt'
            loaded.append(torch.load(file_name))
        dd_result[method] = loaded
    return dd_result

def print_top_k_result(method, diff_result, df, k, n):
    """Print the combined outcome of the first ``k`` templates in ``df``.

    Args:
        method: key into ``diff_result`` and the module-level ``concise_result``.
        diff_result: maps a method to one score list per template.
        df: DataFrame whose index order ranks the templates; its first ``k``
            index labels select the score lists.
        k: number of templates to combine.
        n: denominator for the percentages.

    Side effects:
        Prints the counts and appends
        ``[update %, no-update %, remaining %, upper-bound %]`` to
        ``concise_result[method][k]``.
    """
    def pct(count):
        return round((count/n)*100, 2)

    print(f"top {k}:")
    chosen = [diff_result[method][row_label] for row_label in df.index.tolist()[:k]]

    upper_bound = pos = neg = zero = 0
    for per_example in zip(*chosen):
        # Upper bound: at least one chosen template scored +1 on this example.
        if per_example.count(1) > 0:
            upper_bound += 1
        # Majority vote: the sign of the summed scores decides.
        ballot = sum(per_example)
        if ballot > 0:
            pos += 1
        elif ballot < 0:
            neg += 1
        else:
            zero += 1

    print(f"upper bound: {upper_bound} ({pct(upper_bound)}%)")
    print(f"(+,-,0) = ({pos} ({pct(pos)}%), {neg} ({pct(neg)}%), {zero} ({pct(zero)}%))")
    # The third entry is derived from the other two so that a row sums to 100.
    concise_result[method][k].append([pct(pos), pct(neg), round(100-pct(pos)-pct(neg), 2), pct(upper_bound)])
    assert isinstance(pct(pos), float) and isinstance(pct(neg), float) and isinstance(pct(zero), float)

def do_error_analysis(method, diff_result, templates_list_daily_dialog, n):
    """Print the per-template score counts of ``method`` and its top-k results.

    Args:
        method: one of "t1-tc-tr-tp-ti" or "t1-tp-tc-tr-ti".
        diff_result: maps a method to one score list per template.
        templates_list_daily_dialog: only its length is used.
        n: denominator passed on to ``print_top_k_result``.
    """
    assert method in ["t1-tc-tr-tp-ti", "t1-tp-tc-tr-ti"]
    per_template = [diff_result[method][slot] for slot in range(len(templates_list_daily_dialog))]
    tally = {
        '+1': [scores.count(1) for scores in per_template],
        '-1': [scores.count(-1) for scores in per_template],
        '0': [scores.count(0) for scores in per_template],
        # Original template ids; the DataFrame's own index stays the slot number.
        'index': top_k_idx,
    }
    print("method:", method)
    # Rank the templates by how often they scored +1.
    df = pd.DataFrame(tally).sort_values(by='+1', ascending=False)
    print(df)
    for k in [1, 3, 5, len(templates_list_daily_dialog)]:
        print_top_k_result(method, diff_result, df, k, n)

#################### start report_concise ####################
def _score_answer(old_a, new_a, use_last_token):
    """Compare one model response with the gold answer.

    The whole response is compared first. When ``use_last_token`` is set and
    that comparison returns 0, the last space-separated token of the
    response is compared instead.
    """
    # The first result is reused rather than recomputed; this assumes
    # compare_old_new_answer depends only on its arguments -- TODO confirm.
    verdict = compare_old_new_answer(old_a, new_a)
    if use_last_token and verdict == 0:
        return compare_old_new_answer(old_a, new_a.split(' ')[-1])
    return verdict


def _score_method(method, dd_result):
    """Return one score list per template for ``method`` against ``answer_list``."""
    scores = []
    for i in range(len(templates_list_daily_dialog)):
        cur_new_ans_list = dd_result[method][i]
        assert len(cur_new_ans_list) == len(answer_list)
        scores.append([
            _score_answer(old_a, new_a, extract_last_also)
            for old_a, new_a in zip(answer_list, cur_new_ans_list)
        ])
    return scores


if report_concise:
    # Evaluate all three MTurk rewrites; this overrides any --rewrite_i value.
    for rewrite_i in [0, 1, 2]:
        if data_split in ["train", "validation"]:
            dd_result = load_result_pt(templates_list_daily_dialog, data_split, pt_file_path, rewrite_i, run_oracle)
        else:
            assert data_split == "all"
            # Oracle runs read their results from the ablation directory.
            result_dir = "Tr_ablation" if run_oracle else "Tr"
            dd_result = load_result_pt(templates_list_daily_dialog, "train", f"../data/pt/{model}/train/{result_dir}", rewrite_i, run_oracle)
            dd_result_validation = load_result_pt(templates_list_daily_dialog, "validation", f"../data/pt/{model}/validation/{result_dir}", rewrite_i, run_oracle)
            # Same order as answer_list: train answers first, then validation.
            for method in ["t1-tc-tr-tp-ti", "t1-tp-tc-tr-ti"]:
                for i in range(len(templates_list_daily_dialog)):
                    dd_result[method][i].extend(dd_result_validation[method][i])
        diff_result = {
            "t1-tc-tr-tp-ti": [[] for _ in range(len(templates_list_daily_dialog))],
            "t1-tp-tc-tr-ti": [[] for _ in range(len(templates_list_daily_dialog))],
        }

        for method in ["t1-tc-tr-tp-ti", "t1-tp-tc-tr-ti"]:
            diff_result[method] = _score_method(method, dd_result)
            do_error_analysis(method, diff_result, templates_list_daily_dialog, cnt)
            print("===============")
    print(f"FINAL RAW RESULT: {concise_result}")
    print(f"MODEL: {model}")
    print("AVERAGED RESULT FOR TABLE:")
    for method in concise_result:
        df = pd.DataFrame(columns=['top-k', 'update', 'no update', 'unknown', 'upper bound'])
        print(f"METHOD: {method}")
        for k in concise_result[method]:
            # Three runs, each holding: update, no update, unknown, upper bound.
            assert len(concise_result[method][k]) == 3 and all(len(_) == 4 for _ in concise_result[method][k])
            update, no_update, unknown, upper_bound = [round(sum(values)/len(values), 2) for values in zip(*concise_result[method][k])]
            # Recompute "unknown" from the rounded means so the row sums to exactly 100.
            unknown = round(100.0 - update - no_update, 2)
            assert round(update + no_update + unknown, 2) == 100.00
            update_std, no_update_std, unknown_std, upper_bound_std = [round(stdev(values), 2) for values in zip(*concise_result[method][k])]
            # Each cell is formatted as mean(std).
            df.loc[len(df)] = pd.Series([int(k), f"{update:05.2f}({update_std:04.2f})", f"{no_update:05.2f}({no_update_std:04.2f})", f"{unknown:05.2f}({unknown_std:04.2f})", f"{upper_bound:05.2f}({upper_bound_std:04.2f})"], index=df.columns)
        print(df)

else:
    assert not report_concise  # make sure it does not go here...
    assert data_split != "all"  # make sure it does not go here...

    # Single run for the rewrite chosen on the command line.
    dd_result = load_result_pt(templates_list_daily_dialog, data_split, pt_file_path, rewrite_i, run_oracle)
    diff_result = {
        method: _score_method(method, dd_result)
        for method in ["t1-tc-tr-tp-ti", "t1-tp-tc-tr-ti"]
    }
    for method in ["t1-tc-tr-tp-ti", "t1-tp-tc-tr-ti"]:
        do_error_analysis(method, diff_result, templates_list_daily_dialog, cnt)
        print("===============")
