import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.model_loader import model_inference_batch_vllm
from src.helper import extract_math_answer, verify_from_response
from tqdm import tqdm
import json


def run_math_inference(model, tokenizer, data, save_path, batch_size=8, max_new_tokens=3000, COT=True, temperature=0, base_model_prompt=None):
    """Generate one response per question and save the answers and full records.

    Args:
        model: Passed through unchanged to ``model_inference_batch_vllm``.
        tokenizer: Passed through unchanged to ``model_inference_batch_vllm``.
        data: Unpacked as ``(questions, labels)``; both are indexed by the
            same example position.
        save_path: Output directory; created if it does not exist.
        batch_size: Number of questions sent per inference call.
        max_new_tokens: Forwarded to ``model_inference_batch_vllm``.
        COT: Accepted for compatibility; not read by this function.
        temperature: Forwarded to ``model_inference_batch_vllm``.
        base_model_prompt: Forwarded as ``prompt`` to
            ``model_inference_batch_vllm``.

    Returns:
        ``(submission, full_record)``: the list of extracted answers and the
        list of per-example dicts with keys ``answer``, ``response``,
        ``label`` and ``question``.

    Side effects:
        Writes ``submission.json`` and ``full_record.json`` under
        ``save_path`` and deletes every entry in ``save_path`` whose name
        contains ``check_point``.
    """
    questions, labels = data
    print('we have', len(questions), 'examples')
    answers = []
    records = []
    os.makedirs(save_path, exist_ok=True)

    for start in tqdm(range(0, len(questions), batch_size)):
        chunk = questions[start:start + batch_size]

        # Each question becomes its own single-turn user conversation.
        conversations = []
        for prompt_text in chunk:
            conversations.append([{'role': 'user', 'content': prompt_text}])

        outputs = model_inference_batch_vllm(
            model,
            tokenizer,
            conversations,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            prompt=base_model_prompt,
        )

        for offset, output in enumerate(outputs):
            # A failed extraction is reported and stored as None so the
            # output lists stay aligned with the inputs.
            try:
                extracted = extract_math_answer(output)
            except Exception as e:
                print(f"Error processing response: {e}")
                extracted = None

            answers.append(extracted)
            position = start + offset
            records.append({
                'answer': extracted,
                'response': output,
                'label': labels[position],
                'question': questions[position],
            })

    # Persist both views of the run.
    with open(f'{save_path}/submission.json', 'w') as f:
        json.dump(answers, f)
    with open(f'{save_path}/full_record.json', 'w') as f:
        json.dump(records, f, indent=4)

    # Drop any checkpoint files left in the output directory.
    stale_entries = [name for name in os.listdir(save_path) if 'check_point' in name]
    for name in stale_entries:
        os.remove(os.path.join(save_path, name))

    print('All done!')
    return answers, records

