import copy
import logging
import re
from pathlib import Path
from typing import Optional

import crbench
import pytracify
from crbench.base import Task, Instance

SOLVER_DIR =Path (__file__ ).parent /"solvers"


def read_solver(name_or_path: str) -> str:
    """Return the source text of a solver script.

    ``name_or_path`` is first tried as a path exactly as given. If nothing
    exists there, it is resolved relative to ``SOLVER_DIR``.

    Args:
        name_or_path: Path to a solver file, or a name relative to
            ``SOLVER_DIR``.

    Returns:
        The full contents of the solver file.

    Raises:
        AssertionError: If the solver exists in neither location.
    """
    path = Path(name_or_path)
    if not path.exists():
        assert (SOLVER_DIR / path).exists(), f"Solver {name_or_path} not found"
        path = SOLVER_DIR / path
    print("Reading solver from", name_or_path)
    # read_text() opens and closes the file itself, so no handle is leaked.
    return path.read_text()


def string_format(s, **kwargs):
    """
    Substitute ``{var}`` placeholders in ``s`` with ``str(value)``.

    Like ``s.format(var=value)``, but only the exact ``{var}`` tokens named in
    ``kwargs`` are touched; every other brace in ``s`` is left alone, so ``s``
    may contain literal ``{``/``}`` (e.g. source code).

    All placeholders are substituted in a single pass over ``s``: text
    inserted for one placeholder is never re-scanned, so a value that happens
    to contain ``{other}`` is kept verbatim and the result does not depend on
    the order of ``kwargs``.
    """
    if not kwargs:
        return s
    replacements = {"{" + var + "}": str(value) for var, value in kwargs.items()}
    pattern = re.compile("|".join(re.escape(token) for token in replacements))
    # A function replacement keeps backslashes in the values literal.
    return pattern.sub(lambda match: replacements[match.group(0)], s)


class BaseTransform:
    """
    Tokenize the ``"text"`` field of an example.

    ``tokenize_kwargs`` are forwarded to the tokenizer on every call and
    default to ``{"padding": True}``. When ``tokenizer`` is ``None`` the
    transform is a no-op that returns the example unchanged.
    """

    def __init__(self, tokenizer, tokenize_kwargs: Optional[dict] = None):
        self.tokenizer = tokenizer
        self.tokenize_kwargs = (
            {"padding": True} if tokenize_kwargs is None else tokenize_kwargs
        )

    def __call__(self, example, batched=True, do_eval=False):
        """
        Tokenize ``example["text"]`` and return the tokenizer output.

        ``batched`` is accepted for interface compatibility but is not used
        here. The tokenizer's ``padding_side`` is set on each call: ``"left"``
        when ``do_eval`` is true, ``"right"`` otherwise.
        """
        tok = self.tokenizer
        if tok is None:
            return example

        tok.padding_side = "left" if do_eval else "right"

        # Work on a copy so the per-call override below never leaks back into
        # the configured kwargs.
        call_kwargs = copy.deepcopy(self.tokenize_kwargs)
        if do_eval:
            call_kwargs["padding"] = True

        return tok(example["text"], return_tensors="pt", **call_kwargs)


