"""
Test Input Generation and Validation Module

This module implements optimized parallel processing for generating and validating test inputs.
It executes generated utility functions with different parameter scales and validates the generated test inputs.

"""

import os
import argparse
import json
import re
import itertools
import pandas as pd
import numpy as np
from pebble import ProcessPool, ThreadPool
from pebble.common import ProcessExpired
from tqdm import tqdm
import time
import traceback
import sys
import math
import signal
import random
# Set the interpreter-wide limit on int <-> str conversions to 5000 digits.
sys.set_int_max_str_digits(5000)

# When True, the helpers in this module print the exceptions they swallow
# (message plus traceback) instead of failing silently.
DEBUG = False

def set_resource_limits():
    """Cap the address space (RLIMIT_AS) of the calling process at 15 GiB.

    The same value is installed as both the soft and the hard limit.
    """
    import resource

    bytes_per_gib = 2**30
    address_space_cap = 15 * bytes_per_gib
    resource.setrlimit(resource.RLIMIT_AS, (address_space_cap, address_space_cap))
    
def timeout_handler(signum, frame):
    """Abort whatever is currently running by raising TimeoutError.

    Written to be installed as a signal handler; the two arguments follow the
    signal-handler calling convention and are not used.
    """
    message = "Execution timed out"
    raise TimeoutError(message)

def run_with_timeout(func, timeout, args=(), kwargs=None):
    """
    Execute a function with timeout control using signal mechanism.

    A SIGALRM handler (``timeout_handler``) is installed and an alarm is armed
    for ``timeout`` seconds before ``func`` is called. Every ``Exception``
    raised while ``func`` runs, including the ``TimeoutError`` raised by the
    alarm handler, is swallowed and reported as ``None``. Callers therefore
    cannot tell "failed", "timed out" and "returned None" apart from the
    return value alone.

    Args:
        func: Function to execute
        timeout: Maximum execution time in seconds (passed to ``signal.alarm``)
        args: Positional arguments for the function
        kwargs: Keyword arguments for the function (``None`` means no keyword arguments)

    Returns:
        Function result or None if execution fails
    """
    # A `{}` default would be a single dict shared by every call; build the
    # empty mapping per call instead.
    call_kwargs = {} if kwargs is None else kwargs

    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(timeout)  # Start timer
    try:
        result = func(*args, **call_kwargs)
    except Exception as e:
        if DEBUG:
            print(f"Run with timeout, {e}")
            traceback.print_exc()
        result = None
    finally:
        signal.alarm(0)  # Clear alarm
    return result

def safe_exec(code, globals_=None, locals_=None, timeout=8):
    """
    Safely execute code with timeout and exception handling.

    Args:
        code: Python code string to execute
        globals_: Global namespace
        locals_: Local namespace
        timeout: Maximum execution time

    Returns:
        bool: True if execution succeeds, False otherwise
    """
    # NOTE(security): exec runs `code` with the full privileges of this
    # process. The timeout only bounds run time; it is not a sandbox.
    def _exec_then_report():
        exec(code, globals_, locals_)
        return True

    try:
        # run_with_timeout swallows exceptions (timeouts included) and returns
        # None for them, so a bare call can never signal failure by raising.
        # Success is detected through the explicit True marker instead.
        completed = run_with_timeout(_exec_then_report, timeout)
    except TimeoutError:
        return False
    except Exception as e:
        # Only reachable if the timeout machinery itself fails.
        if DEBUG:
            print(f"Exec error: {e}")
            traceback.print_exc()
        return False
    return completed is True

def extract_parameters(code):
    """
    Extract scale-controlling parameters from the first function defined in code.

    The code is parsed with ``ast`` so that commas inside annotations or
    default values (``Tuple[int, int]``, ``(1, 2)``) do not produce bogus
    names. ``*args`` / ``**kwargs`` and the bare ``*`` and ``/`` markers are
    not returned, since they are not names a caller can pass a value to.
    Code that does not parse falls back to a regex scrape of the first
    ``def`` signature.

    Args:
        code: Python function code string

    Returns:
        list: List of parameter names (empty if no function definition is found)
    """
    import ast  # local import: only this function needs it

    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError, RecursionError):
        # Incomplete or malformed snippet: best-effort textual extraction.
        return _extract_parameters_with_regex(code)

    function_nodes = [
        node for node in ast.walk(tree)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    ]
    if not function_nodes:
        return []

    # ast.walk is breadth-first; pick the definition that appears first in the
    # text, which is what the regex-based lookup selects as well.
    first_def = min(function_nodes, key=lambda node: (node.lineno, node.col_offset))
    signature = first_def.args
    named_args = signature.posonlyargs + signature.args + signature.kwonlyargs
    return [arg.arg for arg in named_args]


def _extract_parameters_with_regex(code):
    """Scrape parameter names from the first ``def`` signature in unparsable code."""
    # Remove comments and normalize whitespace
    code = re.sub(r'#.*$', '', code, flags=re.MULTILINE)
    code = re.sub(r'\s+', ' ', code)

    match = re.search(r"(?:^|\s)def\s+\w+\s*\((.*?)\)\s*(?:->.*?)?:", code)
    if not match:
        return []

    names = []
    for param in match.group(1).split(','):
        # Drop the type annotation and the default value, keep the bare name.
        name = param.split(':')[0].split('=')[0].strip()
        if name:
            names.append(name)
    return names

