import argparse
import os
import re
import sys

import jsonlines
import numpy as np

# Lowercase marker words. main() counts, per record, the paragraphs of the
# record's 'thought' text that contain one of these markers.
keywords = [
    'alternative',
    'wait',
    'but',
    'check',
]


def main():
    """Print the mean number of keyword-bearing paragraphs per record.

    Reads the JSON Lines file at ``--input_file`` after substituting
    ``--dataset`` into its ``{}`` placeholder (default ``omni_math``; pass
    e.g. ``--dataset math_500`` for the other split). Each record must have
    string fields ``'answer'`` and ``'thought'``.

    Records whose ``answer`` contains "clarification" (case-insensitive) are
    skipped. For every other record, ``thought`` is split into paragraphs on
    blank lines (``"\\n\\n"``). The function counts the paragraphs that
    contain at least one word from the module-level ``keywords`` list. The
    mean of these counts is printed to stdout.

    Keywords are matched case-insensitively at the *start of a word*, so
    "Alternatively", "checking" and "Wait," match. Plain substring matching
    would also have counted "but" inside "distribute"/"attribute" and "wait"
    inside "await".
    """
    parser = argparse.ArgumentParser(
        description='Mean count of keyword-bearing paragraphs per thought.')
    parser.add_argument(
        '--input_file', type=str, required=True,
        help='Path to a .jsonl file; a "{}" placeholder is replaced by '
             '--dataset.')
    parser.add_argument(
        '--dataset', type=str, default='omni_math',
        help='Name substituted into --input_file (default: omni_math).')
    args = parser.parse_args()

    # One pattern for all keywords. The leading \b anchors each keyword at a
    # word start, rejecting matches buried inside other words
    # (e.g. "but" in "distribute").
    keyword_re = re.compile(
        r'\b(?:' + '|'.join(map(re.escape, keywords)) + r')')

    thoughts_len = []
    # Context manager so the underlying file is closed after reading.
    with jsonlines.open(args.input_file.format(args.dataset)) as reader:
        for item in reader:
            if 'clarification' in item['answer'].lower():
                continue
            paragraphs = item['thought'].split('\n\n')
            # Lowercase once per paragraph; keywords are all lowercase.
            thoughts_len.append(
                sum(1 for p in paragraphs if keyword_re.search(p.lower())))

    if not thoughts_len:
        # np.mean([]) would emit a RuntimeWarning and print nan; be explicit.
        print('No records to average (input empty or all records skipped).',
              file=sys.stderr)
        return
    print(np.mean(thoughts_len))


# Run the command-line entry point only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
