"""
Solution Input Mutual Verification Module

This module implements a mutual verification mechanism for synthetic programming problems
where no oracle solutions exist. The key idea is to leverage agreement among multiple
independently generated solutions to establish correctness.

The verification process works as follows:
1. For each problem, we sample multiple candidate solutions (e.g., 16) using language models
2. Generate a diverse set of test inputs (e.g., 50+) with varying complexities
3. Execute each candidate solution on all test inputs
4. If more than a configurable fraction of the solutions (60% by default) reproduce the per-input majority-vote outputs on every test input:
   - These consistent outputs are considered correct
   - The solutions that generated them are considered valid

The effectiveness of this approach stems from the observation that incorrect solutions
are more likely to diverge in their errors than to converge on the same incorrect
answers across multiple test inputs. When a majority of independently generated
solutions produce identical results for a diverse set of inputs, it provides strong
evidence that they have successfully solved the problem.

Key steps in the pipeline:
1. Load and process execution results from multiple solution attempts
2. Filter out problems where all solutions fail
3. Determine correct outputs through majority voting
4. Verify solution consistency against majority-voted answers
5. Filter problems based on solution consistency ratios
"""

import os
import json
import argparse
from collections import Counter
import pandas as pd
import resource
from typing import Dict, List, Tuple, Any

# Cap the process's virtual address space (RLIMIT_AS) to prevent OOM issues:
# once the cap is reached, further allocations fail inside this process
# instead of exhausting memory on the whole machine.
MEMORY_LIMIT_GB = 500
memory_limit_bytes = MEMORY_LIMIT_GB * 1024 * 1024 * 1024


def _apply_memory_limit(limit_bytes: int) -> None:
    """Set both the soft and hard RLIMIT_AS of this process to ``limit_bytes``.

    An unprivileged process may lower but not raise its hard limit, so when the
    inherited hard limit is already below ``limit_bytes`` that hard limit is
    used instead of letting ``setrlimit`` fail at import time.
    """
    _soft, hard = resource.getrlimit(resource.RLIMIT_AS)
    if hard != resource.RLIM_INFINITY:
        limit_bytes = min(limit_bytes, hard)
    resource.setrlimit(resource.RLIMIT_AS, (limit_bytes, limit_bytes))


_apply_memory_limit(memory_limit_bytes)

