import dataclasses 
import json 
import random 
import re 
import multiprocessing 
from typing import List ,Dict ,Any ,Tuple 


import fire 
from tqdm import tqdm 

import pytracify 

from common .common_types import Sample 


def extract_code_content(code_block: str) -> str:
    """
    Strip Markdown code fences from ``code_block`` and return the inner code.

    A ```python fence is preferred when present (the first one wins). Otherwise
    the first complete fence of any kind (bare ``` or ```<lang>) is used. Text
    without a complete fence is returned with surrounding whitespace stripped.

    Args:
        code_block: Raw text, possibly wrapped in Markdown code fences.

    Returns:
        The code inside the chosen fence, or the whole text, stripped.
    """
    python_fence = re.compile(r"```python(.*?)```", re.DOTALL)
    match = python_fence.search(code_block)
    if match is None:
        # The language tag is only consumed when it sits alone on the opening
        # fence line, so a one-line fence such as ```x = 1``` keeps its code.
        any_fence = re.compile(r"```(?:[\w+-]*[ \t]*\n)?(.*?)```", re.DOTALL)
        match = any_fence.search(code_block)
    if match is not None:
        return match.group(1).strip()
    return code_block.strip()


def extract_rhs_from_input(input_str: str) -> str:
    """
    Reduce ``name = value`` assignments to just their values.

    Each line of the stripped input is scanned for ``name = value`` pairs
    (pairs on one line are separated by ``, name =``). The values found on a
    line are joined with a single space, and the per-line results are joined
    with ``","``. If no line contains any assignment, ``input_str`` is returned
    unchanged (not stripped).
    """
    assignment_re = re.compile(r"([\w_]+)\s*=\s*(.*?)(?=,\s*[\w_]+\s*=|$)")

    # NOTE(review): values on the same line are joined with " " ("a = 1, b = 2"
    # -> "1 2") while separate lines are joined with ","; when the result is
    # spliced into ``f(<input>)`` a space-joined line is not valid call syntax.
    # Confirm which separator the dataset's input format actually needs.
    rendered_lines = [
        " ".join(value.strip() for _name, value in found)
        for found in map(assignment_re.findall, input_str.strip().splitlines())
        if found
    ]
    if not rendered_lines:
        return input_str
    return ",".join(rendered_lines)


def load_pyx(
    path: str = "../pyx_samples.jsonl",
    min_index: int = 50,
) -> List[Sample]:
    """
    Load PyX samples from a JSONL file.

    Blank lines are skipped, and every other line is parsed as a JSON object.
    A record is kept only when:
      * its ``sample_id`` has the form ``pyx_<N>`` with ``N >= min_index``; and
      * its code does not mention the ``git`` or ``vim`` tools.
    Fenced code is unwrapped with ``extract_code_content``. The input is
    reduced to its right-hand-side values with ``extract_rhs_from_input``.

    Args:
        path: JSONL file to read (defaults to the original location).
        min_index: Smallest numeric ``pyx_`` suffix to keep; lower ids are
            skipped (defaults to the original cut-off of 50).

    Returns:
        A list of ``Sample`` objects with code, input, output and function name.
    """
    # Match tool names only at the start of a word: a plain substring test
    # also rejected ordinary code such as ``str.isdigit()`` or ``string.digits``.
    git_re = re.compile(r"\bgit")
    vim_re = re.compile(r"\bvim")

    samples = []
    with open(path, "r", encoding="utf-8") as f:
        for line in tqdm(f, desc="Loading samples"):
            if not line.strip():
                continue

            data = json.loads(line)
            # ``or ""`` also covers an explicit null sample_id.
            sample_id = data.get("sample_id") or ""
            if not sample_id.startswith("pyx_"):
                continue

            try:
                idx = int(sample_id.split("_")[-1])
            except ValueError:
                continue
            if idx < min_index:
                continue

            code = extract_code_content(data["code"])
            input_str = extract_rhs_from_input(data["input"])

            if git_re.search(code):
                print("Git Found!!!")
                continue
            if vim_re.search(code):
                print("Vim Found!!!")
                continue

            samples.append(
                Sample(
                    sample_id=sample_id,
                    code=code,
                    input=input_str,
                    output=data["output"].strip(),
                    function_name=data["function_name"],
                )
            )

    print(f"Loaded {len(samples)} valid samples from PyX dataset.")
    return samples


