"""
Run CRBench evaluation with vLLM and write per-example results to results/<run-name>.csv.

Usage:
CKPT=logs/ckpt-path
python run_crbench.py \
    --ckpt $CKPT \
    --num_gpus=8 \
    --num_examples=100
"""

import sys 
import argparse 
import os 
import re 
from tqdm import tqdm 
import glob 
import gc 
import itertools 
from natsort import natsorted 
from omegaconf import OmegaConf 

import pandas as pd 
from datasets import Dataset ,Features ,Value 
import difflib 
import pytracify 
from crbench .base import Instance 

import torch 
from torch .utils .data import DataLoader 
from lightning .pytorch import seed_everything 
import ray 
from vllm import LLM ,SamplingParams 

sys .path .append (".")
from ltm .utils import instantiate_from_config ,instantiate_model_from_config 
from ltm .data .transforms import string_format 


SEED =42  # Global seed: passed to seed_everything() and used as the vLLM sampling seed.


def parse_args():
    """Parse the command-line options for this evaluation script.

    All path-like options default to ``None``; integer options default to 1,
    except ``--num_examples`` which defaults to ``None`` (use every example).
    """
    parser = argparse.ArgumentParser()
    for flag in ("--ckpt", "--model_path", "--config"):
        parser.add_argument(flag, type=str, default=None)
    integer_flags = (
        ("--batch_size", 1),
        ("--num_gpus", 1),
        ("--tp_size", 1),
        ("--num_examples", None),
    )
    for flag, fallback in integer_flags:
        parser.add_argument(flag, type=int, default=fallback)
    return parser.parse_args()


def get_config(ckpt):
    """Locate the ``*-project.yaml`` config saved for a training run.

    Args:
        ckpt: Either a path to a ``.ckpt`` file (the run directory is taken to
            be two levels up, e.g. ``<run>/checkpoints/x.ckpt`` -> ``<run>``)
            or the run directory itself.

    Returns:
        Path of a file matching ``<run>/configs/*-project.yaml``. When several
        files match, the last one in sorted order is returned, so the choice is
        deterministic (raw ``glob`` order is filesystem-dependent).

    Raises:
        FileNotFoundError: If no matching config file exists.
    """
    if ckpt.endswith(".ckpt"):
        base_dir = os.path.dirname(os.path.dirname(ckpt))
    else:
        base_dir = ckpt
    pattern = os.path.join(base_dir, "configs", "*-project.yaml")
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(
            f"No config matching {pattern!r} (run directory: {base_dir})"
        ) from None
    # NOTE(review): presumably the file names start with a timestamp, so the
    # last sorted entry is the most recent config — confirm.
    return matches[-1]


def get_best_checkpoint_name(logdir):
    """Pick the checkpoint file to evaluate from ``<logdir>/checkpoints``.

    ``last*.ckpt`` files are preferred. If there are none, ``epoch*.ckpt``
    files are used instead. Within the chosen group, the last file in natural
    sort order is returned.

    Raises:
        FileNotFoundError: If neither pattern matches any file. Previously this
            surfaced as an opaque ``IndexError``.
    """
    ckpt_dir = os.path.join(logdir, "checkpoints")
    for prefix in ("last", "epoch"):
        candidates = natsorted(glob.glob(os.path.join(ckpt_dir, f"{prefix}*.ckpt")))
        if candidates:
            return candidates[-1]
    raise FileNotFoundError(
        f"No last*.ckpt or epoch*.ckpt checkpoints found in {ckpt_dir}"
    )


