#!/usr/bin/env python3
"""
Script to bin rollout responses and original responses by math_equal equivalence.

Given a completions file created with completions_rollouts.py,
this script:
1. Groups all rollout responses per prompt by math_equal equivalence
2. If present, also bins original completion responses using the same equivalence classes
3. Assigns bin indices (0, 1, 2, ...) to each response group
4. Saves the modified completions file with added bin index fields

The binning is done per prompt, meaning responses from different prompts
are not compared against each other. Original responses and rollout responses
are binned together in a unified equivalence class space.
"""

import argparse
import os
import time
import warnings
from typing import List, Dict, Any, Optional

import torch
from calib.utils import math_equal

# Suppress the SyntaxWarning about non-callable objects that is emitted during SymPy parsing.
# (This is mainly needed for the original math_equal; ultra_math_equal handles this better.)
warnings.filterwarnings('ignore', category=SyntaxWarning, message='.*object is not callable.*')


def _assign_bin(
    response: Any,
    response_to_bin: Dict[str, int],
    bin_representatives: Dict[int, str],
    debug_label: Optional[str] = None,
) -> int:
    """
    Return the bin index for a single response, creating a new bin if needed.

    The response is normalised to a string (None becomes ""). An exact string
    match against previously seen responses is tried first; otherwise the
    string is compared with math_equal against every previously seen response
    string, in insertion order, and the first match wins. Both dictionaries
    are updated in place.

    Args:
        response: Raw response value; converted with str() unless it is None.
        response_to_bin: Maps every response string seen so far to its bin index.
        bin_representatives: Maps bin index -> first response string placed in
            that bin. Bin indices must be contiguous from 0, because the next
            new index is taken to be len(bin_representatives).
        debug_label: Prefix (e.g. "Para 3") for debug output; None disables printing.

    Returns:
        The bin index assigned to the response.
    """
    response_str = "" if response is None else str(response)

    # Fast path: this exact string has already been assigned a bin.
    known_bin = response_to_bin.get(response_str)
    if known_bin is not None:
        if debug_label is not None:
            print(f"    {debug_label}: '{response_str}' -> bin {known_bin} (exact string match)")
        return known_bin

    # Slow path: look for the first previously seen string that is mathematically
    # equivalent. The generator is lazy, so math_equal stops being called at the
    # first match.
    match = next(
        (
            (seen_response, seen_bin)
            for seen_response, seen_bin in response_to_bin.items()
            if math_equal(response_str, seen_response)
        ),
        None,
    )

    if match is not None:
        matched_with, matched_bin = match
        # Remember this spelling so later occurrences hit the exact-match path.
        # Done after the search so the dict is never modified while iterated.
        response_to_bin[response_str] = matched_bin
        if debug_label is not None:
            print(f"    {debug_label}: '{response_str}' -> bin {matched_bin} (math equivalent to '{matched_with}')")
        return matched_bin

    # No equivalent response found: open a new bin represented by this string.
    new_bin = len(bin_representatives)
    response_to_bin[response_str] = new_bin
    bin_representatives[new_bin] = response_str
    if debug_label is not None:
        print(f"    {debug_label}: '{response_str}' -> NEW bin {new_bin}")
    return new_bin


