import json
import re
import os
import random
from glob import glob

def contains_numbers(text):
    """Return True when at least one digit character occurs anywhere in *text*."""
    first_digit = re.search(r'\d', text)
    return first_digit is not None

def contains_integers(text):
    """Return True when *text* holds at least one standalone integer.

    Delegates to ``extract_numbers_from_text``; the answer is True exactly
    when that function returns a non-empty list.
    """
    standalone_integers = extract_numbers_from_text(text)
    return bool(standalone_integers)

def is_pure_integer(text):
    """Return True when *text* is nothing but an (optionally negative) integer.

    The text is rejected if it contains a LaTeX command (backslash followed
    by a letter), any ASCII letter, a ``/`` or a ``.``. Backslashes, braces
    and whitespace are then stripped, and what remains must be an optional
    leading ``-`` followed only by digits.
    """
    # LaTeX commands (\sqrt, \frac, ...) and plain letters both disqualify.
    if re.search(r'\\[a-zA-Z]', text) or re.search(r'[a-zA-Z]', text):
        return False

    # Fractions and decimal points are not integers.
    if any(marker in text for marker in ('/', '.')):
        return False

    # Drop formatting noise before the final shape check.
    stripped = re.sub(r'[\\{}\s]', '', text)
    return re.match(r'^-?\d+$', stripped) is not None

def extract_numbers_from_text(text):
    """Find every standalone positive integer in *text*.

    A digit run is skipped when it overlaps a float-like token (``123.45``,
    ``123.`` or ``.45``), when it touches a ``.`` on either side, when it is
    directly preceded by an alphanumeric character, or when it is directly
    followed by a letter. Runs whose value is zero are skipped as well.

    Returns:
        A list of dicts, in order of appearance, each with the keys
        ``number`` (int value), ``start`` / ``end`` (slice bounds in *text*)
        and ``text`` (the matched digits).
    """
    # Every character index covered by a float-like token.
    float_positions = set()
    for float_regex in (r'\d+\.\d+', r'\d+\.', r'\.\d+'):
        for float_match in re.finditer(float_regex, text):
            float_positions.update(range(float_match.start(), float_match.end()))

    text_length = len(text)
    found = []

    for digit_run in re.finditer(r'\d+', text):
        begin, finish = digit_run.span()

        # Any overlap with a float token disqualifies the whole run.
        if any(index in float_positions for index in range(begin, finish)):
            continue

        # A space stands in for the missing neighbour at either end of the text.
        previous_char = text[begin - 1] if begin > 0 else ' '
        next_char = text[finish] if finish < text_length else ' '

        # Belt-and-braces check for a decimal point on either side.
        if '.' in (previous_char, next_char):
            continue

        # Digits glued to an identifier-like prefix (e.g. "x2") are not standalone.
        # Symbols such as '$', '#', '(' or ' ' are not alphanumeric, so they pass.
        if begin > 0 and previous_char.isalnum():
            continue

        # Digits followed by a letter (e.g. "5kg") are not standalone either.
        if finish < text_length and next_char.isalpha():
            continue

        digits = digit_run.group()
        value = int(digits)
        if value <= 0:
            continue  # keep only positive integers

        found.append({
            'number': value,
            'start': begin,
            'end': finish,
            'text': digits,
        })

    return found

def get_target_integer(target_list):
    """Return the first pure-integer target in *target_list* as an ``int``.

    Each target is converted with ``str()`` before it is inspected, so
    non-string entries (for example JSON numbers) are handled instead of
    raising ``TypeError`` inside the regex helpers. This mirrors how
    ``get_filtered_samples`` checks targets.

    Args:
        target_list: Iterable of candidate targets.

    Returns:
        The integer value of the first target accepted by
        ``is_pure_integer``, or ``None`` when no target qualifies.
    """
    for target in target_list:
        target_text = str(target)
        if not is_pure_integer(target_text):
            continue

        # Same clean-up as is_pure_integer: drop backslashes, braces, whitespace.
        cleaned = re.sub(r'[\\{}\s]', '', target_text)
        try:
            return int(cleaned)
        except ValueError as e:
            print(f"Warning: Cannot convert '{target}' -> '{cleaned}' to integer: {e}")
    return None

# The score is read by searching the model_score_result string for a literal
# {"score": <digits>} fragment. Compiled once instead of once per request.
_SCORE_PATTERN = re.compile(r'{"score":\s*(\d+)}')

# Result directories are recognised by a "-<number>-32" fragment in their name.
_RUN_DIR_PATTERN = re.compile(r'-(\d+)-32')