def run_feedback_inference(model, tokenizer, full_record_path, save_path, batch_size=8, max_new_tokens=3000, feedback_prompt=None, base_model_prompt=None, temperature=0.0):
    """Generate feedback for the responses stored by a previous run.

    Reads ``full_record.json`` from ``full_record_path``; each row must have
    ``question``, ``response`` and ``label`` keys. Two input formats are
    supported:

    * chat mode (``base_model_prompt`` falsy): each row becomes a single user
      message built from ``feedback_prompt.format(question=...,
      initial_response=...)``, where the question is prefixed with the
      step-by-step reasoning instruction;
    * base-model mode (``base_model_prompt`` truthy): each row becomes the
      plain list ``[question, response]`` and ``feedback_prompt`` is not
      needed.

    Args:
        model: Passed through unchanged to ``model_inference_batch_vllm``.
        tokenizer: Passed through unchanged to ``model_inference_batch_vllm``.
        full_record_path: Directory containing the input ``full_record.json``.
        save_path: Output directory; created if it does not exist.
        batch_size: Number of rows sent per inference call.
        max_new_tokens: Forwarded to ``model_inference_batch_vllm``.
        feedback_prompt: Format string with ``{question}`` and
            ``{initial_response}`` fields; required in chat mode only.
        base_model_prompt: Forwarded as ``prompt``; when truthy it also
            selects base-model mode.
        temperature: Forwarded to ``model_inference_batch_vllm``.

    Returns:
        List of dicts with keys ``question``, ``response`` (the generated
        feedback), ``initial_response``, ``label`` (the stored label passed
        through ``extract_math_answer``) and ``messages`` (the model input
        followed by an assistant turn holding the feedback).

    Raises:
        ValueError: In chat mode when ``feedback_prompt`` is ``None``.

    Side effects:
        Writes ``full_record.json`` under ``save_path`` and deletes every
        entry in ``save_path`` whose name contains ``check_point``.
    """
    reasoning_prompt = "Please reason step by step, and put your final answer within \\boxed{}. "
    full_record = []
    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(full_record_path, 'full_record.json'), 'r', encoding='utf-8') as f:
        initial_run = json.load(f)

    for batch_idx in tqdm(range(0, len(initial_run), batch_size)):
        batch_slice = initial_run[batch_idx:batch_idx + batch_size]

        # Build only the format that is actually used, so base-model mode
        # does not depend on feedback_prompt being set.
        if base_model_prompt:
            batch_messages = [
                [row['question'], row['response']]
                for row in batch_slice
            ]
        else:
            if feedback_prompt is None:
                raise ValueError(
                    "feedback_prompt is required when base_model_prompt is not set"
                )
            batch_messages = [
                [{
                    'role': 'user',
                    'content': feedback_prompt.format(
                        question=reasoning_prompt + row['question'],
                        initial_response=row['response'],
                    ),
                }]
                for row in batch_slice
            ]

        batch_responses = model_inference_batch_vllm(
            model,
            tokenizer,
            batch_messages,
            max_new_tokens=max_new_tokens,
            retro=False,
            prompt=base_model_prompt,
            temperature=temperature,
        )

        # Pair every generated feedback with the row it was produced for.
        for response_idx, feedback in enumerate(batch_responses):
            row = batch_slice[response_idx]
            full_record.append({
                'question': row['question'],
                'response': feedback,
                'initial_response': row['response'],
                'label': extract_math_answer(row['label']),
                'messages': batch_messages[response_idx] + [{'role': 'assistant', 'content': feedback}],
            })

    # Final save
    with open(os.path.join(save_path, 'full_record.json'), 'w', encoding='utf-8') as f:
        json.dump(full_record, f, indent=4)

    # Drop any checkpoint files left in the output directory.
    for file_name in os.listdir(save_path):
        if 'check_point' in file_name:
            os.remove(os.path.join(save_path, file_name))
    print('All done!')
    return full_record

def _select_feedback(feedback, exclude_final_assessment):
    """Return ``feedback``, cut before "Final Assessment:" and stripped when requested."""
    if exclude_final_assessment:
        return feedback.split("Final Assessment:")[0].strip()
    return feedback


