import json
import logging
import os
import time
from typing import Callable, List
from pprint import pformat

import hydra
import numpy as np
import sglang as sgl
import yaml
from datasets import Dataset
from hydra.utils import get_original_cwd
from omegaconf import DictConfig, OmegaConf
from sglang.lang.ir import SglFunction
from transformers import AutoTokenizer

# Register custom OmegaConf resolvers (done at import time, before Hydra
# composes the config, so interpolations in the config can refer to them).
# "basename": returns the final path component of its argument via os.path.basename.
OmegaConf.register_new_resolver("basename", lambda path: os.path.basename(path))
# SECURITY NOTE(review): the "eval" resolver hands its argument to Python's
# built-in eval(), so any config value or override that uses it can execute
# arbitrary code. Only use this with trusted configuration sources.
OmegaConf.register_new_resolver("eval", eval)

from models.verifier import load_verifier
from search.amulet import amulet
from search.beam_search import beam_search
from search.best_of_n import best_of_n
from search.dvts import dvts
from search.residual_mppi import residual_mppi
from utils.data import get_dataset, normalize_dataset
from utils.evaluator import (
    math_verify_evaluate,
    evaluate_prefeval_explicit,
    evaluate_prefeval_implicit,
    evaluate_prefeval_choice,
    evaluate_ping_pong,
    evaluate_multifaceted_bench,
)
from utils.math import *
from utils.prompts import *
from utils.count_tokens import count_tokens

# Root logging at INFO, plus this module's own logger pinned to INFO.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Only surface errors from the HTTP/OpenAI client loggers.
for _quiet_logger_name in ("openai", "httpx"):
    logging.getLogger(_quiet_logger_name).setLevel(logging.ERROR)


# Maps the `method.approach` config value to the search function that
# implements it; looked up in `run_search`.
INFERENCE_METHODS = dict(
    beam_search=beam_search,
    best_of_n=best_of_n,
    dvts=dvts,
    residual_mppi=residual_mppi,
    amulet=amulet,
)


def post_process_cfg(args):
    """Validate required settings and fill in derived defaults.

    Mutates ``args`` in place and also returns it.

    Args:
        args: Config object exposing ``base_url``, ``model_name``,
            ``model_path`` and ``output_dir`` as attributes.

    Returns:
        The same ``args`` object, with ``model_name`` filled in when it was
        ``None``.

    Raises:
        AssertionError: If ``args.base_url`` is ``None``.

    Side effects:
        Creates ``args.output_dir`` (including parents) if it does not exist.
    """
    if args.base_url is None:
        # Raised explicitly instead of via `assert` so the check is not
        # stripped under `python -O`; the exception type is kept for
        # backward compatibility.
        raise AssertionError("base_url is required")

    if args.model_name is None:
        # Normalise first so a trailing separator ("models/foo/") still
        # yields "foo" rather than an empty name.
        args.model_name = os.path.basename(os.path.normpath(args.model_path))

    os.makedirs(args.output_dir, exist_ok=True)
    return args
    

# Number of worker processes used by every `Dataset.map` call below.
_MAP_NUM_PROC = 32


def _read_chat_template(template):
    """Return chat-template text: file contents if `template` is an existing path, else `template` itself."""
    if os.path.exists(template):
        # Context manager so the file handle is closed deterministically.
        with open(template) as template_file:
            return template_file.read()
    return template


def _run_evaluation(dataset, args):
    """Apply the evaluator(s) selected by `args.data.dataset_type` and return the mapped dataset.

    Raises:
        ValueError: If `args.data.dataset_type` is not one of the supported types.
    """
    dataset_type = args.data.dataset_type
    if dataset_type == "math":
        # The four math evaluators are applied one after another, in this order.
        for evaluate_fn in (
            prediction_evaluate,
            majority_voting_evaluate,
            best_score_evaluate,
            weighted_sum_evaluate,
        ):
            dataset = dataset.map(evaluate_fn, num_proc=_MAP_NUM_PROC, desc="Evaluating math")
    elif dataset_type == "prefeval_explicit":
        dataset = dataset.map(
            evaluate_prefeval_explicit,
            num_proc=_MAP_NUM_PROC,
            desc="Evaluating prefeval explicit",
            fn_kwargs={"model_name": args.data.llm_as_a_judge_model},
        )
    elif dataset_type == "prefeval_choice":
        dataset = dataset.map(evaluate_prefeval_choice, num_proc=_MAP_NUM_PROC, desc="Evaluating prefeval choice")
    elif dataset_type == "prefeval_implicit":
        dataset = dataset.map(
            evaluate_prefeval_implicit,
            num_proc=_MAP_NUM_PROC,
            desc="Evaluating prefeval implicit",
            fn_kwargs={"model_name": args.data.llm_as_a_judge_model},
        )
    elif dataset_type == "multifaceted":
        dataset = dataset.map(
            evaluate_multifaceted_bench,
            num_proc=_MAP_NUM_PROC,
            desc="Evaluating multifaceted bench",
            fn_kwargs={"model_name": args.data.llm_as_a_judge_model},
        )
    else:
        raise ValueError(f"Unsupported dataset type: {args.data.dataset_type}")
    return dataset


