"""
This script analyzes the success rates of different models across datasets and downstream tasks.
"""

import os
import sys
import json
import numpy as np

# Exclusive upper bound of the index range that the per-run result files cover.
count = 100

# Evaluation styles; each one appears verbatim in the result file names.
styles = [
    'choose',
    'option',
    'generate',
]

# Dataset suffixes, used as ``wmdp-<dataset>`` in the result file names.
datasets = [
    'bio',
    'chem',
    'cyber',
]

# Model identifiers, used as the prefix of each result file name.
models = [
    'base',
    'rmu',
    'npo-bio',
    'npo-cyber',
]

def analyze(files):
    """Print the entry count and success rate aggregated over result files.

    Each existing file must contain a JSON array of objects that carry a
    ``'success'`` key. Paths that do not exist are skipped rather than
    treated as errors.

    Args:
        files: Iterable of paths to JSON result files.

    Returns:
        A ``(count, rate)`` tuple: the number of entries read and the mean of
        their ``'success'`` values (``nan`` when no entries were found).
    """
    success = []
    for file in files:
        try:
            # Context manager so every file handle is closed promptly.
            with open(file, encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            # Missing files are skipped, not reported as errors.
            continue
        success.extend(entry['success'] for entry in data)

    # np.mean([]) warns "Mean of empty slice"; produce the same nan quietly.
    rate = float(np.mean(success)) if success else float('nan')
    print(' Count:', len(success))
    print(' Success rate:', rate)
    return len(success), rate

if __name__ == '__main__':
    # Visit every (style, dataset, model) combination with style outermost
    # and model innermost, summarizing the result files for each one.
    for style in styles:
        # 'generate' result files cover index ranges of width 5; others use 10.
        step = 5 if style == 'generate' else 10
        for dataset in datasets:
            for model in models:
                print(f'Style: {style}, Dataset: {dataset}, Model: {model}')
                shard_files = [
                    f'results/{model}_wmdp-{dataset}_{style}_{lo}_{lo + step}.json'
                    for lo in range(0, count, step)
                ]
                analyze(shard_files)