from swift .plugin import ORM ,orms 
import re 
import os 

def _short (s :str ,n :int =160 )->str :
    s =str (s ).replace ("\n","\\n")
    return s [:n ]+("…"if len (s )>n else "")

def _norm (s :str )->str :

    return " ".join (str (s ).strip ().lower ().split ())

class ExactMatch(ORM):
    """Binary exact-match reward on the ``<answer>...</answer>`` span.

    For each ``(completion, solution)`` pair, the first ``<answer>``...``</answer>``
    span is extracted from ``str(completion)``. Tag names match
    case-insensitively, whitespace is allowed inside the tags, and the span
    may cross newlines. The span is compared with the solution after both go
    through ``_norm`` (lowercase, whitespace runs collapsed).

    Scoring:
        - A match scores 1.0.
        - A missing span, a mismatch, or an exception while scoring scores 0.0.

    The first ``EM_DEBUG_BATCHES`` calls (environment variable, default 3)
    print debug output: a batch header, per-example lines for at most the
    first 4 examples, and a summary.
    """

    # Same pattern and flags as the original per-call re.search; compiled once.
    _ANSWER_RE = re.compile(
        r"<\s*answer\s*>(.*?)<\s*/\s*answer\s*>", re.DOTALL | re.IGNORECASE
    )
    # Within a batch, only examples whose index is below this get a debug line.
    _DBG_EXAMPLES = 4

    def __init__(self):
        # Number of leading __call__ invocations that emit debug output.
        self._dbg_batches_to_log = int(os.getenv("EM_DEBUG_BATCHES", "3"))
        # Count of completed __call__ invocations.
        self._called = 0
        print("[EXACTMATCH] plugin loaded", flush=True)

    def __call__(self, completions, solution, **kwargs):
        """Return one exact-match reward per (completion, solution) pair.

        Args:
            completions: Model outputs; each one is converted with ``str()``.
            solution: Reference answers, paired positionally with
                ``completions`` via ``zip``. The shorter input therefore
                bounds how many rewards are produced.
            **kwargs: ``global_step`` (falling back to ``step``, then -1) is
                read only to label debug output.

        Returns:
            list[float]: 1.0 or 0.0 for each pair, in input order. Exactly one
            entry per pair.
        """
        step = kwargs.get("global_step", kwargs.get("step", -1))
        # Decided once per call; the summary below uses the same flag.
        debug = self._called < self._dbg_batches_to_log
        rewards = []
        fmt_ok = 0
        hits = 0

        if debug:
            print(f"[EXACTMATCH][step={step}] batch size: "
                  f"completions={len(completions)} solutions={len(solution)}",
                  flush=True)

        for i, (completion, sol) in enumerate(zip(completions, solution)):
            log = debug and i < self._DBG_EXAMPLES

            # Only the scoring work is guarded. Debug prints stay outside the
            # try so that a failing print can never append a second reward for
            # this pair and misalign rewards with completions.
            try:
                match = self._ANSWER_RE.search(str(completion))
                if match is not None:
                    pred = _norm(match.group(1))
                    gold = _norm(sol)
            except Exception as e:
                rewards.append(0.0)
                if log:
                    print(f"[EXACTMATCH][step={step}][i={i}] exception={e} "
                          f"\n  completion_head={_short(completion)}"
                          f"\n  gold={_short(sol)}", flush=True)
                continue

            if match is None:
                # No <answer> span: format miss, zero reward.
                rewards.append(0.0)
                if log:
                    print(f"[EXACTMATCH][step={step}][i={i}] format_miss "
                          f"| completion_head={_short(completion)}", flush=True)
                continue

            r = 1.0 if pred == gold else 0.0
            rewards.append(r)
            fmt_ok += 1
            hits += int(r == 1.0)

            if log:
                print(f"[EXACTMATCH][step={step}][i={i}] r={r} "
                      f"\n  pred={_short(pred)}"
                      f"\n  gold={_short(gold)}", flush=True)

        self._called += 1

        if debug:
            # Report the real batch size (an empty batch shows 0/0, not 0/1).
            # max(1, ...) is needed only to guard the hit-rate division.
            print(f"[EXACTMATCH][step={step}] summary: "
                  f"fmt_ok={fmt_ok}/{len(completions)} "
                  f"hits={hits} "
                  f"hit_rate={(hits / max(1, fmt_ok)):.3f} "
                  f"(over formatted examples)", flush=True)

        return rewards


# Register a single ExactMatch instance under the name "exactmatch".
# It is created at import time, so the "[EXACTMATCH] plugin loaded" line
# prints when this module is imported.
# NOTE(review): in ms-swift the `orms` registry appears to map names to reward
# *classes* that the trainer instantiates itself. If so, this should register
# `ExactMatch` (the class), not an instance. Confirm against the installed
# swift version.
orms ['exactmatch']=ExactMatch ()