def _find_scenario_files(result_base_path):
    """Return the scenario_state.json paths of matching run directories (number in 0-31)."""
    # Sorted so that log output and float summation order are reproducible.
    all_subdirs = sorted(
        d for d in os.listdir(result_base_path)
        if os.path.isdir(os.path.join(result_base_path, d))
    )
    print(f"Found subdirectories: {all_subdirs}")

    file_paths = []
    for subdir in all_subdirs:
        match = _RUN_DIR_PATTERN.search(subdir)
        if match is None or not 0 <= int(match.group(1)) <= 31:
            continue
        scenario_file = os.path.join(result_base_path, subdir, "scenario_state.json")
        if os.path.exists(scenario_file):
            file_paths.append(scenario_file)
            print(f"Found matching file: {subdir}")
    return file_paths


def _extract_score(model_score_result):
    """Return the score found in *model_score_result* as a float, or None if absent."""
    if not isinstance(model_score_result, str):
        return None
    score_match = _SCORE_PATTERN.search(model_score_result)
    return float(score_match.group(1)) if score_match else None


def _extract_output_tokens(request_state):
    """Return the first completion's output_token_num as a float, or 0.0 if unavailable."""
    result = request_state.get('result')
    completions = result.get('completions') if isinstance(result, dict) else None
    if not isinstance(completions, list) or not completions:
        return 0.0
    first_completion = completions[0]
    if not isinstance(first_completion, dict):
        return 0.0
    try:
        return float(first_completion.get('output_token_num', 0))
    except (ValueError, TypeError):
        return 0.0


def load_pass_rate_data(result_base_path):
    """Aggregate per-instance pass rate and average output tokens over result runs.

    Every sub-directory of *result_base_path* whose name contains
    ``-<k>-32`` with ``0 <= k <= 31`` and which holds a
    ``scenario_state.json`` is read. For each entry of its
    ``request_states`` list, the score and the first completion's
    ``output_token_num`` are recorded under the instance id.

    A request state whose score cannot be extracted counts as a failure
    (score 0.0), and one without a usable token count contributes 0.0 tokens.
    Malformed request states are handled individually, so one bad entry no
    longer discards the remainder of its file.

    Args:
        result_base_path: Directory containing the run sub-directories.

    Returns:
        ``(pass_rate_mapping, avg_token_mapping)``. Both are dicts keyed by
        ``str(instance_id)``, matching how ``get_filtered_samples`` looks
        ids up. Pass rate is the fraction of recorded scores equal to 1.
        Both dicts are empty when the base path is not a directory or no
        result file matches.
    """
    pass_rate_mapping = {}
    avg_token_mapping = {}

    # isdir (not exists): a regular file here would make os.listdir raise.
    if not os.path.isdir(result_base_path):
        print(f"Warning: Base path does not exist - {result_base_path}")
        return pass_rate_mapping, avg_token_mapping

    file_paths = _find_scenario_files(result_base_path)
    if not file_paths:
        print("Warning: No matching result files found")
        return pass_rate_mapping, avg_token_mapping

    print(f"Total found {len(file_paths)} result files")

    instance_results = {}  # {instance_id: {'scores': [], 'tokens': []}}
    failed_score_extraction_count = 0

    for file_path in file_paths:
        # Only the load itself is guarded; per-request problems are handled below.
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            print(f"Warning: File does not exist - {file_path}")
            continue
        except json.JSONDecodeError as e:
            print(f"Warning: JSON parse error - {file_path}: {e}")
            continue
        except (OSError, ValueError) as e:
            print(f"Warning: Error loading file - {file_path}: {e}")
            continue

        request_states = data.get('request_states', []) if isinstance(data, dict) else None
        if not isinstance(request_states, list):
            print(f"Warning: Error loading file - {file_path}: unexpected JSON structure")
            continue

        for request_state in request_states:
            if not isinstance(request_state, dict):
                continue

            instance = request_state.get('instance')
            instance_id = instance.get('id') if isinstance(instance, dict) else None
            if instance_id is None or instance_id == '':
                continue
            # Normalise so lookups by str(instance_id) succeed for non-string ids too.
            instance_id = str(instance_id)

            results = instance_results.setdefault(instance_id, {'scores': [], 'tokens': []})

            score = _extract_score(request_state.get('model_score_result'))
            if score is None:
                # Cannot extract score, record as 0 (error)
                print(f"Warning: Cannot extract score - instance_id: {instance_id}, marked as error")
                score = 0.0
                failed_score_extraction_count += 1
            results['scores'].append(score)
            results['tokens'].append(_extract_output_tokens(request_state))

    # A score and a token count are appended together, so neither list is empty.
    for instance_id, results in instance_results.items():
        scores = results['scores']
        tokens = results['tokens']
        pass_count = sum(1 for score in scores if score == 1.0)
        pass_rate_mapping[instance_id] = pass_count / len(scores)
        avg_token_mapping[instance_id] = sum(tokens) / len(tokens)

    print(f"Successfully loaded pass_rate data, {len(pass_rate_mapping)} samples have pass_rate info")
    print(f"Successfully loaded average output token data, {len(avg_token_mapping)} samples have average token info")
    print(f"Samples with failed score extraction: {failed_score_extraction_count}")

    return pass_rate_mapping, avg_token_mapping

