'''
splits data into train and test sets
'''

import argparse
import json
import os
import random
import time
import sys

from collections import OrderedDict

from constants import DATASETS, SEED_FILES

def create_jsons_for(user_files, which_set, max_users, include_hierarchy):
    '''used in split-by-user case'''
    user_count = 0
    json_index = 0
    users = []
    if include_hierarchy:
        hierarchies = []
    else:
        hierarchies = None
    num_samples = []
    user_data = {}
    for (i, t) in enumerate(user_files):
        if include_hierarchy:
            (u, h, ns, f) = t
        else:
            (u, ns, f) = t

        file_dir = os.path.join(subdir, f)
        with open(file_dir, 'r') as inf:
            data = json.load(inf)
        
        users.append(u)
        if include_hierarchy:
            hierarchies.append(h)
        num_samples.append(ns)
        user_data[u] = data['user_data'][u]
        user_count += 1

        if (user_count == max_users) or (i == len(user_files) - 1):

            all_data = {}
            all_data['users'] = users
            if include_hierarchy:
                all_data['hierarchies'] = hierarchies
            all_data['num_samples'] = num_samples
            all_data['user_data'] = user_data

            data_i = f.find('data')
            num_i = data_i + 5
            num_to_end = f[num_i:]
            param_i = num_to_end.find('_')
            param_to_end = '.json'
            if param_i != -1:
                param_to_end = num_to_end[param_i:]
            nf = '%s_%d%s' % (f[:(num_i-1)], json_index, param_to_end)
            file_name = '%s_%s_%s.json' % ((nf[:-5]), which_set, arg_label)
            ouf_dir = os.path.join(dir, which_set, file_name)

            print('writing %s' % file_name)
            with open(ouf_dir, 'w') as outfile:
                json.dump(all_data, outfile)

            user_count = 0
            json_index += 1
            users = []
            if include_hierarchy:
                hierarchies = []
            num_samples = []
            user_data = {}

parser = argparse.ArgumentParser()

parser.add_argument('--name',
                help='name of dataset to parse; default: sent140;',
                type=str,
                choices=DATASETS,
                default='sent140')
parser.add_argument('--by_user',
                help='divide users into training and test set groups;',
                dest='user', action='store_true')
parser.add_argument('--by_sample',
                help="divide each user's samples into training and test set groups;",
                dest='user', action='store_false')
parser.add_argument('--frac',
                help='fraction in training set; default: 0.9;',
                type=float,
                default=0.9)
parser.add_argument('--seed',
                help='seed for random partitioning of test/train data',
                type=int,
                default=None)

parser.set_defaults(user=False)

args = parser.parse_args()

print('------------------------------')
print('generating training and test sets')

parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
dir = os.path.join(parent_path, args.name, 'data')
subdir = os.path.join(dir, 'rem_user_data')
files = []
if os.path.exists(subdir):
    files = os.listdir(subdir)
if len(files) == 0:
    subdir = os.path.join(dir, 'sampled_data')
    if os.path.exists(subdir):
        files = os.listdir(subdir)
if len(files) == 0:
    subdir = os.path.join(dir, 'all_data')
    files = os.listdir(subdir)
files = [f for f in files if f.endswith('.json')]

rng_seed = (args.seed if (args.seed is not None and args.seed >= 0) else int(time.time()))
rng = random.Random(rng_seed)
if os.environ.get('LEAF_DATA_META_DIR') is not None:
    seed_fname = os.path.join(os.environ.get('LEAF_DATA_META_DIR'), SEED_FILES['split'])
    with open(seed_fname, 'w+') as f:
        f.write("# split_seed used by sampling script - supply as "
                "--spltseed to preprocess.sh or --seed to utils/split_data.py\n")
        f.write(str(rng_seed))
    print ("- random seed written out to {file}".format(file=seed_fname))
else:
    print ("- using random seed '{seed}' for sampling".format(seed=rng_seed))

arg_label = str(args.frac)
arg_label = arg_label[2:]

# check if data contains information on hierarchies
file_dir = os.path.join(subdir, files[0])
with open(file_dir, 'r') as inf:
    data = json.load(inf)
include_hierarchy = 'hierarchies' in data

