import json
import re
import os
import random

def get_target_integer(target_list):
    """Return the first candidate in *target_list* that is a plain integer, as an int.

    Each candidate is tested with ``is_pure_integer``. The first one that passes
    has backslashes, braces and whitespace stripped and is converted with
    ``int``.

    Args:
        target_list: The sample's ``target`` value, normally a list of answer
            strings. A bare string or number is treated as a one-element list
            (iterating a string would otherwise test it one character at a
            time). Non-string candidates are converted with ``str`` first.

    Returns:
        The integer value, or ``None`` if no candidate is a pure integer.
    """
    if target_list is None:
        return None
    if isinstance(target_list, (str, int, float)):
        target_list = [target_list]

    for target in target_list:
        # is_pure_integer runs regexes on its argument, which requires a str.
        text = target if isinstance(target, str) else str(target)
        if not is_pure_integer(text):
            continue
        cleaned = re.sub(r'[\\{}\s]', '', text)
        try:
            return int(cleaned)
        except ValueError as e:
            print(f"Warning: Cannot convert '{target}' -> '{cleaned}' to integer: {e}")
    return None

def is_pure_integer(text):
    """Tell whether *text* spells a bare, optionally negative, integer.

    Text is rejected outright if it contains a LaTeX command (``\\sqrt``,
    ``\\frac`` ...), any ASCII letter, a ``/`` (fraction) or a ``.``
    (decimal point). Otherwise backslashes, braces and whitespace are removed
    and what remains must be an optional ``-`` followed only by digits.
    """
    # Symbolic content: LaTeX commands or any letter at all.
    for disqualifying_regex in (r'\\[a-zA-Z]', r'[a-zA-Z]'):
        if re.search(disqualifying_regex, text):
            return False

    # Fractions and decimals are not integers.
    if any(marker in text for marker in ('/', '.')):
        return False

    stripped = re.sub(r'[\\{}\s]', '', text)
    return re.match(r'^-?\d+$', stripped) is not None

