import os
import json
import joblib
from glob import glob
from optparse import OptionParser


from common import file_handling as fh
from common.labels import encode_labels
from common.docs import encode_documents_as_bow
from models.lr import predict, LogisticRegression


# Run predictions on held-out data using multiple trained models (assumes they all share the same config and vocab).

def main():
    """Run predictions on a held-out dataset with every matching trained model.

    Reads ``data/<dataset>/parsed.jsonlist`` (one JSON document per line),
    encodes it once as bag-of-words using the config and vocabulary of the
    first model found, then calls ``predict`` once per model with evaluation
    disabled, passing that model's directory as the output directory.

    Command-line options:
        --dataset: name of the directory under ``data/`` to predict on
            (default ``neurips``).

    Raises:
        FileNotFoundError: if no model files match the expected glob pattern,
            or if the input ``parsed.jsonlist`` does not exist.
    """
    usage = "%prog"
    parser = OptionParser(usage=usage)
    parser.add_option('--dataset', type=str, default='neurips',
                      help='Dataset to make predictions on [neurips|icml]: default=%default')

    options, _ = parser.parse_args()

    dataset = options.dataset
    infile = os.path.join('data', dataset, 'parsed.jsonlist')

    model_pattern = os.path.join('data', 'classification', 'exp', '*', 'partition_t300_s42',
                                 'linear_f1_binarize_n1_l1', 'model.nontest.pkl')
    # glob() returns matches in arbitrary, filesystem-dependent order; sort so the
    # "first" model (whose config/vocab drive encoding) is chosen deterministically.
    model_files = sorted(glob(model_pattern))
    if not model_files:
        raise FileNotFoundError("No model files found matching " + model_pattern)

    # Only the first model's config and vocab are used for encoding; this assumes
    # every model was trained with the same settings and vocabulary.
    first_model_dir = os.path.dirname(model_files[0])
    config = fh.read_json(os.path.join(first_model_dir, 'config.json'))
    vocab = fh.read_json(os.path.join(first_model_dir, 'vocab.json'))

    print("Loading data")
    with open(infile, encoding='utf-8') as f:
        # Skip blank lines instead of letting json.loads fail on them.
        docs = [json.loads(line) for line in f if line.strip()]
    for i, doc in enumerate(docs):
        # Assign a synthetic sequential id to each document.
        doc['_i'] = 'tr_' + str(i)

    print("Encoding data using first model")
    ids, line_indices, counts, _, instance_weights, _ = encode_documents_as_bow(
        docs, vocab, config, confounders=None, confound_matrix=None, side_data=None)

    for model_file in model_files:
        print("Loading model for", model_file)
        # NOTE: joblib.load unpickles the file; only load trusted model files.
        model = joblib.load(model_file)

        model_dir = os.path.dirname(model_file)
        label_vocab = fh.read_json(os.path.join(model_dir, 'labels.json'))

        # Held-out data has no gold labels here, so labels=None and evaluation is
        # disabled; model_dir is passed as predict's output directory.
        predict([model], counts, None, instance_weights, ids, line_indices, label_vocab,
                model_dir, dataset, do_evaluation=False)


if __name__ == '__main__':
    # Script entry point: run main() only when executed directly, not on import.
    main()
