import json
import logging
import os
import time
from typing import Callable, List
from pprint import pformat

import hydra
import numpy as np
import sglang as sgl
import yaml
from datasets import Dataset
from hydra.utils import get_original_cwd
from omegaconf import DictConfig, OmegaConf
from sglang.lang.ir import SglFunction
from transformers import AutoTokenizer
from jinja2 import Template

# Register custom OmegaConf resolvers so configs can use them in interpolations.
# "basename" returns the final component of a path (os.path.basename).
OmegaConf.register_new_resolver("basename", lambda path: os.path.basename(path))
# NOTE(review): "eval" hands its argument to Python's built-in eval(), so any
# config value or command-line override that uses it can execute arbitrary
# code. Only run this script with configs and overrides you trust.
OmegaConf.register_new_resolver("eval", eval)

from models.verifier import load_verifier
from search.amulet import amulet
from search.beam_search import beam_search
from search.best_of_n import best_of_n
from search.dvts import dvts
from search.residual_mppi import residual_mppi
from utils.data import get_dataset, normalize_dataset
from utils.evaluator import *
from utils.prompts import *

# Configure root logging at INFO and pin this module's logger to INFO too.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Restrict the "httpx" and "openai" loggers to errors only.
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("openai").setLevel(logging.ERROR)

# Maps the `method.approach` config value to the search function that
# run_ping_pong_eval batches over for the player's responses.
INFERENCE_METHODS = dict(
    beam_search=beam_search,
    best_of_n=best_of_n,
    dvts=dvts,
    residual_mppi=residual_mppi,
    amulet=amulet,
)

def post_process_cfg(args):
    """Validate the run configuration and fill in derived defaults.

    Requires ``args.base_url`` to be set, defaults ``args.model_name`` to the
    final path component of ``args.model_path`` when it is ``None``, and
    creates ``args.output_dir`` (including parents) if it does not exist.

    Returns:
        The same ``args`` object, modified in place.

    Raises:
        AssertionError: If ``args.base_url`` is ``None``.
    """
    assert args.base_url is not None, "base_url is required"

    name_is_unset = args.model_name is None
    if name_is_unset:
        # Fall back to the last path component of the model path.
        _, final_component = os.path.split(args.model_path)
        args.model_name = final_component

    os.makedirs(args.output_dir, exist_ok=True)
    return args

def _load_template(path):
    """Read the file at *path* as UTF-8 and compile it into a Jinja2 template."""
    with open(path, "r", encoding="utf-8") as f:
        return Template(f.read())


def _initial_history(sample):
    """Return the starting message list for one sample.

    The list holds the character's ``initial_message`` as an assistant turn
    when that field is present and non-empty; otherwise it is empty.
    """
    opening = sample["character"].get("initial_message")
    if opening:
        return [{"role": "assistant", "content": opening}]
    return []


def _parse_interrogator_utterance(raw_prediction):
    """Extract ``next_utterance`` from the interrogator's raw output.

    The output is expected to be a JSON object with a ``next_utterance`` key.
    Anything else (invalid JSON, a JSON value that is not an object, or an
    object without the key) is logged and the raw text is used verbatim.
    """
    try:
        parsed = json.loads(raw_prediction)
    except json.JSONDecodeError:
        parsed = None
    # Valid JSON is not necessarily an object (e.g. a bare string or list);
    # subscripting those with a str key would raise TypeError.
    if isinstance(parsed, dict) and "next_utterance" in parsed:
        return parsed["next_utterance"]
    logger.warning(f"Failed to parse interrogator output: {raw_prediction}")
    return raw_prediction


def _build_interrogator_prompt(sample, system_template, user_template, tokenizer):
    """Render the interrogator's system/user messages and apply the chat template."""
    character = sample["character"]
    situation = sample["situation"]
    sys_prompt = system_template.render()
    usr_prompt = user_template.render(
        char_summary=character.get("summary", character.get("char_name")),
        situation=situation.get("text", ""),
        messages=sample["history"],
    )
    chat = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": usr_prompt},
    ]
    return tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)


