import dataclasses 
import json 
import random 
import subprocess 
import traceback 
from collections import defaultdict 
from datetime import datetime 
from pathlib import Path 
from typing import Any ,Sequence 

import datasets 
import fire 
from tqdm import tqdm 

import pytracify 


@dataclasses.dataclass
class Sample:
    """One benchmark item: a program, the call to make, and its expected result.

    Every field is stored as text. ``input`` and ``output`` are Python
    source snippets rather than already-evaluated values.
    """

    # Identifier of the sample within its source dataset.
    sample_id: str
    # Source code that defines the function under test.
    code: str
    # Name of the function inside ``code`` to call.
    function_name: str
    # Argument list as source text, spliced in as ``function_name(<input>)``.
    input: str
    # Expected return value, as a Python expression in source form.
    output: str


def execute_cpython(source: str, function_name: str, args: str) -> Any:
    """Run ``function_name(args)`` from ``source`` with the host interpreter.

    A small program is assembled from ``source`` followed by a call whose
    result is passed to a ``save`` callback injected into the globals used
    by ``exec``. The value handed to ``save`` is returned.

    Args:
        source: Python code defining the function to call.
        function_name: Name of the function to invoke.
        args: Argument list as source text, placed between the parentheses.

    Returns:
        The value most recently passed to ``save``, or ``None`` if ``save``
        was never reached.

    Raises:
        Anything raised while compiling or running the assembled program.
    """
    # NOTE: exec() runs the given code with full interpreter privileges.
    program = f"""
from typing import List  # Many samples from LiveCodeBench require this

{source}

save({function_name}({args}))
"""

    # The list's bound append method plays the role of the callback, so
    # the final element is the most recently saved value.
    captured: list[Any] = []
    namespace = {"save": captured.append}
    exec(program, namespace)
    return captured[-1] if captured else None


def execute_pytracify(source: str, function_name: str, args: str) -> Any:
    """Run ``function_name(args)`` from ``source`` through ``pytracify.run``.

    Args:
        source: Python code defining the function to call.
        function_name: Name of the function to invoke.
        args: Argument list as source text, placed between the parentheses.

    Returns:
        The ``result.value`` attribute of the object returned by
        ``pytracify.run``.
    """
    # NOTE(review): the program ends in a bare top-level ``return``;
    # presumably pytracify.run treats the text like a function body -
    # confirm against the pytracify documentation.
    program = f"""
{source}

return {function_name}({args})
"""

    return pytracify.run(program).result.value


def load_cruxeval() -> list[Sample]:
    """Load the ``test`` split of ``cruxeval-org/cruxeval`` as ``Sample`` objects.

    The ``id``, ``code``, ``input`` and ``output`` columns of each row are
    copied over unchanged; the function name is always set to ``"f"``.
    """
    dataset = datasets.load_dataset("cruxeval-org/cruxeval", split="test")

    samples: list[Sample] = []
    for row in dataset:
        sample = Sample(
            sample_id=row["id"],
            code=row["code"],
            function_name="f",
            input=row["input"],
            output=row["output"],
        )
        samples.append(sample)
    return samples


def load_livecodebench() -> list[Sample]:
    """Load the ``test`` split of ``livecodebench/execution-v2`` as ``Sample`` objects.

    Each row's ``input`` column is expected to hold a complete call of the
    form ``<function_name>(<arguments>)``. The wrapper is stripped so that
    ``Sample.input`` holds only the text between the outer parentheses.
    The ``id`` column is converted to ``str``.

    Raises:
        ValueError: If a row's ``input`` does not start with
            ``<function_name>(`` or does not end with ``)``.
    """
    rows = datasets.load_dataset("livecodebench/execution-v2", split="test")

    samples = []
    for row in rows:
        function_name = row["function_name"]
        call_str = row["input"]
        prefix = function_name + "("

        # Validate explicitly instead of with ``assert``: assertions are
        # removed under ``python -O``, and the slice below would then
        # silently produce a wrong argument string for a malformed row.
        if not (call_str.startswith(prefix) and call_str.endswith(")")):
            raise ValueError(
                f"LiveCodeBench sample {row['id']}: expected input of the "
                f"form {prefix + '...)'!r}, got {call_str!r}"
            )

        # Drop the leading "<function_name>(" and the trailing ")".
        input_str = call_str[len(prefix):-1]

        samples.append(
            Sample(
                sample_id=str(row["id"]),
                code=row["code"],
                function_name=function_name,
                input=input_str,
                output=row["output"],
            )
        )

    return samples


def choose_randomly(samples: list[Sample], n_samples: int | None) -> list[Sample]:
    """Pick a reproducible pseudo-random subset of ``samples``.

    Args:
        samples: The full list of samples; it is never modified.
        n_samples: Number of samples to keep, or ``None`` to keep them all.

    Returns:
        ``samples`` itself when ``n_samples`` is ``None``. Otherwise the
        first ``n_samples`` elements of a copy shuffled with a generator
        seeded with 42, so repeated calls select the same subset.
    """
    if n_samples is None:
        return samples

    # Shuffle a copy so the caller's list keeps its original order.
    shuffled = list(samples)
    rng = random.Random(42)
    rng.shuffle(shuffled)
    return shuffled[:n_samples]