def process_example(example, system_prompt, task, tokenizer, input_format=None):
    """Add the rendered chat prompt and the reference answer to a dataset row.

    Reads ``example["input"]`` and ``example["verifier_hint"]``, then writes:
      * ``example["text"]``: system + user messages rendered with the
        tokenizer's chat template, with a generation prompt appended. The user
        message comes from ``task.build_prompt`` unless *input_format* is
        given, in which case ``string_format`` fills it with the input.
      * ``example["answer"]``: ``str`` of ``task.solve(input)``, after checking
        it with ``task.verify`` (an ``AssertionError`` is raised on failure).

    Returns the same (mutated) *example* dict.
    """
    instance = Instance(example["input"], example["verifier_hint"])

    if input_format is None:
        user_content = task.build_prompt(instance)
    else:
        user_content = string_format(input_format, input=instance.input)
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_content},
    ]
    example["text"] = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )

    reference = task.solve(instance.input)
    assert task.verify(instance, reference), f"{instance.input} != {reference}"

    example["answer"] = str(reference)
    return example


def collate_fn(batch):
    """Turn a list of row dicts into a dict mapping each key to a list of values.

    The keys (and their order) come from the first row; every row must provide
    them.
    """
    columns = {}
    for key in batch[0]:
        columns[key] = [row[key] for row in batch]
    return columns


def get_generated_ids(generated_ids: torch.Tensor, input_ids: torch.Tensor):
    """Drop the prompt prefix from generated token ids.

    The prompt length is the size of the last dimension of *input_ids*; that
    many leading positions are removed from dimension 1 of *generated_ids*.
    """
    prompt_len = input_ids.size(-1)
    return generated_ids[:, prompt_len:]


def generate_trace_diff(example, task):
    """Diff the reference execution trace against the model's ``<think>`` trace.

    The reference trace is built in three steps:
      1. Append a ``return <call>`` line to the task's solution program source.
      2. Run the result with ``pytracify.run``.
      3. Format the trace with the ``"numeric_depth"`` formatter.

    The generated trace is the text between ``<think>\\n`` and ``\\n</think>``
    in ``example["generation"]``, with surrounding whitespace stripped.

    Writes one string to ``example["trace_diff"]`` (a unified diff,
    ``"No difference found"``, or a human-readable error/absence message) and
    returns *example*. Failures are reported in that field rather than raised.
    """
    instance = Instance(example["input"], example["verifier_hint"])
    solution = task.get_solution_program()
    body = solution.source_code.strip()
    code = body + f"\n\n\nreturn {solution.generate_call_code(instance.input)}"

    try:
        traced = pytracify.run(code)
        formatter = pytracify.get_trace_formatter("numeric_depth")
        true_trace = formatter(traced.trace)
    except Exception as err:
        print(f"Error running pytracify for seed {example['seed']}: {err}")
        example["trace_diff"] = f"Error generating true trace: {err}"
        return example

    generation = example["generation"]
    try:
        found = re.search(r"<think>\n(.*?)\n</think>", generation, re.DOTALL)
        if found:
            result = diff_strings(true_trace, found.group(1).strip(), n=2)
        else:
            result = "No <think>...</think> trace found in generation"
    except AttributeError:
        result = "Error extracting trace from generation (AttributeError)"
    except Exception as err:
        print(f"Error processing generated trace for seed {example['seed']}: {err}")
        result = f"Error processing generated trace: {err}"
    example["trace_diff"] = result
    return example


def diff_strings(
    a: str,
    b: str,
    fromfile: str = "original_trace",
    tofile: str = "generated_trace",
    n: int = 3,
) -> str:
    """Return a line-based unified diff between two multi-line strings.

    *fromfile* and *tofile* label the ``---``/``+++`` headers, and *n* is the
    number of context lines around each change. Returns
    ``"No difference found"`` when the strings have identical lines.
    """
    diff_text = "\n".join(
        difflib.unified_diff(
            a.splitlines(),
            b.splitlines(),
            fromfile=fromfile,
            tofile=tofile,
            n=n,
            lineterm="",
        )
    )
    # A non-empty diff always starts with a "--- ..." header, so an empty join
    # means there were no differing lines.
    return diff_text or "No difference found"