def _lookup_metric(mapping, instance_id_str, sample_number, value_label, mapping_label, missing_ids):
    """Look up one metric for a sample; log the first five samples and record misses."""
    verbose = sample_number <= 5
    if instance_id_str in mapping:
        value = mapping[instance_id_str]
        if verbose:
            print(f"✅ Sample {sample_number}: instance_id='{instance_id_str}', "
                  f"{value_label}={value} (type: {type(value)})")
        return value, True

    missing_ids.append(instance_id_str)
    if verbose:
        print(f"❌ Sample {sample_number}: instance_id='{instance_id_str}' not found in {mapping_label}")
    return None, False


def _check_metric_conversion(index, value, label, valid, invalid):
    """Check that one metric value converts to float; log the first five samples."""
    verbose = index < 5
    if value is None:
        if verbose:
            print(f"Sample {index+1}: {label} is None")
        return
    try:
        converted = float(value)
    except (ValueError, TypeError) as e:
        invalid.append((index, value, str(e)))
        if verbose:
            print(f"Sample {index+1}: {label} {value} conversion failed: {e} ❌")
        return
    valid.append(converted)
    if verbose:
        print(f"Sample {index+1}: {label} {value} -> {converted} ✅")


def _pass_rate_distribution(pass_rates):
    """Count pass rates per bucket: exactly 0.0, exactly 1.0, and (lo, hi] ranges between."""
    buckets = {
        '0.0': 0, '0.0-0.2': 0, '0.2-0.4': 0, '0.4-0.6': 0,
        '0.6-0.8': 0, '0.8-1.0': 0, '1.0': 0
    }
    upper_bounds = (
        (0.2, '0.0-0.2'), (0.4, '0.2-0.4'), (0.6, '0.4-0.6'),
        (0.8, '0.6-0.8'), (1.0, '0.8-1.0'),
    )
    for pr in pass_rates:
        if pr == 0.0:
            buckets['0.0'] += 1
        elif pr == 1.0:
            buckets['1.0'] += 1
        elif 0.0 < pr < 1.0:
            for upper, name in upper_bounds:
                if pr <= upper:
                    buckets[name] += 1
                    break
        # Values outside [0, 1] are not counted in any bucket.
    return buckets