def _aggregate_metrics(dataset, run_batch_seconds):
    """Build the summary metric dict (per-key means, sample count, timing, token averages)."""
    column_names = dataset.column_names
    num_samples = len(dataset)

    metric = {}
    if "metric" in column_names and num_samples > 0:
        per_sample_metrics = dataset["metric"]
        # Keys are taken from the first row; every row is indexed with them.
        metric = {
            key: float(np.mean([m[key] for m in per_sample_metrics]))
            for key in per_sample_metrics[0].keys()
        }
    else:
        # Happens when evaluation was skipped (no "metric" column) or the
        # dataset is empty; timing and token statistics are still reported.
        logger.warning("No per-sample 'metric' values available; skipping evaluation metric averages.")

    metric["num_samples"] = num_samples
    metric["run_batch_seconds"] = run_batch_seconds
    metric["run_batch_seconds_per_sample"] = (run_batch_seconds / num_samples) if num_samples > 0 else 0.0

    # Aggregate token usage metrics if available
    if "prompt_tokens" in column_names:
        metric["avg_prompt_tokens"] = float(np.mean(dataset["prompt_tokens"]))
    if "total_tokens" in column_names:
        metric["avg_total_tokens"] = float(np.mean(dataset["total_tokens"]))

    # Average completion tokens per response across all outputs
    if "outputs_tokens_sum" in column_names and "outputs" in column_names:
        total_responses = sum(len(outs or []) for outs in dataset["outputs"])
        tokens_sum = sum(dataset["outputs_tokens_sum"])
        metric["avg_response_tokens"] = (tokens_sum / total_responses) if total_responses > 0 else 0.0
    elif "prediction_tokens" in column_names:
        # Fallback: use prediction tokens as a proxy for completion length
        metric["avg_response_tokens"] = float(np.mean(dataset["prediction_tokens"]))
    else:
        metric["avg_response_tokens"] = 0.0
    return metric