def load_pass_rate_data(result_base_path):
    """Aggregate per-instance pass rate and mean output tokens across repeated runs.

    Every sub-directory of ``result_base_path`` whose name contains ``-<n>-32``
    with ``0 <= n <= 31`` and that holds a ``scenario_state.json`` is read.
    For each entry of the file's ``request_states``:

    * the score is extracted from ``model_score_result`` with the regex
      ``{"score": <digits>}``. Entries whose score cannot be extracted count
      as 0 and are tallied as failures.
    * the first completion's ``output_token_num`` is recorded. It is 0 when
      there is no completion or the value is not numeric.

    A file that cannot be read, or that fails part-way through, is skipped as
    a whole, so it never contributes a partial set of records.

    Args:
        result_base_path: Directory containing the per-run result folders.

    Returns:
        ``(pass_rate_mapping, avg_token_mapping)``: dicts keyed by the instance
        id converted to ``str``.
        - ``pass_rate_mapping`` holds the fraction of runs scoring exactly 1.
        - ``avg_token_mapping`` holds the mean output token count.
        Both are empty when nothing could be loaded.
    """
    pass_rate_mapping = {}
    avg_token_mapping = {}

    if not os.path.isdir(result_base_path):
        print(f"Warning: Base path does not exist - {result_base_path}")
        return pass_rate_mapping, avg_token_mapping

    # Sorted so discovery logs and processing order are reproducible;
    # os.listdir returns entries in arbitrary order.
    all_subdirs = sorted(
        d for d in os.listdir(result_base_path)
        if os.path.isdir(os.path.join(result_base_path, d))
    )
    print(f"Found subdirectories: {all_subdirs}")

    # Match "-number-32" anywhere in the name (search, not match) and keep
    # only run numbers 0-31.
    run_pattern = re.compile(r'-(\d+)-32')
    file_paths = []
    for subdir in all_subdirs:
        match = run_pattern.search(subdir)
        if not match or not 0 <= int(match.group(1)) <= 31:
            continue
        scenario_file = os.path.join(result_base_path, subdir, "scenario_state.json")
        if os.path.exists(scenario_file):
            file_paths.append(scenario_file)
            print(f"Found matching file: {subdir}")

    if not file_paths:
        print("Warning: No matching result files found")
        return pass_rate_mapping, avg_token_mapping

    print(f"Total found {len(file_paths)} result files")

    score_pattern = re.compile(r'{"score":\s*(\d+)}')
    instance_results = {}  # {instance_id (str): {'scores': [], 'tokens': []}}
    failed_score_extraction_count = 0  # Count samples where score extraction failed

    for file_path in file_paths:
        # Stage this file's records locally; they are merged only once the
        # whole file has been processed without error.
        file_results = {}
        file_failed_count = 0
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            for request_state in data.get('request_states', []):
                instance = request_state.get('instance') or {}
                raw_id = instance.get('id', '')
                if raw_id is None or raw_id == '':
                    continue
                # Key by str so lookups with str(sample['instanceId']) succeed
                # even when result files store numeric ids.
                instance_id = str(raw_id)
                entry = file_results.setdefault(instance_id, {'scores': [], 'tokens': []})

                model_score_result = request_state.get('model_score_result', '{}')
                if not isinstance(model_score_result, str):
                    # An already-decoded value (e.g. a dict) is re-serialised so
                    # the regex applies instead of raising TypeError.
                    model_score_result = json.dumps(model_score_result)
                score_match = score_pattern.search(model_score_result)
                if score_match:
                    entry['scores'].append(float(score_match.group(1)))
                else:
                    # Cannot extract score, record as 0 (error)
                    print(f"Warning: Cannot extract score - instance_id: {instance_id}, marked as error")
                    entry['scores'].append(0.0)
                    file_failed_count += 1

                result = request_state.get('result') or {}
                completions = result.get('completions') or []
                output_token_num = completions[0].get('output_token_num', 0) if completions else 0
                try:
                    entry['tokens'].append(float(output_token_num))
                except (ValueError, TypeError):
                    entry['tokens'].append(0.0)
        except FileNotFoundError:
            print(f"Warning: File does not exist - {file_path}")
        except json.JSONDecodeError as e:
            print(f"Warning: JSON parse error - {file_path}: {e}")
        except Exception as e:  # Best effort: skip any malformed file and keep going
            print(f"Warning: Error loading file - {file_path}: {e}")
        else:
            for instance_id, entry in file_results.items():
                merged = instance_results.setdefault(instance_id, {'scores': [], 'tokens': []})
                merged['scores'].extend(entry['scores'])
                merged['tokens'].extend(entry['tokens'])
            failed_score_extraction_count += file_failed_count

    for instance_id, results in instance_results.items():
        scores = results['scores']
        tokens = results['tokens']
        if scores:
            # Proportion of runs whose score is exactly 1
            pass_rate_mapping[instance_id] = sum(1 for score in scores if score == 1.0) / len(scores)
        if tokens:
            avg_token_mapping[instance_id] = sum(tokens) / len(tokens)

    print(f"Successfully loaded pass_rate data, {len(pass_rate_mapping)} samples have pass_rate info")
    print(f"Successfully loaded average output token data, {len(avg_token_mapping)} samples have average token info")
    print(f"Samples with failed score extraction: {failed_score_extraction_count}")

    return pass_rate_mapping, avg_token_mapping

def create_dependency_expression(prev_answer_placeholder, prev_answer_value, selected_number):
    """Express ``selected_number`` as the previous answer plus or minus a constant.

    The offset is ``selected_number - prev_answer_value``. A non-negative
    offset renders as ``"<placeholder> + <offset>"``, a negative one as
    ``"<placeholder> - <|offset|>"``, so substituting the previous answer's
    value yields ``selected_number``.
    """
    offset = selected_number - prev_answer_value
    operator_text, magnitude = ('+', offset) if offset >= 0 else ('-', abs(offset))
    return f"{prev_answer_placeholder} {operator_text} {magnitude}"

