from typing import Any ,List ,Tuple 
import pytracify 
from tqdm import tqdm 
import dataclasses 
import multiprocessing 

from concurrent .futures import ProcessPoolExecutor ,as_completed 

from .common_types import Sample 


def create_prompt(
    code: str,
    trace: str,
    final_output: str,
    input_data: str = "",
    function_name: str = "",
    problem_statement: str = "",
    is_code_gen: bool = False,
) -> dict:
    """Build one chat-style fine-tuning record from a program and its trace.

    The user message depends on the arguments (when ``is_code_gen`` is False):

    * no ``function_name`` and no ``input_data``: the code alone;
    * no ``function_name``: the code followed by an ``<input>`` block;
    * otherwise: the code with a ``print(function_name(input_data))`` line
      appended, i.e. ``input_data`` is spliced in as argument source text.

    When ``is_code_gen`` is True the user message is the problem statement
    (plus an ``<input>`` block if ``input_data`` is non-empty) and the code is
    moved to the front of the assistant message instead.

    Args:
        code: Program source shown to (or produced by) the model.
        trace: Formatted execution trace, placed inside ``<think>``.
        final_output: Expected answer, placed inside ``<answer>``.
        input_data: Program input, or call arguments when ``function_name``
            is given.
        function_name: Function to call with ``input_data``; empty means the
            code is shown as a standalone program.
        problem_statement: Problem text used only when ``is_code_gen``.
        is_code_gen: Produce a code-generation record instead of an
            execution-prediction record.

    Returns:
        ``{"input": [{"role": "user", "content": ...}],
        "output": {"role": "assistant", "content": ...}}``
    """
    assistant_content = f"<think>\n{trace}\n</think>\n<answer>{final_output}</answer>"

    if is_code_gen:
        user_content = (
            f"<problem>{problem_statement}</problem>\n<input>\n{input_data}\n</input>"
            if input_data
            else f"<problem>{problem_statement}</problem>"
        )
        # Code-generation records put the code on the assistant side.
        assistant_content = f"This is the code:\n<code>\n{code}\n</code>\n{assistant_content}"
    elif not function_name and not input_data:
        # Must be tested before the input-only case below, which would
        # otherwise also match and emit an empty <input> block.
        user_content = f"<code>\n{code}\n</code>"
    elif not function_name:
        user_content = f"<code>\n{code}\n</code>\n<input>\n{input_data}\n</input>"
    else:
        # NOTE: </code> directly follows the print line (no newline), unlike
        # the other branches; kept as-is to preserve the record format.
        user_content = f"<code>\n{code}\nprint({function_name}({input_data}))</code>"

    return {
        "input": [
            {
                "role": "user",
                "content": user_content,
            }
        ],
        "output": {
            "role": "assistant",
            "content": assistant_content,
        },
    }










def execute_pytracify(
    code: str,
    input_data: str,
    trace_formatter_name: str,
    function_name: str = "",
    hide_mnemonics: bool = False,
) -> Tuple[str, Any]:
    """Trace ``code`` with pytracify; return the formatted trace and the output.

    Two modes are used:

    * stdin mode — when ``function_name`` is empty or the code contains
      ``input(``: the program runs as-is with ``input_data`` passed to
      ``pytracify.run``; the output is captured stdout with trailing newlines
      stripped.
    * function mode — otherwise: a ``return function_name(input_data)`` line
      is appended (``input_data`` is spliced in as argument source text) and
      the output is ``result.result.value``.

    Args:
        code: Program source to trace.
        input_data: Stdin text (stdin mode) or call-argument source (function
            mode).
        trace_formatter_name: Name passed to ``pytracify.get_trace_formatter``.
        function_name: Function to call in function mode.
        hide_mnemonics: Forwarded to ``pytracify.run``.

    Returns:
        ``(trace_str, output)``.
    """
    # Replaces the builtin len() with a loop-based version during tracing;
    # presumably so len() calls appear as traced steps — TODO confirm intent.
    override_code_str = """
def len(target):
    cnt = 0
    for element in target:
        cnt += 1
    return cnt
"""

    # Decide the mode once, on the ORIGINAL code: the wrapped source below
    # also contains function_name/input_data, which must not affect the mode.
    use_stdin = function_name == "" or "input(" in code

    if use_stdin:
        result = pytracify.run(
            code,
            input_data=input_data,
            hide_mnemonics=hide_mnemonics,
            override_code_str=override_code_str,
        )
    else:
        wrapped_code = f"""
{code}

return {function_name}({input_data})
"""
        result = pytracify.run(
            wrapped_code, hide_mnemonics=hide_mnemonics, override_code_str=override_code_str
        )

    trace_formatter = pytracify.get_trace_formatter(trace_formatter_name)
    trace_str = trace_formatter(result.trace)

    output = result.stdout.rstrip("\n") if use_stdin else result.result.value
    return trace_str, output


