import argparse
import json
import pandas as pd
import transformers
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import requests
import re
from tqdm import tqdm
from utils.qa_em import compute_score_em
import numpy as np

# Few-shot rollout prompt using ChatML-style <|im_start|>/<|im_end|> markers.
# Filled via ``FEW_SHOT_TEMPLATE.format(question=...)``; the string ends with
# "<|im_end|>\n" and main() then appends the assistant header before generation.
# NOTE(review): the wording (including the "as your want" typo and the trailing
# spaces) presumably mirrors the prompt used at training time — confirm before
# editing, since any change alters the model's input.
FEW_SHOT_TEMPLATE = """<|im_start|>user
Answer the given question. You must conduct reasoning inside <think> and </think> first every time you get new information. 
After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search> and 
it will return the top searched results between <information> and </information>. You can search as many times as your want. 
If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, 
without detailed illustrations. For example, <answer> Beijing </answer>.
This is a few-shot learning exercise. Examples are provided below.
Question: What is the birth date of the lead singer of Coldplay?<think>I need to find out the lead singer of Coldplay and their birth date.</think><search>lead singer of Coldplay birth date</search><information>Doc 1(Title: Chris Martin) Christopher Anthony John Martin (born 2 March 1977) is an English singer, songwriter, and musician. He is the lead singer, pianist, rhythm guitarist, and co-founder of the rock band Coldplay.
</information><think>The lead singer of Coldplay is Chris Martin, and he was born on March 2, 1977.</think><answer>March 2, 1977</answer>

Question: What is the most populous city in the United States?
<think>I need to determine which city in the United States has the largest population.</think><search>most populous city in the United States</search>
<information>Doc 1(Title: New York City) New York City is the most populous city in the United States, with an estimated population of over 8.3 million people.
</information><think>New York City is the most populous city in the United States.</think><answer>New York City</answer>

Question: {question}
<|im_end|>
"""


# Zero-shot variant of the same instructions on a single line, with no worked
# examples. Also filled via ``.format(question=...)`` and ends with "<|im_end|>\n".
ZERO_SHOT_TEMPLATE = """<|im_start|>user
Answer the given question. You must conduct reasoning inside <think> and </think> first every time you get new information. After reasoning, if you find you lack some knowledge, you can call a search engine by <search> query </search> and it will return the top searched results between <information> and </information>. You can search as many times as your want. If you find no further external knowledge needed, you can directly provide the answer inside <answer> and </answer>, without detailed illustrations. For example, <answer> Beijing </answer>. Question: {question}<|im_end|>
"""