def create_linked_input(prev_answer_placeholder, prev_answer_value, sample, problem_index):
    """Rewrite one problem so its selected number depends on the previous answer.

    The span ``[start_pos, end_pos)`` of ``sample['input']`` given by
    ``sample['selected_variable']`` is replaced with ``[variable{problem_index}]``
    (wrapped in braces when ``is_in_math_env`` is true). A preamble then
    defines that variable as ``prev_answer_placeholder`` plus or minus the
    offset that makes it equal the original number.

    Args:
        prev_answer_placeholder: Text standing for the previous answer,
            e.g. ``"[answer1]"``.
        prev_answer_value: Integer value of the previous answer.
        sample: Sample dict with ``input`` and ``selected_variable``, where
            ``selected_variable`` has ``number``, ``start_pos``, ``end_pos``
            and optionally ``is_in_math_env``.
        problem_index: 1-based index of this problem in the chain.

    Returns:
        The linked problem text, or ``None`` if the selected-variable info is
        missing or its span is not a valid, non-empty range inside the input.
    """
    input_text = sample.get('input', '')
    selected_variable = sample.get('selected_variable') or {}

    selected_number = selected_variable.get('number')
    start_pos = selected_variable.get('start_pos')
    end_pos = selected_variable.get('end_pos')
    is_in_math_env = selected_variable.get('is_in_math_env', False)
    instance_label = sample.get('instanceId', 'unknown')

    if selected_number is None or start_pos is None or end_pos is None:
        print(f"Warning: Sample {instance_label} missing selected_variable info")
        return None

    # Slicing never raises on bad indices: negative, inverted or out-of-range
    # positions would silently splice the wrong text, so reject them here.
    if not (isinstance(start_pos, int) and isinstance(end_pos, int)
            and 0 <= start_pos < end_pos <= len(input_text)):
        print(f"Warning: Sample {instance_label} has invalid selected_variable span ({start_pos}, {end_pos})")
        return None

    base_var_name = f"[variable{problem_index}]"
    # Inside a math environment, braces keep the placeholder a single unit.
    var_name = f"{{{base_var_name}}}" if is_in_math_env else base_var_name

    # Variable value must equal the number it replaces.
    dependency_expr = create_dependency_expression(prev_answer_placeholder, prev_answer_value, selected_number)

    linked_input = f"Using the result {prev_answer_placeholder} from the previous calculation, {base_var_name} = {dependency_expr}. "
    linked_input += input_text[:start_pos] + var_name + input_text[end_pos:]
    return linked_input

def load_key_variable_samples(input_file, pass_rate_mapping, avg_token_mapping):
    """Load key-variable samples and attach their pass_rate and avg_output_tokens.

    The input file holds JSON objects separated by blank lines; a separator
    line may also contain only whitespace. Blocks that fail to parse, are not
    JSON objects, or lack any of ``instanceId``, ``input``, ``target`` or
    ``selected_variable`` are skipped with a warning.

    Args:
        input_file: Path of the blank-line-separated JSON file.
        pass_rate_mapping: ``{str(instance_id): pass_rate}``.
        avg_token_mapping: ``{str(instance_id): average output tokens}``.

    Returns:
        The kept samples in file order. Each one gains ``pass_rate`` and
        ``avg_output_tokens``, looked up by ``str(instanceId)``; the value is
        ``None`` when the id is absent from a mapping. Returns an empty list
        if the file does not exist.
    """
    samples = []

    if not os.path.exists(input_file):
        print(f"Error: Input file does not exist - {input_file}")
        return samples

    required_keys = ('instanceId', 'input', 'target', 'selected_variable')
    found_pass_rate = 0
    found_avg_token = 0

    with open(input_file, 'r', encoding='utf-8') as infile:
        content = infile.read().strip()

    # Split on blank lines. Tolerating whitespace-only separator lines avoids
    # merging two objects into one unparseable block.
    json_blocks = re.split(r'\n\s*\n', content)

    for i, block in enumerate(json_blocks):
        block = block.strip()
        if not block:
            continue

        try:
            sample = json.loads(block)
        except json.JSONDecodeError as e:
            print(f"Warning: JSON block {i+1} parse failed: {e}")
            continue

        # Non-object JSON (numbers, lists, ...) cannot hold the required fields.
        if not isinstance(sample, dict) or not all(key in sample for key in required_keys):
            print(f"Warning: Sample {i+1} missing required fields")
            continue

        instance_id_str = str(sample['instanceId'])

        if instance_id_str in pass_rate_mapping:
            sample['pass_rate'] = pass_rate_mapping[instance_id_str]
            found_pass_rate += 1
        else:
            sample['pass_rate'] = None

        if instance_id_str in avg_token_mapping:
            sample['avg_output_tokens'] = avg_token_mapping[instance_id_str]
            found_avg_token += 1
        else:
            sample['avg_output_tokens'] = None

        samples.append(sample)

    total = len(samples)
    pass_rate_pct = found_pass_rate / total * 100 if total else 0.0
    avg_token_pct = found_avg_token / total * 100 if total else 0.0

    print(f"Successfully loaded {total} key variable samples")
    print(f"Samples with pass_rate data: {found_pass_rate}")
    print(f"Samples with avg_token data: {found_avg_token}")
    print(f"pass_rate match rate: {pass_rate_pct:.1f}%")
    print(f"avg_token match rate: {avg_token_pct:.1f}%")

    return samples

