#!/usr/bin/env python3

import json
import random
import re
import argparse
import math
from typing import List, Dict, Any, Optional, Tuple

def is_integer_answer(answer: str) -> bool:
    """Return True if ``answer`` is a string that parses as an integer.

    Commas are stripped before parsing, so ``"1,234"`` is accepted.

    Args:
        answer: The answer text to test. Non-string values (``None``, numbers
            decoded from JSON, ...) are rejected rather than raising.

    Returns:
        True when ``int()`` accepts the comma-stripped text, False otherwise.
    """
    # Guard first: calling ``.replace`` on a non-string would raise
    # AttributeError, which the ValueError handler below does not cover.
    if not isinstance(answer, str):
        return False
    try:
        int(answer.replace(',', ''))
    except ValueError:
        return False
    return True

def find_integers_in_problem(problem: str) -> Dict[int, List[Tuple[int, int]]]:
    """Locate the integers written in a problem statement.

    Number-like tokens (digits with optional commas and an optional decimal
    part) are scanned; tokens that do not parse as an integer, such as
    ``3.5``, are skipped.

    Args:
        problem: The problem text to scan.

    Returns:
        A dict mapping each integer value to the list of ``(start, end)``
        character spans, in order of appearance, where it occurs in
        ``problem``. Each span covers only the number itself.
    """
    number_pattern = r'\b(\d|,)+\.?\d*\b'
    matches: Dict[int, List[Tuple[int, int]]] = {}

    for match in re.finditer(number_pattern, problem):
        raw = match.group()

        # The pattern may swallow a comma that is punctuation rather than a
        # digit-group separator (e.g. "(2,x)" matches "2,"). Trim such commas
        # so that a later substitution of this span does not delete them.
        start = match.start() + (len(raw) - len(raw.lstrip(',')))
        end = match.end() - (len(raw) - len(raw.rstrip(',')))

        try:
            num = int(problem[start:end].replace(',', ''))
        except ValueError:
            # Decimals and comma-only matches are not integers.
            continue

        matches.setdefault(num, []).append((start, end))

    return matches

def replace_integer_with_variable(problem: str, integer_to_replace: int, variable_name: str, positions: List[Tuple[int, int]]) -> str:
    """Replace the given occurrences of an integer in ``problem`` with a variable name.

    Args:
        problem: The original problem text.
        integer_to_replace: The integer value expected at every span.
        variable_name: The text substituted for each occurrence.
        positions: ``(start, end)`` character spans into the original
            ``problem``; they may be given in any order.

    Returns:
        The problem text with every listed span replaced by ``variable_name``.

    Raises:
        ValueError: If the text at a span does not parse (after removing
            commas) to ``integer_to_replace``.
    """
    # Substitute from the rightmost span to the leftmost so that the offsets
    # of the spans still to be processed are not shifted by earlier edits.
    for start, end in sorted(positions, reverse=True):
        fragment = problem[start:end]
        try:
            found = int(fragment.replace(',', ''))
        except ValueError:
            found = None
        # Explicit check instead of ``assert`` so it still runs under ``-O``.
        if found != integer_to_replace:
            raise ValueError(
                f"Problem: {problem} start: {start} end: {end} integer_to_replace: {integer_to_replace}"
            )
        problem = problem[:start] + variable_name + problem[end:]
    return problem

def create_variable_relationship(previous_answer: int, target_integer: int) -> str:
    """Describe how to obtain ``target_integer`` from ``previous_answer``.

    Every arithmetic phrase that maps the previous answer onto the target is
    collected, and one is picked at random with ``random.choice``.

    Args:
        previous_answer: The answer of the preceding problem.
        target_integer: The integer the variable must evaluate to.

    Returns:
        A phrase such as ``" plus 3"``, ``" minus 2"``, ``" times 4"`` or
        ``" divided by 5"``, or the empty string when the two values are
        equal and that option is drawn.
    """
    candidates = []

    if previous_answer == target_integer:
        candidates.append("")
    if target_integer > previous_answer:
        candidates.append(f" plus {target_integer - previous_answer}")
    if previous_answer > target_integer:
        candidates.append(f" minus {previous_answer - target_integer}")
    if previous_answer > 0 and target_integer % previous_answer == 0:
        candidates.append(f" times {target_integer // previous_answer}")
    # A zero previous answer is divisible by every target, which would give
    # the meaningless phrase " divided by 0"; exclude it explicitly.
    if target_integer > 0 and previous_answer != 0 and previous_answer % target_integer == 0:
        candidates.append(f" divided by {previous_answer // target_integer}")

    # Exactly one of ==, <, > always holds, so the list is never empty.
    return random.choice(candidates)