def evaluate(samples: Sequence[Sample]) -> list[dict]:
    """Run each sample under CPython and Pytracify and score both results.

    Args:
        samples: Samples to evaluate; progress is reported through ``tqdm``.

    Returns:
        One dict per sample, in input order. Each dict starts as
        ``dataclasses.asdict(sample)`` and, for every backend
        (``cpython`` and ``pytracify``), gains ``<backend>_correct`` plus,
        when that is ``False``, ``<backend>_error`` and
        ``<backend>_details``. A mismatch is recorded with the error
        ``"Wrong answer"`` and the stringified actual value; an exception
        is recorded with its message and the formatted traceback.

    Raises:
        Anything raised while evaluating a sample's expected output, since
        that step runs outside the per-backend error handling.
    """
    # CPython is scored first, then Pytracify, for every sample.
    backends = (
        ("cpython", execute_cpython),
        ("pytracify", execute_pytracify),
    )

    records = []
    for sample in tqdm(samples):
        record = dataclasses.asdict(sample)

        # NOTE: eval() on dataset-provided text executes arbitrary
        # expressions; only use with trusted datasets.
        expected = eval(sample.output)

        for backend, run in backends:
            try:
                actual = run(sample.code, sample.function_name, sample.input)
                matches = expected == actual
                if matches:
                    record[f"{backend}_correct"] = True
                else:
                    record[f"{backend}_correct"] = False
                    record[f"{backend}_error"] = "Wrong answer"
                    record[f"{backend}_details"] = str(actual)
            except Exception as exc:
                # Comparison and str() failures end up here as well.
                record[f"{backend}_correct"] = False
                record[f"{backend}_error"] = str(exc)
                record[f"{backend}_details"] = traceback.format_exc()

        records.append(record)

    return records


def get_git_repo_state() -> dict[str, Any]:
    """Capture the state of the git repository in the current directory.

    Runs ``git rev-parse HEAD``, ``git status --porcelain``,
    ``git branch --show-current``, ``git log -n 10 --pretty=oneline`` and
    ``git diff``, in that order.

    Returns:
        A dict with the keys ``commit``, ``branch``, ``status``, ``log``
        and ``diff``, each holding the raw standard output of the
        corresponding command.

    Raises:
        subprocess.CalledProcessError: If any git command exits non-zero.
    """
    # Listed in the order in which the commands are executed.
    queries = (
        ("commit", ("rev-parse", "HEAD")),
        ("status", ("status", "--porcelain")),
        ("branch", ("branch", "--show-current")),
        ("log", ("log", "-n", "10", "--pretty=oneline")),
        ("diff", ("diff",)),
    )

    outputs: dict[str, str] = {}
    for key, git_args in queries:
        completed = subprocess.run(
            ["git", *git_args], capture_output=True, text=True, check=True
        )
        outputs[key] = completed.stdout

    # Key order of the returned dict differs from the execution order.
    return {key: outputs[key] for key in ("commit", "branch", "status", "log", "diff")}


def main(n_samples: int | None = None) -> None:
    """Evaluate both benchmark datasets and print a summary for each.

    For ``cruxeval`` and then ``livecodebench`` this loads the samples,
    optionally subsamples them, evaluates them, and prints the number of
    correct results per backend followed by the Pytracify error messages
    ordered by descending frequency.

    Args:
        n_samples: Number of samples to evaluate per dataset. With the
            default ``None`` every sample is evaluated and the full report
            (including the git state) is written to a timestamped JSON
            file under ``benchmark_coverage/``. With any other value no
            file is written.
    """
    report: dict[str, Any] = {"git": get_git_repo_state()}

    loaders = {
        "cruxeval": load_cruxeval,
        "livecodebench": load_livecodebench,
    }

    for dataset_name, load_samples in loaders.items():
        print(f"Dataset: {dataset_name}")
        samples = choose_randomly(load_samples(), n_samples)
        total = len(samples)

        results = evaluate(samples)

        num_correct_cpython = sum(entry["cpython_correct"] for entry in results)
        num_correct_pytracify = sum(entry["pytracify_correct"] for entry in results)
        print(f"Correct (CPython):   {num_correct_cpython}/{total}")
        print(f"Correct (Pytracify): {num_correct_pytracify}/{total}")

        # Count how often each Pytracify error message occurs.
        error_freqs: dict[str, int] = {}
        for entry in results:
            if entry["pytracify_correct"]:
                continue
            message = entry["pytracify_error"]
            error_freqs[message] = error_freqs.get(message, 0) + 1

        print("Errors:")
        ranked = sorted(error_freqs.items(), key=lambda item: item[1], reverse=True)
        for error_str, freq in ranked:
            print(f"{freq:4d} {error_str}")

        report[dataset_name] = {
            "summary": {
                "correct_cpython": num_correct_cpython,
                "correct_pytracify": num_correct_pytracify,
                "total_samples": total,
            },
            "errors": dict(error_freqs),
            "samples": results,
        }

    # Only full runs are persisted.
    if n_samples is not None:
        return

    filename = datetime.now().strftime("%Y%m%d-%H%M%S.json")
    filepath = Path("benchmark_coverage") / filename
    filepath.parent.mkdir(parents=True, exist_ok=True)
    with open(filepath, "w") as report_file:
        json.dump(report, report_file, indent=2)


# When executed as a script, hand main() to fire.Fire.
if __name__ =="__main__":
    fire .Fire (main )
