import glob
import numpy
import pandas as pd
from train import (load_bats_pairs, BaseWord2VecWrapper, load_bats_for_train,
                   load_google_analogy_examples_by_category,
                   load_google_analogy_pairs_by_category, preprocess_ds, load_model_state_dict)
from utils import get_device


def delete_non_included_pairs(pairs, model):
    """Keep only the pairs whose words are known to ``model``.

    ``pairs`` is an iterable of ``(word, candidates)`` tuples.  Candidates
    missing from ``model.word_to_index`` are dropped first; a pair is then
    kept only when its source word is in ``model.word_to_index`` and at
    least one candidate survived.  Returns a new list of
    ``(word, surviving_candidates)`` tuples in the original order.
    """
    kept = []
    for source, candidates in pairs:
        # Drop candidates the model has no entry for.
        known = [cand for cand in candidates if cand in model.word_to_index]
        # Skip pairs with an unknown source word or no usable candidate.
        if source not in model.word_to_index or not known:
            continue
        kept.append((source, known))
    return kept


def bats_table():
    """Return a LaTeX table with one row per file under ``data/BATS_3.0``.

    Each row holds the category and subcategory parsed from the file path,
    the file's first line as an example (tab replaced by ``:`` and cut to
    20 characters), the number of pairs loaded from the file and the number
    of those pairs that survive ``delete_non_included_pairs`` for a freshly
    constructed ``BaseWord2VecWrapper``.
    """
    model = BaseWord2VecWrapper()
    rows = []
    for filename in sorted(glob.glob("data/BATS_3.0/**/*")):
        # Paths are expected to look like data/BATS_3.0/<category dir>/<file>.
        _, _, category_dir, relation_file = filename.split("/")
        category = category_dir.split("_")[1]
        subcategory = relation_file.split(".")[0].replace("[", "").replace("]", "")

        with open(filename) as f:
            first_line = f.readline()
        # Strip the trailing "\n" and show the pair as "a:b".
        example = first_line[:-1].replace("\t", ":")
        if len(example) > 20:
            example = example[:20] + "..."

        pairs = load_bats_pairs(filename)
        total = len(pairs)
        usable = len(delete_non_included_pairs(pairs, model))
        rows.append([category, subcategory, example, total, usable])

    # NOTE(review): the caption talks about "num" while the column is
    # labelled "pair" -- confirm which wording is intended.
    header = ["category", "subcategory", "example", "pair", "used"]
    df = pd.DataFrame(rows, columns=header)
    caption = (
        "Detailed explanation of bigger analogy test set. "
        "num refers to the whole relation size and used refers to the number included in the word2vec model."
    )
    return df.to_latex(index=False, label="tbl:bats-detail", caption=caption)


def google_table():
    """Return a LaTeX table with one row per Google analogy subcategory.

    Subcategories whose name starts with ``"gram"`` are labelled
    "Syntactic", all others "Semantic".  The example column is built from
    the first example of the subcategory.

    NOTE: no vocabulary filtering is applied here, so the "used" column
    simply repeats the "num" column.
    """
    rows = []
    for subcategory, examples in load_google_analogy_examples_by_category().items():
        category = "Syntactic" if subcategory.startswith("gram") else "Semantic"
        first = examples[0]
        example = f"{first[0]}:{first[1][0]}"
        size = len(examples)
        rows.append([category, subcategory, example, size, size])

    header = ["category", "subcategory", "example", "num", "used"]
    df = pd.DataFrame(rows, columns=header)
    caption = (
        "Detailed explanation of Google analogy test set. "
        "num refers to the whole relation size and used refers to the number included in the word2vec model."
    )
    return df.to_latex(index=False, label="tbl:google-detail", caption=caption)


