import os
from collections import defaultdict
import json

from evalplus.data import get_mbpp_plus

from inference_rlhf.code.helpers.io import json_dump

# Task name; it appears as a path component of the output location.
TASK = 'mbpp'
# Root directory that is searched recursively for result files.
AMLT_RESULTS_DIR = 'data/mbpp/gpt-4o-mini'

# Group every matching path under AMLT_RESULTS_DIR by its base file name, so
# that same-named files found in different sub-directories end up together.
# Only '.jsonl' files whose name does not contain 'raw' are kept.
jsonl_name_to_files = defaultdict(list)
for dirpath, _subdirs, names in os.walk(AMLT_RESULTS_DIR):
    for name in names:
        if 'raw' in name or not name.endswith('.jsonl'):
            continue
        jsonl_name_to_files[name].append(os.path.join(dirpath, name))

# sort each path list in place to guarantee reproducibility
for paths in jsonl_name_to_files.values():
    paths.sort()

# Ordered (marker, policy) pairs: the first marker that occurs as a substring
# of a results file name decides the policy directory the output is written to.
_POLICY_MARKERS = (
    ('Qwen2.5-3B-Instruct', 'qwen-25-3b'),
    ('Llama-3.1-8B-Instruct', 'llama-3-8b'),
    ('Llama-3.2-3B-Instruct', 'llama-3-3b'),
    ('Mistral-7B-Instruct', 'mistral-7b'),
    ('Qwen2.5-0.5B-Instruct', 'qwen-25-05b'),
    ('Qwen2.5-7B-Instruct', 'qwen-25-7b'),
    ('Qwen2.5-14B-Instruct', 'qwen-25-14b'),
    ('Qwen2.5-Coder-3B-Instruct', 'qwen-25-coder-3b'),
    ('Qwen2.5-Coder-7B-Instruct', 'qwen-25-coder-7b'),
    ('Qwen2.5-Coder-14B-Instruct', 'qwen-25-coder-14b'),
    ('Phi-3-medium-4k-instruct', 'phi-3-medium'),
    ('phi-4', 'phi-4'),
    ('gpt-4o-mini', 'gpt-4o-mini'),
)


def _policy_for(file_name):
    """Return the policy directory name matching a results file name.

    Args:
        file_name: Base name of a ``.jsonl`` results file.

    Returns:
        The policy paired with the first marker contained in ``file_name``.

    Raises:
        ValueError: If none of the known markers occurs in ``file_name``.
    """
    for marker, policy_name in _POLICY_MARKERS:
        if marker in file_name:
            return policy_name
    raise ValueError(f'Unknown policy: {file_name}')


# The task-id -> prompt-index mapping is the same for every results file, so
# the dataset is loaded once here rather than once per file name.
dataset_dict = get_mbpp_plus(version="default")
task_id_to_prompt_idx = {
    problem_key: i for i, problem_key in enumerate(dataset_dict)
}

for file_name, files in jsonl_name_to_files.items():
    # Resolve the destination first so an unrecognised file name fails before
    # any file is read.
    policy = _policy_for(file_name)

    all_responses = defaultdict(list)
    all_responses_jsonl = defaultdict(list)
    skipped_lines = 0
    for file in files:
        with open(file, 'r') as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    # Malformed lines are skipped, but counted so that the
                    # loss is visible instead of silent.
                    skipped_lines += 1
                    continue
                # A task_id missing from the dataset raises KeyError here.
                all_responses[data['task_id']].append({
                    "prompt_idx": task_id_to_prompt_idx[data['task_id']],
                    "response": data['solution'],
                })
                all_responses_jsonl[data['task_id']].append(data)

    # NOTE: the number printed is the count of distinct task ids.
    print(f'Collected {len(all_responses)} responses for {file_name}')
    if skipped_lines:
        print(f'Skipped {skipped_lines} malformed lines for {file_name}')

    # Every task shares one output path, so dumping once per task would leave
    # only the last task's responses on disk. Flatten all tasks (responses of
    # the same task stay adjacent) and write the file a single time.
    merged_responses = [
        response
        for responses in all_responses.values()
        for response in responses
    ]
    if merged_responses:
        new_file_name = file_name.replace(".jsonl", ".json")
        json_dump(merged_responses, f'data/{TASK}/{policy}/{new_file_name}')

        # # write big jsonl file
        # with open(f'data/{TASK}/{policy}/{new_file_name.replace(".json", ".jsonl")}', 'w') as f:
        #     for jsonl_responses in all_responses_jsonl.values():
        #         for jsonl_response in jsonl_responses:
        #             f.write(json.dumps(jsonl_response) + '\n')