def bin_responses_per_prompt(
    rollout_responses: List[List[str]],
    responses: Optional[List[str]] = None,
    prompt_idx: Optional[int] = None,
) -> tuple[List[List[int]], Optional[List[int]], Dict[int, str]]:
    """
    Bin responses for a single prompt by math_equal equivalence.

    Original responses (if given) are binned first, then rollout responses, so
    both share one set of bin indices. Bin indices are assigned in order of
    first appearance, starting at 0.

    Args:
        rollout_responses: 2D list [M][P] where M=completions, P=paragraphs
        responses: 1D list [M] of original completion responses (optional)
        prompt_idx: Optional prompt index; when not None, detailed debugging
            output is printed

    Returns:
        tuple of:
        - bin_indices: 2D list [M][P] with bin index for each rollout response
        - response_bin_indices: 1D list [M] with bin index for each original response (None if responses not provided)
        - bin_representatives: Dict mapping bin_index -> representative response string
    """
    debug = prompt_idx is not None

    # Every response string seen so far -> its bin index.
    response_to_bin: Dict[str, int] = {}
    bin_representatives: Dict[int, str] = {}

    if debug:
        print(f"\n=== DEBUGGING: Prompt {prompt_idx} ===")

    # First, process original responses if provided
    response_bin_indices: Optional[List[int]] = None
    if responses is not None:
        if debug:
            print("  Processing original responses:")
        response_bin_indices = [
            _assign_bin(
                response,
                response_to_bin,
                bin_representatives,
                f"Original {completion_idx}" if debug else None,
            )
            for completion_idx, response in enumerate(responses)
        ]

    # Then process rollout responses against the same bins
    bin_indices: List[List[int]] = []
    for completion_idx, completion_responses in enumerate(rollout_responses):
        if debug:
            print(f"  Completion {completion_idx}:")
        bin_indices.append([
            _assign_bin(
                response,
                response_to_bin,
                bin_representatives,
                f"Para {paragraph_idx}" if debug else None,
            )
            for paragraph_idx, response in enumerate(completion_responses)
        ])

    if debug:
        print(f"  Final bin representatives: {bin_representatives}")
        print(f"  Final rollout bin indices: {bin_indices}")
        if response_bin_indices is not None:
            print(f"  Final response bin indices: {response_bin_indices}")

    return bin_indices, response_bin_indices, bin_representatives


