"""
Program Execution Module

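This module executes code solutions with the following safety controls:
1. Resource limitation: in stdin mode, the child process is pinned to a fixed set of CPU cores and given an address-space (memory) cap
2. Timeout control
3. Process isolation: stdin-mode solutions run in a separate ``python3`` subprocess; function-mode solutions are ``exec``-ed inside the calling process
4. Input/output handling
5. Error management

The module supports two execution modes: standard input/output (stdin) and function-based (calling a named function or ``Solution`` method).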
"""

import os
import pandas as pd
from pebble import ProcessPool
import subprocess
import time
import json
from tqdm import tqdm
import resource
import sys
import signal
from functools import wraps
import argparse

# Constants
TIMEOUT = 16  # Per-run time limit, in seconds
MAX_MEMORY = 512 * 1024 ** 2  # Address-space cap for stdin-mode child processes: 512 MiB, in bytes
MAX_CORES = 48  # Stdin-mode child processes are pinned to CPU ids 0 .. MAX_CORES-1

def limit_memory(max_memory):
    """Cap the virtual address space (``RLIMIT_AS``) of the calling process.

    When used as a ``preexec_fn``, the calling process is the freshly forked
    child, so the cap applies to the solution about to be executed. Both the
    soft and the hard limit are set to ``max_memory``.

    Args:
        max_memory: Address-space limit in bytes.
    """
    soft_and_hard = (max_memory,) * 2
    resource.setrlimit(resource.RLIMIT_AS, soft_and_hard)

def run_solution(solution_code, input_data, language='python'):
    """Run a stdin/stdout solution in a resource-limited ``python3`` subprocess.

    Before the solution starts, the child is pinned to CPU ids
    ``0 .. MAX_CORES-1`` and its address space is capped at ``MAX_MEMORY``
    bytes. The child is killed if it does not finish within ``TIMEOUT``
    seconds.

    Args:
        solution_code: Python source passed to ``python3 -c``.
        input_data: Text written to the child's stdin.
        language: Currently unused; the code is always run with ``python3``.

    Returns:
        tuple: ``(output, success, execution_time)``.

        - ``output`` is the stripped stdout when the exit status is 0 and
          the stripped stderr otherwise.
        - ``success`` tells whether the exit status was 0.
        - ``execution_time`` is the elapsed time in seconds.
        - On timeout or launch failure, ``output`` is ``"Timeout"`` or the
          error text, and ``execution_time`` is ``None``.
    """
    def set_cpu_affinity():
        # Runs in the child after fork, before exec: limit CPU cores and memory.
        os.sched_setaffinity(0, set(range(MAX_CORES)))
        limit_memory(MAX_MEMORY)

    try:
        # Popen as a context manager closes the pipes and reaps the child on
        # exit. A Popen failure (e.g. python3 missing, preexec_fn error) raises
        # here and is reported by the handlers below; no half-bound state.
        with subprocess.Popen(
            ['python3', '-c', solution_code],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            preexec_fn=set_cpu_affinity,
        ) as process:
            try:
                start_time = time.perf_counter()
                stdout, stderr = process.communicate(input=input_data, timeout=TIMEOUT)
                end_time = time.perf_counter()
            except BaseException:
                # Kill first: Popen.__exit__ waits for the child, which would
                # otherwise block on a runaway solution.
                process.kill()
                raise
    except subprocess.TimeoutExpired:
        return "Timeout", False, None
    except MemoryError:
        return "Memory Limit Exceeded", False, None
    except Exception as e:
        return str(e), False, None

    success = process.returncode == 0
    output = stdout.strip() if success else stderr.strip()
    return output, success, end_time - start_time

def timeout_handler(signum, frame):
    """SIGALRM handler that aborts the interrupted call.

    Args:
        signum: Number of the delivered signal (ignored).
        frame: Interrupted stack frame (ignored).

    Raises:
        TimeoutError: Always.
    """
    message = "Execution timed out"
    raise TimeoutError(message)

def run_with_timeout(func, args=(), kwargs=None, timeout_duration=TIMEOUT):
    """Call ``func(*args, **kwargs)``, aborting it if it runs too long.

    The limit is enforced with a SIGALRM real-time timer whose handler is
    ``timeout_handler``. It therefore only works on Unix and from the main
    thread (``signal.signal`` raises ``ValueError`` elsewhere). The
    previously installed SIGALRM handler is restored afterwards.

    Args:
        func: Callable to execute.
        args: Positional arguments for ``func``.
        kwargs: Keyword arguments for ``func``; ``None`` means no keywords.
        timeout_duration: Limit in seconds; fractional values are allowed.
            A value ``<= 0`` disables the timer.

    Returns:
        Whatever ``func`` returns.

    Raises:
        TimeoutError: If ``func`` does not finish within ``timeout_duration``.
    """
    if kwargs is None:
        kwargs = {}
    previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    try:
        # setitimer (unlike alarm) accepts fractional seconds; 0 disarms it.
        signal.setitimer(signal.ITIMER_REAL, max(timeout_duration, 0))
        return func(*args, **kwargs)
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)
        # signal.signal returns None when the old handler was not installed
        # from Python; it cannot be reinstalled in that case.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)

def run_solution_with_function(solution_code, input_dict, fn_name=None):
    """Define a function-style solution in this process and call it.

    The solution source is executed to define its functions/classes. If it
    defines ``Solution``, ``Solution().<fn_name>`` is called; otherwise the
    top-level function ``fn_name`` is called. Both steps are bounded by
    ``run_with_timeout`` with ``TIMEOUT`` seconds.

    NOTE(security): ``exec`` runs the solution inside the calling process,
    with no subprocess isolation and no memory limit. Only use this on code
    you are prepared to run with this process's privileges.

    Args:
        solution_code: Source code defining the target function or class.
        input_dict: Call arguments. A mapping is passed as keyword
            arguments; a list/tuple is passed as positional arguments.
        fn_name: Name of the function / ``Solution`` method to call.

    Returns:
        tuple: ``(output, success, execution_time)``.

        - On success, ``output`` is ``str`` of the return value and
          ``execution_time`` is the elapsed time of the call in seconds.
        - On failure, ``output`` is ``"Timeout"``,
          ``"Memory Limit Exceeded"`` or the exception text, and
          ``execution_time`` is ``None``.
    """
    try:
        # Use ONE namespace. With separate globals/locals dicts, exec behaves
        # like a class body: top-level imports and helper functions land in
        # locals and are invisible to the solution's own functions.
        namespace = {}

        def define_solution():
            exec(solution_code, namespace)

        def execute_solution():
            if 'Solution' in namespace:
                target = getattr(namespace['Solution'](), fn_name)
            else:
                target = namespace[fn_name]
            if isinstance(input_dict, (list, tuple)):
                return target(*input_dict)
            return target(**input_dict)

        # Top-level statements in the solution can loop forever too.
        run_with_timeout(define_solution, timeout_duration=TIMEOUT)

        start_time = time.perf_counter()
        result = run_with_timeout(execute_solution, timeout_duration=TIMEOUT)
        end_time = time.perf_counter()

        return str(result), True, end_time - start_time

    except TimeoutError:
        return "Timeout", False, None
    except MemoryError:
        return "Memory Limit Exceeded", False, None
    except Exception as e:
        return str(e), False, None

def determine_input_type(solution_code, starter_code):
    """Decide whether a solution is driven via stdin or via a function call.

    Args:
        solution_code: The solution source. Currently unused; kept for
            interface compatibility.
        starter_code: Starter template for the task. May be ``None`` or, when
            read from a DataFrame with missing cells, a non-string such as NaN.

    Returns:
        str: ``'function'`` if ``starter_code`` is a non-blank string,
        otherwise ``'stdin'``.
    """
    # pandas fills missing cells with NaN, a *truthy* float without .strip(),
    # so only genuine non-blank strings select function mode.
    if isinstance(starter_code, str) and starter_code.strip():
        return 'function'
    return 'stdin'

def process_task(description, solution_code, inputs_with_index, max_cases=15, starter_code=None, fn_name=None, language='python'):
    """Run one solution against its test cases and collect the results.

    Args:
        description: Problem description; copied into the result unchanged.
        solution_code: Solution source code.
        inputs_with_index: Sequence of cases. Each case is a mapping with an
            ``'index'`` and an ``'input'``: stdin text, or in function mode
            either a JSON string or already-decoded arguments.
        max_cases: When truthy and smaller than the number of cases, only the
            last ``max_cases`` cases are run.
        starter_code: Starter code template, if any. It decides between stdin
            and function mode via ``determine_input_type``.
        fn_name: Function or ``Solution`` method name to call in function mode.
        language: Programming language. It is recorded in the result and
            passed to ``run_solution`` in stdin mode.

    Returns:
        dict: ``description``, ``solution_code``, ``language`` and
        ``execution_results``.

        ``execution_results`` is a list with one ``input_index`` /
        ``output`` / ``success`` / ``execution_time`` dict per case actually
        run. The remaining cases are skipped after one fails with
        ``"Timeout"`` or ``"Memory Limit Exceeded"``.
    """
    results = {
        'description': description,
        'solution_code': solution_code,
        'language': language,
        'execution_results': [],
    }

    input_type = determine_input_type(solution_code, starter_code)

    if max_cases and len(inputs_with_index) > max_cases:
        # NOTE(review): this keeps the *last* cases rather than the first —
        # confirm that is intended.
        inputs_with_index = inputs_with_index[-max_cases:]

    for item in inputs_with_index:
        input_data = item['input']
        input_index = item['index']

        if input_type == 'stdin':
            output, success, exec_time = run_solution(solution_code, input_data, language)
        else:
            try:
                if isinstance(input_data, str):
                    input_data = json.loads(input_data)
            except (ValueError, RecursionError):
                # json.JSONDecodeError subclasses ValueError; very deeply
                # nested input can hit the recursion limit instead.
                output, success, exec_time = "Invalid input format", False, None
            else:
                output, success, exec_time = run_solution_with_function(solution_code, input_data, fn_name)

        results['execution_results'].append({
            'input_index': input_index,
            'output': output,
            'success': success,
            'execution_time': exec_time,
        })

        # Stop after a resource failure. ``success`` is checked so that a
        # program which legitimately prints "Timeout" does not end the run.
        if not success and output in ('Timeout', 'Memory Limit Exceeded'):
            break

    return results

def main():
    """Command-line entry point: run every task in a JSON file in parallel.

    Reads the tasks with ``pd.read_json`` from ``--query_path`` and runs each
    one through ``process_task`` in a pebble ``ProcessPool``, in batches of
    ``--batch_size``. Each completed task is appended as one JSON line to
    ``--output_file``. Tasks that raise, or exceed their pool timeout, are
    counted as failures and are not written.

    When ``--max_cases`` is omitted, every test case of a task is run.
    """
    parser = argparse.ArgumentParser(description='Execute and evaluate code solutions')
    parser.add_argument("--query_path", type=str, required=True, help="Path to query file")
    parser.add_argument("--output_file", type=str, required=True, help="Path to output file")
    parser.add_argument("--max_cases", type=int, help="Maximum number of test cases to run")
    parser.add_argument("--batch_size", type=int, default=128, help="Batch size for processing")
    parser.add_argument("--num_processes", type=int, default=32, help="Number of parallel processes")
    args = parser.parse_args()

    def cell(row, key, default):
        # pandas fills missing cells with NaN (a float that is != itself);
        # treat those exactly like an absent column.
        value = row.get(key, default)
        if value is None or (isinstance(value, float) and value != value):
            return default
        return value

    df = pd.read_json(args.query_path)

    # Tuple order must match process_task's positional parameters:
    # (description, solution_code, inputs_with_index, max_cases,
    #  starter_code, fn_name, language). Omitting max_cases would shift
    # every later argument into the wrong parameter.
    tasks = [
        (
            row['description'],
            row['code'],
            cell(row, 'inputs_with_index', []),
            args.max_cases,
            cell(row, 'starter_code', ''),
            cell(row, 'extract_fn_name', None),
            cell(row, 'language', 'python'),
        )
        for _, row in df.iterrows()
    ]

    total_tasks = len(tasks)
    print(f"Total tasks: {total_tasks}")

    success_count = 0
    fail_count = 0

    def progress_label():
        return f"Processing tasks | Success: {success_count} | Fail: {fail_count}"

    with ProcessPool(max_workers=args.num_processes) as pool:
        # Open once; UTF-8 explicitly because ensure_ascii=False can emit
        # non-ASCII text that the locale's default encoding may reject.
        with open(args.output_file, 'a', encoding='utf-8') as out_file:
            with tqdm(total=total_tasks, desc=progress_label(), unit="task") as pbar:
                for batch_start in range(0, total_tasks, args.batch_size):
                    start_time = time.time()
                    batch_tasks = tasks[batch_start:batch_start + args.batch_size]
                    futures = []

                    for task in batch_tasks:
                        # Budget TIMEOUT + 1 s per case actually run, at least 120 s.
                        n_cases = len(task[2])
                        if args.max_cases:
                            n_cases = min(n_cases, args.max_cases)
                        task_timeout = max(120, (TIMEOUT + 1) * n_cases)
                        futures.append(pool.schedule(process_task, args=task, timeout=task_timeout))

                    for future in futures:
                        try:
                            result = future.result()
                            out_file.write(json.dumps(result, ensure_ascii=False) + '\n')
                            # Flush per result so progress survives a crash.
                            out_file.flush()
                            success_count += 1
                        except Exception as exc:
                            # tqdm.write keeps the progress bar intact.
                            tqdm.write(f"Task failed: {exc}")
                            fail_count += 1
                        finally:
                            pbar.set_description(progress_label())
                            pbar.update(1)

                    end_time = time.time()
                    tqdm.write(f"Batch time: {end_time - start_time:.2f} seconds")

# Only start the CLI when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()