import os
import numpy as np
import json
from tqdm import trange

# Load the merged records. They are indexed as dicts with a 'belong' key (the
# record's group); the sampling loop later in this script also reads 'reward'.
# JSON interchange text is UTF-8, so decode it explicitly rather than relying
# on the platform's locale encoding.
with open("merge_belong_inds.json", encoding="utf-8") as fd:
    d = json.load(fd)

# Distinct group labels in first-appearance order. list(set(...)) produced an
# order that could change between interpreter runs for str labels (hash
# randomisation), making the printed output and any seeded sampling that
# iterates over `belongs` non-reproducible.
belongs = list(dict.fromkeys(x['belong'] for x in d))
print(belongs)

# Bucket the records by group, keeping their original order within each group.
data_by_belongs = {bel: [] for bel in belongs}
for item in d:
    data_by_belongs[item['belong']].append(item)

# Report how many records each group holds.
for k, items in data_by_belongs.items():
    print(k, len(items))

import random

# Directory that receives one JSON file per random subset.
output_dir = "./random_splits"
os.makedirs(output_dir, exist_ok=True)

# Experiment knobs. N_SPLITS and SUBSET_SIZE reproduce the previously
# hard-coded 200 and 1000. SEED=None seeds the generator from OS entropy on
# every run (matching the unseeded module-level RNG used before); set an int
# to make the generated splits reproducible.
N_SPLITS = 200
SUBSET_SIZE = 1000
SEED = None

rng = random.Random(SEED)

# Mean 'reward' of each generated subset, in generation order.
avg_rewards = []

print(len(data_by_belongs))

for i in trange(N_SPLITS):
    # One uniform random weight per group, normalised into that group's share
    # of SUBSET_SIZE. int() truncates, so a subset can come out slightly
    # smaller than SUBSET_SIZE (the shortfall is under one item per group).
    values = [rng.random() for _ in belongs]
    sum_values = sum(values)
    counts = [int(SUBSET_SIZE * v / sum_values) for v in values]

    one_subset = []
    for bel, count in zip(belongs, counts):
        pool = data_by_belongs[bel]
        # random.sample raises ValueError when asked for more items than the
        # group holds, which would abort the run after some files were already
        # written. Cap the draw at the group size; this is a no-op whenever the
        # group is large enough, so such runs behave exactly as before.
        one_subset.extend(rng.sample(pool, min(count, len(pool))))

    reward_scores = [item['reward'] for item in one_subset]
    # Guard the degenerate empty-subset case (e.g. no input records) instead of
    # dividing by zero; NaN then shows up visibly in the summary below.
    if reward_scores:
        avg_reward = sum(reward_scores) / len(reward_scores)
    else:
        avg_reward = float("nan")

    with open(os.path.join(output_dir, f"mb_rand_{i:04d}.json"), "w") as fd:
        json.dump(one_subset, fd, indent=1)

    avg_rewards.append(avg_reward)

# Per-subset mean rewards, then their (population) variance across subsets.
print(avg_rewards)
print(np.var(avg_rewards))
# os.makedirs("./random_splits_v2")