def get_filtered_samples(input_path, result_base_path, filtered_output_dir):
    """Select samples with integer inputs and pure-integer targets, and attach metrics.

    Each non-blank line of *input_path* is parsed as JSON, and lines without
    an ``examples`` list are ignored. An example qualifies when
    ``contains_integers`` accepts its ``input`` and every entry of its
    ``target`` passes ``is_pure_integer``. Qualified examples are copied and
    extended with ``pass_rate`` and ``avg_output_tokens``, looked up by
    ``str(instanceId)`` in the mappings from ``load_pass_rate_data``; a
    missing value is stored as ``None``.

    Errors are isolated per example: a malformed example is reported and
    skipped without discarding the remaining examples on the same line.

    Progress and statistics are printed. The qualified samples plus a
    ``statistics`` summary are written to
    ``<filtered_output_dir>/filtered_samples_with_pass_rates.jsonl``. Despite
    the extension, the file holds one indented JSON document.

    Args:
        input_path: Path to the input file, one JSON object per line.
        result_base_path: Directory handed to ``load_pass_rate_data``.
        filtered_output_dir: Directory for the output file (created if needed).

    Returns:
        ``(filtered_samples, all_samples_with_data)``. The second list holds
        a copy of every example, qualified or not, for which both a pass
        rate and an average token count were found.
    """
    filtered_samples = []
    all_samples_with_data = []
    total_samples = 0
    rejected_no_integers = 0

    # Load pass_rate data and average output token data
    pass_rate_mapping, avg_token_mapping = load_pass_rate_data(result_base_path)

    print(f"Loaded pass_rate_mapping sample count: {len(pass_rate_mapping)}")
    print(f"Loaded avg_token_mapping sample count: {len(avg_token_mapping)}")
    print(f"pass_rate_mapping first 10 keys: {list(pass_rate_mapping)[:10]}")
    print(f"avg_token_mapping first 10 keys: {list(avg_token_mapping)[:10]}")

    # Statistics for matching
    found_pass_rate = 0
    found_avg_token = 0
    not_found_pass_rate = []
    not_found_avg_token = []
    sample_check_count = 0  # number of qualified samples seen so far

    with open(input_path, 'r', encoding='utf-8') as infile:
        for line in infile:
            line = line.strip()
            if not line:
                continue  # blank lines are not data
            try:
                data = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"JSON parse error: {e}")
                continue

            examples = data.get("examples") if isinstance(data, dict) else None
            if not isinstance(examples, list):
                continue

            for example in examples:
                total_samples += 1
                try:
                    input_text = example.get("input", "")
                    target = example.get("target", [])
                    instance_id_str = str(example.get("instanceId", ""))

                    if not isinstance(target, list):
                        target = [str(target)]

                    # Keep every example that has both metrics, qualified or not.
                    if instance_id_str in pass_rate_mapping and instance_id_str in avg_token_mapping:
                        sample_with_data = example.copy()
                        sample_with_data['pass_rate'] = pass_rate_mapping[instance_id_str]
                        sample_with_data['avg_output_tokens'] = avg_token_mapping[instance_id_str]
                        all_samples_with_data.append(sample_with_data)

                    if not contains_integers(input_text):
                        rejected_no_integers += 1
                        continue

                    # NOTE: an empty target list passes this check vacuously.
                    if not all(is_pure_integer(str(t)) for t in target):
                        continue

                    sample_check_count += 1
                    pass_rate, has_pass_rate = _lookup_metric(
                        pass_rate_mapping, instance_id_str, sample_check_count,
                        'pass_rate', 'pass_rate_mapping', not_found_pass_rate)
                    if has_pass_rate:
                        found_pass_rate += 1

                    avg_tokens, has_avg_tokens = _lookup_metric(
                        avg_token_mapping, instance_id_str, sample_check_count,
                        'avg_tokens', 'avg_token_mapping', not_found_avg_token)
                    if has_avg_tokens:
                        found_avg_token += 1

                    sample_with_metrics = example.copy()
                    sample_with_metrics['pass_rate'] = pass_rate
                    sample_with_metrics['avg_output_tokens'] = avg_tokens
                    filtered_samples.append(sample_with_metrics)
                except Exception as e:
                    # Best effort: report the bad example and carry on with the next.
                    print(f"Error processing sample: {e}")

    # Print matching statistics
    print(f"\n=== Data Matching Statistics ===")
    print(f"Total processed samples: {total_samples}")
    print(f"Qualified samples: {len(filtered_samples)}")
    print(f"Samples with pass_rate data: {found_pass_rate}")
    print(f"Samples with avg_token data: {found_avg_token}")
    print(f"pass_rate match rate: {found_pass_rate/len(filtered_samples)*100:.1f}%" if filtered_samples else "0%")
    print(f"avg_token match rate: {found_avg_token/len(filtered_samples)*100:.1f}%" if filtered_samples else "0%")

    # Show some unmatched sample IDs (if any)
    if not_found_pass_rate:
        print(f"First 5 instance_ids without pass_rate: {not_found_pass_rate[:5]}")
    if not_found_avg_token:
        print(f"First 5 instance_ids without avg_token: {not_found_avg_token[:5]}")

    # Detailed data conversion validation
    print(f"\n=== Data Conversion Validation ===")
    valid_pass_rates = []
    valid_tokens = []
    invalid_pass_rates = []
    invalid_tokens = []

    for i, sample in enumerate(filtered_samples):
        _check_metric_conversion(i, sample['pass_rate'], 'pass_rate',
                                 valid_pass_rates, invalid_pass_rates)
        _check_metric_conversion(i, sample['avg_output_tokens'], 'avg_tokens',
                                 valid_tokens, invalid_tokens)

    samples_with_pass_rates = len(valid_pass_rates)
    samples_without_pass_rates = len(filtered_samples) - samples_with_pass_rates
    samples_with_tokens = len(valid_tokens)
    samples_without_tokens = len(filtered_samples) - samples_with_tokens

    print(f"\nConversion results:")
    print(f"Valid pass_rate count: {len(valid_pass_rates)}")
    print(f"Invalid pass_rate count: {len(invalid_pass_rates)}")
    print(f"Valid avg_tokens count: {len(valid_tokens)}")
    print(f"Invalid avg_tokens count: {len(invalid_tokens)}")

    if invalid_pass_rates:
        print(f"Invalid pass_rate examples: {invalid_pass_rates[:3]}")
    if invalid_tokens:
        print(f"Invalid avg_tokens examples: {invalid_tokens[:3]}")

    # Pass_rate statistics
    print(f"\n=== Pass Rate Statistics ===")
    if valid_pass_rates:
        average_pass_rate = sum(valid_pass_rates) / len(valid_pass_rates)
        print(f"Samples with pass_rate info: {samples_with_pass_rates}")
        print(f"Samples without pass_rate info: {samples_without_pass_rates}")
        print(f"Average pass_rate: {average_pass_rate:.4f}")

        pass_rate_ranges = _pass_rate_distribution(valid_pass_rates)
        print(f"Pass Rate distribution:")
        for range_name, count in pass_rate_ranges.items():
            if count > 0:
                percentage = (count / samples_with_pass_rates) * 100
                print(f"  {range_name}: {count} samples ({percentage:.1f}%)")
    else:
        average_pass_rate = None
        pass_rate_ranges = {}
        print(f"Warning: No valid pass_rate information found")

    # Average output token statistics
    print(f"\n=== Average Output Token Statistics ===")
    if valid_tokens:
        overall_average_tokens = sum(valid_tokens) / len(valid_tokens)
        print(f"Samples with average output token info: {samples_with_tokens}")
        print(f"Samples without average output token info: {samples_without_tokens}")
        print(f"Overall average output tokens: {overall_average_tokens:.2f}")
        print(f"Average output token range: {min(valid_tokens):.2f} - {max(valid_tokens):.2f}")
    else:
        overall_average_tokens = None
        print(f"Warning: No valid average output token information found")

    print(f"\n=== Filtering Results ===")
    print(f"Total samples: {total_samples}")
    print(f"Filtered {len(filtered_samples)} qualified samples")
    print(f"Samples rejected for no replaceable integers: {rejected_no_integers}")
    if average_pass_rate is not None:
        print(f"Average pass_rate of filtered samples: {average_pass_rate:.4f}")
    if overall_average_tokens is not None:
        print(f"Overall average output tokens of filtered samples: {overall_average_tokens:.2f}")

    # Save filtered samples (including pass_rate and average output token info)
    os.makedirs(filtered_output_dir, exist_ok=True)

    filtered_output_path = os.path.join(filtered_output_dir, "filtered_samples_with_pass_rates.jsonl")
    filtered_data = {
        "examples": filtered_samples,
        "statistics": {
            "total_samples": total_samples,
            "filtered_samples": len(filtered_samples),
            "rejected_no_integers": rejected_no_integers,
            "samples_with_pass_rates": samples_with_pass_rates,
            "samples_without_pass_rates": samples_without_pass_rates,
            "samples_with_tokens": samples_with_tokens,
            "samples_without_tokens": samples_without_tokens,
            "average_pass_rate": average_pass_rate,
            "overall_average_tokens": overall_average_tokens,
            "pass_rate_distribution": pass_rate_ranges,
            "found_pass_rate": found_pass_rate,
            "found_avg_token": found_avg_token
        }
    }
    with open(filtered_output_path, 'w', encoding='utf-8') as outfile:
        outfile.write(json.dumps(filtered_data, ensure_ascii=False, indent=2))

    print(f"Filtered samples saved to: {filtered_output_path}")

    # Return filtered samples and all samples with data
    return filtered_samples, all_samples_with_data

def main():
    """Run the filtering pipeline using the paths configured below.

    Creates the output directory, stops with an error message when the input
    file is missing, and otherwise calls ``get_filtered_samples`` and prints
    a short summary of the outcome.
    """
    # ============ All path configurations centralized here ============
    input_path = "INPUT_DATA_PATH"                # input data
    result_base_path = "RESULT_BASE_PATH"         # evaluation results (source of pass_rate data)
    filtered_output_dir = "FILTERED_OUTPUT_DIR"   # where filtered samples are saved
    # ============ Path configuration end ============

    # The output directory is created before the input check, so it exists
    # even when the run stops early.
    os.makedirs(filtered_output_dir, exist_ok=True)

    if not os.path.exists(input_path):
        print(f"Error: Input file does not exist - {input_path}")
        return

    print("Starting to filter qualified samples...")
    filtered_samples, _samples_with_data = get_filtered_samples(
        input_path, result_base_path, filtered_output_dir
    )

    if filtered_samples:
        print(f"\nFiltering completed! Filtered {len(filtered_samples)} qualified samples")
        print(f"Filtering results saved to: {filtered_output_dir}")
    else:
        print("No qualified samples found")


if __name__ == "__main__":
    main()
