import json
import logging
import os
from pprint import pformat

import hydra
import numpy as np
import yaml
from datasets import Dataset
from hydra.utils import get_original_cwd
from omegaconf import DictConfig, OmegaConf

# Custom OmegaConf resolvers made available to config interpolations.
# "basename" maps a path to its final component.
OmegaConf.register_new_resolver("basename", os.path.basename)
# NOTE(review): "eval" passes its argument to Python's built-in eval(), so
# config files and command-line overrides must come from a trusted source.
OmegaConf.register_new_resolver("eval", eval)

from utils.evaluator import (
    math_verify_evaluate,
    evaluate_prefeval_explicit,
    evaluate_prefeval_implicit,
    evaluate_prefeval_choice,
    evaluate_ping_pong,
    evaluate_multifaceted_bench,
)
from utils.math import *

# Root logging at INFO for the whole process.
logging.basicConfig(level=logging.INFO)

# Only let errors through from the "openai" and "httpx" loggers.
for _quiet_name in ("openai", "httpx"):
    logging.getLogger(_quiet_name).setLevel(logging.ERROR)

# Module logger, pinned to INFO independently of the root configuration.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def post_process_cfg(args):
    """Validate the config and make sure the output directory exists.

    Args:
        args: Config object exposing ``eval_result_path`` and ``output_dir``
            as attributes.

    Returns:
        The same ``args`` object that was passed in.

    Raises:
        ValueError: If ``args.eval_result_path`` is ``None``.
    """
    # Explicit raise instead of ``assert`` so the check is not stripped when
    # Python runs with optimizations enabled (``-O``).
    if args.eval_result_path is None:
        raise ValueError("eval_result_path is required")

    os.makedirs(args.output_dir, exist_ok=True)
    return args
    