def execute_pytracify_in_subprocess(
    code: str,
    function_name: str,
    input_data: str,
    trace_formatter_name: str,
    return_dict: dict,
):
    """
    Child-process worker for ``execute_pytracify``.

    How ``code`` is run through ``pytracify`` depends on whether it reads stdin:
      * If it contains ``input(``, it is run as-is with ``input_data`` passed
        as ``input_data``.
      * Otherwise a trailing ``print(<function_name>(<input_data>))`` is
        appended, so the call's result is written to stdout.

    Results are written into ``return_dict`` under ``"trace"``, ``"output"``
    and ``"error"``. On success ``"error"`` is None. On any ``Exception``,
    ``"trace"`` and ``"output"`` are None and ``"error"`` holds the message.
    """
    try:
        if "input(" in code:
            run_result = pytracify.run(code, input_data=input_data)
        else:
            harness = f"\n{code}\n\nprint({function_name}({input_data}))\n"
            run_result = pytracify.run(harness)
        format_trace = pytracify.get_trace_formatter(trace_formatter_name)
        return_dict.update(
            {
                "trace": format_trace(run_result.trace),
                "output": run_result.stdout.rstrip("\n"),
                "error": None,
            }
        )
    except Exception as exc:
        return_dict.update({"trace": None, "output": None, "error": str(exc)})


def execute_pytracify(
    code: str,
    function_name: str,
    input_data: str,
    trace_formatter_name: str,
    timeout: float = 5.0,
) -> Tuple[str, str]:
    """
    Run ``code`` under pytracify in a child process and return (trace, output).

    The work is done by ``execute_pytracify_in_subprocess`` in a separate
    ``multiprocessing.Process``, so a program that hangs can be terminated.

    Args:
        code: Source code to execute.
        function_name: Function called with ``input_data`` when the code does
            not read stdin.
        input_data: Call arguments, or stdin text for programs using ``input(``.
        trace_formatter_name: Name passed to ``pytracify.get_trace_formatter``.
        timeout: Seconds to wait before killing the child.

    Returns:
        A tuple of the formatted trace and the captured stdout, with trailing
        newlines removed.

    Raises:
        TimeoutError: If the child is still running after ``timeout`` seconds.
        RuntimeError: If the child reported an error, or exited without
            reporting any result (e.g. it crashed or raised ``SystemExit``).
    """
    # Using the Manager as a context manager shuts its server process down on
    # every exit path, instead of leaving it to garbage collection.
    with multiprocessing.Manager() as manager:
        return_dict = manager.dict()
        p = multiprocessing.Process(
            target=execute_pytracify_in_subprocess,
            args=(code, function_name, input_data, trace_formatter_name, return_dict),
        )
        p.start()
        p.join(timeout)

        if p.is_alive():
            p.terminate()
            p.join()
            raise TimeoutError(f"Execution timed out after {timeout} seconds")

        # Copy out of the proxy before the manager shuts down.
        outcome = dict(return_dict)

    if "error" not in outcome:
        # The worker always sets "error" unless it died before finishing.
        raise RuntimeError(
            f"Subprocess exited with code {p.exitcode} without reporting a result"
        )
    if outcome["error"] is not None:
        raise RuntimeError(outcome["error"])

    return outcome["trace"], outcome["output"]


def create_finetune_dataset(
    code: str,
    input_data: str,
    trace: str,
    final_output: str,
    trace_formatter_name: str,
) -> Dict[str, Any]:
    """
    Build one chat-style fine-tuning record.

    The user turn wraps the code and input in ``<code>``/``<input>`` tags. The
    assistant turn wraps the trace in ``<think>`` and the answer in
    ``<answer>``. ``trace_formatter_name`` is accepted but not used here.

    Returns:
        ``{"input": [user_message], "output": assistant_message}``, where each
        message is a ``{"role": ..., "content": ...}`` dict.
    """
    prompt = "\n\n".join(
        [f"<code>\n{code}\n</code>", f"<input>\n{input_data}\n</input>"]
    )
    completion = "\n\n".join(
        [f"<think>\n{trace}\n</think>", f"<answer>\n{final_output}\n</answer>"]
    )
    user_turn = {"role": "user", "content": prompt}
    assistant_turn = {"role": "assistant", "content": completion}
    return {"input": [user_turn], "output": assistant_turn}








def _outputs_match(actual: str, expected: str) -> bool:
    """Return True when two printed outputs agree, by value if possible, else as text."""
    actual_text = actual.strip()
    expected_text = expected.strip()
    try:
        # NOTE(security): eval() executes text produced by arbitrary sample
        # code and read from the dataset file; only run on trusted data or in
        # a sandbox. A fresh globals dict keeps the caller's names out of scope.
        return bool(eval(actual_text, {}) == eval(expected_text, {}))
    except Exception:
        # Output that is not a Python expression (e.g. ``YES`` or multi-line
        # prints) cannot be evaluated; compare the stripped text instead.
        return actual_text == expected_text