class R1DistillTransform(BaseTransform):
    """
    Render examples as chat-templated conversations and tokenize them.

    Each example supplies ``"input"``, ``"generation"`` and ``"answer"``. The
    assistant turn is the generation, or the answer wrapped in ``<answer>``
    tags when the generation is ``None``. The system prompt is either sent as
    its own system turn or prepended to the user turn
    (``system_prompt_in_user=True``).
    """

    def __init__(
        self,
        system_prompt: str,
        system_prompt_in_user: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.system_prompt_in_user = system_prompt_in_user
        self.system_prompt = system_prompt

    def apply_chat_template(self, example) -> str:
        """Return the untokenized chat-template string for one example."""
        reply = example["generation"]
        if reply is None:
            # No generation available: fall back to the bare answer.
            reply = f"<answer>\n{example['answer']}\n</answer>"

        if self.system_prompt_in_user:
            merged = f"{self.system_prompt}\n\nInput\n{example['input']}"
            conversation = [
                {"role": "user", "content": merged},
                {"role": "assistant", "content": reply},
            ]
        else:
            conversation = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": example["input"]},
                {"role": "assistant", "content": reply},
            ]
        return self.tokenizer.apply_chat_template(conversation, tokenize=False)

    def __call__(self, batch):
        """Tokenize a column-oriented batch (dict of equal-length lists)."""
        columns = list(batch.keys())
        num_rows = len(batch[columns[0]])

        # Turn the dict-of-columns into one dict per row.
        records = []
        for row in range(num_rows):
            records.append({name: batch[name][row] for name in columns})

        rendered = list(map(self.apply_chat_template, records))
        return super().__call__({"text": rendered}, batched=True, do_eval=False)


class ChatTransformForCR(BaseTransform):
    """
    Chat-format transform for CRBench tasks without a reasoning trace.

    Builds a system/user conversation for each instance. For training
    (``do_eval=False``) the task's solution is appended as the assistant turn
    using ``answer_format``; for evaluation the conversation ends with a
    generation prompt instead. The rendered text is stored in
    ``example["text"]`` and then tokenized by ``BaseTransform``.
    """

    def __init__(
        self,
        system_prompt: Optional[str] = None,
        input_format: Optional[str] = None,
        answer_format: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.system_prompt = (
            crbench.utils.SYSTEM_PROMPT if system_prompt is None else system_prompt
        )

        if input_format is not None:
            assert "{input}" in input_format, "input_format must contain {input}"
        self.input_format = input_format

        if answer_format is None:
            answer_format = """<answer>\n{answer}\n</answer>"""
        else:
            assert "{answer}" in answer_format, "answer_format must contain {answer}"
        self.answer_format = answer_format

    def process_instance(self, task: Task, instance: Instance, do_eval: bool) -> str:
        """Return the chat-templated text for a single instance."""
        # User turn: custom input_format if configured, else the task's prompt.
        if self.input_format is None:
            user_content = task.build_prompt(instance)
        else:
            user_content = string_format(self.input_format, input=instance.input)

        chat = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_content},
        ]

        if not do_eval:
            # Training text also carries the reference answer.
            solved = task.solve(instance.input)
            chat.append(
                {
                    "role": "assistant",
                    "content": string_format(self.answer_format, answer=solved),
                }
            )

        return self.tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=bool(do_eval)
        )

    def __call__(self, example, task: Task, batched=True, do_eval=False):
        """Fill ``example["text"]`` for one example or a batch, then tokenize."""
        if batched:
            first_key = list(example.keys())[0]
            batch_size = len(example[first_key])
            instances = [
                Instance(
                    input=example["input"][row],
                    verifier_hint=example["verifier_hint"][row],
                )
                for row in range(batch_size)
            ]
            example["text"] = [
                self.process_instance(task, item, do_eval) for item in instances
            ]
        else:
            single = Instance(
                input=example["input"], verifier_hint=example["verifier_hint"]
            )
            example["text"] = self.process_instance(task, single, do_eval)
        return super().__call__(example, batched, do_eval)


