import fasttext
import logging
import argparse
import os

'''
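Training and validation files must contain one example per line: the label
(prefixed with __label__) followed by the sentence, e.g.
__label__1  sentence 1
__label__2  sentence 2
__label__1  sentence 4

Parameters accepted by fasttext.train_supervised, with [default value]:
input             # training file path (required)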
lr                # learning rate [0.1]
dim               # size of word vectors [100]
ws                # size of the context window [5]
epoch             # number of epochs [5]
minCount          # minimal number of word occurrences [1]
minCountLabel     # minimal number of label occurrences [1]
minn              # min length of char ngram [0]
maxn              # max length of char ngram [0]
neg               # number of negatives sampled [5]
wordNgrams        # max length of word ngram [1]
loss              # loss function {ns, hs, softmax, ova} [softmax]
bucket            # number of buckets [2000000]
thread            # number of threads [number of cpus]
lrUpdateRate      # change the rate of updates for the learning rate [100]
t                 # sampling threshold [0.0001]
label             # label prefix ['__label__']
verbose           # verbose [2]
pretrainedVectors # pretrained word vectors (.vec file) for supervised learning []
'''

logger = logging.getLogger(__name__)
if __name__ == '__main__':
    # Command-line configuration: input files, output location and training length.
    parser = argparse.ArgumentParser()
    parser.add_argument("--training_data", default='train_data_raw.txt',
                        help="")  # train_augmented_data_eda #train_data_raw #train_augmented_data_naacl
    parser.add_argument("--val_data", default='test_fastText.txt', help="")
    parser.add_argument("--data_dir", default='yelp_small_train_augmented', help="")
    parser.add_argument("--saving_model_path", default='fastText', help="")
    parser.add_argument("--model_id", default='fastText', help="")
    # Previously hard-coded in the train_supervised call; the default keeps the old value.
    parser.add_argument("--epoch", type=int, default=10000,
                        help="number of training epochs")

    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Create the output directory. Only "already exists" is tolerated: any other
    # failure (e.g. permission denied) is raised now instead of surfacing after
    # training, when the model cannot be saved.
    # An empty path means the model goes to the current directory, so there is
    # nothing to create.
    if args.saving_model_path:
        try:
            os.makedirs(args.saving_model_path)
        except FileExistsError:
            logger.info("Folder {} already exists".format(args.saving_model_path))

    logger.info("***** Start Training *****")

    training_file = os.path.join(args.data_dir, args.training_data)
    model = fasttext.train_supervised(training_file, epoch=args.epoch)
    print(model.words)
    print(model.labels)

    # The model is written to <saving_model_path>/<model_id>.bin and read back,
    # so the code that follows works with the serialized model.
    model_path = '{}.bin'.format(os.path.join(args.saving_model_path, args.model_id))
    model.save_model(model_path)

    model = fasttext.load_model(model_path)

    def print_results(N, p, r):
        """Print an evaluation summary and write it to a results file.

        The report has three tab-separated lines (``N``, ``P@1`` and ``R@1``) and
        is saved as ``results_<model_id>.txt`` inside ``args.data_dir``.

        Args:
            N: value shown on the ``N`` line.
            p: value shown on the ``P@1`` line, formatted to three decimals.
            r: value shown on the ``R@1`` line, formatted to three decimals.
        """
        report = ''.join([
            "N\t{}\n".format(N),
            "P@{}\t{:.3f}\n".format(1, p),
            "R@{}\t{:.3f}\n".format(1, r),
        ])
        print(report)
        results_path = os.path.join(args.data_dir, 'results_{}.txt'.format(args.model_id))
        # Mode 'w' replaces any results file left by an earlier run with the same model_id.
        with open(results_path, 'w') as results_file:
            results_file.write(report)


    logger.info("***** Start Testing on custom examples *****")

    # Score the model on the validation file and hand the unpacked result
    # to print_results.
    validation_file = os.path.join(args.data_dir, args.val_data)
    metrics = model.test(validation_file)
    print_results(*metrics)

    logger.info("***** Examples *****")
    # Spot-check two hand-written reviews; the sentences are kept verbatim.
    for sentence in (
        'I hate this stupid restaurant food service was so slow',
        'I loved this amazin restaurant food was so great',
    ):
        label = model.predict(sentence)
        logger.info("Sentence : {} \t Label {}".format(sentence, label))