def evaluate(samples: List[Sample], trace_formatter_name: str) -> List[Dict[str, Any]]:
    """
    Run every sample through pytracify and record whether its output matches.

    Samples whose code contains ``quit(``, ``open(`` or ``exit(`` are skipped
    and left out of the result. Every other sample yields
    ``dataclasses.asdict(sample)`` extended with:

    * ``pytracify_correct``: whether the traced output matched ``sample.output``;
    * ``pytracify_error``: ``"Wrong answer"``, ``"Timeout"`` or the exception
      message (only set when not correct);
    * ``fine_tune_data``: the record from ``create_finetune_dataset`` (only
      set when execution succeeded).

    Skip and timeout counts, and a frequency table of errors, are printed at
    the end.
    """
    results = []
    # Substring markers for code we refuse to run, checked in this order.
    skip_counts = dict.fromkeys(("quit(", "open(", "exit("), 0)
    timeout_cnt = 0

    for sample in tqdm(samples, desc="Evaluating samples"):
        blocked = next((marker for marker in skip_counts if marker in sample.code), None)
        if blocked is not None:
            skip_counts[blocked] += 1
            continue

        result_dict = dataclasses.asdict(sample)
        try:
            trace_str, actual_output = execute_pytracify(
                sample.code,
                sample.function_name,
                sample.input,
                trace_formatter_name,
                timeout=5.0,
            )

            is_correct = _outputs_match(actual_output, sample.output)
            result_dict["pytracify_correct"] = is_correct
            if not is_correct:
                result_dict["pytracify_error"] = "Wrong answer"

            result_dict["fine_tune_data"] = create_finetune_dataset(
                code=sample.code,
                input_data=sample.input,
                trace=trace_str,
                final_output=sample.output,
                trace_formatter_name=trace_formatter_name,
            )
        except TimeoutError:
            timeout_cnt += 1
            result_dict["pytracify_correct"] = False
            result_dict["pytracify_error"] = "Timeout"
        except Exception as e:
            result_dict["pytracify_correct"] = False
            result_dict["pytracify_error"] = str(e)

        results.append(result_dict)

    skipped = ", ".join(f"{marker} {count}" for marker, count in skip_counts.items())
    print(f"Skipped samples -> {skipped}; timeouts: {timeout_cnt}")

    incorrect_results = [r for r in results if not r.get("pytracify_correct", False)]
    if incorrect_results:
        print("\n=== PyX: Incorrect samples with errors ===")
        err_dict = {}
        for r in incorrect_results:
            err_key = r.get("pytracify_error", "UnknownError")
            err_dict[err_key] = err_dict.get(err_key, 0) + 1
        for err, count in sorted(err_dict.items(), key=lambda x: x[1], reverse=True):
            print(f"Error: {err} (count: {count})")
        print("=== End of incorrect samples ===\n")

    return results


def _write_finetune_jsonl(path: str, rows) -> None:
    """Write each row's ``fine_tune_data`` to ``path`` as one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r["fine_tune_data"], ensure_ascii=False) + "\n")


def create_pyx_data(train_ratio: float = 1.0, seed: int = 42):
    """
    Create the complete PyX dataset by loading samples and evaluating them.

    For each entry in ``formatter_settings``:
      1. The shuffled samples are traced with that formatter.
      2. The correctly reproduced samples are shuffled again.
      3. They are split into the train and val JSONL files.

    Args:
        train_ratio: Fraction in [0, 1] of correct samples written to the train
            file; the rest go to the val file. The default 1.0 keeps the
            original behaviour (all samples in train, empty val file).
        seed: Seed applied to the global ``random`` module before shuffling.

    Raises:
        ValueError: If ``train_ratio`` is outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")

    print("=== Creating PyX data ===")
    random.seed(seed)
    samples = load_pyx()
    random.shuffle(samples)

    formatter_settings = [
        (
            "numeric_depth",
            "../datasets/LLTM-pyx-numeric-depth-train.jsonl",
            "../datasets/LLTM-pyx-numeric-depth-val.jsonl",
        ),
    ]

    for formatter_name, train_output, val_output in formatter_settings:
        print(f"\n[PyX] Evaluating with formatter: {formatter_name}")
        results = evaluate(samples, formatter_name)

        correct_results = [r for r in results if r.get("pytracify_correct")]
        print(f"Correct count (overall): {len(correct_results)}/{len(samples)}")

        random.shuffle(correct_results)
        train_count = int(train_ratio * len(correct_results))
        train_data = correct_results[:train_count]
        val_data = correct_results[train_count:]

        _write_finetune_jsonl(train_output, train_data)
        print(f"Train split: {len(train_data)}")

        _write_finetune_jsonl(val_output, val_data)
        print(f"Val split: {len(val_data)}")


def main():
    """Entry point passed to ``fire.Fire``: build the PyX dataset with default settings."""
    return create_pyx_data()


if __name__ == "__main__":
    # Hand ``main`` to python-fire, which exposes it as a command-line interface.
    fire.Fire(component=main)
