from olym_gen.utils.sample_utils import create_random_sample
import olym_gen.utils.utils
import os
import json
import sys
import asyncio
from olym_gen.generator.generate_check import main as check_main

# Root directory that main() walks looking for dataset folders.
root_dir = "validset_baseline/"

# (provider, model name) pairs; every input is checked with each pair.
models = [
    ("local_vllm", "xxxxx"),
]


def _convert_to_jsonl(input_path):
    """Export *input_path* to a sibling ``<input_path>.jsonl`` file and return that path.

    Delegates to ``RephraseGenerator.from_json_to_jsonl`` with the field names
    ``'question'``, ``'orig_solution'``, ``'field'`` and ``'proofs'``.
    """
    jsonl_path = input_path + ".jsonl"
    # Function-scope import: the generator module is only loaded when a
    # conversion is actually performed.
    from olym_gen.generator.generate_rephrase import RephraseGenerator

    fields = ('question', 'orig_solution', 'field', 'proofs')
    RephraseGenerator().from_json_to_jsonl(jsonl_path, input_path, *fields)
    return jsonl_path

# Extra sampling / chat-template parameters, JSON-encoded and forwarded to the
# check runner via ``--extra_model_paras`` by generate_check().
extra_model_paras = {
    "temperature": 0.6,
    "top_p": 0.9,
    # Chat templates consume this flag as a boolean (e.g. Qwen3-style
    # templates test ``enable_thinking is false``). A string such as
    # "True"/"False" is always truthy, so it could never turn thinking off;
    # use a real bool, which json.dumps encodes as ``true``.
    "chat_template_kwargs": {"enable_thinking": True},
}

async def generate_check(input_path, output_path, check_model_provider, check_model_name,
                         num_returns=4, check_worker=3000, max_token=32000, extra_paras=None):
    """Run the generate-check entry point on one input file.

    Builds a CLI-style argv list and awaits ``check_main`` (imported from
    ``olym_gen.generator.generate_check``) with it. ``--local_model`` and
    ``--resume`` are always passed.

    Args:
        input_path: Input file, passed as ``--file``.
        output_path: Save location, passed as ``--save_path``.
        check_model_provider: Leading positional argument (e.g. ``'local_vllm'``).
        check_model_name: Model name, passed as ``--model``.
        num_returns: Value for ``--num_returns``.
        check_worker: Value for ``--num_worker``.
        max_token: Value for ``--max_token``; defaults to 32000.
        extra_paras: Dict JSON-encoded into ``--extra_model_paras``. ``None``
            (the default) uses the module-level ``extra_model_paras``.
    """
    # Resolved at call time so edits to the module-level dict are still honored.
    model_paras = extra_model_paras if extra_paras is None else extra_paras
    argv = [
        check_model_provider,
        "--model", check_model_name,
        "--file", input_path,
        "--save_path", output_path,
        "--num_returns", str(num_returns),
        "--num_worker", str(check_worker),
        "--max_token", str(max_token),
        "--extra_model_paras", json.dumps(model_paras),
        "--local_model",
        "--resume",
    ]
    await check_main(argv)

def _generate_mask_norm(input_path, output_path=None):
    """Run mask normalization on *input_path* and return the output location.

    Args:
        input_path: Path to normalize (main() passes a leaf dataset directory).
        output_path: Destination handed to ``normalize_mask_completion``.
            ``None`` selects a ``mask_norm`` sibling of *input_path*.

    Returns:
        The output path that was passed to ``normalize_mask_completion``.
    """
    print(f"Normalizing mask for {input_path}")
    if output_path is None:
        # Sibling of input_path. normpath collapses the literal "<dir>/.." so
        # paths derived from the return value (e.g. save dirs created with
        # os.makedirs, which is documented to mishandle pardir components)
        # never contain "..".
        output_path = os.path.normpath(os.path.join(input_path, os.pardir, 'mask_norm'))
    from olym_gen.utils.utils import normalize_mask_completion
    normalize_mask_completion(input_path, output_path)
    return output_path

def generate_mask_norm(input_path, output_path=None):
    """Mask-normalize *input_path*, export the result to JSONL, and return the JSONL path.

    *output_path* is forwarded to ``_generate_mask_norm`` unchanged (``None``
    selects that helper's default location).
    """
    return _convert_to_jsonl(_generate_mask_norm(input_path, output_path))

def _strip_jsonl_suffix(path):
    """Return *path* with one trailing ``.jsonl`` extension removed, if present.

    ``str.rstrip('.jsonl')`` is not a suffix strip: it removes any trailing run
    of the characters ``.``, ``j``, ``s``, ``o``, ``n``, ``l`` (so
    ``'lessons.jsonl'`` would become ``'le'``).
    """
    suffix = '.jsonl'
    if path.endswith(suffix):
        return path[:-len(suffix)]
    return path


async def main():
    """Collect a check job for every dataset under ``root_dir`` and run them.

    Only leaf directories (no subdirectories, at least one file) whose path
    does not contain ``'mask_norm'`` are considered:

    * if the directory contains ``.json`` files it is exported to a single
      ``<dir>.jsonl`` file (mask-normalized first when the path contains
      ``'mask'``), and that file is checked;
    * otherwise every ``.jsonl`` file in the directory is checked.

    Paths containing ``'testset'`` get 8 returns per item, all others 4. For
    each (provider, model) in ``models`` the output directory is
    ``save/baseline/<model>-generate/<input path without .jsonl>``, created if
    missing. Checks are awaited one after another.
    """
    check_configs = []
    for root, dirs, files in os.walk(root_dir):
        # Only leaf directories that contain files are treated as datasets.
        if dirs or not files:
            continue
        # Outputs of the mask-normalization step are not inputs.
        if 'mask_norm' in root:
            continue
        num_returns = 8 if 'testset' in root else 4

        # (input file, save-path suffix) pairs for this directory. Classify the
        # directory from all of its files rather than files[0], whose position
        # in os.walk's listing is arbitrary.
        if any(name.endswith('.json') for name in files):
            input_path = generate_mask_norm(root) if 'mask' in root else _convert_to_jsonl(root)
            jobs = [(input_path, _strip_jsonl_suffix(input_path))]
        else:
            jobs = [
                (os.path.join(root, name), os.path.join(root, _strip_jsonl_suffix(name)))
                for name in files
                if name.endswith('.jsonl')
            ]

        for input_path, save_suffix in jobs:
            for provider, model in models:
                # normpath removes any "<dir>/.." components; os.makedirs is
                # documented to get confused by pardir path elements.
                output_path = os.path.join(
                    'save', 'baseline', model + '-generate', os.path.normpath(save_suffix)
                )
                os.makedirs(output_path, exist_ok=True)
                check_configs.append((input_path, output_path, provider, model, num_returns))

    # Checks run sequentially, one at a time.
    for config in check_configs:
        await generate_check(*config)


if __name__ == "__main__":
    asyncio.run(main())