def main():
    """
    Command-line entry point: bin the responses in a completions bundle and save it.

    Loads the bundle at --input_path with torch.load, bins each prompt's
    responses with bin_responses_per_prompt, and writes a copy of the bundle
    with the extra keys "rollout_bin_indices", "rollout_bin_representatives"
    and (only if the input has "responses") "response_bin_indices" to
    <output_dir>/<input basename>_binned.pt.

    Raises:
        FileNotFoundError: If the input file does not exist.
        FileExistsError: If the output file exists and --force_overwrite is not set.
        ValueError: If the bundle lacks a required field or has no rollout responses.
    """
    parser = argparse.ArgumentParser(
        description="Bin rollout responses by math_equal equivalence and save modified completions file."
    )
    parser.add_argument(
        "--input_path",
        type=str,
        # Without required=True a missing flag reached os.path.exists(None) and
        # failed with an unhelpful TypeError.
        required=True,
        help="Path to the completions .pt file created by completions_rollouts_from_file.py"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs/completions_rollouts",
        help="Directory to save binned completions file"
    )
    parser.add_argument(
        "--force_overwrite",
        action="store_true",
        help="Overwrite output file if it already exists"
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Print detailed debugging information about binning process"
    )

    args = parser.parse_args()

    # Validate input file exists
    if not os.path.exists(args.input_path):
        raise FileNotFoundError(f"Input file not found: {args.input_path}")

    # Generate output filename and check for a collision *before* doing any
    # work, so an existing output does not waste a full binning run.
    input_basename = os.path.basename(args.input_path).removesuffix(".pt")
    output_filename = os.path.join(args.output_dir, f"{input_basename}_binned.pt")

    if os.path.exists(output_filename) and not args.force_overwrite:
        raise FileExistsError(
            f"Output file already exists: {output_filename}. "
            "Use --force_overwrite to overwrite."
        )

    print(f"Loading completions from {args.input_path}...")
    start_time = time.time()

    # Load existing completions bundle.
    # NOTE(security): weights_only=False unpickles arbitrary Python objects and
    # can execute code from the file; only load bundles from trusted sources.
    bundle: Dict[str, Any] = torch.load(args.input_path, map_location="cpu", weights_only=False)

    # Validate required fields
    required_fields = ["rollout_responses", "completions", "prompts_text"]
    for field in required_fields:
        if field not in bundle:
            raise ValueError(f"Input bundle missing required field: {field}")

    rollout_responses: List[List[List[str]]] = bundle["rollout_responses"]  # [N][M][P]
    responses: Optional[List[List[str]]] = bundle.get("responses")  # [N][M] - original completion responses
    N = len(rollout_responses)

    if N == 0:
        raise ValueError("No rollout responses found in input bundle.")

    print(f"Found {N} prompts with rollout responses")

    # Initialize bin structures
    rollout_bin_indices: List[List[List[int]]] = []  # [N][M][P]
    response_bin_indices: Optional[List[List[int]]] = [] if responses is not None else None  # [N][M]
    rollout_bin_representatives: List[Dict[int, str]] = []  # [N] -> {bin_id: representative}

    # Process each prompt separately; bins are never shared across prompts
    for prompt_idx, prompt_rollout_responses in enumerate(rollout_responses):
        prompt_responses = responses[prompt_idx] if responses is not None else None
        print(f"Processing prompt {prompt_idx + 1}/{N}...")

        rollout_bins, response_bins, bin_reps = bin_responses_per_prompt(
            prompt_rollout_responses,
            prompt_responses,
            prompt_idx if args.debug else None
        )
        rollout_bin_indices.append(rollout_bins)
        if response_bin_indices is not None:
            response_bin_indices.append(response_bins)
        rollout_bin_representatives.append(bin_reps)

        # Print some stats
        total_rollout_responses = sum(len(completion) for completion in prompt_rollout_responses)
        total_original_responses = len(prompt_responses) if prompt_responses is not None else 0
        num_bins = len(bin_reps)
        print(f"  Prompt {prompt_idx}: {total_rollout_responses} rollout responses + {total_original_responses} original responses -> {num_bins} bins")

    print(f"Saving binned completions to {output_filename}...")

    # Create output bundle with additional fields (shallow copy of the input)
    output_bundle = dict(bundle)
    output_bundle.update({
        "rollout_bin_indices": rollout_bin_indices,
        "rollout_bin_representatives": rollout_bin_representatives,
    })

    # Add response_bin_indices only if we processed original responses
    if response_bin_indices is not None:
        output_bundle["response_bin_indices"] = response_bin_indices

    # Record the binning parameters on the stored args, whether they were saved
    # as a namespace-like object or as a plain dict. The object is shared with
    # the loaded bundle, so it is updated in place.
    binning_info = {
        "binned": True,
        "binning_input_path": args.input_path,
        "binning_output_dir": args.output_dir,
    }
    saved_args = output_bundle.get("args")
    if hasattr(saved_args, "__dict__"):
        vars(saved_args).update(binning_info)
    elif isinstance(saved_args, dict):
        saved_args.update(binning_info)

    # Ensure output directory exists
    os.makedirs(args.output_dir, exist_ok=True)

    # Save the bundle
    torch.save(output_bundle, output_filename)

    end_time = time.time()

    # Print summary statistics
    total_rollout_responses = sum(
        sum(len(completion) for completion in prompt_responses)
        for prompt_responses in rollout_responses
    )
    total_original_responses = sum(len(prompt_responses) for prompt_responses in responses) if responses is not None else 0
    total_responses = total_rollout_responses + total_original_responses
    total_bins = sum(len(bin_reps) for bin_reps in rollout_bin_representatives)

    print(f"\nBinning completed in {end_time - start_time:.2f} seconds")
    print(f"Total rollout responses: {total_rollout_responses}")
    if responses is not None:
        print(f"Total original responses: {total_original_responses}")
    print(f"Total responses: {total_responses}")
    print(f"Total unique bins: {total_bins}")
    # total_bins is 0 when every prompt had no responses at all; avoid crashing
    # with ZeroDivisionError after the file has already been written.
    if total_bins > 0:
        print(f"Compression ratio: {total_responses / total_bins:.2f}x")
    else:
        print("Compression ratio: n/a (no responses were binned)")
    print(f"Saved to: {output_filename}")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