def generate(args, model_path, max_length, dataset):
    """Greedily decode every prompt in *dataset* with vLLM and extract answers.

    Args:
        args: Parsed CLI namespace; ``tp_size`` and ``batch_size`` are read.
        model_path: Model directory handed to ``vllm.LLM``.
        max_length: Maximum number of new tokens per generation.
        dataset: Map-style dataset whose rows contain a ``text`` prompt; all
            other columns are passed through unchanged.

    Returns:
        A list with one dict per batch, mapping each column to a list of
        values. Each dict is extended with:
          * ``generation``: the raw completion text.
          * ``pred``: the text inside ``<answer>\\n...\\n</answer>``, or ``-1``
            if no such block is found.
    """
    model = LLM(
        model=model_path,
        trust_remote_code=True,
        tensor_parallel_size=args.tp_size,
    )

    # ``stop_token_ids`` takes integer ids. The previous literal string
    # "<|im_end|>" never equalled a generated id, so that stop condition never
    # fired. Resolve the id from the tokenizer vocab, and fall back to a
    # text-level stop string when it is not a single vocab entry.
    stop_token = "<|im_end|>"
    stop_id = model.get_tokenizer().get_vocab().get(stop_token)
    if stop_id is not None:
        stop_kwargs = {"stop_token_ids": [stop_id]}
    else:
        stop_kwargs = {"stop": [stop_token]}
    sampling_params = SamplingParams(
        seed=SEED,
        temperature=0.0,
        max_tokens=max_length,
        **stop_kwargs,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )
    answer_pattern = re.compile(r"<answer>\n(.*?)\n</answer>", re.DOTALL)

    all_results = []
    for batch in tqdm(dataloader, total=len(dataloader), desc="Generating"):
        outputs = model.generate(
            batch["text"],
            sampling_params=sampling_params,
            use_tqdm=False,
        )
        generated_texts = []
        preds = []
        for output in outputs:
            generated_text = output.outputs[0].text
            match = answer_pattern.search(generated_text)
            # NOTE(review): -1 is later stored in a string column as "-1", so a
            # task whose true answer is "-1" would score a parse failure as
            # correct. Consider a sentinel that cannot collide with an answer.
            preds.append(match.group(1) if match else -1)
            generated_texts.append(generated_text)
        batch["generation"] = generated_texts
        batch["pred"] = preds
        all_results.append(batch)
    return all_results