def load_all_data(execution_results_paths: List[str], qr_paths: List[str] = None) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load execution results and question-response pairs from multiple files.

    Execution result files are read as JSON Lines (one JSON object per line);
    blank lines are skipped. Duplicate solutions are removed per problem,
    i.e. rows sharing both ``description`` and ``solution_code``, so that an
    identical solution submitted for a *different* problem is not dropped
    from that problem's vote.

    Args:
        execution_results_paths: List of paths to execution result files
        qr_paths: Optional list of paths to question-response pair files

    Returns:
        Tuple of (execution_results_df, qr_pairs_df); qr_pairs_df is None
        when no qr_paths are given.
    """
    # Load execution results using line-by-line JSON parsing
    records = []
    for path in execution_results_paths:
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline) in JSONL files.
                if line.strip():
                    records.append(json.loads(line))
    df = pd.DataFrame(records)

    # Load QR pairs if provided
    df_qr = pd.concat([pd.read_json(path) for path in qr_paths]) if qr_paths else None

    print(f"Number of unique descriptions: {df['description'].nunique()}")
    print(f"Total entries before deduplication: {len(df)}")

    # Remove duplicate solutions within each problem only.
    df = df.drop_duplicates(subset=['description', 'solution_code'])
    print(f"Total entries after deduplication: {len(df)}")

    return df, df_qr

def remove_problem_with_all_fail_solutions(df: pd.DataFrame) -> pd.DataFrame:
    """
    Remove problems where all solution attempts failed.

    A problem (``description``) is kept when at least one of its solutions
    has at least one entry in ``execution_results`` with ``success`` set.

    Args:
        df: DataFrame containing execution results; each ``execution_results``
            cell is an iterable of dicts with a ``success`` key.

    Returns:
        DataFrame with all-fail problems removed
    """
    # Per-row flag: did this solution succeed on at least one test input?
    row_has_success = df['execution_results'].map(
        lambda results: any(item['success'] for item in results)
    ).astype(bool)

    # Per-problem flag: did any of its solutions succeed anywhere?
    description_has_success = row_has_success.groupby(df['description']).any().astype(bool)
    all_fail_descriptions = description_has_success[~description_has_success].index

    return df[~df['description'].isin(all_fail_descriptions)]

def get_majority_vote_answer_and_check_consistency(df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict]:
    """
    Calculate majority vote answers for each problem and check solution consistency.

    For every problem (``description``), each test input's answer is the most
    common output among the solutions that succeeded on it (a plurality vote;
    ``Counter.most_common`` breaks ties by first occurrence). Problems whose
    voted answers contain fewer than two distinct outputs are dropped as
    uninformative. Each remaining solution is marked consistent when it
    succeeded on, and reproduced the voted output for, every input that has a
    voted answer.

    Args:
        df: DataFrame containing execution results. Uses the columns
            ``description``, ``solution_code`` and ``execution_results`` (a
            list of dicts with ``input_index``, ``success``, ``output`` and,
            for successful runs, ``execution_time``).

    Returns:
        Tuple of (processed_df, description_to_majority_vote_dict).
        processed_df is a copy of the rows of the retained problems with two
        added columns: ``is_consistent`` and ``avg_execution_time`` (mean
        execution time over the voted inputs; missing for inconsistent
        solutions). The dict maps each retained description to
        ``{input_index: voted_output}``, with None where no solution succeeded.
    """
    def get_mj_answer(df_group: pd.DataFrame) -> Dict[Any, Any]:
        """Return ``{input_index: plurality output}`` for one problem's solutions."""
        df_explode = df_group.explode('execution_results')
        # A solution with an empty execution_results list explodes to a NaN row.
        df_explode = df_explode[df_explode['execution_results'].notna()].copy()
        if df_explode.empty:
            return {}
        df_explode['input_index'] = df_explode['execution_results'].apply(lambda x: x['input_index'])
        # Failed runs cast no vote: their None output is dropped before counting.
        df_explode['output'] = df_explode['execution_results'].apply(
            lambda x: x['output'] if x['success'] else None
        )
        df_explode = df_explode.drop_duplicates(subset=['solution_code', 'input_index'])

        # Rows: solutions, columns: test inputs, cells: outputs (None/NaN if failed or missing).
        pivot_table = df_explode.pivot(
            index='solution_code',
            columns='input_index',
            values='output'
        )

        majority_votes = {}
        for col in pivot_table.columns:
            values = pivot_table[col].dropna()
            majority_votes[col] = Counter(values).most_common(1)[0][0] if not values.empty else None
        return majority_votes

    # Get majority vote answers for each description, keeping only problems
    # whose voted outputs are not identical for all test cases. None means
    # "no solution succeeded on this input" and is not counted as an output.
    description2majority_vote: Dict[Any, Dict[Any, Any]] = {}
    for description, group in df.groupby('description'):
        votes = get_mj_answer(group)
        if len({output for output in votes.values() if output is not None}) > 1:
            description2majority_vote[description] = votes

    def check_solution_consistent_with_mj_answer(row) -> Tuple[bool, Any]:
        """Return (is_consistent, avg_execution_time) for one solution row."""
        majority_vote_answer = description2majority_vote[row['description']]
        successful_runs = [entry for entry in row['execution_results'] if entry['success']]

        # Normalize indices with int() on both sides so that the lookup keys
        # match the (int-converted) majority-vote keys below.
        execution_results_dict = {int(entry['input_index']): entry['output'] for entry in successful_runs}
        execution_time_dict = {int(entry['input_index']): entry['execution_time'] for entry in successful_runs}

        execution_times = []
        for idx, expected_output in majority_vote_answer.items():
            if expected_output is None:
                # No solution succeeded on this input; nothing to compare against.
                continue
            idx = int(idx)
            if idx not in execution_results_dict or execution_results_dict[idx] != expected_output:
                return False, None
            execution_times.append(execution_time_dict[idx])

        avg_time = sum(execution_times) / len(execution_times) if execution_times else None
        return True, avg_time

    # Work on a copy so the new columns are not assigned onto a slice view.
    df = df[df['description'].isin(list(description2majority_vote))].copy()
    checks = [check_solution_consistent_with_mj_answer(row) for _, row in df.iterrows()]
    df['is_consistent'] = [is_consistent for is_consistent, _ in checks]
    df['avg_execution_time'] = [avg_time for _, avg_time in checks]

    return df, description2majority_vote