def run_search(args):
    """Run (or reload) the configured search, optionally evaluate, and write results and metrics.

    If `args.result_path` already exists, or `args.eval_only` is set, the
    stored results are loaded instead of running inference. Otherwise the
    dataset is loaded, sliced, normalized (using `args.dataset_cache_path` as
    a cache), run through the method chosen by `args.method.approach`, token
    counted, and saved to `args.result_path`. When `args.run_eval` is set the
    evaluator for `args.data.dataset_type` is applied. The dataset is then
    written back to `args.result_path` and the summary metrics are written to
    `args.metric_path` as JSON.

    Args:
        args: Config object; may be mutated (`eval_only` and the
            `data.dataset_size` / `dataset_start` / `dataset_end` defaults).

    Raises:
        FileNotFoundError: In eval-only mode when `args.result_path` is missing.
        ValueError: When evaluation is requested for an unsupported dataset type.
    """
    run_batch_seconds = 0.0
    if not args.eval_only and os.path.exists(args.result_path):
        logger.info(f"Result file {args.result_path} already exists. Automatically setting eval_only=True.")
        args.eval_only = True

    if args.eval_only:
        logger.info("Running in eval_only mode - loading existing results...")
        if not os.path.exists(args.result_path):
            raise FileNotFoundError(f"result_path {args.result_path} not found for eval_only mode")

        dataset = Dataset.from_json(args.result_path)
        logger.info(f"Loaded {len(dataset)} samples from {args.result_path}")
    else:
        logger.info("Initializing sglang backend...")
        sgl.set_default_backend(sgl.RuntimeEndpoint(base_url=args.base_url, api_key=args.api_key))

        logger.info(f"Loading dataset from {args.data.dataset_path}...")
        dataset: Dataset = get_dataset(args.data.dataset_path, args.data.dataset_split)
        tokenizer = AutoTokenizer.from_pretrained(args.model_path)
        if args.customized_chat_template is not None:
            tokenizer.chat_template = _read_chat_template(args.customized_chat_template)

        logger.info(f"Loading inference method '{args.method.approach}' and verifier...")    
        method: SglFunction = INFERENCE_METHODS[args.method.approach]
        verifier: Callable[[str, List[str]], List[List[float]]] = load_verifier(args.verifier, tokenizer)

        # Slice the dataset using the nested `data` configuration:
        # first `dataset_size` rows, then [dataset_start, dataset_end).
        if args.data.dataset_size is not None:
            dataset = dataset.select(range(args.data.dataset_size))
        else:
            args.data.dataset_size = len(dataset)
        if args.data.dataset_start is None:
            args.data.dataset_start = 0
        if args.data.dataset_end is None:
            args.data.dataset_end = len(dataset)
        dataset = dataset.select(range(args.data.dataset_start, args.data.dataset_end))
        # Every `num_processes`-th row starting at `process_id` goes to this process.
        dataset = dataset.select(range(args.process_id, len(dataset), args.num_processes))

        # Check if cached normalized dataset exists
        if os.path.exists(args.dataset_cache_path):
            logger.info(f"Loading cached normalized dataset from {args.dataset_cache_path}...")
            dataset = Dataset.from_json(args.dataset_cache_path)
        else:
            logger.info("Normalizing dataset...")
            dataset = dataset.map(
                normalize_dataset,
                num_proc=_MAP_NUM_PROC,
                desc="Normalizing dataset",
                with_indices=True,
                fn_kwargs={"args": args.data, "tokenizer": tokenizer},
            )
            # Save normalized dataset to cache
            logger.info(f"Saving normalized dataset to cache: {args.dataset_cache_path}")
            os.makedirs(os.path.dirname(args.dataset_cache_path), exist_ok=True)
            dataset.to_json(args.dataset_cache_path, lines=True)

        print(f"Dataset keys: {dataset.column_names}")

        prompts = dataset["_prompt"]
        params_list = [
            dict(prompt=prompt, sample=sample, verifier=verifier, tokenizer=tokenizer, args=args.method)
            for sample, prompt in zip(dataset, prompts)
        ]
        # Only the batched inference call itself is timed.
        _t0 = time.perf_counter()
        states = method.run_batch(params_list, num_threads=args.max_threads, progress_bar=True)
        run_batch_seconds = time.perf_counter() - _t0

        logger.info("Processing results...")

        # Result columns are taken from the first state's variables; guard
        # against an empty batch so `states[0]` cannot raise IndexError.
        result_keys = states[0].stream_executor.variables.keys() if states else ()
        for key in result_keys:
            dataset = dataset.add_column(key, [state[key] for state in states])

        # token counting
        dataset = dataset.map(
            count_tokens,
            num_proc=_MAP_NUM_PROC,
            desc="Counting tokens",
            fn_kwargs={"tokenizer": tokenizer},
        )

        # save results ahead of evaluation
        dataset.to_json(args.result_path, lines=True)

    if args.run_eval:
        dataset = _run_evaluation(dataset, args)

    metric = _aggregate_metrics(dataset, run_batch_seconds)

    dataset.to_json(args.result_path, lines=True)
    # Context manager so the metrics file is flushed and closed before returning.
    with open(args.metric_path, "w") as metric_file:
        json.dump(metric, metric_file, ensure_ascii=False, indent=4)

    logger.info(f"Done! Metrics: {metric}")


@hydra.main(version_base=None, config_path="configs", config_name="config")
def hydra_main(cfg: DictConfig):
    """Hydra entry point: prepare the config, save a resolved copy, and run the search."""
    # Switch back to the directory the program was launched from so that
    # relative paths in the config keep their meaning.
    launch_dir = get_original_cwd()
    os.chdir(launch_dir)

    # Fills in derived defaults; the config object is mutated in place.
    args = post_process_cfg(cfg)
    print(f"Experiment output directory: {args.output_dir}")

    # Store the fully resolved config next to the outputs for reproducibility.
    config_file = os.path.join(args.output_dir, "config.yaml")
    with open(config_file, "w") as stream:
        config_dict = OmegaConf.to_container(args, resolve=True)
        yaml.dump(config_dict, stream, default_flow_style=False)

    logger.info(f"Running search with config: {pformat(config_dict)}")
    run_search(args)


# Script entry point: only run the Hydra-decorated main when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    hydra_main()