if __name__ == "__main__":
    args = parse_args()

    # Validate the CLI before doing any expensive work. Explicit errors are
    # used instead of asserts because asserts are stripped under ``python -O``.
    if args.tp_size < 1 or args.num_gpus < args.tp_size or args.num_gpus % args.tp_size:
        raise SystemExit("num_gpus must be a positive multiple of tp_size")
    if not args.model_path and args.ckpt is None:
        raise SystemExit("Either --ckpt or --model_path must be provided")
    if not args.config and args.ckpt is None:
        # get_config() needs a checkpoint path to find the run's config.
        raise SystemExit("--config is required when --ckpt is not provided")

    seed_everything(SEED)
    if not args.config:
        args.config = get_config(args.ckpt)
    if args.model_path:
        model_path = args.model_path
    else:
        if os.path.isdir(args.ckpt) and not args.ckpt.endswith(".ckpt"):
            args.ckpt = get_best_checkpoint_name(args.ckpt)
        if os.path.isfile(args.ckpt):
            # Export location: an "hf_model" directory next to the checkpoint file.
            model_path = os.path.join(os.path.dirname(args.ckpt), "hf_model")
        else:
            model_path = os.path.join(args.ckpt, "hf_model")

    config = OmegaConf.load(args.config)
    is_ltm = (
        "ltm" in os.path.basename(args.config)
        or "pytracify" in str(config.data.params.transform_config.target).lower()
    )

    # Prefer a dedicated evaluation data config when the run defines one.
    data_source = config.eval_data if hasattr(config, "eval_data") else config.data
    data_config = OmegaConf.to_container(data_source, resolve=True)
    data = instantiate_from_config(data_config)

    # Export an HF-format model (config.json + weights + tokenizer) for vLLM,
    # unless one already exists or --model_path was given.
    if args.model_path is None and (
        not os.path.exists(model_path)
        or not os.path.exists(os.path.join(model_path, "config.json"))
    ):
        tokenizer = data.tokenizer
        tokenizer.padding_side = "left"

        model = instantiate_model_from_config(
            config=OmegaConf.to_container(
                config.model.params.model_config, resolve=True
            ),
            ckpt=args.ckpt,
        )
        model.save_pretrained(model_path)
        tokenizer.save_pretrained(model_path)
        del model, tokenizer
        print(f"Saved model to {model_path}")
        gc.collect()
    else:
        print(f"Found existing model in {model_path}")

    ray.init(num_gpus=args.num_gpus)

    tokenizer = data.tokenizer
    if config.system_prompt is None:
        raise ValueError("System prompt is not set")
    system_prompt = config.system_prompt
    max_length = data.transform.tokenize_kwargs["max_length"]
    input_format = getattr(data.transform, "input_format", None)
    # NOTE(review): presumably cleared so the test split yields untransformed
    # rows for process_example — confirm against the data module.
    data.transform = None
    data.prepare_data()
    data.setup(stage="fit")
    dataset = data.datasets["test"]
    cr_task = data.task
    del data
    gc.collect()

    # Subsample before the map. process_example builds prompts and
    # solves/verifies every row, so mapping rows we then drop is wasted work.
    # Capping at len(dataset) avoids an out-of-range select.
    if args.num_examples:
        dataset = dataset.select(range(min(args.num_examples, len(dataset))))
    if len(dataset) == 0:
        raise SystemExit("No examples to evaluate")
    dataset = dataset.map(
        process_example,
        fn_kwargs={
            "system_prompt": system_prompt,
            "task": cr_task,
            "tokenizer": tokenizer,
            "input_format": input_format,
        },
        num_proc=os.cpu_count(),
    )
    dataset = dataset.select_columns(["seed", "input", "text", "answer"])

    print(f"Data length: {len(dataset)}")
    vllm_inference_func = ray.remote(num_gpus=args.tp_size)(generate).remote
    # One Ray task per tensor-parallel group, but never more tasks than rows,
    # so no worker loads a model just to process an empty shard.
    num_workers = min(args.num_gpus // args.tp_size, len(dataset))
    futures = [
        vllm_inference_func(args, model_path, max_length, dataset.shard(num_workers, i))
        for i in tqdm(range(num_workers))
    ]
    batches = [batch for worker_batches in ray.get(futures) for batch in worker_batches]

    # Merge every batch dict (column -> list) into one column -> list dict.
    all_results = {
        k: list(itertools.chain.from_iterable(batch[k] for batch in batches))
        for k in batches[0].keys()
    }

    print(all_results.keys())
    ds = Dataset.from_dict(
        all_results,
        features=Features(
            {
                "seed": Value("int32"),
                "input": Value("string"),
                "text": Value("string"),
                "answer": Value("string"),
                "pred": Value("string"),
                "generation": Value("string"),
            }
        ),
    )

    ds = ds.map(
        lambda x: {"correct": str(x["answer"]).lower() == str(x["pred"]).lower()},
        num_proc=os.cpu_count(),
        desc="Computing accuracy",
    )
    accuracy = sum(ds["correct"]) / len(ds)
    print(f"Accuracy: {accuracy:.4f}")

    if is_ltm:
        ds = ds.map(
            generate_trace_diff,
            fn_kwargs={"task": cr_task},
            num_proc=os.cpu_count(),
            desc="Generating trace diff",
        )

    # Name results after the run directory (<run>/configs/x.yaml -> <run>).
    # Fall back to the config's file stem when the config is not nested like
    # that; previously this produced "results/.csv".
    name = os.path.basename(os.path.dirname(os.path.dirname(args.config)))
    if not name:
        name = os.path.splitext(os.path.basename(args.config))[0]
    output = os.path.join("results", f"{name}.csv")
    os.makedirs(os.path.dirname(output), exist_ok=True)
    ds.to_csv(output, index=False)
