import ast 
import dataclasses 
import io 

from contextlib import redirect_stdout ,redirect_stderr ,_RedirectStream 
from .interpreter import Interpreter ,Result 
from .trace import Trace 


@dataclasses.dataclass()
class ExecutionResult:
    """Bundle of everything ``run`` collects from one program execution."""

    # Interpreter instance that executed the program.
    interpreter: Interpreter
    # Value returned by ``interpreter.visit`` on the parsed module.
    result: Result
    # Execution trace; ``run`` fills this from ``interpreter.trace``.
    trace: Trace
    # Text written to ``sys.stdout`` while the program ran.
    stdout: str
    # Text written to ``sys.stderr`` while the program ran.
    stderr: str



class redirect_stdin(_RedirectStream):
    """Temporarily replace ``sys.stdin`` inside a ``with`` block.

    Input-side counterpart of ``contextlib.redirect_stdout`` and
    ``contextlib.redirect_stderr``. On entry ``sys.stdin`` is set to the
    given stream (which ``__enter__`` also returns). On exit the previous
    value is restored.
    """

    # Name of the ``sys`` attribute that the base class swaps in and out.
    # NOTE: _RedirectStream is a private contextlib helper, not public API,
    # so its behaviour could change between Python versions.
    _stream = "stdin"


def run(
    source: str,
    input_data: str = "",
    max_char_length: int = 128_000 * 10,
    hide_mnemonics: bool = False,
    override_code_str: str | None = None,
) -> ExecutionResult:
    """Execute ``source`` with the project ``Interpreter`` and capture its I/O.

    Args:
        source: Python program text to execute.
        input_data: Text made available to the program as ``sys.stdin``.
        max_char_length: Forwarded unchanged to ``Interpreter``.
        hide_mnemonics: Forwarded unchanged to ``Interpreter``.
        override_code_str: Optional extra code appended after ``source``.
            Its top-level ``def`` statements are passed to the interpreter as
            ``override_func_defs``. Every top-level ``def`` carrying one of
            those names is removed from the executed tree, whether it came
            from ``source`` or from the override code itself. The override
            code's other top-level statements stay in the tree after
            ``source``'s statements.

    Returns:
        An ``ExecutionResult`` holding the interpreter, the value returned by
        ``interpreter.visit``, the interpreter's trace, and the text captured
        from stdout and stderr.

    Raises:
        SyntaxError: if ``override_code_str`` or the combined source fails to
            parse.
        Any exception raised by ``interpreter.visit`` propagates unchanged;
        output captured up to that point is not returned.
    """
    override_func_defs: dict[str, ast.FunctionDef] = {}
    combined_source = source

    if override_code_str:
        # The override code starts right after ``source + "\n"``. Its line 1
        # therefore sits just past the line breaks in that prefix. Python's
        # tokenizer treats "\r\n" and a lone "\r" as line breaks too, so
        # normalize them before counting. Count on the prefix, not on
        # ``source`` alone, because a trailing "\r" in ``source`` fuses with
        # the "\n" separator into a single "\r\n" break.
        prefix = (source + "\n").replace("\r\n", "\n").replace("\r", "\n")
        line_offset = prefix.count("\n")
        override_tree = ast.parse(override_code_str)
        # Shift the separately parsed override nodes onto the line numbers
        # they occupy in ``combined_source``.
        ast.increment_lineno(override_tree, line_offset)
        override_func_defs = {
            node.name: node
            for node in override_tree.body
            if isinstance(node, ast.FunctionDef)
        }
        combined_source = source + "\n" + override_code_str

    tree = ast.parse(combined_source)

    if override_func_defs:
        # Drop every top-level def whose name is overridden: both the original
        # in ``source`` and the copy parsed from the override code. The
        # interpreter receives the override definitions separately via
        # ``override_func_defs`` below.
        tree.body = [
            node
            for node in tree.body
            if not (
                isinstance(node, ast.FunctionDef)
                and node.name in override_func_defs
            )
        ]

    interpreter = Interpreter(
        combined_source,
        max_char_length=max_char_length,
        hide_mnemonics=hide_mnemonics,
        override_func_defs=override_func_defs,
    )

    stdout_buffer = io.StringIO()
    stderr_buffer = io.StringIO()
    stdin_buffer = io.StringIO(input_data)

    # Anything written to sys.stdout/sys.stderr or read from sys.stdin while
    # the tree is interpreted goes through the in-memory buffers.
    # NOTE: these redirections replace the process-global ``sys`` streams, so
    # concurrent ``run`` calls from different threads would interfere.
    with (
        redirect_stdout(stdout_buffer),
        redirect_stderr(stderr_buffer),
        redirect_stdin(stdin_buffer),
    ):
        result = interpreter.visit(tree, None)

    return ExecutionResult(
        interpreter=interpreter,
        result=result,
        trace=interpreter.trace,
        stdout=stdout_buffer.getvalue(),
        stderr=stderr_buffer.getvalue(),
    )