def execute_pytracify_in_subprocess(
    code: str,
    input_data: str,
    trace_formatter_name: str,
    function_name: str,
    return_dict: dict,
    hide_mnemonics: bool = False,
):
    """Child-process entry point: run execute_pytracify and report via return_dict.

    On success, ``return_dict`` receives ``"trace"``, ``"output_val"`` and
    ``"error"`` (set to None). If anything raises an ``Exception`` — including
    the writes into ``return_dict`` themselves — ``"trace"`` and
    ``"output_val"`` are set to None and ``"error"`` holds ``str(exc)``.
    Exceptions not derived from ``Exception`` (e.g. ``SystemExit``) are not
    caught, so ``return_dict`` may then be left incomplete.
    """
    try:
        formatted_trace, value = execute_pytracify(
            code=code,
            input_data=input_data,
            trace_formatter_name=trace_formatter_name,
            function_name=function_name,
            hide_mnemonics=hide_mnemonics,
        )
        return_dict["trace"] = formatted_trace
        return_dict["output_val"] = value
        return_dict["error"] = None
    except Exception as exc:
        for key in ("trace", "output_val"):
            return_dict[key] = None
        return_dict["error"] = str(exc)


def execute_pytracify_with_timeout(
    code: str,
    input_data: str,
    trace_formatter_name: str,
    function_name: str,
    timeout: float = 5.0,
    hide_mnemonics: bool = False,
) -> Tuple[str, Any]:
    """Run execute_pytracify in a child process under a wall-clock timeout.

    The child (execute_pytracify_in_subprocess) reports through a
    Manager-backed dict with keys ``"trace"``, ``"output_val"`` and
    ``"error"``. The manager is used as a context manager so its server
    process is shut down deterministically when this function returns or
    raises.

    Args:
        code, input_data, trace_formatter_name, function_name, hide_mnemonics:
            Forwarded to execute_pytracify in the child.
        timeout: Seconds to wait for the child before terminating it.

    Returns:
        ``(trace, output_val)`` as reported by the child.

    Raises:
        TimeoutError: The child was still running after ``timeout`` seconds
            (it is terminated first).
        RuntimeError: The child reported an error, or exited without
            reporting any result (e.g. it raised a non-Exception such as
            SystemExit, or was killed).
    """
    with multiprocessing.Manager() as manager:
        return_dict = manager.dict()
        process = multiprocessing.Process(
            target=execute_pytracify_in_subprocess,
            args=(
                code,
                input_data,
                trace_formatter_name,
                function_name,
                return_dict,
                hide_mnemonics,
            ),
        )
        process.start()
        process.join(timeout)
        if process.is_alive():
            process.terminate()
            process.join()
            raise TimeoutError(f"Execution timed out after {timeout} seconds.")
        # Copy out of the proxy before the manager is shut down.
        outcome = return_dict.copy()

    if "error" not in outcome:
        raise RuntimeError(
            f"Subprocess exited with code {process.exitcode} without reporting a result."
        )
    if outcome["error"] is not None:
        raise RuntimeError(outcome["error"])
    return outcome["trace"], outcome["output_val"]


# Any of these substrings causes a sample to be skipped outright.
# NOTE(review): presumably these constructs (process exit, file access,
# blocking queues) cannot be traced reliably — confirm the rationale.
_UNSUPPORTED_CODE_MARKERS = ("quit(", "open(", "exit(", "queue.Queue", "import Queue")

# Skipped only when the code also mentions "itertools"; presumably these
# target the unbounded itertools.count/cycle/repeat iterators — TODO confirm.
_ITERTOOLS_UNBOUNDED_MARKERS = ("count(", "cycle", "repeat")


