import re
import pandas as pd
def format_MMLU(example):
    """
    Build a multiple-choice prompt from one MMLU-style record.

    Suitable as the callable passed to dataset.map().

    Args:
        example: Mapping providing 'question' (str), 'choices' (iterable of
            option texts) and 'answer'.

    Returns:
        A dict with two keys: 'question' holds the prompt (question text,
        one "Option <i>: <choice>" line per choice, a blank line, then the
        answering instruction) and 'answer' is passed through unchanged.
    """
    # One numbered line per choice; numbering starts at 0.
    option_block = "".join(
        f"Option {index}: {option}\n"
        for index, option in enumerate(example['choices'])
    )

    # The instruction always names indices 0-3, regardless of how many
    # choices the record actually carries.
    instruction = (
        "IMPORTANT: Your answer must be ONLY the numerical index "
        "(0, 1, 2, or 3) of the correct option. "
    )

    full_prompt = example['question'] + '\n' + option_block + '\n' + instruction

    return {'question': full_prompt, 'answer': example['answer']}

def format_MED(example):
    """
    Build a multiple-choice prompt from one record whose four options are
    stored under the keys 'opa', 'opb', 'opc' and 'opd'.

    Suitable as the callable passed to dataset.map().

    Args:
        example: Mapping providing 'question' (str), the four option keys
            listed above, and 'cop' (the value returned as the answer).

    Returns:
        A dict with two keys: 'question' holds the prompt (question text,
        the options numbered 0-3, a blank line, then the answering
        instruction) and 'answer' is example['cop'] unchanged.
    """
    pieces = [example['question'] + '\n']

    # Map the lettered option fields onto numeric indices 0..3.
    for index, key in enumerate(('opa', 'opb', 'opc', 'opd')):
        pieces.append(f"Option {index}: {example[key]}\n")

    pieces.append("\n")
    pieces.append(
        "IMPORTANT: Your answer must be ONLY the numerical index "
        "(0, 1, 2, or 3) of the correct option. "
    )

    return {'question': "".join(pieces), 'answer': example['cop']}

def extract_math_answer(text):
    """
    Extracts a single integer answer from a label or a model response.

    The following strategies are tried in order; the first one that
    applies decides the result:

      1. ``int(text)`` succeeds directly.
      2. ``text`` contains ``####``: everything after the first marker is
         parsed as an integer (commas removed, surrounding whitespace
         ignored).
      3. ``text`` contains a ``Final Answer:`` fenced block of four rows of
         four whitespace-separated digits: the 16 digits are concatenated
         and returned as one integer.
      4. Otherwise the last ``boxed{...}`` expression is used: everything
         except digits and '.' is dropped, the remainder is parsed as a
         float and truncated to int. A leading minus sign is preserved.

    Args:
        text: The string (or int-convertible value) to extract from.

    Returns:
        The extracted integer, or None when nothing could be extracted or
        any step failed.
    """
    try:
        # Plain integers (or integer-looking strings) need no parsing.
        try:
            return int(text)
        except (TypeError, ValueError, OverflowError):
            pass  # Not a bare integer; fall through to the pattern extractors.

        # if we are dealing with original data we would have ####
        if '####' in text:
            # Start right after the marker. int() tolerates surrounding
            # whitespace, so both "#### 42" and "####42" parse correctly.
            answer_start = text.find('####') + len('####')
            try:
                return int(text[answer_start:].replace(',', ''))
            except ValueError as e:
                print(f"Error converting to int: {e}")
                return None

        final_answer_match = re.search(r'Final Answer:\s*```\s*((?:\d\s+\d\s+\d\s+\d\s*\n?){4})\s*```', text, re.MULTILINE)

        if final_answer_match:
            # Flatten the 4x4 grid into a single 16-digit integer.
            digits = re.findall(r'\d', final_answer_match.group(1))
            return int(''.join(digits))

        # Find all content inside \boxed{...}
        boxed_matches = re.findall(r'boxed\{([^}]*)\}', text)
        if not boxed_matches:
            return None

        # Take the last boxed answer (usually the final answer)
        boxed_content = boxed_matches[-1]

        # Stripping to digits would lose the sign, so remember a leading
        # minus (optionally preceded by whitespace or '$') beforehand.
        is_negative = re.match(r'[\s$]*-', boxed_content) is not None

        # Keep only digits and '.'; this ignores formatting like commas,
        # dollar signs, etc.
        numbers_only = re.sub(r'[^\d.]', '', boxed_content)
        try:
            magnitude = int(float(numbers_only))
        except (ValueError, OverflowError) as e:
            print(e)
            return None
        return -magnitude if is_negative else magnitude
    except Exception:
        # Best-effort extractor: any unexpected input (e.g. None or a
        # non-string object) yields "no answer" rather than an exception.
        return None

