import os
import json
import sys
import random
import spacy
import multiprocessing


# Books: 5
# Electronics: 10
# Home_and_Kitchen: 13
# Video Games: 23
# Keys looked up in every input example to obtain the label `y`.  The script
# below builds and saves one set of environment files per entry, using the
# entry as the output file-name prefix.
task = ['penetrance', 'incidence']
# Key looked up in every input example to obtain the attribute `c` that is
# paired with the label when the examples are grouped.
bias = 'breast cancer'

def save_json(obj, path):
    '''
        Write a list of json object to the specified path, one object
        per line.

        If a file already exists at path the user is asked whether to
        overwrite it. Any reply other than 'y' / 'yes' (case-insensitive)
        keeps the existing file and nothing is written.

        obj:  iterable of json-serializable objects.
        path: destination file path.
    '''
    if os.path.isfile(path):
        reply = str(input('File already exsits. Overwrite? (y/n):'))
        if reply.strip().lower() not in ('y', 'yes'):
            # Honour the answer instead of overwriting unconditionally.
            print('Keeping existing file {}'.format(path))
            return

    # Mode 'w' truncates an existing file, so no separate delete is needed.
    with open(path, 'w') as f:
        for d in obj:
            json.dump(d, f)
            # Text mode already translates '\n' to the platform line ending;
            # writing os.linesep here would emit '\r\r\n' on Windows.
            f.write('\n')


def worker(j, lines):
    '''
        Parse each entry of lines as json, replace its 'text' field with
        the lower-cased, space-joined spaCy tokens of that text, and save
        the resulting objects to tmp_<j>.json.
    '''
    disabled_pipes = ["tagger", "parser", "ner", "lemmatizer", "textcat"]
    tokenizer = spacy.load('en_core_web_sm', disable=disabled_pipes)

    processed = []
    for raw in lines:
        record = json.loads(raw)
        lowered = [tok.text.lower() for tok in tokenizer(record['text'])]
        record['text'] = ' '.join(lowered)
        processed.append(record)

    out_path = 'tmp_%d.json' % j
    save_json(processed, out_path)


def merge(n_parts=80, out_path='ask2me.json'):
    '''
        Concatenate tmp_0.json ... tmp_<n_parts - 1>.json (in index order)
        into out_path, deleting every partial file once it has been read.

        n_parts:  number of partial files to merge (default 80).
        out_path: destination file (default 'ask2me.json').

        Raises FileNotFoundError if one of the partial files is missing.
    '''
    data = []
    for i in range(n_parts):
        part_path = 'tmp_%d.json' % i
        # Close the handle before removing the file rather than leaking it.
        with open(part_path) as part:
            data += part.readlines()
        os.remove(part_path)

    with open(out_path, 'w') as f:
        f.writelines(data)


# with open(sys.argv[1], 'r') as f:
#     lines = f.readlines()

# plist = []
# for i in range(80):
#     plist.append(multiprocessing.Process(target=worker,
#         args=(i,lines[len(lines)//80*i:len(lines)//80*(i+1)])))

# for i in range(80):
#     plist[i].start()

# for i in range(80):
#     plist[i].join()

# merge()

# exit(0)



# Load the examples: the first command-line argument is the path of a file
# holding one json object per line.
if len(sys.argv) < 2:
    sys.exit('usage: {} <input_json_lines_file>'.format(sys.argv[0]))

with open(sys.argv[1], 'r') as f:
    # Iterate the file lazily and skip blank lines (e.g. a trailing newline),
    # which json.loads would otherwise reject.
    data = [json.loads(line) for line in f if line.strip()]


# Fixed seed so the shuffle, and every split derived from it, is reproducible.
random.seed(1)
random.shuffle(data)

def _pos_neg_ratio(n_pos, n_neg):
    '''
        Return n_pos / n_neg, or float('inf') when n_neg is zero so that a
        very small group does not abort the script with ZeroDivisionError.
    '''
    return n_pos / n_neg if n_neg else float('inf')


# For every label in `task`, build four lists of examples from one half of
# the shuffled data and save each list to its own json file.
for task_id, label in enumerate(task):
    # The first task uses the first half of the data, any other task the
    # second half.
    half = len(data) // 2
    cur_data = data[:half] if task_id == 0 else data[half:]

    # Group the examples under the key '<y>_<c>'.  Only the four keys below
    # exist, so any other (y, c) combination raises KeyError.
    data_dict = {
        '0_0': [],
        '0_1': [],
        '1_0': [],
        '1_1': [],
    }
    for example in cur_data:
        y = example[label]
        c = example[bias]

        data_dict['{}_{}'.format(y, c)].append({
            'y': y,
            'c': c,
            'text': example['text'],
        })

    for k, v in data_dict.items():
        print(k, ':  ', len(v))

    # env3 takes the same number of examples from the front of every group;
    # the remaining examples stay in data_dict for the other environments.
    min_data_size = min(len(v) for v in data_dict.values())
    held_out = min_data_size // 4
    env3 = []
    for k, v in data_dict.items():
        env3 += v[:held_out]
        data_dict[k] = v[held_out:]

    # add positive correlation: the groups where y == c are split into
    # thirds, one third per environment ('0_0' first, then '1_1').
    env0, env1, env2 = [], [], []
    for group in (data_dict['0_0'], data_dict['1_1']):
        third = len(group) // 3
        env0 += group[:third]
        env1 += group[third:third * 2]
        env2 += group[third * 2:]
    env0_pos, env1_pos, env2_pos = len(env0), len(env1), len(env2)

    # add negative correlation: the groups where y != c contribute their
    # first 1/10 to env0, the next 2/10 to env1 and the next 2/10 to env2
    # ('0_1' first, then '1_0'); the remaining half is not used.
    for group in (data_dict['0_1'], data_dict['1_0']):
        n = len(group)
        env0 += group[:n // 10]
        env1 += group[n // 10:3 * n // 10]
        env2 += group[3 * n // 10:5 * n // 10]

    # Report the positive / negative ratio of each environment.
    for env, n_pos in ((env0, env0_pos), (env1, env1_pos), (env2, env2_pos)):
        print(_pos_neg_ratio(n_pos, len(env) - n_pos))

    # # # save data to json
    save_json(env0, '{}_env_0.json'.format(label))
    save_json(env1, '{}_env_1.json'.format(label))
    save_json(env2, '{}_env_1_val.json'.format(label))
    save_json(env3, '{}_env_2.json'.format(label))