def run_math_inference_with_feedback(model, tokenizer, data, full_record_path, save_path, batch_size=8, max_new_tokens=3000, refine_prompt=None, skip_if_verified=False, base_model_prompt=None, exclude_final_assessment=False):
    """Generate a second-attempt response for each question using stored feedback.

    Reads ``full_record.json`` from ``full_record_path``; each row must have
    ``question``, ``initial_response`` and ``response`` (the feedback) keys.
    Two input formats are supported:

    * chat mode (``base_model_prompt`` falsy): each row becomes a single user
      message built from ``refine_prompt.format(question=...,
      initial_response=..., feedback=...)``;
    * base-model mode (``base_model_prompt`` truthy): each row becomes the
      plain list ``[question, initial_response, feedback]`` and
      ``refine_prompt`` is not needed.

    Args:
        model: Passed through unchanged to ``model_inference_batch_vllm``.
        tokenizer: Passed through unchanged to ``model_inference_batch_vllm``.
        data: Unpacked as ``(questions, labels)``; only the labels are read,
            indexed by the row position in the feedback record.
        full_record_path: Directory containing the input ``full_record.json``.
        save_path: Output directory; created if it does not exist.
        batch_size: Number of rows sent per inference call.
        max_new_tokens: Forwarded to ``model_inference_batch_vllm``.
        refine_prompt: Format string with ``{question}``,
            ``{initial_response}`` and ``{feedback}`` fields; required in
            chat mode only.
        skip_if_verified: When true, keep the initial response for rows whose
            feedback ``verify_from_response`` accepts, instead of the newly
            generated one.
        base_model_prompt: Forwarded as ``prompt``; when truthy it also
            selects base-model mode.
        exclude_final_assessment: When true, the feedback given to the model
            is cut before the first "Final Assessment:" marker.

    Returns:
        ``(submission, full_record)``: the list of extracted answers and the
        list of per-example dicts with keys ``question`` (the model input for
        that row), ``response``, ``label`` and ``answer``.

    Raises:
        ValueError: In chat mode when ``refine_prompt`` is ``None``.

    Side effects:
        Writes ``submission.json`` and ``full_record.json`` under
        ``save_path`` and deletes every entry in ``save_path`` whose name
        contains ``check_point``.
    """
    # Unpacking still checks that data is a pair; the questions come from
    # the feedback record instead.
    _, dataset_answer = data
    submission = []
    full_record = []
    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(full_record_path, 'full_record.json'), 'r', encoding='utf-8') as f:
        messages_record = json.load(f)

    for batch_idx in tqdm(range(0, len(messages_record), batch_size)):
        batch_slice = messages_record[batch_idx:batch_idx + batch_size]
        feedbacks = [
            _select_feedback(row['response'], exclude_final_assessment)
            for row in batch_slice
        ]

        # Build only the format that is actually used, so base-model mode
        # does not depend on refine_prompt being set.
        if base_model_prompt:
            batch_messages = [
                [row['question'], row['initial_response'], feedback]
                for row, feedback in zip(batch_slice, feedbacks)
            ]
        else:
            if refine_prompt is None:
                raise ValueError(
                    "refine_prompt is required when base_model_prompt is not set"
                )
            batch_messages = [
                [{
                    'role': 'user',
                    'content': refine_prompt.format(
                        question=row['question'],
                        initial_response=row['initial_response'],
                        feedback=feedback,
                    ),
                }]
                for row, feedback in zip(batch_slice, feedbacks)
            ]

        batch_responses = model_inference_batch_vllm(
            model,
            tokenizer,
            batch_messages,
            max_new_tokens=max_new_tokens,
            retro=False,
            prompt=base_model_prompt,
        )

        # Process results
        for response_idx, refined_response in enumerate(batch_responses):
            row = batch_slice[response_idx]
            # Only consult the verifier when its verdict can change the
            # outcome. The check uses the untrimmed feedback.
            if skip_if_verified and verify_from_response(row['response']):
                response = row['initial_response']
            else:
                response = refined_response

            # A failed extraction is reported and stored as None so the
            # output lists stay aligned with the inputs.
            try:
                answer = extract_math_answer(response)
            except Exception as e:
                print(f"Error processing response: {e}")
                answer = None

            full_record.append({
                'question': batch_messages[response_idx],
                'response': response,
                'label': extract_math_answer(dataset_answer[batch_idx + response_idx]),
                'answer': answer,
            })
            submission.append(answer)

    # Final save
    with open(os.path.join(save_path, 'submission.json'), 'w', encoding='utf-8') as f:
        json.dump(submission, f)
    with open(os.path.join(save_path, 'full_record.json'), 'w', encoding='utf-8') as f:
        json.dump(full_record, f, indent=4)

    # Drop any checkpoint files left in the output directory.
    for file_name in os.listdir(save_path):
        if 'check_point' in file_name:
            os.remove(os.path.join(save_path, file_name))
    print('All done!')
    return submission, full_record