def process_hendrycks_math_dataset(
    input_file: str,
    output_file: str,
    num_subproblems: int,
    num_repetitions: int = 1,
    seed: int = 42
):
    """Process Hendrycks Math dataset to create long horizon reasoning tasks.

    With ``num_subproblems == 1`` every entry is written back with an added
    ``prompt`` and ``final_answer`` field. Otherwise the entries whose answer
    is an integer are shuffled and split into consecutive groups of
    ``num_subproblems``; each group is passed to ``create_combined_problem``
    and the successful results are written out.

    Args:
        input_file: Path to a JSONL file; entries are read via their
            ``problem`` and ``answer`` keys. Blank lines are ignored.
        output_file: Path of the JSONL file to write. It is overwritten.
        num_subproblems: Number of problems chained per output item (>= 1).
        num_repetitions: How many reshuffle-and-chunk passes to run.
        seed: Seed given to ``random.seed`` before shuffling.

    Raises:
        ValueError: If ``num_subproblems`` is smaller than 1.
    """
    if num_subproblems < 1:
        raise ValueError(f"num_subproblems must be at least 1, got {num_subproblems}")

    print(f"Loading dataset from {input_file}...")
    with open(input_file, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing on them.
        data = [json.loads(line) for line in f if line.strip()]

    print(f"Loaded {len(data)} entries")

    if num_subproblems == 1:
        # Open once and overwrite, matching the multi-problem branch below;
        # appending would duplicate entries every time the script is rerun.
        with open(output_file, 'w', encoding='utf-8') as f:
            for entry in data:
                entry['prompt'] = "Solve the following math problem step by step:\n\n" + entry['problem']
                entry['final_answer'] = entry['answer']
                json.dump(entry, f, ensure_ascii=False)
                f.write('\n')
        print(f"Saved {len(data)} entries to {output_file}")
        return

    # Only problems with an integer answer can feed the next step.
    integer_answer_data = [
        entry for entry in data
        if entry.get('answer') is not None and is_integer_answer(entry['answer'])
    ]

    print(f"Found {len(integer_answer_data)} entries with integer answers")

    random.seed(seed)

    all_results = []
    skipped_sequences = 0
    successful_sequences = 0

    for rep in range(num_repetitions):
        print(f"\nRepetition {rep + 1}/{num_repetitions}")

        # Fresh shuffle per repetition so each pass groups problems differently.
        shuffled_data = integer_answer_data.copy()
        random.shuffle(shuffled_data)

        # Walk over full, non-overlapping chunks; a short remainder is dropped.
        last_start = len(shuffled_data) - num_subproblems
        for start in range(0, last_start + 1, num_subproblems):
            sequence = shuffled_data[start:start + num_subproblems]

            success, combined_problem = create_combined_problem(sequence, num_subproblems)

            if success:
                all_results.append(combined_problem)
                successful_sequences += 1
                if successful_sequences % 100 == 0:
                    print(f"Created {successful_sequences} successful sequences...")
            else:
                skipped_sequences += 1

    print(f"\nTotal successful sequences: {successful_sequences}")
    print(f"Total skipped sequences: {skipped_sequences}")

    print(f"Saving {len(all_results)} combined problems to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        for result in all_results:
            json.dump(result, f, ensure_ascii=False)
            f.write('\n')

    print("Dataset processing completed successfully!")

def _parse_int_answer(answer: Any) -> Optional[int]:
    """Return ``answer`` parsed as an int (commas removed), or None if it is not an integer string."""
    if not isinstance(answer, str):
        return None
    try:
        return int(answer.replace(',', ''))
    except ValueError:
        return None


def create_combined_problem(sequence: List[Dict], num_subproblems: int) -> Tuple[bool, Optional[Dict]]:
    """Create a combined problem from a sequence of individual problems.

    The first problem is used verbatim. In each following problem one
    randomly chosen integer is replaced by a variable that is defined in
    terms of the previous problem's answer, so the problems must be solved
    in order.

    Args:
        sequence: Problem entries, read via their ``problem`` and ``answer``
            keys.
        num_subproblems: How many entries of ``sequence`` to chain.

    Returns:
        ``(True, {"prompt": ..., "final_answer": ...})`` on success, where
        ``final_answer`` is the integer answer of the last chained problem.
        ``(False, None)`` if the sequence is empty or too short, an answer is
        not an integer string, or a later problem contains no integer.
    """
    if not sequence:
        return False, None

    # variable_names = ['X', 'Y', 'Z', 'U', 'V', 'W']
    variable_names = ['TMP']
    random.shuffle(variable_names)

    # First problem (no variable substitution)
    first_problem = sequence[0]
    previous_answer = _parse_int_answer(first_problem.get('answer'))
    if previous_answer is None:
        return False, None

    prompt_parts = [
        f"Step 1: Solve the following math problem step by step:\n\n{first_problem['problem']}"
    ]

    # Subsequent problems (with variable substitution)
    for idx in range(1, num_subproblems):
        if idx >= len(sequence):
            return False, None

        current_problem = sequence[idx]

        integers_in_problem = find_integers_in_problem(current_problem['problem'])
        if not integers_in_problem:
            # Nothing to substitute, so this problem cannot depend on the last one.
            return False, None

        # Pick a random integer to replace
        target_integer = random.choice(list(integers_in_problem))
        variable_name = variable_names[idx % len(variable_names)]

        modified_problem = replace_integer_with_variable(
            current_problem['problem'],
            target_integer,
            variable_name,
            integers_in_problem[target_integer]
        )

        relationship = create_variable_relationship(previous_answer, target_integer)

        # Each later problem takes two steps: substitute (even), solve (odd).
        step_num = idx * 2

        prompt_parts.append(
            f"Step {step_num}: Let {variable_name} be the final answer from the previous step{relationship}. "
            f"Substitute {variable_name} in the following problem:\n\n{modified_problem}\n\n"
            f"Write out the updated version of the problem with the actual number in place of {variable_name}."
        )

        prompt_parts.append(
            f"Step {step_num + 1}: Solve the updated problem from Step {step_num} step by step."
        )

        # Update previous answer for next iteration
        previous_answer = _parse_int_answer(current_problem.get('answer'))
        if previous_answer is None:
            return False, None

    final_prompt = "\n\n".join(prompt_parts)
    final_prompt += " In the end, provide only the final numerical answer."

    # ``previous_answer`` now holds the answer of the last problem that was
    # actually chained, which is the right final answer even when
    # ``sequence`` has more entries than ``num_subproblems``.
    combined_problem = {
        "prompt": final_prompt,
        "final_answer": previous_answer,
    }

    return True, combined_problem

def main():
    """Parse the command-line options and run the dataset processing."""
    cli = argparse.ArgumentParser(
        description="Combine Hendrycks Math problems into long horizon reasoning tasks"
    )

    # (flag, add_argument keyword options), registered in this order.
    option_specs = [
        ("--input_file", {
            "type": str,
            "default": "datasets/hendrycks_math_train_all_with_answers.jsonl",
            "help": "Input file with Hendrycks Math problems",
        }),
        ("--output_file", {
            "type": str,
            "required": True,
            "help": "Output file for combined problems",
        }),
        ("--num_subproblems", {
            "type": int,
            "required": True,
            "help": "Number of subproblems to combine",
        }),
        ("--num_repetitions", {
            "type": int,
            "default": 1,
            "help": "Number of times to repeat the process",
        }),
        ("--seed", {
            "type": int,
            "default": 42,
            "help": "Random seed for reproducibility",
        }),
    ]
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)

    options = cli.parse_args()

    # The parsed attribute names match the keyword parameters one-to-one.
    process_hendrycks_math_dataset(**vars(options))


if __name__ == "__main__":
    main()