def get_param_combinations(num_params, max_exponent_for_complex=5):
    """
    Build every assignment of candidate scale values to ``num_params`` parameters.

    The candidate values are 1 through 9 followed by 10**1 up to
    10**max_exponent_for_complex; the result is their Cartesian product.

    Args:
        num_params: How many parameters each combination must cover
        max_exponent_for_complex: Largest power of ten included as a candidate

    Returns:
        list: Tuples of length ``num_params``, in ``itertools.product`` order
    """
    single_digits = [*range(1, 10)]
    powers_of_ten = [10 ** exponent for exponent in range(1, max_exponent_for_complex + 1)]
    candidate_values = single_digits + powers_of_ten
    return [*itertools.product(candidate_values, repeat=num_params)]

def generate_and_validate_input_string(generate_test_input, validate_test_input, case_timeout, param_dict):
    """
    Produce one test input and keep it only if the validator accepts it.

    Both callables are run through ``run_with_timeout`` with ``case_timeout``.

    Args:
        generate_test_input: Generator, called with ``param_dict`` as keyword arguments
        validate_test_input: Validator, called with ``input_string=<generated value>``
        case_timeout: Time limit applied separately to each of the two calls
        param_dict: Mapping of parameter name to value for the generator

    Returns:
        str: The generated string when it is a non-empty str and the validator
        returns a truthy value; otherwise None
    """
    candidate = run_with_timeout(generate_test_input, case_timeout, kwargs=param_dict)
    # Reject missing/empty results and anything that is not a string.
    if not (candidate and isinstance(candidate, str)):
        return None

    verdict = run_with_timeout(
        validate_test_input,
        case_timeout,
        kwargs={"input_string": candidate},
    )
    if verdict:
        return candidate
    return None

def process_row(data, row_timeout, case_timeout):
    """
    Process a single row of test generation data.

    The entries of ``data["response_info"]`` are tried in order. For each one
    the generation and validation code are executed, then
    ``generate_test_input`` is called over a grid of parameter scales and every
    distinct output accepted by ``validate_test_input`` is collected. The first
    response that yields at least one valid input determines the result.

    Args:
        data: Row data containing generation and validation code
        row_timeout: Maximum processing time for the row
        case_timeout: Maximum time for each test case

    Returns:
        dict: The row's fields plus "generated_inputs" and
        "generated_inputs_scales". When every response fails, both lists are
        empty and "error" holds the message of the last failure.
    """
    set_resource_limits()
    start_time = time.time()

    error_message = None
    for i, response in enumerate(data['response_info']):
        try:
            validation_code = response["test_input_validation_code"]
            generation_code = response["test_input_generation_code"]

            # Run the generated code in a private copy of the module namespace.
            # It still sees the names this module imports, but anything it
            # rebinds (e.g. `from time import time`) cannot break this module,
            # and functions left over from an earlier response cannot be
            # picked up by mistake.
            namespace = dict(globals())

            # Set up execution environment
            import_code = """import random\nimport cyaron as cy"""
            if not safe_exec(import_code, namespace):
                sys.exit(-1)

            # Execute and validate generation code
            if not safe_exec(generation_code, namespace):
                error_message = f"Response [{i}] Generation code execution failed."
                print(error_message)
                continue

            if not safe_exec(validation_code, namespace):
                error_message = f"Response [{i}] Validation code execution failed."
                print(error_message)
                continue

            # The generated code is expected to define these two entry points.
            generate_test_input = namespace.get("generate_test_input")
            validate_test_input = namespace.get("validate_test_input")
            if not callable(generate_test_input) or not callable(validate_test_input):
                error_message = f"Response [{i}] generate_test_input/validate_test_input is not defined."
                print(error_message)
                continue

            # Extract and test different parameter combinations
            generation_params = extract_parameters(generation_code)
            results = []
            scales = []
            seen_inputs = set()  # O(1) duplicate check; `results` keeps the order
            for values in get_param_combinations(len(generation_params), 6):
                param_dict = dict(zip(generation_params, values))
                input_string = generate_and_validate_input_string(generate_test_input, validate_test_input, case_timeout, param_dict)

                # Store unique valid inputs
                if input_string and input_string not in seen_inputs:
                    seen_inputs.add(input_string)
                    results.append(input_string)
                    scales.append(values)

                # Check remaining time: stop while there is still room for one
                # more generate + validate pair inside 90% of the row budget.
                if time.time() - start_time + case_timeout * 2 > row_timeout * 0.9:
                    print(f"Break before row timeout, generated {len(results)} inputs.")
                    break

            if not results:
                error_message = f"Response [{i}] No valid input generated."
                continue

            row_result = data.to_dict()
            row_result["generated_inputs"] = results
            row_result["generated_inputs_scales"] = scales
            return row_result

        except Exception as e:
            error_traceback = traceback.format_exc()
            error_message = f"Response [{i}] Error processing row: {e}"
            if DEBUG:
                print(f"Error processing row response [{i}]: {e}")
                print(error_traceback)

    # Handle case where all responses fail
    row_result = data.to_dict() if isinstance(data, pd.Series) else {}
    row_result["generated_inputs"] = []
    row_result["generated_inputs_scales"] = []
    row_result["error"] = {
        "message": error_message,  # Record error message of the last response
    }
    return row_result

