import jsonlines
import argparse
import os

# Default run directory; expected to contain "sampling_stats.json".
DEFAULT_PATH = 'samples_coco/eagle_t2i-mscoco_text_features-size-256-size-256-VQ-16-topk-1000-topp-1.0-temperature-1.0-cfg-7.5-seed-0-warmup-0-adative-jsd-a-0.0001-b-1000.0'


def _accepted_lengths(record):
    """Return the list of accepted lengths stored in one stats record.

    The list may be stored under either ``'accepted_length'`` or
    ``'accept_length'``; ``'accepted_length'`` wins when both are present.
    A record with neither key raises ``KeyError``.
    """
    if 'accepted_length' in record:
        return record['accepted_length']
    return record['accept_length']


def summarize(records, clip=None):
    """Count records and average the accepted lengths across all of them.

    Args:
        records: iterable of dict-like stats records (e.g. the objects
            yielded by a ``jsonlines`` reader), each holding a list of
            accepted lengths.
        clip: optional upper bound; when not ``None``, every length is capped
            with ``min(length, clip)`` before averaging.

    Returns:
        ``(count, mean)``: ``count`` is the number of records consumed and
        ``mean`` is the average over every individual accepted length (not a
        per-record average). ``mean`` is ``nan`` when no lengths were seen,
        rather than raising ``ZeroDivisionError``.

    Raises:
        KeyError: if a record has neither length key.
    """
    lengths = []
    count = 0
    for record in records:
        lengths.extend(_accepted_lengths(record))
        count += 1
    if clip is not None:
        lengths = [min(x, clip) for x in lengths]
    # Guard the empty case so the report is still written.
    mean = sum(lengths) / len(lengths) if lengths else float('nan')
    return count, mean


def format_report(count, mean):
    """Render the two-line report written to ``accepted_length.txt``."""
    return "count: %d\nmean: %f\n" % (count, mean)


def main(argv=None):
    """Read ``<path>/sampling_stats.json`` and write ``<path>/accepted_length.txt``.

    Args:
        argv: optional argument list for ``argparse``; ``None`` uses
            ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, default=DEFAULT_PATH)
    parser.add_argument(
        '--clip', type=int, default=None,
        help='optional cap applied to each accepted length before averaging',
    )
    args = parser.parse_args(argv)

    stats_file = os.path.join(args.path, "sampling_stats.json")
    # The stats file is JSON Lines (one JSON object per line) despite its suffix.
    with jsonlines.open(stats_file) as reader:
        count, mean = summarize(reader, clip=args.clip)

    with open(os.path.join(args.path, "accepted_length.txt"), 'w') as f:
        f.write(format_report(count, mean))


if __name__ == "__main__":
    main()