class PytracifyTransform(BaseTransform):
    """
    Replace each problem with the formatted execution trace of a solver script.

    The solver script is loaded once with ``read_solver``. For every problem
    it is filled in, run under ``pytracify`` and its trace is formatted with
    the chosen trace formatter. The resulting text is stored in
    ``example["text"]`` and tokenized by ``BaseTransform``.
    """

    def __init__(
        self,
        solver_name_or_path: str,
        trace_formatter: str = "numeric_depth",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.solver_script = read_solver(solver_name_or_path)
        self.trace_formatter = pytracify.get_trace_formatter(trace_formatter)

    def get_trace(self, problem) -> str:
        """
        Run the solver on ``problem`` and return its formatted trace.

        Raises:
            ValueError: If running or formatting the trace fails; the original
                exception is attached as the cause.
        """
        # NOTE: the script is filled in with str.format, so any literal brace
        # in the solver script has to be doubled ("{{" / "}}").
        solver_script = self.solver_script.format(problem=problem)
        try:
            result = pytracify.run(solver_script)
            trace = self.trace_formatter(result.trace)
        except Exception as e:
            raise ValueError(f"Error in problem={problem}\nerror: {e}") from e
        return trace

    def __call__(self, example, batched=True):
        """Fill ``example["text"]`` with trace(s) of ``example["input"]``, then tokenize."""
        if not batched:
            example["text"] = self.get_trace(example["input"])
        else:
            example["text"] = [self.get_trace(problem) for problem in example["input"]]
        return super().__call__(example, batched)


class PytracifyTransformForCD(BaseTransform):
    """
    Pytracify transform for CountDown dataset

    The solver script is filled in with the problem (``example["input"]``) and
    the target (``example["answer"]``), run under ``pytracify``, and its
    formatted trace becomes ``example["text"]`` before tokenization.
    """

    def __init__(
        self,
        solver_name_or_path: str,
        trace_formatter: str = "numeric_depth",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.solver_script = read_solver(solver_name_or_path)
        self.trace_formatter = pytracify.get_trace_formatter(trace_formatter)

    def get_trace(self, problem, target) -> str:
        """
        Return the formatted solver trace for ``problem``/``target``.

        This is best-effort: if running or formatting the trace raises, the
        failure is logged as a warning and an empty string is returned.
        """
        solver_script = string_format(
            self.solver_script, problem=problem, target=target
        )
        try:
            result = pytracify.run(solver_script)
            trace = self.trace_formatter(result.trace)
        except Exception as e:
            # Keep going with an empty trace, but leave a record of the
            # failure instead of hiding it completely.
            logging.getLogger(__name__).warning(
                "Tracing failed for problem=%r target=%r: %s; using empty trace",
                problem,
                target,
                e,
            )
            trace = ""
        return trace

    def __call__(self, example, batched=True):
        """Fill ``example["text"]`` with trace(s), then tokenize."""
        if not batched:
            example["text"] = self.get_trace(example["input"], example["answer"])
        else:
            bsz = len(example["input"])
            example["text"] = [
                self.get_trace(example["input"][idx], example["answer"][idx])
                for idx in range(bsz)
            ]
        return super().__call__(example, batched)


class PytracifyTransformForCR(BaseTransform):
    """
    Pytracify transform for CRBench dataset

    For training (``do_eval=False``) the task's solution program is executed
    under ``pytracify`` and the assistant turn is built from the formatted
    trace plus the returned value, using ``answer_format``. For evaluation only
    the system/user prompt with a generation prompt is produced. The rendered
    text is stored in ``example["text"]`` and tokenized by ``BaseTransform``.
    """

    def __init__(
        self,
        trace_formatter: str = "numeric_depth",
        system_prompt: Optional[str] = None,
        input_format: Optional[str] = None,
        answer_format: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if system_prompt is None:
            system_prompt = crbench.utils.SYSTEM_PROMPT
        self.trace_formatter = pytracify.get_trace_formatter(trace_formatter)
        self.system_prompt = system_prompt
        self.input_format = input_format
        self.answer_format = answer_format
        if input_format is not None:
            assert "{input}" in input_format, "input_format must contain {input}"

        if answer_format is not None:
            assert "{trace}" in answer_format, "answer_format must contain {trace}"
            assert "{answer}" in answer_format, "answer_format must contain {answer}"
        else:
            self.answer_format = (
                """<think>\n{trace}\n</think>\n<answer>\n{answer}\n</answer>"""
            )

    def _build_user_input(self, task: Task, instance: Instance) -> str:
        """Return the user-turn text: ``input_format`` filled in, else the task prompt."""
        if self.input_format is None:
            return task.build_prompt(instance)
        return string_format(self.input_format, input=instance.input)

    def get_trace(self, task: Task, instance: Instance) -> str:
        """
        Return the full training conversation (prompt, trace and answer) as text.

        Raises:
            AssertionError: If the traced program did not finish with a return
                value, or the returned value fails ``task.verify``.
        """
        program = task.get_solution_program()
        instance_input = instance.input
        # List inputs are passed to generate_call_code as tuples
        # (presumably so the generated call is rendered correctly; confirm
        # against generate_call_code).
        if isinstance(instance_input, list):
            instance_input = tuple(instance_input)
        code = (
            program.source_code.strip()
            + f"\n\n\nreturn {program.generate_call_code(instance_input)}"
        )
        pytracify_result = pytracify.run(code)
        # self.trace_formatter is already the resolved formatter (see
        # __init__); call it directly instead of resolving it a second time.
        trace_str = self.trace_formatter(pytracify_result.trace)
        assert pytracify_result.result.kind == "return", (
            f"Seems like the program is not returning a value\ninstance: {instance}"
        )
        assert task.verify(instance, pytracify_result.result.value), (
            f"Verification failed\ninstance: {instance}"
        )
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": self._build_user_input(task, instance)},
            {
                "role": "assistant",
                "content": string_format(
                    self.answer_format,
                    trace=trace_str,
                    answer=pytracify_result.result.value,
                ),
            },
        ]
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )

    def get_prompt(self, task: Task, instance: Instance) -> str:
        """
        Return the evaluation prompt (system + user turns, generation prompt added).

        The user turn is built the same way as in ``get_trace`` so that a
        configured ``input_format`` applies to evaluation prompts as well.
        """
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": self._build_user_input(task, instance)},
        ]
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    def process_instance(self, task: Task, instance: Instance, do_eval: bool) -> str:
        """Return the eval prompt when ``do_eval`` is set, else the training text."""
        if not do_eval:
            return self.get_trace(task, instance)
        else:
            return self.get_prompt(task, instance)

    def __call__(self, example, task: Task, batched=True, do_eval=False):
        """Fill ``example["text"]`` for one example or a batch, then tokenize."""
        if not batched:
            instance = Instance(
                input=example["input"], verifier_hint=example["verifier_hint"]
            )
            example["text"] = self.process_instance(task, instance, do_eval)
        else:
            keys = list(example.keys())
            bsz = len(example[keys[0]])
            instances = [
                Instance(
                    input=example["input"][idx],
                    verifier_hint=example["verifier_hint"][idx],
                )
                for idx in range(bsz)
            ]
            example["text"] = [
                self.process_instance(task, instance, do_eval) for instance in instances
            ]
        return super().__call__(example, batched, do_eval)