def parallel_process_input(df_path, output_path, num_processes=8, row_timeout=60, case_timeout=5, start_idx=0, end_idx=-1, batch_size=100):
    """
    Process test input generation in parallel with batching.

    Rows are scheduled on a pebble ``ProcessPool`` one batch at a time. After
    each batch, the rows that produced a result are merged per ``description``
    and appended to ``output_path`` as JSON lines. Rows that time out or whose
    worker fails are reported on stdout and contribute nothing to the output.

    Args:
        df_path: Path to input data file
        output_path: Path to save results (opened in append mode)
        num_processes: Number of parallel processes
        row_timeout: Maximum time per row
        case_timeout: Maximum time per test case
        start_idx: Position of the first row to process
        end_idx: Position one past the last row to process; a negative value
            (the default) means "through the last row"
        batch_size: Number of rows to process in each batch
    """
    # pebble reports a task timeout as concurrent.futures.TimeoutError, which
    # is only an alias of the builtin TimeoutError from Python 3.11 on.
    from concurrent.futures import TimeoutError as FutureTimeoutError

    with open(df_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    df = pd.DataFrame(data)
    # Restrict the run to the requested window of rows.
    df = df.iloc[start_idx:] if end_idx < 0 else df.iloc[start_idx:end_idx]
    print('output path', output_path)

    def batch_generator(df, batch_size):
        """Generate batches of rows for processing."""
        for start in range(0, len(df), batch_size):
            yield df.iloc[start:start + batch_size]

    total_rows = len(df)
    print('df length', len(df), 'total tasks', total_rows, 'Worst time', f'{math.ceil((len(df) / num_processes) * row_timeout/60)}m')

    with tqdm(total=total_rows, desc="Processing entire DataFrame", unit="row") as pbar:
        with ProcessPool(max_workers=num_processes) as pool:
            for batch_df in batch_generator(df, batch_size):
                batch_results = []

                # Schedule batch processing; keep the row label for error reports
                futures = [
                    (row_label, pool.schedule(process_row, args=(row, row_timeout, case_timeout), timeout=row_timeout))
                    for row_label, row in batch_df.iterrows()
                ]

                # Collect results
                for row_label, future in futures:
                    try:
                        batch_results.append(future.result())
                    except (TimeoutError, FutureTimeoutError):
                        print(f"Row {row_label} timed out after {row_timeout}s; no result recorded.")
                    except ProcessExpired as e:
                        print(f"Process expired: {e}.")
                    except Exception as e:
                        print(f"Row {row_label} failed: {e!r}; no result recorded.")
                    finally:
                        pbar.update(1)

                if not batch_results:
                    continue

                # Aggregate and save results: one record per description, taking
                # the first value of every other column and concatenating the
                # generated inputs and their scales.
                results_df = pd.DataFrame(batch_results)
                results_df = results_df.groupby('description').apply(
                    lambda group: pd.Series({
                        **{col: group[col].iloc[0] for col in group.columns if col not in ['generated_inputs', 'description', 'generated_inputs_scales']},
                        'generated_inputs': list(itertools.chain.from_iterable(group['generated_inputs'])),
                        'generated_inputs_scales': list(itertools.chain.from_iterable(group['generated_inputs_scales']))
                    })
                ).reset_index()

                with open(output_path, 'a', encoding='utf-8') as f:
                    f.writelines(
                        json.dumps(row) + "\n"
                        for row in results_df.to_dict(orient="records")
                    )
    print(f"Results saved to {output_path}")

if __name__ == "__main__":
    # Command-line entry point: parse the options and forward them to
    # parallel_process_input.
    parser = argparse.ArgumentParser(description="Generate test inputs for test cases.")

    # Path options have no default, so they are None when omitted.
    for flag, help_text in (
        ("--input_path", "Path to the input JSON file."),
        ("--output_path", "Path to save the output JSON file."),
    ):
        parser.add_argument(flag, type=str, help=help_text)

    # Integer options; registration order is the order shown by --help.
    for flag, default_value, help_text in (
        ("--row_timeout", 60, "Timeout for processing a single row."),
        ("--case_timeout", 5, "Timeout for processing a single test case."),
        ("--num_processes", 8, "Number of processes to use for parallel processing."),
        ("--batch_size", 100, "Batch size for processing rows."),
    ):
        parser.add_argument(flag, type=int, default=default_value, help=help_text)

    args = parser.parse_args()

    parallel_process_input(
        args.input_path,
        args.output_path,
        num_processes=args.num_processes,
        row_timeout=args.row_timeout,
        case_timeout=args.case_timeout,
        batch_size=args.batch_size,
    )
