"""
This script analyzes the ACR of the models on the WMDP datasets (bio, chem,
cyber) for each style (choose, option, generate).
"""

import os
import sys
import json
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer

# Number of examples considered per dataset (also the upper bound of the
# index ranges encoded in the result file names).
count = 100
# Style names that appear in the result file names; the driver reads
# 'generate' results in chunks of 5 examples and the others in chunks of 10.
styles = ['choose', 'option', 'generate']
# Dataset suffixes; each one is expanded to the config name 'wmdp-<suffix>'.
datasets = ['bio', 'chem', 'cyber']
# Model names used as the prefix of the result file names.
models = ['base', 'rmu', 'npo-bio', 'npo-cyber']
# Tokenizer used to measure the token length of the correct answer choice.
# NOTE: this is fetched/loaded at import time.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

def wmdp_dataset_length(config_name='wmdp-bio'):
    """Return token lengths of the correct answers of the first examples.

    Loads the ``cais/wmdp`` dataset for ``config_name``, walks its splits in
    order and tokenizes the correct choice (``choices[answer]``) of each
    example with the module-level ``tokenizer``.

    Args:
        config_name: One of ``'wmdp-bio'``, ``'wmdp-chem'``, ``'wmdp-cyber'``.

    Returns:
        A list with at most ``count`` (module-level) token lengths, in dataset
        iteration order.

    Raises:
        ValueError: If ``config_name`` is not a known WMDP config.
    """
    if config_name not in ('wmdp-bio', 'wmdp-chem', 'wmdp-cyber'):
        raise ValueError(f'Unknown config_name: {config_name}')
    dataset = load_dataset("cais/wmdp", config_name)
    lengths = []
    for split_dataset in dataset.values():
        for example in split_dataset:
            # Only the first `count` examples are ever used, so stop before
            # tokenizing the rest of the dataset.
            if len(lengths) >= count:
                return lengths
            answer_text = example['choices'][example['answer']]
            encoded = tokenizer(answer_text, return_length=True)
            lengths.append(encoded['length'][0])
    return lengths

def load_file(file, missing_count=None):
    """Load per-example results from a JSON result file.

    The file is expected to contain a list of entries, each providing a
    ``success`` field and a ``free_tokens`` field.

    Args:
        file: Path to the JSON result file.
        missing_count: Number of ``-1`` placeholders to return when ``file``
            does not exist. When omitted, the module-level ``delta`` set by
            the ``__main__`` driver is used.

    Returns:
        A list holding ``entry['free_tokens']`` for every successful entry and
        ``-1`` for every failed one; if the file does not exist, a list of
        ``missing_count`` copies of ``-1``.
    """
    if not os.path.exists(file):
        if missing_count is None:
            # Fall back to the chunk size chosen by the driver loop.
            missing_count = delta
        return [-1] * missing_count
    # Use a context manager so the file handle is closed deterministically.
    with open(file, encoding='utf-8') as handle:
        data = json.load(handle)
    return [entry['free_tokens'] if entry['success'] else -1 for entry in data]

if __name__ == '__main__':
    # For every (style, dataset) pair, print the median ratio between the
    # token length of the correct answer and the `free_tokens` value recorded
    # for each model, restricted to examples where every model succeeded.
    lengths_by_dataset = {}
    for s in styles:
        for d in datasets:
            print(f'Style: {s}, Dataset: {d}')
            # NOTE: `delta` must remain a module-level name here because
            # load_file() reads it to size the placeholder list returned for
            # a missing result file.
            delta = 5 if s == 'generate' else 10
            # The answer lengths depend only on the dataset, so load and
            # tokenize each dataset once instead of once per style.
            if d not in lengths_by_dataset:
                lengths_by_dataset[d] = wmdp_dataset_length(f'wmdp-{d}')
            lengths = lengths_by_dataset[d]
            results = {m: [] for m in models}
            for m in models:
                for i in range(0, count, delta):
                    results[m].extend(load_file(f'results/{m}_wmdp-{d}_{s}_{i}_{i+delta}.json'))
            # Only visit indices available in every sequence so a short
            # dataset or a truncated result file cannot raise IndexError.
            usable = min([count, len(lengths)] + [len(results[m]) for m in models])
            ratios = {m: [] for m in models}
            for i in range(usable):
                # -1 marks a missing or failed result; skip the example
                # unless all models have a value for it.
                if any(results[m][i] == -1 for m in models):
                    continue
                for m in models:
                    ratios[m].append(lengths[i] / results[m][i])
            print(' Count:', len(ratios['base']))
            for m in models:
                # np.median of an empty list warns and yields nan; produce
                # the same nan output without the warning.
                median = np.median(ratios[m]) if ratios[m] else float('nan')
                print(f' Model: {m}, Median ratio:', median)