def run_eval(args):
    """Re-score previously generated results and write aggregate metrics.

    Loads the JSON results at ``args.eval_result_path``, re-selects each
    sample's prediction when an ``outputs`` column is present, optionally runs
    the evaluator matching ``args.data.dataset_type``, then writes the
    per-sample rows to ``args.result_path`` and the aggregated metrics to
    ``args.metric_path``.

    Args:
        args: Config object. Reads ``eval_result_path``, ``run_eval``,
            ``data.dataset_type``, ``result_path`` and ``metric_path``; also
            ``method.n_samples`` when the data has an ``outputs`` column, and
            ``data.llm_as_a_judge_model`` for the judge-based dataset types.

    Raises:
        FileNotFoundError: If ``args.eval_result_path`` does not exist.
        ValueError: If ``args.run_eval`` is truthy and
            ``args.data.dataset_type`` is not a supported type.
    """
    # Worker count shared by every ``Dataset.map`` call below.
    num_proc = 32

    logger.info("Running in eval_only mode - loading existing results...")
    if not os.path.exists(args.eval_result_path):
        raise FileNotFoundError(f"eval_result_path {args.eval_result_path} not found")

    dataset = Dataset.from_json(args.eval_result_path)
    logger.info(f"Loaded {len(dataset)} samples from {args.eval_result_path}")

    if "metric" not in dataset.column_names:
        dataset = dataset.add_column("metric", [{} for _ in range(len(dataset))])

    if "outputs" in dataset.column_names:
        def remap_prediction(example):
            # Pick the output with the highest aggregate score among the
            # first ``n_samples`` candidates.
            n_samples = args.method.n_samples
            example["prediction"] = example["outputs"][np.argmax(example["agg_scores"][:n_samples])]
            return example
        dataset = dataset.map(remap_prediction, num_proc=num_proc, desc="Remapping prediction")

    if args.run_eval:
        dataset_type = args.data.dataset_type
        if dataset_type == "math":
            # The math evaluators are applied one after another, in this order.
            for evaluate_fn in (
                prediction_evaluate,
                majority_voting_evaluate,
                best_score_evaluate,
                weighted_sum_evaluate,
            ):
                dataset = dataset.map(evaluate_fn, num_proc=num_proc, desc="Evaluating math")
        elif dataset_type == "prefeval_explicit":
            dataset = dataset.map(
                evaluate_prefeval_explicit,
                num_proc=num_proc,
                desc="Evaluating prefeval explicit",
                fn_kwargs={"model_name": args.data.llm_as_a_judge_model},
            )
        elif dataset_type == "prefeval_choice":
            dataset = dataset.map(
                evaluate_prefeval_choice,
                num_proc=num_proc,
                desc="Evaluating prefeval choice",
            )
        elif dataset_type == "prefeval_implicit":
            dataset = dataset.map(
                evaluate_prefeval_implicit,
                num_proc=num_proc,
                desc="Evaluating prefeval implicit",
                fn_kwargs={"model_name": args.data.llm_as_a_judge_model},
            )
        elif dataset_type == "multifaceted":
            dataset = dataset.map(
                evaluate_multifaceted_bench,
                num_proc=num_proc,
                desc="Evaluating multifaceted bench",
                fn_kwargs={"model_name": args.data.llm_as_a_judge_model},
            )
        else:
            raise ValueError(f"Unsupported dataset type: {args.data.dataset_type}")

    metrics = dataset["metric"]
    # The metric keys are taken from the first row; with no rows there is
    # nothing to average, so no per-key means are produced.
    metric_keys = list(metrics[0].keys()) if len(metrics) > 0 else []
    # Cast to built-in float so these values match the token averages below.
    metric = {key: float(np.mean([m[key] for m in metrics])) for key in metric_keys}
    metric["num_samples"] = len(dataset)
    # NOTE(review): timing fields are always written as 0.0 here; presumably
    # kept so the metric file has the same keys as a run that measures them.
    metric["run_batch_seconds"] = 0.0
    metric["run_batch_seconds_per_sample"] = 0.0

    # Aggregate token usage metrics if available
    if "prompt_tokens" in dataset.column_names:
        metric["avg_prompt_tokens"] = float(np.mean(dataset["prompt_tokens"]))
    if "total_tokens" in dataset.column_names:
        metric["avg_total_tokens"] = float(np.mean(dataset["total_tokens"]))

    # Average completion tokens per response across all outputs
    if "outputs_tokens_sum" in dataset.column_names and "outputs" in dataset.column_names:
        total_responses = sum(len(outs or []) for outs in dataset["outputs"])
        tokens_sum = sum(dataset["outputs_tokens_sum"])
        metric["avg_response_tokens"] = (tokens_sum / total_responses) if total_responses > 0 else 0.0
    elif "prediction_tokens" in dataset.column_names:
        # Fallback: use prediction tokens as a proxy for completion length
        metric["avg_response_tokens"] = float(np.mean(dataset["prediction_tokens"]))
    else:
        metric["avg_response_tokens"] = 0.0

    dataset.to_json(args.result_path, lines=True)
    # Context manager closes the file; explicit UTF-8 because
    # ``ensure_ascii=False`` may emit non-ASCII characters.
    with open(args.metric_path, "w", encoding="utf-8") as metric_file:
        json.dump(metric, metric_file, ensure_ascii=False, indent=4)

    logger.info(f"Done! Metrics: {metric}")


@hydra.main(version_base=None, config_path="configs", config_name="config")
def hydra_main(cfg: DictConfig):
    """Hydra entry point: prepare the config, save a resolved copy, run the evaluation.

    Args:
        cfg: Config object supplied by the ``@hydra.main`` decorator.
    """
    # Switch back to the directory the script was launched from so that
    # relative paths in the config keep their meaning.
    launch_dir = get_original_cwd()
    os.chdir(launch_dir)

    # Validates the config and creates the output directory.
    args = post_process_cfg(cfg)
    print(f"Experiment output directory: {args.output_dir}")

    # Store the fully resolved config next to the results for reproducibility.
    config_file = os.path.join(args.output_dir, "config.yaml")
    with open(config_file, "w") as handle:
        resolved_cfg = OmegaConf.to_container(args, resolve=True)
        yaml.dump(resolved_cfg, handle, default_flow_style=False)

    logger.info(f"Running search with config: {pformat(resolved_cfg)}")
    run_eval(args)


# Script entry point. hydra_main is called without arguments; its `cfg`
# parameter is provided through the @hydra.main decorator.
if __name__ == "__main__":
    hydra_main()
