import os
import json
from collections import Counter
from optparse import OptionParser

import numpy as np


def main():
    """Create one partition file per sufficiently frequent label.

    Parses the command-line options, loads the label counts stored in
    ``data/classification/value_counts.json`` and, for every label whose
    count is at least ``--min-count``, prints the label and calls
    ``create_partition`` on ``data/classification/all.tokenized.jsonlist``.
    """
    usage = "%prog"
    parser = OptionParser(usage=usage)
    parser.add_option('--name', type=str, default='partition',
                      help='Base name: default=%default')
    # --n-test is the number of instances sampled into the test set
    # (it is passed as ``size`` to the sampler), not a number of folds.
    parser.add_option('--n-test', type=int, default=300,
                      help='Number of test instances: default=%default')
    parser.add_option('--min-count', type=int, default=10,
                      help='Min count: default=%default')
    parser.add_option('--seed', type=int, default=42,
                      help='Random seed: default=%default')

    # Positional arguments are not used by this script.
    options, _ = parser.parse_args()

    name = options.name
    n_test = options.n_test
    min_count = options.min_count
    seed = options.seed

    base_dir = os.path.join('data', 'classification')
    data_file = os.path.join(base_dir, 'all.tokenized.jsonlist')
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(os.path.join(base_dir, 'value_counts.json'), encoding='utf-8') as f:
        # presumably a mapping of label name -> number of labelled instances
        value_counter = Counter(json.load(f))

    for value, count in value_counter.items():
        if count >= min_count:
            print(value)
            create_partition(data_file, value, name, n_test, seed)


def create_partition(data_file, task, name='partition', n_test=0, seed=42):
    """Write a train/test partition of the instances labelled for ``task``.

    ``data_file`` is read as one JSON document per line. An instance counts
    as labelled when ``task in instance`` is true (each line is expected to
    be a JSON object, so this is a key test). If ``n_test`` is positive,
    that many labelled instances are sampled without replacement as the test
    set and the remaining labelled instances become the training set;
    otherwise all labelled instances are used for training. The dev set is
    always empty.

    The result is written as JSON to
    ``<dir of data_file>/exp/<task>/<name>_t<n_test>_s<seed>/partition.json``;
    indices in it are 0-based line positions in ``data_file``.

    Args:
        data_file: Path to the JSON-lines data file.
        task: Label key used to select instances; also used as a directory
            name in the output path.
        name: Base name of the partition directory.
        n_test: Number of test instances to sample (values <= 0 mean none).
        seed: Seed for NumPy's global random number generator.

    Raises:
        ValueError: If ``n_test`` exceeds the number of labelled instances.
    """
    # Seed the global NumPy RNG so the same seed always yields the same split.
    np.random.seed(seed)

    print("Reading data from", data_file)
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(data_file, encoding='utf-8') as f:
        data = [json.loads(line) for line in f]
    print("Found {:d} instances".format(len(data)))

    data_indices = [i for i, instance in enumerate(data) if task in instance]
    print("Found {:d} instances with labels".format(len(data_indices)))

    if n_test > 0:
        # Fail with a clear message instead of NumPy's generic sampling error.
        if n_test > len(data_indices):
            raise ValueError(
                "Cannot select {:d} test instances for task {!r}: only {:d} "
                "labelled instances are available".format(
                    n_test, task, len(data_indices)))
        print("Selecting {:d} test instances".format(n_test))
        sampled = np.random.choice(data_indices, size=n_test, replace=False)
        # convert to basic ints for serialization
        test_indices = [int(i) for i in sampled]
        held_out = set(test_indices)
        train_indices = [i for i in data_indices if i not in held_out]
    else:
        test_indices = []
        train_indices = data_indices
    dev_indices = []

    # Train, dev and test all refer to the same file; the index lists
    # distinguish them.
    output = {
        'train_file': data_file,
        'dev_file': data_file,
        'test_file': data_file,
        'unlabeled_file': None,
        'dev_folds': 1,
        'stratified': False,
        'train_indices': train_indices,
        'dev_indices': dev_indices,
        'test_indices': test_indices,
        'task': task,
        'seed': seed,
    }

    partition_name = '{}_t{}_s{}'.format(name, n_test, seed)
    output_dir = os.path.join(os.path.dirname(data_file), 'exp', task, partition_name)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, 'partition.json')

    with open(output_file, 'w') as f:
        json.dump(output, f, indent=None)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
