# Written by Seonwoo Min, LG AI Research (seonwoo.min0@gmail.com)

import os
import sys
import pickle
import numpy as np
from tqdm import tqdm
from collections import Counter

from src.data import Vocabulary
from src.pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer


def main(path):
    """Build the vocabulary for the dataset at ``path`` and embed its captions.

    Runs ``build_vocab`` first and forwards its three results (tokenized
    texts, file names, vocabulary) straight to ``embed_texts``.
    """
    vocab_outputs = build_vocab(path)
    embed_texts(path, *vocab_outputs)


def embed_texts(path, tokenized_texts, file_names, vocab):
    """Convert tokenized captions to fixed-length sequences and pickle them.

    Each caption is truncated or padded with "<EOS>" to ``max_length - 2``
    words and then wrapped in the vocabulary's start and end tokens, so every
    stored array has exactly ``max_length`` entries. ``max_length`` is 64 when
    "domain_net" occurs in ``path`` and 32 otherwise.

    Two files are written into ``path``:
      * ``texts_w.pkl``: dict mapping file name -> np.ndarray of ``vocab(word)``
        values (presumably integer word indices).
      * ``lengths.pkl``: dict mapping file name -> number of kept words + 1.

    Args:
        path: Output directory; its name also selects the maximum length.
        tokenized_texts: Sequence of token lists. They are not modified.
        file_names: Keys for the output dicts, aligned with ``tokenized_texts``.
        vocab: Callable ``vocab(word)`` that also exposes ``start_token`` and
            ``end_token`` attributes.
    """
    max_length = 64 if "domain_net" in path else 32
    max_words = max_length - 2  # two slots are reserved for the start/end tokens

    texts_w, lengths = [], []
    for text in tokenized_texts:
        # Truncate before measuring so the recorded length can never exceed
        # the stored sequence (long captions used to report their full length).
        words = list(text[:max_words])
        lengths.append(len(words) + 1)
        # Pad on a fresh list: ``text`` belongs to the caller and must not be
        # extended in place.
        padded = words + ["<EOS>"] * (max_words - len(words))
        texts_w.append(np.array(
            [vocab(vocab.start_token)]
            + [vocab(word) for word in padded]
            + [vocab(vocab.end_token)]
        ))

    with open(os.path.join(path, "texts_w.pkl"), 'wb') as f:
        pickle.dump(dict(zip(file_names, texts_w)), f)
    with open(os.path.join(path, "lengths.pkl"), 'wb') as f:
        pickle.dump(dict(zip(file_names, lengths)), f)


def build_vocab(path, threshold=1):
    """Tokenize every caption under ``path`` and build the word vocabulary.

    The directory is walked as ``path/<domain>/<class>/<file>``; top-level
    entries whose name contains a '.' (for example previously written ``.pkl``
    files) are skipped. The first line of each file is used as its caption.
    Files that are empty or cannot be decoded are reported and skipped, so a
    bad file can never inherit the caption of the file processed before it.

    The vocabulary is saved to ``path/vocab.pkl``.

    Args:
        path: Root directory of the caption files.
        threshold: Minimum number of occurrences for a word to be kept.

    Returns:
        Tuple ``(tokenized_texts, file_names, vocab)`` where ``file_names[i]``
        is the ``<domain>_<class>_<stem>`` key of ``tokenized_texts[i]``.
    """
    tokenizer = PTBTokenizer()  # one instance is reused for every caption
    tokenized_texts, file_names = [], []
    for domain_name in tqdm(sorted(os.listdir(path))):
        if '.' in domain_name:
            continue
        print(f"Start Domain: {domain_name}")
        domain_dir = os.path.join(path, domain_name)
        for class_name in tqdm(sorted(os.listdir(domain_dir))):
            class_dir = os.path.join(domain_dir, class_name)
            for filename in tqdm(sorted(os.listdir(class_dir))):
                file_path = os.path.join(class_dir, filename)
                with open(file_path, "r") as caption_file:
                    try:
                        # Only the first line is needed; readline() returns ''
                        # for an empty file instead of raising.
                        first_line = caption_file.readline()
                    except UnicodeDecodeError:
                        first_line = ""
                if not first_line:
                    print(f"Skipping empty or undecodable caption file: {file_path}")
                    continue
                line = first_line.strip()
                # The key drops the last extension; any other dots in the file
                # name are removed as well (kept as-is for compatibility).
                file_names.append(domain_name + '_' + class_name + '_' + ''.join(filename.split('.')[:-1]))
                tokenized_texts.append(tokenizer.tokenize({0: [{'caption': line}]})[0][0])

    counter = Counter()
    for text in tokenized_texts:
        counter.update(text)
    words = [word for word, cnt in counter.items() if cnt >= threshold]

    vocab = Vocabulary()
    for word in words:
        vocab.add_word(word)
    Vocabulary.save(vocab, os.path.join(path, "vocab.pkl"))

    return tokenized_texts, file_names, vocab


if __name__ == "__main__":
    # The dataset root directory is the only required argument; exit with a
    # usage message instead of an IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <dataset_path>")
    main(sys.argv[1])