def _has_unsupported_construct(code: str) -> bool:
    """Return True if ``code`` matches one of the substring-based skip rules."""
    if any(marker in code for marker in _UNSUPPORTED_CODE_MARKERS):
        return True
    return "itertools" in code and any(
        marker in code for marker in _ITERTOOLS_UNBOUNDED_MARKERS
    )


def evaluate_sample(
    sample: "Sample",
    trace_formatter_name: str,
    hide_mnemonics: bool,
    is_code_gen: bool,
    timeout: float = 5.0,
) -> dict:
    """Trace one sample, check its output, and build its fine-tuning record.

    The returned dict is ``dataclasses.asdict(sample)`` extended with:

    * ``"pytracify_correct"``: whether the traced output's ``str()`` equals
      the ``str()`` of the expected output;
    * ``"pytracify_error"``: ``""``, ``"Wrong answer"``,
      ``"Execution Timeout"``, or the text of any other exception;
    * ``"fine_tune_data"``: the create_prompt record, present only when
      tracing and prompt construction succeeded (including wrong answers).

    Samples matching the substring skip rules are returned untraced with
    ``pytracify_correct=False`` and ``pytracify_error=""``.

    Args:
        sample: Sample to evaluate.
        trace_formatter_name: Trace formatter passed to pytracify.
        hide_mnemonics: Forwarded to pytracify.
        is_code_gen: Build a code-generation prompt instead of an execution one.
        timeout: Seconds allowed for the traced execution (default 5.0).
    """
    result_dict = dataclasses.asdict(sample)

    # SECURITY NOTE(review): eval() executes the dataset's expected-output
    # string as Python. Acceptable only for trusted datasets; ast.literal_eval
    # would be safer but rejects some outputs eval accepts (e.g. "float('inf')").
    try:
        expected_value = eval(sample.output)
    except Exception:
        # Not a Python expression: compare against the raw string instead.
        expected_value = sample.output

    result_dict["pytracify_correct"] = False
    result_dict["pytracify_error"] = ""

    if _has_unsupported_construct(sample.code):
        return result_dict

    try:
        trace_str, output_val = execute_pytracify_with_timeout(
            sample.code,
            sample.input,
            trace_formatter_name,
            sample.function_name,
            timeout=timeout,
            hide_mnemonics=hide_mnemonics,
        )
        is_correct = str(output_val) == str(expected_value)
        result_dict["pytracify_correct"] = is_correct
        if not is_correct:
            result_dict["pytracify_error"] = "Wrong answer"

        # NOTE(review): function_name is not forwarded to create_prompt, so its
        # function-call prompt branch is never used from here — confirm intent.
        result_dict["fine_tune_data"] = create_prompt(
            code=sample.code,
            trace=trace_str,
            final_output=sample.output,
            input_data=sample.input,
            problem_statement=sample.problem_statement,
            is_code_gen=is_code_gen,
        )

    except TimeoutError as e:
        print(f"Skipping sample {sample.sample_id} due to timeout: {e}")
        result_dict["pytracify_correct"] = False
        result_dict["pytracify_error"] = "Execution Timeout"

    except Exception as e:
        # Reset: pytracify_correct may already be True if create_prompt failed.
        result_dict["pytracify_correct"] = False
        result_dict["pytracify_error"] = str(e)

    return result_dict



def evaluate(
    samples: List[Sample],
    trace_formatter_name: str,
    num_workers: int,
    hide_mnemonics: bool = False,
    is_code_gen: bool = False,
) -> List[dict]:
    """Evaluate ``samples`` in parallel with evaluate_sample.

    Work is spread over a ProcessPoolExecutor with ``num_workers`` workers,
    and a tqdm progress bar advances as futures complete. If retrieving a
    future's result raises, the failure is printed and that sample is left
    out of the output.

    Returns:
        One result dict per successfully evaluated sample, in the same order
        as ``samples`` regardless of completion order, so repeated runs
        produce identically ordered output.
    """
    completed: dict[int, dict] = {}
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        future_to_item = {
            executor.submit(
                evaluate_sample,
                sample,
                trace_formatter_name,
                hide_mnemonics,
                is_code_gen,
            ): (index, sample)
            for index, sample in enumerate(samples)
        }
        for future in tqdm(as_completed(future_to_item), total=len(future_to_item)):
            index, sample = future_to_item[future]
            try:
                completed[index] = future.result()
            except Exception as exc:
                print(f"Sample {sample.sample_id} generated an exception: {exc}")

    return [completed[index] for index in sorted(completed)]
