import ast
import datetime
import glob
import os
import pickle

import gensim.downloader
import gensim.test.utils
import numpy
import torch

from utils import get_device

"""
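pair: (str, list(str)) -- a source word and the list of its accepted target words
example: (str, list(str), str, list(str)) -- two pairs (a, bs, c, ds), read as the analogy a : bs :: c : ds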
"""


def load_google_analogy_pairs_by_category(lowercase=False):
    """Load the distinct word pairs of the Google analogy set, grouped by category.

    Reads ``questions-words.txt`` from gensim's bundled test data.  A line
    starting with ":" opens a new category; every other non-blank line holds
    four words ``a b c d`` and contributes the pairs ``(a, [b])`` and
    ``(c, [d])``.

    Args:
        lowercase: If true, lowercase each question line before splitting it.
            Category names are never lowercased.

    Returns:
        dict mapping category name -> list of ``(str, [str])`` pairs, without
        duplicates and in order of first appearance.

    Raises:
        ValueError: If a non-blank question line does not hold exactly four words.
    """
    dic = {}
    seen = set()  # (word, target) tuples already stored for the current category
    with open(gensim.test.utils.datapath("questions-words.txt"), encoding="utf-8") as f:
        for line in f:
            if line.startswith(":"):
                category = line[1:].strip()
                dic[category] = []
                seen = set()
                continue
            if lowercase:
                line = line.lower()
            # split() already discards the trailing newline; slicing it off by
            # hand would eat a real character on a final line without "\n".
            words = line.split()
            if not words:
                continue  # tolerate blank lines
            a, b, c, d = words
            for word, target in ((a, b), (c, d)):
                # Set membership keeps de-duplication linear instead of
                # rescanning the category list for every pair.
                if (word, target) not in seen:
                    seen.add((word, target))
                    dic[category].append((word, [target]))
    return dic


def load_google_analogy_examples_by_category(lowercase=False):
    """Load every example of the Google analogy set, grouped by category.

    Reads ``questions-words.txt`` from gensim's bundled test data.  A line
    starting with ":" opens a new category; every other non-blank line holds
    four words ``a b c d`` and becomes the example ``(a, [b], c, [d])``.

    Args:
        lowercase: If true, lowercase each question line before splitting it.
            Category names are never lowercased.

    Returns:
        dict mapping category name -> list of ``(str, [str], str, [str])``
        examples in file order (duplicates are kept).

    Raises:
        ValueError: If a non-blank question line does not hold exactly four words.
    """
    dic = {}
    with open(gensim.test.utils.datapath("questions-words.txt"), encoding="utf-8") as f:
        for line in f:
            if line.startswith(":"):
                category = line[1:].strip()
                dic[category] = []
                continue
            if lowercase:
                line = line.lower()
            # split() already discards the trailing newline; slicing it off by
            # hand would eat a real character on a final line without "\n".
            words = line.split()
            if not words:
                continue  # tolerate blank lines
            a, b, c, d = words
            dic[category].append((a, [b], c, [d]))
    return dic


def is_syntactic(category):
    """Return True when the category name begins with "gram"."""
    prefix = category[:4]
    return prefix == "gram"


def load_bats_pairs(filename):
    """Read one pair file: a word, whitespace, then "/"-separated candidates.

    Args:
        filename: Path of the file to read.

    Returns:
        list of ``(str, list(str))`` pairs in file order.

    Raises:
        ValueError: If a non-blank line does not hold exactly two
            whitespace-separated fields.
    """
    pairs = []
    with open(filename, encoding="utf-8") as f:
        for line in f:
            # split() already discards the trailing newline; slicing it off by
            # hand would eat a real character on a final line without "\n".
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines
            a, bs = fields
            pairs.append((a, bs.split("/")))  # multiple candidates are separated by "/"
    return pairs


def load_bats_pairs_by_category():
    """Load every BATS pair file found under ``data/BATS_3.0``.

    The glob is not recursive, so "**" behaves like "*" and only entries
    exactly two levels below ``data/BATS_3.0`` are matched.  The path is
    relative to the current working directory.

    Returns:
        dict mapping each file's base name (extension included) -> list of
        ``(str, list(str))`` pairs, inserted in sorted path order.
    """
    dic = {}
    for filename in sorted(glob.glob("data/BATS_3.0/**/*")):
        # basename() copes with the platform's path separator, unlike
        # splitting on a literal "/".
        dic[os.path.basename(filename)] = load_bats_pairs(filename)
    return dic


def delete_non_included(pairs, wv):
    """Keep only the pairs that can be expressed with the vocabulary ``wv``.

    Candidates missing from ``wv`` are removed first; a pair is then dropped
    when its source word is missing from ``wv`` or no candidate is left.

    Args:
        pairs: iterable of ``(str, list(str))`` pairs.
        wv: any container supporting ``word in wv``.

    Returns:
        list of ``(str, list(str))`` pairs, in the original order.
    """
    kept = []
    for source, candidates in pairs:
        known = [candidate for candidate in candidates if candidate in wv]
        if source in wv and len(known) > 0:
            kept.append((source, known))
    return kept


def _cross_examples(pairs):
    """Combine every ordered couple of different pairs into (a, b, c, d) examples."""
    return [(a, b, c, d) for (a, b) in pairs for (c, d) in pairs if (a, b) != (c, d)]


def split_pairs(pairs, train_ratio=0.6, valid_ratio=0.2, test_ratio=0.2):
    """Shuffle pairs, split them three ways and expand each part into examples.

    The pairs themselves are partitioned, so no pair is shared between the
    train, validation and test examples.

    Args:
        pairs: sequence of ``(str, list(str))`` pairs.
        train_ratio: fraction of the pairs used for training.
        valid_ratio: fraction of the pairs used for validation.
        test_ratio: accepted for interface symmetry but not read; the test
            part is whatever remains after the train and validation parts.

    Returns:
        ``(train_data, valid_data, test_data)``, each a list of
        ``(a, b, c, d)`` examples built from two different pairs of that part.
    """
    pairs = list(pairs)
    # Shuffle indices rather than the pairs themselves: handing the ragged
    # (str, list) pairs to numpy would force an array conversion, which
    # recent NumPy versions reject for inhomogeneous shapes.
    order = numpy.random.permutation(len(pairs))
    shuffled = [pairs[i] for i in order]
    train_end = int(len(shuffled) * train_ratio)
    valid_end = int(len(shuffled) * (train_ratio + valid_ratio))
    train_data = _cross_examples(shuffled[:train_end])
    valid_data = _cross_examples(shuffled[train_end:valid_end])
    test_data = _cross_examples(shuffled[valid_end:])
    return train_data, valid_data, test_data


def split_dic_into_distinct_examples(dic, wv, train_ratio=0.6, valid_ratio=0.2, test_ratio=0.2):
    """Split every category on its own and concatenate the resulting examples.

    For each category the pairs are first filtered with ``delete_non_included``
    and the number of remaining pairs is printed; the survivors are then handed
    to ``split_pairs``.

    Args:
        dic: dict mapping category name -> list of ``(str, list(str))`` pairs.
        wv: vocabulary container used to filter the pairs.
        train_ratio, valid_ratio, test_ratio: forwarded to ``split_pairs``.

    Returns:
        ``(train_data, valid_data, test_data)`` accumulated over all categories.
    """
    train_data = []
    valid_data = []
    test_data = []
    for category, raw_pairs in dic.items():
        usable = delete_non_included(raw_pairs, wv)
        print(f"{category}: {len(usable)}", flush=True)
        train_part, valid_part, test_part = split_pairs(usable, train_ratio, valid_ratio, test_ratio)
        train_data.extend(train_part)
        valid_data.extend(valid_part)
        test_data.extend(test_part)
    return train_data, valid_data, test_data


def save_bats_split():
    """Build a 60/20/20 split of the BATS pairs and pickle it.

    Loads the BATS pairs and the "word2vec-google-news-300" vectors through
    gensim's downloader, splits each category with
    ``split_dic_into_distinct_examples`` and writes the resulting tuple to
    ``word_analogy/data/bats_split.pickle`` (relative to the working directory).
    """
    dic = load_bats_pairs_by_category()
    wv = gensim.downloader.load("word2vec-google-news-300")
    # Compute the split before opening the output file: opening in 'wb' mode
    # truncates it, so a failure while splitting would otherwise destroy a
    # previously saved split and leave an empty file behind.
    split = split_dic_into_distinct_examples(dic, wv, 0.6, 0.2, 0.2)
    with open('word_analogy/data/bats_split.pickle', 'wb') as f:
        pickle.dump(split, f)


def load_bats_split():
    """Unpickle and return the content of ``word_analogy/data/bats_split.pickle``."""
    with open('word_analogy/data/bats_split.pickle', 'rb') as f:
        split = pickle.load(f)
    return split


def preprocess_d(data, model):
    """ list((str, list(str), str, list(str))) -> [batch (int)], [batch (int)], [batch (int)], [batch (int)]

    Words are mapped through ``model.word_to_index``.  Only the first candidate
    of ``bs`` and of ``ds`` is used.
    """
    lookup = model.word_to_index
    a_ids, b_ids, c_ids, d_ids = [], [], [], []
    for a, bs, c, ds in data:
        a_ids.append(lookup[a])
        b_ids.append(lookup[bs[0]])
        c_ids.append(lookup[c])
        d_ids.append(lookup[ds[0]])
    return numpy.array(a_ids), numpy.array(b_ids), numpy.array(c_ids), numpy.array(d_ids)


def preprocess_ds(data, model):
    """ list((str, list(str), str, list(str))) -> [batch (int)], [batch (int)], [batch (int)], [batch (list(int))]

    Words are mapped through ``model.word_to_index``.  Only the first candidate
    of ``bs`` is used, whereas every candidate of ``ds`` is kept.

    NOTE: the last array is built with ``dtype=object``.  When the candidate
    lists have different lengths it is a 1-D array of lists; when they all
    have the same length NumPy builds a 2-D object array instead.
    """
    lookup = model.word_to_index
    a_ids, b_ids, c_ids, d_id_lists = [], [], [], []
    for a, bs, c, ds in data:
        a_ids.append(lookup[a])
        b_ids.append(lookup[bs[0]])
        c_ids.append(lookup[c])
        d_id_lists.append([lookup[d] for d in ds])
    ds_batch = numpy.array(d_id_lists, dtype=object)
    return numpy.array(a_ids), numpy.array(b_ids), numpy.array(c_ids), ds_batch


def load_bats_for_train(model):
    """Load the saved BATS split and convert it to index batches.

    Returns a 4-tuple: the training examples processed with ``preprocess_d``,
    then the training, validation and test examples each processed with
    ``preprocess_ds``.
    """
    train_data, valid_data, test_data = load_bats_split()
    train_single = preprocess_d(train_data, model)
    train_multi = preprocess_ds(train_data, model)
    valid_multi = preprocess_ds(valid_data, model)
    test_multi = preprocess_ds(test_data, model)
    return train_single, train_multi, valid_multi, test_multi


def save_model_state_dict(model):  # {model_name}_{dim}_{hidden_dim}_{layer_num}_{timestamp}
    """Save ``model.state_dict()`` under ``word_analogy/trained_model/``.

    The file name joins the model's class name, ``dim``, ``hidden_dim``,
    ``layer_num`` and a month-day-hour-minute-second timestamp with "_" and
    ends in ".pt".
    """
    stamp = datetime.datetime.now().strftime("%m%d%H%M%S")
    model_name = type(model).__name__
    stem = f"{model_name}_{model.dim}_{model.hidden_dim}_{model.layer_num}_{stamp}"
    target = f"word_analogy/trained_model/{stem}.pt"
    torch.save(model.state_dict(), target)


def load_model_state_dict(filename):
    """Rebuild a model from a file name produced by ``save_model_state_dict``.

    The name must consist of exactly five "_"-separated fields:
    ``{model_name}_{dim}_{hidden_dim}_{layer_num}_{timestamp}.pt``.  The class
    called ``model_name`` is looked up among this module's globals and built
    with ``hidden_dim`` and ``layer_num``; the ``dim`` field is parsed but not
    passed to the constructor.

    Args:
        filename: Bare file name inside ``word_analogy/trained_model/``.

    Returns:
        The model with its weights loaded, moved to ``get_device()``.

    Raises:
        ValueError: If the name does not split into five fields, or a
            constructor field is not a Python literal.
        NameError: If no global named ``model_name`` exists in this module.
    """
    model_name, _dim, hidden_dim, layer_num, _ = filename.split("_")
    # SECURITY: the pieces of the file name used to be spliced into a string
    # and passed to eval(), which would run arbitrary code hidden in a crafted
    # name.  Resolve the class by plain lookup and parse the arguments as
    # literals instead.
    try:
        model_cls = globals()[model_name]
    except KeyError:
        raise NameError(f"name '{model_name}' is not defined") from None
    model = model_cls(hidden_dim=ast.literal_eval(hidden_dim), layer_num=ast.literal_eval(layer_num))
    device = get_device()
    state_dict = torch.load(f"word_analogy/trained_model/{filename}", map_location=device)
    # strict=False: keys missing from or unexpected in the checkpoint are
    # ignored rather than reported as an error.
    model.load_state_dict(state_dict, strict=False)
    return model.to(device)