def format_corrects(corrects, num):
    r"""Format per-model correct counts as LaTeX table cells.

    Each entry of ``corrects`` becomes ``"<count> (<accuracy>\%)"`` where the
    accuracy is ``count / num`` rendered with two decimals.  Every entry
    equal to the maximum of ``corrects`` is wrapped in ``\textbf{...}`` so
    the best model(s) stand out; ties are all highlighted.

    Args:
        corrects: Sequence of correct-answer counts, one per model.
        num: Total number of examples the counts refer to.

    Returns:
        A list of strings, one per entry of ``corrects``, with ``%`` escaped
        for LaTeX.  An empty ``corrects`` yields an empty list.  When ``num``
        is 0 the accuracy is reported as 0.00% instead of raising
        ``ZeroDivisionError``.
    """
    # Computed once; ``default`` keeps the empty case from raising.
    best = max(corrects, default=None)
    row = []
    for correct in corrects:
        ratio = correct / num if num else 0.0
        cell = f"{correct} ({ratio:.2%})"
        if correct == best:
            # The backslash must be escaped: "\t" in a plain string literal
            # is a TAB, which would emit "<TAB>extbf{...}" instead of the
            # LaTeX command.
            cell = f"\\textbf{{{cell}}}"
        row.append(cell.replace("%", "\\%"))
    return row


# Column header shown for each model, keyed by the model's class name.
MODEL_NAME_DICT = dict(
    BaseWord2VecWrapper="WV",
    MLPWord2VecWrapper="WV + MLP",
    AGNWord2VecWrapper="WV + AGN",
)


def construct_dataframe(models, dic, overall=True, exclude_abc=False):
    """Evaluate every model on every dataset and tabulate the results.

    Args:
        models: Sequence of model objects exposing
            ``eval_data(data_ds, exclude_abc=...)``; the returned value is
            treated as the count of correct answers.
        dic: Mapping ``{category: data_ds}``.  ``len(data_ds[0])`` is used
            as the number of examples in the category.
        overall: When true, an "Overall" row aggregating all categories is
            placed first.
        exclude_abc: Forwarded unchanged to each ``eval_data`` call.

    Returns:
        A ``pandas.DataFrame`` indexed by (possibly truncated) category
        name, with a "num" column followed by one formatted column per
        model, labelled through ``MODEL_NAME_DICT``.
    """
    labels = []
    rows = []
    total_size = 0
    total_corrects = [0] * len(models)

    for category, data_ds in dic.items():
        # Long category names are cut to 15 characters plus an ellipsis.
        label = category[:15] + "..." if len(category) > 15 else category
        labels.append(label)

        size = len(data_ds[0])
        total_size += size

        corrects = [model.eval_data(data_ds, exclude_abc=exclude_abc) for model in models]
        total_corrects = [running + hits for running, hits in zip(total_corrects, corrects)]
        rows.append([size, *format_corrects(corrects, size)])

    if overall:
        # The aggregate row goes on top of the per-category rows.
        labels = ["Overall", *labels]
        rows = [[total_size, *format_corrects(total_corrects, total_size)], *rows]

    header = ["num"]
    header.extend(MODEL_NAME_DICT[type(model).__name__] for model in models)
    return pd.DataFrame(rows, columns=header, index=labels)


def eval_bats_main(models, exclude_abc=False):
    """Return a LaTeX accuracy table over the four BATS categories.

    The test split returned by ``load_bats_for_train(models[0])`` is cut
    into four consecutive index ranges, one per category, and handed to
    ``construct_dataframe`` (which also adds the "Overall" row).
    """
    # (category name, exclusive end index of its range in the test split).
    # NOTE(review): the end indices are hard-coded; presumably they match the
    # split produced by load_bats_for_train -- confirm if the data changes.
    category_ends = (
        ("Inflectional", 900),
        ("Derivational", 1782),
        ("Lexicographic", 2414),
        ("Encyclopedic", 3314),
    )

    _, _, _, test_ds = load_bats_for_train(models[0])
    ta, tb, tc, tds = test_ds

    dic = {}
    start = 0
    for name, stop in category_ends:
        part = numpy.arange(start, stop)
        dic[name] = (ta[part], tb[part], tc[part], tds[part])
        start = stop  # the next range begins where this one ended

    df = construct_dataframe(models, dic, exclude_abc=exclude_abc)
    return df.to_latex(
        label="tbl:eval-bats-main",
        caption="Accuracy on bigger analogy test set.",
        escape=False,
    )