def run_ping_pong_eval(args):
    """Run the multi-turn Ping-Pong evaluation and write results and metrics.

    For each of ``args.data.num_turns`` turns (default 4) an interrogator
    model produces the next user utterance for every sample, then the
    configured inference method (``args.method.approach``) produces the
    player's reply. Both are appended to the sample's ``history``.

    After the last turn the conversations are written as JSON lines to
    ``args.result_path`` (before evaluation, so that file holds no metrics).
    Each sample is then scored with ``evaluate_ping_pong``, and the
    per-key mean of the resulting ``metric`` values, together with
    ``num_samples`` and ``run_batch_seconds``, is written to
    ``args.metric_path``.

    Args:
        args: Run configuration. Read here: ``base_url``, ``api_key``,
            ``model_path``, ``verifier``, ``max_threads``, ``result_path``,
            ``metric_path``, ``method`` (``approach``; the whole node is
            passed to the inference method) and ``data`` (``dataset_path``,
            ``dataset_split``, ``dataset_size``, ``templates``,
            ``num_turns``, ``llm_as_a_judge_model``).
    """
    # Accumulated wall-clock time spent in the player's run_batch calls only.
    run_batch_seconds = 0.0
    sgl.set_default_backend(sgl.RuntimeEndpoint(base_url=args.base_url, api_key=args.api_key))

    logger.info(f"Loading dataset from {args.data.dataset_path}...")
    dataset: Dataset = get_dataset(args.data.dataset_path, args.data.dataset_split)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    logger.info(f"Loading inference method '{args.method.approach}' and verifier...")
    method: SglFunction = INFERENCE_METHODS[args.method.approach]
    verifier: Callable[[str, List[str]], List[List[float]]] = load_verifier(args.verifier, tokenizer)

    if args.data.dataset_size is not None:
        dataset = dataset.select(range(args.data.dataset_size))

    dataset = dataset.map(
        normalize_dataset,
        num_proc=32,
        desc="Normalizing dataset",
        with_indices=True,
        fn_kwargs={"args": args.data, "tokenizer": tokenizer},
    )

    # Initialize conversation history
    dataset = dataset.add_column("history", [_initial_history(sample) for sample in dataset])

    @sgl.function
    def interrogator_generator(s, prompt):
        s += prompt
        s += sgl.gen("prediction", max_tokens=1024)

    # Load templates once
    interrogator_system_template = _load_template(args.data.templates.interrogator_system)
    interrogator_user_template = _load_template(args.data.templates.interrogator_user)

    num_turns = args.data.get("num_turns", 4)

    # Work on plain dicts so history and prompts can be mutated in place.
    dataset_list = dataset.to_list()

    for i in range(num_turns):
        logger.info(f"--- Turn {i+1}/{num_turns} ---")

        # 1. Generate interrogator utterances
        logger.info("Generating interrogator utterances...")
        interrogator_params = [
            dict(
                prompt=_build_interrogator_prompt(
                    sample,
                    interrogator_system_template,
                    interrogator_user_template,
                    tokenizer,
                )
            )
            for sample in dataset_list
        ]
        interrogator_states = interrogator_generator.run_batch(
            interrogator_params, num_threads=args.max_threads, progress_bar=True
        )

        for sample, state in zip(dataset_list, interrogator_states):
            user_utterance = _parse_interrogator_utterance(state["prediction"])
            sample["history"].append({"role": "user", "content": user_utterance})

        # 2. Generate player responses
        logger.info("Generating player responses...")

        # Reconstruct prompts for the player for this turn
        for sample in dataset_list:
            history = sample["history"]

            pref_messages = [{"role": "system", "content": sample["_preference"]}] + history
            sample["_pref_prompt"] = tokenizer.apply_chat_template(
                pref_messages, add_generation_prompt=True, tokenize=False
            )
            sample["_non_pref_prompt"] = tokenizer.apply_chat_template(
                history, add_generation_prompt=True, tokenize=False
            )

            # Use non_pref_prompt for generation, verifier will use both
            sample["_prompt"] = sample["_non_pref_prompt"]

        params_list = [
            dict(
                prompt=sample["_prompt"],
                sample=sample,
                verifier=verifier,
                tokenizer=tokenizer,
                args=args.method,
            )
            for sample in dataset_list
        ]

        _t0_turn = time.perf_counter()
        player_states = method.run_batch(params_list, num_threads=args.max_threads, progress_bar=True)
        run_batch_seconds += time.perf_counter() - _t0_turn

        for sample, state in zip(dataset_list, player_states):
            sample["history"].append({"role": "assistant", "content": state["prediction"]})

    dataset = Dataset.from_list(dataset_list)

    dataset = dataset.rename_column("history", "full_conversation")
    dataset.to_json(args.result_path, lines=True)

    logger.info("Evaluation phase...")
    dataset = dataset.map(
        evaluate_ping_pong,
        num_proc=32,
        desc="Evaluating Ping-Pong Bench",
        fn_kwargs={
            "model_name": args.data.llm_as_a_judge_model,
            "templates_paths": OmegaConf.to_container(args.data.templates, resolve=True),
        },
    )

    metrics = dataset["metric"]
    # Average every metric key over all samples; the key set is taken from the
    # first sample. float() keeps NumPy scalar wrappers out of logs and JSON.
    metric = {key: float(np.mean([m[key] for m in metrics])) for key in metrics[0].keys()}
    metric["num_samples"] = len(dataset)
    metric["run_batch_seconds"] = run_batch_seconds

    # ensure_ascii=False emits non-ASCII characters verbatim, so the file
    # encoding must not depend on the locale default.
    with open(args.metric_path, "w", encoding="utf-8") as f:
        json.dump(metric, f, ensure_ascii=False, indent=4)
    logger.info(f"Done! Metrics: {metric}")


@hydra.main(version_base=None, config_path="configs", config_name="config")
def hydra_main(cfg: DictConfig):
    """Hydra entry point: finalize the config, snapshot it, and run the evaluation.

    The fully resolved configuration is written to ``config.yaml`` inside
    ``args.output_dir`` and logged before the evaluation starts.
    """
    # Work from the directory the script was launched from, regardless of any
    # run directory Hydra may have switched to.
    launch_dir = get_original_cwd()
    os.chdir(launch_dir)

    args = post_process_cfg(cfg)

    snapshot_path = os.path.join(args.output_dir, "config.yaml")
    with open(snapshot_path, "w") as snapshot_file:
        config_dict = OmegaConf.to_container(args, resolve=True)
        yaml.dump(config_dict, snapshot_file, default_flow_style=False)

    logger.info(f"Running Ping Pong evaluation with config: {pformat(config_dict)}")
    run_ping_pong_eval(args)

# Start the Hydra-decorated entry point only when run as a script, not on import.
if __name__ == "__main__":
    hydra_main()