def _optional_float(value):
    """Return ``float(value)``, or ``None`` when value is None or not convertible."""
    if value is None:
        return None
    try:
        return float(value)
    except (ValueError, TypeError):
        return None


def create_single_combination(samples_group, k, combination_id):
    """Chain the samples in ``samples_group`` into one multi-problem example.

    Problem 1 is copied verbatim. Every later problem j+1 has its selected
    variable replaced by ``[variable{j+1}]``, which is defined in terms of
    ``[answer{j}]`` (the previous problem's answer) via ``create_linked_input``.

    Args:
        samples_group: Ordered samples to chain. Each carries ``target``,
            ``input``, ``selected_variable``, ``pass_rate`` and
            ``avg_output_tokens``.
        k: Expected number of samples in the group.
        combination_id: Used only to build placeholder ids for samples
            without an ``instanceId``.

    Returns:
        The combined example dict, or ``None`` in any of these cases:
        - the group is empty or does not hold exactly ``k`` samples;
        - any sample lacks an integer target or a numeric ``pass_rate`` or
          ``avg_output_tokens``;
        - a linked input cannot be built.
    """
    target_values = []     # Integer answer of each problem, used for chaining
    group_targets = []     # Each problem's original target, in order
    group_pass_rates = []  # pass_rate of each sample (float)
    group_avg_tokens = []  # Average output tokens of each sample (float)

    for sample in samples_group:
        target_val = get_target_integer(sample['target'])
        if target_val is None:
            return None
        pass_rate = _optional_float(sample.get('pass_rate'))
        if pass_rate is None:
            return None
        avg_tokens = _optional_float(sample.get('avg_output_tokens'))
        if avg_tokens is None:
            return None

        target_values.append(target_val)
        group_targets.append(sample['target'])
        group_pass_rates.append(pass_rate)
        group_avg_tokens.append(avg_tokens)

    # An empty group would later fail on instance_ids[-1] / samples_group[-1].
    if not samples_group or len(samples_group) != k:
        return None

    # Expected pass rate of the chain: product of per-problem pass rates
    expected_pass_rate = 1.0
    for pass_rate in group_pass_rates:
        expected_pass_rate *= pass_rate

    # Expected output tokens of the chain: sum of per-problem averages
    expected_avg_output_tokens = sum(group_avg_tokens)

    combined_input_parts = []
    instance_ids = []
    selected_variables = []  # selected_variable info of every problem

    for j, sample in enumerate(samples_group):
        instance_ids.append(sample.get("instanceId", f"unknown_{combination_id}_{j}"))
        selected_variables.append(sample.get("selected_variable", {}))

        if j == 0:
            # First problem is used as-is
            combined_input_parts.append(f"Problem {j+1}: {sample['input']}")
            continue

        linked_input = create_linked_input(
            f"[answer{j}]",        # Placeholder for the previous problem's answer
            target_values[j - 1],  # Its integer value
            sample,
            j + 1,                 # This problem's 1-based index
        )
        if linked_input is None:
            return None
        combined_input_parts.append(f"Problem {j+1}: {linked_input}")

    answer_format = "\n\n".join(
        f"Problem {i}: \\boxed{{[answer{i}]}}" for i in range(1, k + 1)
    )

    combined_input = "\n\n".join(combined_input_parts)
    # Closing instructions: placeholder explanation and required answer format
    combined_input += f"""\n\nNote: In this problem set:
- [variablek] represents the calculated variable needed to solve problem k.
- [answerk] represents the answer to problem k.

Solve all problems step by step and provide the answers for all problems in the following format:

### Final Answers

{answer_format}
"""
    # Combined target: first element of each list target (the whole value
    # otherwise), joined with commas and wrapped in a one-element list.
    combined_target = [",".join(
        str(target[0]) if isinstance(target, list) and len(target) > 0 else str(target)
        for target in group_targets
    )]

    combined_example = {
        "input": combined_input,
        "instanceId": instance_ids[-1],  # Use last problem's instanceId
        "target": combined_target,  # Comma-joined answers of all problems
        # New fields
        "instanceIds": instance_ids,
        "num_problems": k,
        "problem_type": "chained_reasoning",
        "group_targets": group_targets,  # Each problem's original target, in order
        "group_pass_rates": group_pass_rates,  # Per-problem pass_rate (float)
        "group_avg_tokens": group_avg_tokens,  # Per-problem average output tokens (float)
        "expected_pass_rate": expected_pass_rate,  # Product of per-problem pass rates
        "expected_avg_output_tokens": expected_avg_output_tokens,  # Sum of per-problem averages
        "selected_variables": selected_variables,  # Every problem's selected variable info
    }

    # Carry over any other fields of the last sample without overwriting
    for key, value in samples_group[-1].items():
        combined_example.setdefault(key, value)

    return combined_example