def _convert_numpy_to_native(obj):
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, np.generic):
        return obj.item()
    elif isinstance(obj, dict):
        return {k: _convert_numpy_to_native(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [_convert_numpy_to_native(elem) for elem in obj]
    return obj

def batch_search(queries: list, retriever_url, top_k, timeout=300):
    """Query the retrieval server for a batch of queries and format the hits.

    Sends one POST with JSON body ``{"queries", "topk", "return_scores": True}``
    to *retriever_url*. Each per-query result list is rendered as lines of the
    form ``Doc i(Title: <first line>) <remaining lines>``.

    Args:
        queries: search strings. An empty list short-circuits to ``[]``.
        retriever_url: endpoint that accepts the JSON payload above.
        top_k: number of documents requested per query.
        timeout: seconds to wait for the server (``None`` waits forever).

    Returns:
        One formatted string per query. This is best-effort: on any error, or if
        the server returns a different number of result lists than queries,
        every entry is ``""`` so the rollout can continue.
    """
    if not queries:
        return []

    def _passages2string(retrieval_result):
        """Render one query's hits; the first line of ``contents`` is the title."""
        lines = []
        for idx, doc_item in enumerate(retrieval_result or [], start=1):
            # Tolerate null ``document`` / ``contents`` fields instead of
            # failing (and blanking) the whole batch.
            content = (doc_item.get('document') or {}).get('contents') or ''
            title, _, text = content.partition("\n")
            lines.append(f"Doc {idx}(Title: {title}) {text}\n")
        return "".join(lines)

    try:
        payload = {"queries": queries, "topk": top_k, "return_scores": True}
        # Without a timeout, a stalled retriever would hang the whole evaluation.
        response = requests.post(retriever_url, json=payload, timeout=timeout)
        response.raise_for_status()
        results = response.json().get('result', [])

        if len(results) != len(queries):
            print(f"Warning: Number of search results ({len(results)}) does not match number of queries ({len(queries)}).")
            return [""] * len(queries)

        return [_passages2string(res) for res in results]
    except Exception as e:  # best-effort by design: retrieval failures must not abort the eval
        print(f"Error in batch search: {e}")
        return [""] * len(queries)

from utils.model_loading import select_checkpoint_path, load_causal_lm, _normalize_local_path


def _resolve_model_path(args):
    """Return the model directory / hub id to load for this run.

    ``--ckpt_root`` takes precedence when set and is resolved with
    ``select_checkpoint_path`` and ``--ckpt_step``. Otherwise ``--model_path``
    is passed through ``_normalize_local_path``.

    Raises:
        ValueError: if neither ``--ckpt_root`` nor ``--model_path`` is given.
    """
    if getattr(args, "ckpt_root", None):
        return select_checkpoint_path(args.ckpt_root, args.ckpt_step)
    if args.model_path:
        return _normalize_local_path(args.model_path)
    raise ValueError("Either --model_path or --ckpt_root must be provided.")


def _load_model_and_tokenizer(model_path):
    """Load a left-padding tokenizer and a bfloat16 causal LM from *model_path*."""
    # Left padding makes every prompt end at the same column, which
    # _rollout_batch relies on when slicing the prompt off the generated ids.
    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')
    if tokenizer.pad_token is None:
        # Batched generation needs a pad token; fall back to EOS.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
        attn_implementation="sdpa",
        # NOTE(review): unlike the tokenizer load above, this never downloads, so
        # a hub repo id only works if it is already cached — confirm intent.
        local_files_only=True,
    )
    return model, tokenizer


def _build_initial_prompts(questions, fewshot):
    """Fill the chosen prompt template for each question and open the assistant turn."""
    template = FEW_SHOT_TEMPLATE if fewshot else ZERO_SHOT_TEMPLATE
    return [template.format(question=q) + '\n<|im_start|>assistant\n' for q in questions]


def _truncate_at_action(gen_text):
    """Cut *gen_text* after its first ``</search>``, else after its first ``</answer>``.

    If neither closing tag is present, the text is returned unchanged.
    """
    for end_tag in ('</search>', '</answer>'):
        if end_tag in gen_text:
            return gen_text.split(end_tag, 1)[0] + end_tag
    return gen_text


def _parse_action(text):
    """Return ``(action, content)`` for the first ``<search>``/``<answer>`` tag in *text*.

    ``action`` is ``'search'``, ``'answer'`` or ``None`` (no tag found).
    ``content`` is the stripped text after the tag, cut at its closing tag when
    present. It is ``''`` when no tag was found.
    """
    match = re.search(r'<(search|answer)>(.*)', text, re.DOTALL)
    if match is None:
        return None, ''
    action = match.group(1)
    content = match.group(2).strip()
    end_tag = f'</{action}>'
    if end_tag in content:
        content = content.split(end_tag, 1)[0].strip()
    return action, content


def _rollout_batch(model, tokenizer, prompts, args):
    """Run up to ``args.max_turns`` generate -> (search) rounds for one batch.

    Each turn, every still-active prompt is extended with the model output cut
    at its first action. A sample finishes on ``<answer>``, on no recognizable
    action, or on ``<search>`` while ``--do_search`` is off. A ``<search>`` with
    ``--do_search`` appends an ``<information>...</information>`` block and
    keeps the sample active.

    Returns:
        ``(final_prompts, do_search_flags)``: the extended prompt strings and,
        per sample, whether it issued a search while search was enabled.
    """
    prompts = list(prompts)
    do_search_flags = [False] * len(prompts)
    active_mask = [True] * len(prompts)

    for _turn in range(args.max_turns):
        active_indices = [idx for idx, active in enumerate(active_mask) if active]
        if not active_indices:
            break
        active_prompts = [prompts[idx] for idx in active_indices]

        # NOTE(review): truncation=True without max_length truncates to the
        # tokenizer's model_max_length on its configured truncation side;
        # presumably never triggered for these prompt lengths — confirm.
        inputs = tokenizer(active_prompts, return_tensors='pt', padding=True, truncation=True).to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            temperature=0.7,
        )
        # With left padding all prompts end at the same column, so this slice
        # keeps only the newly generated tokens.
        generated_outputs = tokenizer.batch_decode(
            outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
        )

        search_queries = []
        search_indices = []
        for original_index, gen_text in zip(active_indices, generated_outputs):
            processed_text = _truncate_at_action(gen_text)
            prompts[original_index] += processed_text
            action, content = _parse_action(processed_text)

            if action == 'answer':
                active_mask[original_index] = False
            elif action == 'search' and args.do_search:
                do_search_flags[original_index] = True
                if content:
                    search_queries.append(content)
                    search_indices.append(original_index)
                else:
                    # Empty query: skip the retriever but still append an empty block.
                    prompts[original_index] += "<information></information>\n\n"
            else:
                active_mask[original_index] = False

        if search_queries:
            # One retriever round-trip per turn for all searching samples.
            search_results = batch_search(search_queries, args.retriever_url, args.top_k)
            for original_index, result in zip(search_indices, search_results):
                prompts[original_index] += f"<information>{result}</information>\n\n"

    return prompts, do_search_flags