def filter_by_consistency_ratio(df: pd.DataFrame, mj_ratio: float) -> pd.DataFrame:
    """
    Keep the consistent solutions of problems whose consistent share is high enough.

    For each ``description`` the fraction of rows whose ``is_consistent`` flag
    is True is computed. A problem passes only when that fraction is strictly
    greater than ``mj_ratio``. The number of passing problems is printed.

    Args:
        df: DataFrame with ``description`` and boolean ``is_consistent`` columns
        mj_ratio: Threshold in [0, 1] that the consistent fraction must exceed

    Returns:
        Rows of ``df`` that belong to a passing problem and are themselves
        consistent.
    """
    consistent_share = df.groupby('description')['is_consistent'].agg(
        lambda flags: flags.value_counts(normalize=True).get(True, 0)
    )
    passing_descriptions = consistent_share[consistent_share > mj_ratio].index
    print(f'Number of problems meeting consistency threshold: {len(passing_descriptions)}')

    in_passing_problem = df['description'].isin(passing_descriptions)
    return df[in_passing_problem & df['is_consistent']]

def mutual_verification(
    execution_result_paths: str,
    output_path_prefix: str,
    consistency_ratio: int
) -> None:
    """
    Run the full verification pipeline and write its two output files.

    Steps: load execution results, drop problems where every solution failed,
    compute per-input majority answers and per-solution consistency, keep the
    consistent solutions of problems that clear the threshold, then write
    ``<prefix>_maj<consistency_ratio>.json`` (records with ``query``,
    ``response``, ``code`` and ``avg_execution_time``) and
    ``<prefix>_description2majority_vote.json`` (majority answers per problem).

    Args:
        execution_result_paths: Comma-separated paths to execution result files
        output_path_prefix: Prefix for output file paths
        consistency_ratio: Threshold percentage (0-100); a problem is kept only
            if its share of consistent solutions is strictly greater than
            ``consistency_ratio / 100``.
    """
    path_list = execution_result_paths.split(',')
    loaded, _unused_qr = load_all_data(path_list)

    with_success = remove_problem_with_all_fail_solutions(loaded)
    checked, description2majority_vote = get_majority_vote_answer_and_check_consistency(with_success)
    selected = filter_by_consistency_ratio(checked, consistency_ratio/100)

    records_path = output_path_prefix + f'_maj{consistency_ratio}.json'
    votes_path = output_path_prefix + '_description2majority_vote.json'
    output_columns = ['query', 'response', 'code', 'avg_execution_time']

    selected[output_columns].to_json(records_path, orient='records', indent=4)
    with open(votes_path, 'w') as handle:
        json.dump(description2majority_vote, handle, indent=4)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Verify programming problem solutions through majority voting')
    parser.add_argument(
        "--execution_result_paths",
        type=str,
        required=True,
        help="Comma-separated paths to execution result files"
    )
    parser.add_argument(
        "--output_path_prefix",
        type=str,
        required=True,
        help="Prefix for output file paths"
    )
    parser.add_argument(
        "--consistency_ratio",
        type=int,
        default=60,
        help="Minimum percentage of consistent solutions required (0-100)"
    )

    args = parser.parse_args()
    # The percentage is divided by 100 and compared with a share in [0, 1], so
    # a negative value would keep every problem and a value above 100 none.
    # Reject such values with a usage error instead of running silently.
    if not 0 <= args.consistency_ratio <= 100:
        parser.error(
            f"--consistency_ratio must be between 0 and 100, got {args.consistency_ratio}"
        )
    mutual_verification(
        args.execution_result_paths,
        args.output_path_prefix,
        args.consistency_ratio
    )