def combine_samples_for_k_fixed_num(key_variable_samples, output_dir, input_filename, k, seed, target_num_samples=500):
    """Generate up to ``target_num_samples`` chained k-problem examples and save them.

    Generation runs in two phases:
    - Phase 1 draws random ordered groups of ``k`` samples, skipping
      instance-id tuples that were already produced.
    - If that yields too few examples, phase 2 partitions the shuffled pool
      into disjoint ``k``-sized groups and tries random orderings of them.

    All randomness comes from the module-level ``random`` state, which the
    caller seeds; ``seed`` is used only in the output filename.

    The examples and summary statistics are written as a single JSON line to
    ``<output_dir>/combined_key_var_k{k}_sd{seed}_{input_filename}``.

    Args:
        key_variable_samples: Pool of samples to combine.
        output_dir: Directory for the output file.
        input_filename: Original input file name, embedded in the output name.
        k: Number of problems per combined example.
        seed: Seed value recorded in the output filename.
        target_num_samples: Maximum number of combined examples to produce.

    Returns:
        ``(output_path, expected_pass_rates, expected_avg_tokens)``, where the
        two lists hold per-example values. Returns ``(None, [], [])`` when
        ``k`` is not positive or the pool has fewer than ``k`` samples.
    """
    # k < 1 would later fail (negative random.sample size, empty groups,
    # integer division by k), so reject it up front.
    if k < 1:
        print(f"Warning: Invalid combination size k={k}, skipping")
        return None, [], []

    if len(key_variable_samples) < k:
        print(f"Warning: Key variable sample count ({len(key_variable_samples)}) less than combination requirement (k={k}), skipping")
        return None, [], []

    output_filename = f"combined_key_var_k{k}_sd{seed}_{input_filename}"
    output_path = os.path.join(output_dir, output_filename)

    # Shuffle a copy so the caller's list is untouched
    shuffled_samples = key_variable_samples.copy()
    random.shuffle(shuffled_samples)

    all_combined_examples = []
    all_expected_pass_rates = []
    all_expected_avg_tokens = []
    skipped_no_replacement = 0  # Candidate groups rejected by create_single_combination

    # Ordered instance-id tuples already produced, for O(1) de-duplication
    generated_combinations = set()

    print(f"k={k}: Using key variables for combination, target generating {target_num_samples} samples")

    # Phase 1: random sampling
    attempt_count = 0
    max_attempts = target_num_samples * 5

    while len(all_combined_examples) < target_num_samples and attempt_count < max_attempts:
        attempt_count += 1
        random_group = random.sample(shuffled_samples, k)

        current_instance_ids = tuple(sample.get("instanceId", "") for sample in random_group)
        if current_instance_ids in generated_combinations:
            continue

        combined_example = create_single_combination(random_group, k, len(all_combined_examples))
        if combined_example is None:
            skipped_no_replacement += 1
            continue

        all_combined_examples.append(combined_example)
        all_expected_pass_rates.append(combined_example['expected_pass_rate'])
        all_expected_avg_tokens.append(combined_example['expected_avg_output_tokens'])
        generated_combinations.add(current_instance_ids)

    print(f"k={k}: Random sampling phase generated {len(all_combined_examples)} samples")

    # Phase 2: if random sampling is insufficient, supplement with permutations
    if len(all_combined_examples) < target_num_samples:
        print(f"k={k}: Random sampling insufficient, starting permutation supplement phase")

        # Disjoint consecutive k-sized chunks of the shuffled pool
        base_groups = [
            shuffled_samples[i * k:(i + 1) * k]
            for i in range(len(shuffled_samples) // k)
        ]

        combination_count = len(all_combined_examples)
        retry_count = 0
        max_retries = (target_num_samples - len(all_combined_examples)) * 3

        while len(all_combined_examples) < target_num_samples and retry_count < max_retries:
            if not base_groups:
                break

            # Reorder a randomly chosen base group
            shuffled_group = random.choice(base_groups).copy()
            random.shuffle(shuffled_group)

            current_instance_ids = tuple(sample.get("instanceId", "") for sample in shuffled_group)
            if current_instance_ids in generated_combinations:
                retry_count += 1
                continue

            combined_example = create_single_combination(shuffled_group, k, combination_count)
            if combined_example is None:
                skipped_no_replacement += 1
                retry_count += 1
                continue

            all_combined_examples.append(combined_example)
            all_expected_pass_rates.append(combined_example['expected_pass_rate'])
            all_expected_avg_tokens.append(combined_example['expected_avg_output_tokens'])
            generated_combinations.add(current_instance_ids)
            combination_count += 1
            retry_count = 0  # Only consecutive failures count toward the limit

        print(f"k={k}: Permutation phase supplemented samples, total now {len(all_combined_examples)}")

    print(f"k={k}: Finally generated {len(all_combined_examples)} samples, skipped {skipped_no_replacement}")

    avg_expected_pass_rate = sum(all_expected_pass_rates) / len(all_expected_pass_rates) if all_expected_pass_rates else None
    avg_expected_avg_tokens = sum(all_expected_avg_tokens) / len(all_expected_avg_tokens) if all_expected_avg_tokens else None

    output_data = {
        "examples": all_combined_examples,
        "statistics": {
            "num_combined_samples": len(all_combined_examples),
            "k_value": k,
            "target_num_samples": target_num_samples,
            "samples_with_pass_rates": len(all_expected_pass_rates),
            "average_expected_pass_rate": avg_expected_pass_rate,
            "average_expected_avg_tokens": avg_expected_avg_tokens,
            "skipped_no_replacement": skipped_no_replacement,
            "using_key_variables": True
        }
    }
    with open(output_path, 'w', encoding='utf-8') as outfile:
        outfile.write(json.dumps(output_data, ensure_ascii=False) + '\n')

    pass_rate_str = f"{avg_expected_pass_rate:.6f}" if avg_expected_pass_rate is not None else "N/A"
    tokens_str = f"{avg_expected_avg_tokens:.2f}" if avg_expected_avg_tokens is not None else "N/A"

    print(f"k={k}: Combined samples {len(all_combined_examples)}, skipped samples {skipped_no_replacement}, average expected_pass_rate {pass_rate_str}, average expected output tokens {tokens_str}, saved to {output_path}")

    return output_path, all_expected_pass_rates, all_expected_avg_tokens

def create_pretty_files(output_paths, output_dir):
    """Write an indented copy of each JSON-lines output file.

    ``None`` entries are ignored. A file whose name contains ``k<digits>`` is
    copied to ``combined_key_var_k<digits>_pretty_formatted.jsonl``; any other
    file is copied to ``pretty_<name>``. Both land in ``output_dir``. Every
    line is re-serialised with ``indent=2`` and ``ensure_ascii=False``.

    Returns:
        The paths of the pretty files, in the same order as the inputs.
    """
    written = []
    for source_path in (p for p in output_paths if p is not None):
        base_name = os.path.basename(source_path)
        k_found = re.search(r'k(\d+)', base_name)
        pretty_name = (
            f"combined_key_var_k{k_found.group(1)}_pretty_formatted.jsonl"
            if k_found
            else f"pretty_{base_name}"
        )
        destination = os.path.join(output_dir, pretty_name)

        with open(source_path, 'r', encoding='utf-8') as src, \
             open(destination, 'w', encoding='utf-8') as dst:
            for raw_line in src:
                record = json.loads(raw_line.strip())
                dst.write(json.dumps(record, ensure_ascii=False, indent=2) + '\n')

        written.append(destination)

    return written

def calculate_k1_statistics(samples_with_pass_rate):
    """Summarise the single-problem (k=1) samples.

    Values of ``pass_rate`` and ``avg_output_tokens`` that are ``None`` or
    cannot be converted to float are ignored when averaging.

    Returns:
        ``(mean_pass_rate, mean_avg_tokens, sample_count)``. Each mean is
        ``None`` when no usable value exists. ``sample_count`` is the total
        number of samples passed in. An empty input yields ``(None, None, 0)``.
    """
    if not samples_with_pass_rate:
        return None, None, 0

    def _usable_floats(field):
        collected = []
        for item in samples_with_pass_rate:
            raw = item.get(field)
            if raw is None:
                continue
            try:
                collected.append(float(raw))
            except (ValueError, TypeError):
                pass
        return collected

    def _mean(values):
        return sum(values) / len(values) if values else None

    return (
        _mean(_usable_floats('pass_rate')),
        _mean(_usable_floats('avg_output_tokens')),
        len(samples_with_pass_rate),
    )

def main():
    """Build chained key-variable datasets for several k values and report statistics.

    Steps:
    1. Load per-instance pass rates and output-token averages from
       ``result_base_path``.
    2. Attach them to the key-variable samples.
    3. Write one combined dataset, plus an indented copy, per k value.
    4. Print summary tables.

    The three path constants below are placeholders to be set before running.
    """
    # ============ All path configurations centralized here ============
    # Key variable data path
    key_variable_input_path = "KEY_VARIABLE_INPUT_PATH"

    # Result data path (for loading pass_rate data)
    result_base_path = "RESULT_BASE_PATH"

    # Final output path
    output_dir = "OUTPUT_DIRECTORY"
    # ============ Path configuration end ============

    os.makedirs(output_dir, exist_ok=True)

    # k values to generate and the target sample count for each
    k_values = [2, 3, 4, 5, 6, 8, 10, 12, 16, 20]
    target_num_samples = 500

    if not os.path.exists(key_variable_input_path):
        print(f"Error: Input file does not exist - {key_variable_input_path}")
        return

    # Seed the global RNG once so the whole run is reproducible
    seed = 43
    random.seed(seed)

    print("Starting to load pass_rate data...")
    pass_rate_mapping, avg_token_mapping = load_pass_rate_data(result_base_path)

    print("Starting to load key variable samples...")
    key_variable_samples = load_key_variable_samples(key_variable_input_path, pass_rate_mapping, avg_token_mapping)

    if not key_variable_samples:
        print("No key variable samples found")
        return

    k1_pass_rate, k1_avg_tokens, k1_sample_count = calculate_k1_statistics(key_variable_samples)

    input_filename = os.path.basename(key_variable_input_path)
    output_paths = []
    output_path_by_k = {}  # k -> output file, only for k values that were generated
    k_expected_pass_rates = {}  # Store expected_pass_rates list for each k value
    k_expected_avg_tokens = {}  # Store expected_avg_tokens list for each k value

    print(f"\nStarting to generate combination data for different k values...")
    print(f"k value list: {k_values}")
    print(f"Available key variable samples: {len(key_variable_samples)}")
    print(f"Target samples per k value: {target_num_samples}")

    for k in k_values:
        print(f"\nProcessing k={k}...")
        output_path, expected_pass_rates, expected_avg_tokens = combine_samples_for_k_fixed_num(
            key_variable_samples, output_dir, input_filename, k, seed, target_num_samples
        )
        if output_path:
            output_paths.append(output_path)
            output_path_by_k[k] = output_path
            k_expected_pass_rates[k] = expected_pass_rates
            k_expected_avg_tokens[k] = expected_avg_tokens

    print(f"\nCreating beautiful format files...")
    pretty_paths = create_pretty_files(output_paths, output_dir)

    # Output summary
    print(f"\n" + "="*60)
    print(f"Processing complete! Generated {len(output_paths)} data files:")
    # Look paths up by k: output_paths omits skipped k values, so zipping it
    # against k_values would mislabel files and never report the skipped ones.
    for k in k_values:
        output_path = output_path_by_k.get(k)
        if output_path:
            print(f"k={k:2d}: {os.path.basename(output_path)}")
        else:
            print(f"k={k:2d}: Skipped (insufficient samples)")

    print(f"\nBeautiful format files:")
    for pretty_path in pretty_paths:
        print(f"       {os.path.basename(pretty_path)}")

    # Output expected_pass_rate and expected_avg_tokens statistics for each k value
    print(f"\n" + "="*95)
    print(f"Average Expected Pass Rate and Expected Avg Output Tokens statistics for each k value (target samples: {target_num_samples}):")
    print(f"{'k value':<10} {'Actual samples':<12} {'Average Expected Pass Rate':<25} {'Average Expected Avg Tokens':<30}")
    print("-" * 95)

    # k=1 row: statistics of the individual key variable samples
    if k1_pass_rate is not None and k1_avg_tokens is not None:
        print(f"{'k=1':<10} {k1_sample_count:<12} {k1_pass_rate:<25.6f} {k1_avg_tokens:<30.2f}")
    else:
        print(f"{'k=1':<10} {'N/A':<12} {'N/A':<25} {'N/A':<30}")

    for k in k_values:
        if k in k_expected_pass_rates and k_expected_pass_rates[k]:
            pass_rates = k_expected_pass_rates[k]
            tokens = k_expected_avg_tokens[k]
            avg_pass_rate = sum(pass_rates) / len(pass_rates)
            avg_tokens = sum(tokens) / len(tokens)
            print(f"k={k:<7} {len(pass_rates):<12} {avg_pass_rate:<25.6f} {avg_tokens:<30.2f}")
        else:
            print(f"k={k:<7} {'0':<12} {'N/A':<25} {'N/A':<30}")

    print(f"\nExplanation:")
    print(f"k=1: Key variable sample statistics ({k1_sample_count} samples) - using current model's pass_rate")
    print(f"k>=2: Combination chained reasoning sample statistics based on key variables - using current model's pass_rate")
    print(f"Pass_rate data source: {result_base_path}")
    print(f"\nVariable marking explanation:")
    print(f"- [variablek]: Calculated variable needed to solve problem k")
    print(f"- [answerk]: Answer to problem k") 
    print(f"- Variables in math environment use {{[variablek]}} marking")
    print(f"\nAll files saved in: {output_dir}")

if __name__ == "__main__":
    main()