def eval_google_main(models, exclude_abc=False):
    """Return a LaTeX accuracy table for the Google analogy test set.

    Examples are pooled into "Semantic" and "Syntactic" groups (a
    subcategory counts as syntactic when its name starts with ``"gram"``),
    each pool is run through ``preprocess_ds`` with ``models[0]`` and the
    result is tabulated by ``construct_dataframe``.
    """
    pooled = {"Semantic": [], "Syntactic": []}
    for subcategory, examples in load_google_analogy_examples_by_category().items():
        group = "Syntactic" if subcategory.startswith("gram") else "Semantic"
        pooled[group].extend(examples)

    # Insertion order is kept, so "Semantic" is preprocessed and listed first.
    dic = {group: preprocess_ds(examples, models[0]) for group, examples in pooled.items()}

    df = construct_dataframe(models, dic, exclude_abc=exclude_abc)
    return df.to_latex(
        label="tbl:eval-google-main",
        caption="Accuracy on Google analogy test set.",
        escape=False,
    )


def eval_bats_detail(models, exclude_abc=False):
    """Return a LaTeX table comparing models on each BATS subcategory.

    The test split from ``load_bats_for_train(models[0])`` is cut into 40
    consecutive index ranges.  Ranges are labelled ``I01``..``I10``,
    ``D01``..``D10``, ``E01``..``E10`` and ``L01``..``L10`` in that order,
    and no "Overall" row is added.
    """
    # Exclusive end index of each subcategory's range in the test split.
    partitions = [
        90, 180, 270, 360, 450, 540, 630, 720, 810, 900,
        990, 1080, 1170, 1260, 1332, 1422, 1512, 1602, 1692, 1782,
        1838, 1894, 1914, 1934, 1964, 2054, 2144, 2234, 2324, 2414,
        2504, 2594, 2684, 2774, 2864, 2954, 3044, 3134, 3224, 3314,
    ]
    group_letters = "IDEL"  # one letter per block of ten subcategories

    _, _, _, test_ds = load_bats_for_train(models[0])
    ta, tb, tc, tds = test_ds

    dic = {}
    start = 0
    for i, stop in enumerate(partitions):
        group, position = divmod(i, 10)
        part = numpy.arange(start, stop)
        dic[f"{group_letters[group]}{position + 1:0=2}"] = (ta[part], tb[part], tc[part], tds[part])
        start = stop  # the next range begins where this one ended

    df = construct_dataframe(models, dic, overall=False, exclude_abc=exclude_abc)
    return df.to_latex(
        label="tbl:eval-bats-detail",
        caption="Model comparison for each subcategory of bigger analogy test set.",
        escape=False,
    )


def eval_google_detail(models, exclude_abc=False):
    """Return a LaTeX table comparing models on each Google subcategory.

    Every subcategory's examples are run through ``preprocess_ds`` with
    ``models[0]`` and tabulated by ``construct_dataframe`` without an
    "Overall" row.
    """
    dic = {}
    for subcategory, examples in load_google_analogy_examples_by_category().items():
        dic[subcategory] = preprocess_ds(examples, models[0])

    df = construct_dataframe(models, dic, overall=False, exclude_abc=exclude_abc)
    return df.to_latex(
        label="tbl:eval-google-detail",
        caption="Model comparison for each subcategory of Google analogy test set.",
        escape=False,
    )


if __name__ == "__main__":
    # Baseline wrapper plus the two saved checkpoints, each moved to the
    # device reported by get_device().
    base = BaseWord2VecWrapper().to(get_device())
    mlp, agn = (
        load_model_state_dict(checkpoint).to(get_device())
        for checkpoint in (
            "MLPWord2VecWrapper_300_223_4_0116200114.pt",
            "AGNWord2VecWrapper_300_151_5_0117205351.pt",
        )
    )
    # Print the two headline accuracy tables, BATS first.
    for make_table in (eval_bats_main, eval_google_main):
        print(make_table([base, mlp, agn]))
    # The dataset description tables are currently disabled:
    # print(bats_table())
    # print(google_table())