def _write_batch_results(f, questions, ground_truths, prompts, do_search_flags):
    """Score each finished rollout with ``compute_score_em`` and write one JSON line per sample."""
    for question, ground_truth, prompt, searched in zip(questions, ground_truths, prompts, do_search_flags):
        sequence = prompt + '<|im_end|>'  # close the assistant turn
        gt_dict = {'target': ground_truth}
        score = compute_score_em(solution_str=sequence, ground_truth=gt_dict)
        result_item = {
            "question": question,
            "sequences_str": sequence,
            "ground_truth": _convert_numpy_to_native(gt_dict),
            "reward": score,
            "do_search": searched,
        }
        f.write(json.dumps(result_item) + '\n')


def main(args):
    """Evaluate a (search-augmented) causal LM on a parquet QA dataset.

    Reads the ``question`` / ``golden_answers`` columns from ``args.data_path``
    and rolls out each batch of ``args.val_batch_size`` questions, with optional
    retrieval. Writes one JSON line per sample (full sequence, ground truth,
    exact-match reward, whether it searched) to ``args.output_path``.

    Raises:
        ValueError: if ``--do_search`` is set without ``--retriever_url``, or if
            neither ``--model_path`` nor ``--ckpt_root`` is given.
    """
    # Fail fast: without a URL every search would silently come back empty,
    # because batch_search swallows request errors.
    if args.do_search and not args.retriever_url:
        raise ValueError("--retriever_url is required when --do_search is set.")

    model_path = _resolve_model_path(args)
    print("debugging", "model_path", model_path)
    model, tokenizer = _load_model_and_tokenizer(model_path)

    df = pd.read_parquet(args.data_path)
    if args.num_samples:
        df = df.head(args.num_samples)
    questions = df['question'].tolist()
    ground_truths = df['golden_answers'].tolist()

    batch_size = args.val_batch_size
    with open(args.output_path, 'w', encoding='utf-8') as f:
        for start in tqdm(range(0, len(questions), batch_size), desc="Evaluating batches"):
            batch_questions = questions[start:start + batch_size]
            batch_ground_truths = ground_truths[start:start + batch_size]
            prompts = _build_initial_prompts(batch_questions, args.fewshot)
            final_prompts, do_search_flags = _rollout_batch(model, tokenizer, prompts, args)
            _write_batch_results(f, batch_questions, batch_ground_truths, final_prompts, do_search_flags)


def _build_arg_parser():
    """Return the command-line parser for this evaluation script.

    Model selection: ``--ckpt_root`` / ``--ckpt_step`` take precedence over
    ``--model_path`` in main(). Rollout: ``--max_new_tokens`` per generate
    call, up to ``--max_turns`` rounds, ``--val_batch_size`` questions per batch.
    Retrieval is enabled by ``--do_search`` and queries ``--retriever_url`` for
    ``--top_k`` documents.
    """
    # (flag, add_argument kwargs) in the order they appear in --help.
    options = [
        ("--model_path", {"type": str, "required": False, "default": None,
                          "help": "HF repo ID or local model dir. Ignored if --ckpt_root is set."}),
        ("--ckpt_root", {"type": str, "default": None,
                         "help": "Training run root containing global_step_* subdirs, or a direct model dir."}),
        ("--ckpt_step", {"type": str, "default": "latest",
                         "help": "Checkpoint step to load (integer) or 'latest'."}),
        ("--data_path", {"type": str, "required": True}),
        ("--output_path", {"type": str, "required": True}),
        ("--max_new_tokens", {"type": int, "default": 500}),
        ("--retriever_url", {"type": str}),
        ("--top_k", {"type": int, "default": 3}),
        ("--max_turns", {"type": int, "default": 5}),
        ("--val_batch_size", {"type": int, "default": 8, "help": "Validation batch size."}),
        ("--do_search", {"action": "store_true"}),
        ("--fewshot", {"action": "store_true", "help": "Enable few-shot prompting."}),
        ("--num_samples", {"type": int, "default": None,
                           "help": "Number of samples to evaluate on for debugging."}),
    ]
    parser = argparse.ArgumentParser()
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)
    return parser


if __name__ == "__main__":
    main(_build_arg_parser().parse_args())