class PytracifyTransformForCRraw(BaseTransform):
    """
    Pytracify transform for CRBench dataset that returns untokenized text.

    Builds the same conversations as the tokenizing CRBench transform, but
    ``__call__`` only fills ``example["text"]`` and returns the example without
    running the tokenizer on it. With ``return_answer=True`` the value
    returned by the task's solution program is also stored in
    ``example["answer"]``.
    """

    def __init__(
        self,
        trace_formatter: str = "numeric_depth",
        system_prompt: Optional[str] = None,
        input_format: Optional[str] = None,
        answer_format: Optional[str] = None,
        return_answer: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if system_prompt is None:
            system_prompt = crbench.utils.SYSTEM_PROMPT
        self.trace_formatter = pytracify.get_trace_formatter(trace_formatter)
        self.system_prompt = system_prompt
        self.input_format = input_format
        self.answer_format = answer_format
        self.return_answer = return_answer
        if input_format is not None:
            assert "{input}" in input_format, "input_format must contain {input}"

        if answer_format is not None:
            assert "{trace}" in answer_format, "answer_format must contain {trace}"
            assert "{answer}" in answer_format, "answer_format must contain {answer}"
        else:
            self.answer_format = (
                """<think>\n{trace}\n</think>\n<answer>\n{answer}\n</answer>"""
            )

    def _build_user_input(self, task: Task, instance: Instance) -> str:
        """Return the user-turn text: ``input_format`` filled in, else the task prompt."""
        if self.input_format is None:
            return task.build_prompt(instance)
        return string_format(self.input_format, input=instance.input)

    def _run_solution(self, task: Task, instance: Instance):
        """Run the task's solution program on the instance input under pytracify."""
        program = task.get_solution_program()
        instance_input = instance.input
        # List inputs are passed to generate_call_code as tuples
        # (presumably so the generated call is rendered correctly; confirm
        # against generate_call_code).
        if isinstance(instance_input, list):
            instance_input = tuple(instance_input)
        code = (
            program.source_code.strip()
            + f"\n\n\nreturn {program.generate_call_code(instance_input)}"
        )
        return pytracify.run(code)

    def get_trace(self, task: Task, instance: Instance):
        """
        Return the full training conversation (prompt, trace and answer) as text.

        Returns:
            The chat-templated string, or ``(string, answer)`` when
            ``return_answer`` is set, where ``answer`` is the value returned
            by the solution program.

        Raises:
            AssertionError: If the traced program did not finish with a return
                value, or the returned value fails ``task.verify``.
        """
        pytracify_result = self._run_solution(task, instance)
        # self.trace_formatter is already the resolved formatter (see
        # __init__); call it directly instead of resolving it a second time.
        trace_str = self.trace_formatter(pytracify_result.trace)
        assert pytracify_result.result.kind == "return", (
            f"Seems like the program is not returning a value\ninstance: {instance}"
        )
        assert task.verify(instance, pytracify_result.result.value), (
            f"Verification failed\ninstance: {instance}"
        )
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": self._build_user_input(task, instance)},
            {
                "role": "assistant",
                "content": string_format(
                    self.answer_format,
                    trace=trace_str,
                    answer=pytracify_result.result.value,
                ),
            },
        ]
        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        if self.return_answer:
            return text, pytracify_result.result.value
        return text

    def get_prompt(self, task: Task, instance: Instance):
        """
        Return the evaluation prompt (system + user turns, generation prompt added).

        The user turn is built the same way as in ``get_trace`` so that a
        configured ``input_format`` applies to evaluation prompts as well.

        Returns:
            The chat-templated string, or ``(string, answer)`` when
            ``return_answer`` is set. In that case the solution program is
            executed to obtain the answer; its result is not verified here.
        """
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": self._build_user_input(task, instance)},
        ]
        if self.return_answer:
            pytracify_result = self._run_solution(task, instance)
            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            ), pytracify_result.result.value
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    def process_instance(self, task: Task, instance: Instance, do_eval: bool):
        """
        Return the eval prompt when ``do_eval`` is set, else the training text.

        The result is a string, or a ``(string, answer)`` tuple when
        ``return_answer`` is set.
        """
        if do_eval:
            return self.get_prompt(task, instance)
        return self.get_trace(task, instance)

    def __call__(self, example, task: Task, batched=True, do_eval=False):
        """
        Fill ``example["text"]`` (and ``example["answer"]`` when
        ``return_answer`` is set) and return the example untokenized.
        """
        if not batched:
            instance = Instance(
                input=example["input"], verifier_hint=example["verifier_hint"]
            )
            if self.return_answer:
                trace, answer = self.process_instance(task, instance, do_eval)
                example["text"] = trace
                example["answer"] = answer
            else:
                example["text"] = self.process_instance(task, instance, do_eval)
        else:
            keys = list(example.keys())
            bsz = len(example[keys[0]])
            instances = [
                Instance(
                    input=example["input"][idx],
                    verifier_hint=example["verifier_hint"][idx],
                )
                for idx in range(bsz)
            ]
            if self.return_answer:
                example["text"] = []
                example["answer"] = []
                for instance in instances:
                    trace, answer = self.process_instance(task, instance, do_eval)
                    example["text"].append(trace)
                    example["answer"].append(answer)
            else:
                example["text"] = [
                    self.process_instance(task, instance, do_eval)
                    for instance in instances
                ]

        return example