def extract_all_math_answer(text):
    """
    Extracts every integer answer found inside ``boxed{...}`` expressions.

    For each boxed expression, everything except digits and '.' is dropped,
    the remainder is parsed as a float and truncated to int. A leading
    minus sign in the boxed content is preserved. Boxed expressions that
    do not yield a number are reported via print and skipped.

    Args:
        text: The string to search.

    Returns:
        A list of integers in order of appearance; empty when ``text``
        contains no boxed expression.
    """
    # Find all content inside boxed{...}
    boxed_matches = re.findall(r'boxed\{([^}]*)\}', text)

    results = []
    for content in boxed_matches:
        # Stripping to digits would lose the sign, so remember a leading
        # minus (optionally preceded by whitespace or '$') beforehand.
        is_negative = re.match(r'[\s$]*-', content) is not None

        # Remove any characters that are not digits or decimal points.
        numbers_only = re.sub(r'[^\d.]', '', content)
        try:
            # Convert to number and then to int.
            magnitude = int(float(numbers_only))
        except (ValueError, OverflowError) as e:
            print(f"Error parsing '{content}': {e}")
            continue
        results.append(-magnitude if is_negative else magnitude)
    return results

def base_model_chat_template(messages, retro=False, tokenizer=None):
    """
    Renders chat messages as plain "User: ... / Assistant: ..." text.

    Args:
        messages: Iterable of mappings from role name to content. Within
            each mapping, a "user" entry is rendered as "User: <content>"
            followed by a newline and an "assistant" entry as
            "Assistant: <content>"; entries with any other role are
            ignored.
        retro: When True, one rendered string is produced per element of
            ``messages``. When False, nothing is rendered.
        tokenizer: Accepted for interface compatibility; not used here.

    Returns:
        The list of rendered strings (empty when ``retro`` is False).
    """
    result = []
    if retro:
        for message in messages:
            parts = []
            for role, content in message.items():
                if role == "user":
                    parts.append(f"User: {content}\n")
                elif role == "assistant":
                    parts.append(f"Assistant: {content}")
            result.append("".join(parts))
    # The rendered list was previously built and then discarded; hand it
    # back to the caller.
    return result

def extract_last_question_and_solution(text):
    """
    Extracts the last "Problem:" / "Solution:" pair from a text.

    The question is the text between the last "Problem:" marker and the
    last "Solution:" marker. The solution is the text after the last
    "Solution:" marker, up to the first following "Analysis:" marker if
    there is one, otherwise to the end of the text. Both are stripped of
    surrounding whitespace.

    Args:
        text: The string to search.

    Returns:
        A tuple ``(question, solution)``. ``(None, None)`` is returned when
        there is no "Solution:" marker; ``question`` alone is None when
        there is a "Solution:" marker but no "Problem:" marker.
    """
    problem_marker = "Problem:"
    solution_marker = "Solution:"

    # Find last Problem and Solution markers
    last_problem_idx = text.rfind(problem_marker)
    last_solution_idx = text.rfind(solution_marker)

    # Without a Solution marker neither part can be delimited.
    if last_solution_idx == -1:
        return None, None

    # Extract last question (between Problem and Solution)
    if last_problem_idx != -1:
        question_start = last_problem_idx + len(problem_marker)
        last_question = text[question_start:last_solution_idx].strip()
    else:
        last_question = None

    # Extract last solution (between Solution and Analysis); the marker
    # itself is skipped by starting right after it.
    solution_start = last_solution_idx + len(solution_marker)
    last_analysis_idx = text.find("Analysis:", last_solution_idx)
    if last_analysis_idx != -1:
        last_solution = text[solution_start:last_analysis_idx].strip()
    else:
        last_solution = text[solution_start:].strip()

    return last_question, last_solution

def verify_from_response(response):
    """
    Report whether a verifier response judges the checked answer correct.

    The response counts as a "correct" verdict exactly when it contains
    the phrase "there is no error", compared case-insensitively.

    Args:
        response: The verifier's response text.

    Returns:
        True if the phrase is present, False otherwise.
    """
    verdict_phrase = "there is no error"
    lowered = response.lower()
    return verdict_phrase in lowered
    
def evaluate_record(full_record):
    """
    Scores a collection of prediction records and prints a summary.

    For each record, an integer answer is extracted from record['response']
    and from record['label'] using extract_math_answer. A record is
    "valid" when an integer could be extracted from the response, and
    "correct" when it is valid and that integer equals the one extracted
    from the label.

    Args:
        full_record: Sized collection of mappings, each with 'response'
            and 'label' entries.

    Returns:
        A one-row pandas DataFrame with the columns
        'number_of_predictions', 'valid_answer' (fraction of valid records)
        and 'accuracy' (fraction of correct records). Both fractions are
        0.0 when ``full_record`` is empty.
    """
    total = len(full_record)
    correct, valid = 0, 0
    for record in full_record:
        # Extract once per field; the extractor may print diagnostics.
        prediction = extract_math_answer(record['response'])
        label = extract_math_answer(record['label'])

        # An unparseable response must never count as correct, even if the
        # label is unparseable as well (None == None).
        if prediction is not None and prediction == label:
            correct += 1
        if isinstance(prediction, int):
            valid += 1

    eval_df = {
        'number_of_predictions': [total],
        # Guard against division by zero on an empty input.
        'valid_answer': [valid / total if total else 0.0],
        'accuracy': [correct / total if total else 0.0],
    }
    print('=>'*50)
    for key, values in eval_df.items():
        print(f'{key}: {values[0]}')
    print('=>'*50)
    return pd.DataFrame(eval_df)