import sys
import os
import json
import random
from copy import deepcopy
from tqdm import tqdm

root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, root_dir)
from verl_utils.reward.reward_func import reward_func


def read_jsonl(fp):
    """Load a JSON Lines file into a list of parsed objects.

    Args:
        fp: Path to a file containing one JSON value per line.

    Returns:
        A list with one parsed value per non-blank line, in file order.
    """
    # Explicit UTF-8 so the result does not depend on the platform locale.
    with open(fp, 'r', encoding='utf-8') as f:
        # Skip whitespace-only lines (e.g. a stray blank line), which
        # json.loads would otherwise reject with JSONDecodeError.
        return [json.loads(line) for line in f if line.strip()]

# User-turn template for the critique task.
# The placeholders __special_original_question__ and
# __special_original_response__ are filled with str.replace by the
# record-building code below. str.format is not used because the text
# contains literal braces in \boxed{}.
# The "\\boxed" escapes in the literal render as a single backslash at runtime.
# Ratings the critic is asked to give:
#   -1   when the response has no final answer in \boxed{} at the end,
#    1   when it is correct,
#   -0.5 when it is incorrect.
CRITIQUE_PROMPT = """Below you are presented with a question and a tentative response. Your task is to evaluate and assign a rating to the response based on the following clear criteria:

Rating Criteria:

1. Missing final answer enclosed in \\boxed{} at the end: assign \\boxed{-1}.
2. Correct response with the final answer enclosed in \\boxed{} at the end: assign \\boxed{1}.
3. Incorrect response with the final answer enclosed in \\boxed{} at the end: assign \\boxed{-0.5}.

### Question Begin ###
__special_original_question__
### Question End ###

### Response Begin ###
__special_original_response__
### Response End ###

Briefly summarize your analysis, then clearly state your final rating value enclosed in \\boxed{} at the end.
"""


# Chat skeleton for every output record: a fixed system instruction followed
# by a user turn whose content is filled in per item. Callers deep-copy it
# before writing into it, so this module-level list stays pristine.
messages_template = [
    dict(
        content='Please reason step by step, and put your final answer within \\boxed{}.',
        role='system',
    ),
    dict(content='', role='user'),
]

# Model-size tag that is embedded in the output filename.
model_size = '1.5b'

def _sample_ids(ids, k):
    """Return ``random.sample(ids, k)``, capping ``k`` at ``len(ids)``.

    ``random.sample`` raises ValueError when ``k`` exceeds the population
    size, so a small input file would otherwise abort the whole run. When
    ``k <= len(ids)`` the call (and the random stream) is unchanged.
    """
    if k > len(ids):
        print(f"Warning: requested {k} ids but only {len(ids)} available; "
              f"using all of them", file=sys.stderr)
        k = len(ids)
    return random.sample(ids, k)


def _build_record(item_id, item):
    """Turn one evaluated item into a critique-training record.

    One entry of ``item['response']`` is chosen with the module-level
    ``random`` generator and scored with ``reward_func`` against
    ``item['answer']``. The question and the chosen response are then
    embedded into ``CRITIQUE_PROMPT``. ``str(reward)`` is stored as the
    record's ``answer``; presumably this is the rating the critic should
    output. TODO: confirm against reward_func.

    Args:
        item_id: The item's ``id`` value.
        item: Parsed input row. It must provide ``response`` (a non-empty
            sequence of candidate responses), ``problem``, ``answer``,
            ``subject``, ``level`` (int-convertible) and ``solution``.

    Returns:
        The output record as a dict, or ``None`` when the item has no
        responses to sample from.
    """
    responses = item['response']
    if not responses:
        # random.choice raises IndexError on an empty sequence.
        return None
    response = random.choice(responses)
    reward = reward_func(None, response, item['answer'], None)

    # deepcopy so each record owns its message dicts instead of mutating
    # the shared module-level template.
    messages = deepcopy(messages_template)
    messages[1]['content'] = (
        CRITIQUE_PROMPT
        .replace('__special_original_question__', item['problem'])
        .replace('__special_original_response__', response)
    )

    # The key order here fixes the key order of each JSON line in the output.
    return {
        'id': item_id,
        'problem': item['problem'],
        'messages': messages,
        'subject': item['subject'],
        'level': int(item['level']),
        'answer': str(reward),
        'solution': item['solution'],
        'expected_answer': item['answer'],
    }


def main():
    """Build critique-training JSONL files from an evaluation dump.

    Usage: ``python script.py <input_file.jsonl> <output_dir>``.

    Reads the input and samples item ids per split with a fixed seed. For
    each split it writes ``rise_<model_size>_offline_<pct>.jsonl`` into
    ``<output_dir>``, creating the directory if it is missing.
    """
    # Fixed seed so the sampled ids and chosen responses are reproducible.
    random.seed(666)

    if len(sys.argv) < 3:
        print("Usage: python script.py <input_file.jsonl> <output_dir>", file=sys.stderr)
        sys.exit(1)
    input_file, output_dir = sys.argv[1], sys.argv[2]
    data_from_eval = read_jsonl(input_file)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    # id -> item; a later row with a duplicate id overwrites an earlier one.
    item_dict = {item['id']: item for item in data_from_eval}
    math_ids = list(item_dict)

    # (sampled ids, percentage tag used in the progress bar and filename)
    splits = [
        (_sample_ids(math_ids, 1024), '012'),
        # (_sample_ids(math_ids, 2131), '025'),
        # (_sample_ids(math_ids, 4262), '050'),
        # (math_ids, '100')
    ]

    for id_list, pct in splits:
        records = []
        skipped = 0
        for item_id in tqdm(id_list, desc=f"Processing {pct}%"):
            # Every sampled id comes from item_dict's keys, so indexing is safe.
            record = _build_record(item_id, item_dict[item_id])
            if record is None:
                skipped += 1
                continue
            records.append(record)

        output_path = os.path.join(output_dir, f"rise_{model_size}_offline_{pct}.jsonl")
        with open(output_path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(record) + '\n' for record in records)

        if skipped:
            print(f"Skipped {skipped} items with no responses", file=sys.stderr)
        print(f"Generated {len(records)} entries in {output_path}")


if __name__ == '__main__':
    main()