if (args.user):
    print('splitting data by user')

    # 1 pass through all the json files to instantiate arr
    # containing all possible (user, .json file name) tuples
    user_files = []
    for f in files:
        file_dir = os.path.join(subdir, f)
        with open(file_dir, 'r') as inf:
            # Load data into an OrderedDict, to prevent ordering changes
            # and enable reproducibility
            data = json.load(inf, object_pairs_hook=OrderedDict)
        if include_hierarchy:
            user_files.extend([(u, h, ns, f) for (u, h, ns) in 
                zip(data['users'], data['hierarchies'], data['num_samples'])])
        else:
            user_files.extend([(u, ns, f) for (u, ns) in 
                zip(data['users'], data['num_samples'])])

    # randomly sample from user_files to pick training set users
    num_users = len(user_files)
    num_train_users = int(args.frac * num_users)
    indices = [i for i in range(num_users)]
    train_indices = rng.sample(indices, num_train_users)
    train_blist = [False for i in range(num_users)]
    for i in train_indices:
        train_blist[i] = True
    train_user_files = []
    test_user_files = []
    for i in range(num_users):
        if (train_blist[i]):
            train_user_files.append(user_files[i])
        else:
            test_user_files.append(user_files[i])

    max_users = sys.maxsize
    if args.name == 'femnist':
        max_users = 50 # max number of users per json file
    create_jsons_for(train_user_files, 'train', max_users, include_hierarchy)
    create_jsons_for(test_user_files, 'test', max_users, include_hierarchy)

else:
    print('splitting data by sample')

    for f in files:
        file_dir = os.path.join(subdir, f)
        with open(file_dir, 'r') as inf:
            # Load data into an OrderedDict, to prevent ordering changes
            # and enable reproducibility
            data = json.load(inf, object_pairs_hook=OrderedDict)

        num_samples_train = []
        user_data_train = {}
        num_samples_test = []
        user_data_test = {}

        user_indices = [] # indices of users in data['users'] that are not deleted

        removed = 0
        for i, u in enumerate(data['users']):

            curr_num_samples = len(data['user_data'][u]['y'])
            if curr_num_samples >= 2:
                # ensures number of train and test samples both >= 1
                num_train_samples = max(1, int(args.frac * curr_num_samples))
                if curr_num_samples == 2:
                    num_train_samples = 1

                num_test_samples = curr_num_samples - num_train_samples

                indices = [j for j in range(curr_num_samples)]
                if args.name in ['shakespeare']:
                    train_indices = [i for i in range(num_train_samples)]
                    test_indices = [i for i in range(num_train_samples + 80 - 1, curr_num_samples)]
                else:
                    train_indices = rng.sample(indices, num_train_samples)                    
                    test_indices = [i for i in range(curr_num_samples) if i not in train_indices]
                
                if len(train_indices) >= 1 and len(test_indices) >= 1:
                    user_indices.append(i)
                    num_samples_train.append(num_train_samples)
                    num_samples_test.append(num_test_samples)
                    user_data_train[u] = {'x': [], 'y': [], 'a': []}
                    user_data_test[u] = {'x': [], 'y': [], 'a': []}

                    train_blist = [False for _ in range(curr_num_samples)]
                    test_blist = [False for _ in range(curr_num_samples)]
                    
                    for j in train_indices:
                        train_blist[j] = True
                    for j in test_indices:
                        test_blist[j] = True

                    for j in range(curr_num_samples):
                        if (train_blist[j]):
                            user_data_train[u]['x'].append(data['user_data'][u]['x'][j])
                            user_data_train[u]['y'].append(data['user_data'][u]['y'][j])
                            user_data_train[u]['a'].append(data['user_data'][u]['a'][j])
                        elif (test_blist[j]):
                            user_data_test[u]['x'].append(data['user_data'][u]['x'][j])
                            user_data_test[u]['y'].append(data['user_data'][u]['y'][j])
                            user_data_test[u]['a'].append(data['user_data'][u]['a'][j])

        users = [data['users'][i] for i in user_indices]

        all_data_train = {}
        all_data_train['users'] = users
        all_data_train['num_samples'] = num_samples_train
        all_data_train['user_data'] = user_data_train

        all_data_test = {}
        all_data_test['users'] = users
        all_data_test['num_samples'] = num_samples_test
        all_data_test['user_data'] = user_data_test 

        if include_hierarchy:
            hierarchies = [data['hierarchies'][i] for i in user_indices] 
            all_data_train['hierarchies'] = hierarchies
            all_data_test['hierarchies'] = hierarchies

        file_name_train = '%s_train_%s.json' % ((f[:-5]), arg_label)
        file_name_test = '%s_test_%s.json' % ((f[:-5]), arg_label)
        ouf_dir_train = os.path.join(dir, 'train', file_name_train)
        ouf_dir_test = os.path.join(dir, 'test', file_name_test)
        print('writing %s' % file_name_train)
        with open(ouf_dir_train, 'w') as outfile:
            json.dump(all_data_train, outfile)
        print('writing %s' % file_name_test)
        with open(ouf_dir_test, 'w') as outfile:
            json.dump(all_data_test, outfile)
