# Step-aware lm-eval, to be used only to intercept additional stats from the diffusion sampler
# Requires: pip install matplotlib
import os

os.environ["HF_ALLOW_CODE_EVAL"] = "1"

import numpy as np
from lm_eval import simple_evaluate
from lm_eval.api.registry import get_model
import argparse
import json
from lm_eval.utils import make_table, handle_non_serializable
from lm_eval.loggers import EvaluationTracker
from lm_eval.tasks import TaskManager

from pathlib import Path
import sys
import time

# Import the `modeling` module, retrying from two fallback locations if it is not importable.
# The name itself is never used below (hence the noqa markers).
# NOTE(review): presumably the import is needed for its side effects (e.g. registering the
# custom model classes) — confirm against the `modeling` package.
try:
    import modeling  # noqa: F401
except ModuleNotFoundError:
    try:
        # Fallback 1: append the current working directory to sys.path and retry.
        wd = Path.cwd()
        sys.path.append(str(wd))
        import modeling  # noqa: F401
    except ModuleNotFoundError:
        # Fallback 2: append the parent of the working directory and retry.
        # A failure here is not caught and aborts the script.
        wd = Path.cwd().parent
        sys.path.append(str(wd))
        import modeling  # noqa: F401


#########################
def handle_arg_string(arg):
    """Convert a raw argument string into a bool, int, or float where possible.

    Conversion rules, tried in order:
      * "true" / "false" (case-insensitive) -> ``True`` / ``False``
      * a string made up only of decimal digits -> ``int``
      * anything ``float()`` accepts -> ``float``
      * everything else is returned unchanged

    Args:
        arg: The string value to interpret.

    Returns:
        The converted value, or ``arg`` itself when no conversion applies.
    """
    lowered = arg.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    # str.isnumeric() also accepts characters such as superscripts or vulgar
    # fractions that int() rejects with ValueError. str.isdecimal() only admits
    # decimal digits, so such input now falls through to the string fallback
    # instead of crashing.
    if arg.isdecimal():
        return int(arg)
    try:
        return float(arg)
    except ValueError:
        return arg


def simple_parse_args_string(args_string) -> dict:
    """
    Turn a comma-separated ``key=value`` string such as
        args1=val1,arg2=val2
    into a dictionary. Every value is passed through ``handle_arg_string``.
    ``None`` and blank input both give an empty dictionary.
    """
    if args_string is None:
        return {}
    trimmed = args_string.strip()
    if not trimmed:
        return {}
    parsed = {}
    for item in trimmed.split(","):
        if not item:
            # Empty segments come from leading, trailing, or doubled commas.
            continue
        # Only the first "=" separates key from value; later ones stay in the value.
        # An item without "=" maps its whole text to an empty value string.
        name, _, raw_value = item.partition("=")
        parsed[name] = handle_arg_string(raw_value)
    return parsed


def _summarize_stats(stats_per_query, tensor_keys):
    """Collapse per-query sampler stats into ``{key}_mean/_median/_max/_min`` entries.

    Args:
        stats_per_query: List of per-query stat dicts collected during evaluation.
        tensor_keys: Keys whose values are tensors and must be reduced per query first.

    Returns:
        A flat dict of summary statistics; empty when no queries were recorded.
    """
    summary_stats = {}
    if not stats_per_query:
        # Nothing was intercepted (e.g. no generation requests reached the model).
        return summary_stats

    # The keys of the first record define which statistics get summarized.
    for key in stats_per_query[0]:
        if key in tensor_keys:
            # Tensor stats: reduce each query's tensor to a scalar, then aggregate across queries.
            summary_stats[f"{key}_mean"] = np.mean([scores[key].float().mean().item() for scores in stats_per_query])
            summary_stats[f"{key}_median"] = np.median([scores[key].median().item() for scores in stats_per_query])
            summary_stats[f"{key}_max"] = np.max([scores[key].max().item() for scores in stats_per_query])
            summary_stats[f"{key}_min"] = np.min([scores[key].min().item() for scores in stats_per_query])
        else:
            values = [scores[key] for scores in stats_per_query]
            summary_stats[f"{key}_mean"] = np.mean(values)
            summary_stats[f"{key}_median"] = np.median(values)
            summary_stats[f"{key}_max"] = np.max(values)
            summary_stats[f"{key}_min"] = np.min(values)
    return summary_stats


def evaluate_with_steps(model, tasks, model_args, task_manager, **kwargs):
    """Run lm-eval + capture adaptive computation steps in one pass.

    While the evaluation runs, ``generate`` on the underlying model is replaced by a
    wrapper that calls ``generate_diffusion_style``, records the returned ``scores``
    dict (plus a rough ``time_per_sample``) for every call, and hands back only the
    generated sequences. The original ``generate`` is restored afterwards, even if
    the evaluation raises.

    Args:
        model: lm-eval model wrapper; the underlying model is read from ``model._model``.
        tasks: Task names forwarded to ``simple_evaluate``.
        model_args: Model argument string forwarded to ``simple_evaluate``.
        task_manager: Task manager forwarded to ``simple_evaluate``.
        **kwargs: Any further keyword arguments for ``simple_evaluate``.

    Returns:
        A ``(results, stats_per_query)`` tuple. ``results`` is whatever
        ``simple_evaluate`` returned (possibly ``None``), with the summary statistics
        merged into every task entry. ``stats_per_query`` is the list of recorded
        per-call stat dicts with the large per-position tensors removed.

    Raises:
        ValueError: From the generate wrapper when a call returns no ``scores``.
    """
    # Per-position tensors: summarized separately and dropped before returning.
    tensor_keys = ("recurrence_per_position", "token_stable_per_position")
    stats_per_query = []

    # Store original generate method from the underlying Huginn model
    huginn_model = model._model  # HFLM stores HF model in _model
    original_generate = huginn_model.generate

    def stats_tracking_generate(*args, **gen_kwargs):
        # Ensure we get step info back
        model._model.generation_config.return_dict_in_generate = True

        # perf_counter is monotonic, so the elapsed time cannot be skewed by clock adjustments.
        started = time.perf_counter()
        output = huginn_model.generate_diffusion_style(*args, **gen_kwargs)

        if not output.scores:
            raise ValueError("No scores recorded, is this an adaptive generation?")
        output.scores["time_per_sample"] = time.perf_counter() - started
        stats_per_query.append(output.scores)

        return output.sequences

    # Monkey patch during evaluations
    huginn_model.generate = stats_tracking_generate

    try:
        # Run normal lm-eval
        results = simple_evaluate(
            model, model_args=model_args, tasks=tasks, task_manager=task_manager, confirm_run_unsafe_code=True, **kwargs
        )

        summary_stats = _summarize_stats(stats_per_query, tensor_keys)

        # Add to each task's results
        if results is not None:  # None on other workers?
            for task_name in results["results"]:  # type: ignore
                results["results"][task_name].update(summary_stats)  # type: ignore
    finally:
        # Restore original method
        huginn_model.generate = original_generate

    # remove to reduce size of json; pop() tolerates samplers that never emitted these keys
    for scores in stats_per_query:
        for key in tensor_keys:
            scores.pop(key, None)
    return results, stats_per_query


def parse_gen_kwargs(gen_kwargs_str):
    """Parse gen_kwargs string like 'key1=val1,key2=val2' into dict.

    Values are converted as follows:
      * "true" / "false" (case-insensitive) -> ``True`` / ``False``
      * number-like text (digits with optional ".", "-", "e"/"E") -> ``float`` when it
        contains "." or an exponent, otherwise ``int``
      * everything else, including number-like text that fails to parse
        (e.g. "2024-01-01" or "1.2.3"), is kept as the original string

    Pairs without "=" are ignored.

    Args:
        gen_kwargs_str: The raw comma-separated string, or ``None`` / empty.

    Returns:
        A dict mapping each key to its converted value.
    """
    if not gen_kwargs_str:
        return {}

    kwargs = {}
    for pair in gen_kwargs_str.split(","):
        if "=" not in pair:
            continue
        key, value = pair.split("=", 1)
        lowered = value.lower()
        # Try to convert to appropriate type
        if lowered == "true":
            value = True
        elif lowered == "false":
            value = False
        elif lowered.replace(".", "").replace("-", "").replace("e", "").isdigit():
            # The check above is only a heuristic: it also matches text such as
            # dates or version strings, so a failed conversion keeps the string.
            try:
                value = float(value) if "." in value or "e" in lowered else int(value)
            except ValueError:
                pass
        kwargs[key] = value
    return kwargs


# Usage script - replaces your command line call
if __name__ == "__main__":
    # Command-line entry point: build the model, run the step-tracking evaluation,
    # print the lm-eval tables, and save the aggregated results and per-sample stats.
    parser = argparse.ArgumentParser(description="lm-eval with adaptive computation step tracking")
    parser.add_argument("--model", default="hf", help="Model type")
    parser.add_argument("--model_args", required=True, help="Model arguments (comma-separated)")
    parser.add_argument("--tasks", required=True, help="Tasks to evaluate (comma-separated)")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("--num_fewshot", type=int, default=0, help="Number of few-shot examples")
    parser.add_argument("--output_path", help="Output directory")
    parser.add_argument("--apply_chat_template", action="store_true", help="Apply chat template")
    parser.add_argument("--system_instruction", help="System instruction")
    parser.add_argument("--fewshot_as_multiturn", action="store_true", help="Few-shot as multiturn")
    parser.add_argument("--gen_kwargs", help="Generation kwargs (comma-separated)")
    parser.add_argument("--confirm_run_unsafe_code", action="store_true", help="Dummy arg, always turned on.")
    parser.add_argument("--limit", type=int, help="Limit number of examples")
    parser.add_argument(
        "--include_path",
        type=str,
        default=None,
        metavar="DIR",
        help="Additional path to include if there are external tasks to include.",
    )
    parser.add_argument(
        "--metadata",
        type=json.loads,
        default=None,
        help="""JSON string metadata to pass to task configs, for example '{"max_seq_lengths":[4096,8192]}'. Will be merged with model_args. Can also be set in task config.""",
    )
    args = parser.parse_args()

    # Create model
    model = get_model(args.model).create_from_arg_string(args.model_args)

    # Parse tasks
    tasks = [task.strip() for task in args.tasks.split(",")]
    # Task metadata = parsed model_args, overridden by any explicit --metadata entries.
    metadata = (
        simple_parse_args_string(args.model_args)
        if isinstance(args.model_args, str)
        else args.model_args
        if isinstance(args.model_args, dict)
        else {}
    ) | (args.metadata if isinstance(args.metadata, dict) else simple_parse_args_string(args.metadata))

    task_manager = TaskManager(include_path=args.include_path, metadata=metadata)

    # Parse gen_kwargs
    gen_kwargs = parse_gen_kwargs(args.gen_kwargs)

    # Build evaluation kwargs
    eval_kwargs = {
        "batch_size": args.batch_size,
        "num_fewshot": args.num_fewshot,
        "gen_kwargs": gen_kwargs,
        "log_samples": False,
    }

    if args.apply_chat_template:
        eval_kwargs["apply_chat_template"] = True
    if args.system_instruction:
        # "$HUGINN_SYS" is a shorthand for the default system prompt below.
        if args.system_instruction == "$HUGINN_SYS":
            eval_kwargs["system_instruction"] = (
                "You are a helpful assistant that can assist users with mathematical reasoning."
            )
        else:
            eval_kwargs["system_instruction"] = args.system_instruction
    if args.fewshot_as_multiturn:
        eval_kwargs["fewshot_as_multiturn"] = True
    if args.limit:
        eval_kwargs["limit"] = args.limit

    evaluation_tracker = EvaluationTracker(output_path=args.output_path)
    eval_kwargs["evaluation_tracker"] = evaluation_tracker

    # Run evaluation
    results, per_sample_stats = evaluate_with_steps(
        model=model, tasks=tasks, model_args=args.model_args, task_manager=task_manager, **eval_kwargs
    )

    # Print summary
    if results is not None:
        batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
        print(
            f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
            f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
        )
        print(make_table(results))
        if "groups" in results:  # type: ignore
            print(make_table(results, "groups"))

    # Steal lm-eval printer:
    if results is not None:
        evaluation_tracker.save_results_aggregated(results=results, samples=None)  # type: ignore
        # Save all step counts in a single json.
        # Only possible when an output directory was requested; without this guard the
        # path would start with the literal "None/" and the write would fail.
        if args.output_path:
            step_dir = Path(f"{args.output_path}/{evaluation_tracker.general_config_tracker.model_name_sanitized}")
            # Do not rely on the tracker having created the directory already.
            step_dir.mkdir(parents=True, exist_ok=True)
            step_file = step_dir / f"step_counts_{evaluation_tracker.date_id}.json"
            with open(step_file, "w", encoding="utf-8") as f:
                # Fall back to lm-eval's serializer for values json cannot encode natively.
                json.dump(per_sample_stats, f, default=handle_non_serializable)
        else:
            print("No --output_path given; per-sample step counts were not saved.")


# To run (replaces your original command):
"""
CUDA_VISIBLE_DEVICES=2,3,4,5 accelerate launch evaluate_raven/record_steps_in_bench.py \
  --model hf \
  --model_args "pretrained=tomg-group-umd/huginn_swa_75_7_ema_0.9_merge,trust_remote_code=True,dtype=bfloat16,mean_recurrence=32" \
  --tasks gsm8k_cot \
  --batch_size 1 \
  --num_fewshot 8 \
  --output_path outputs/step_counting \
  --apply_chat_template \
  --system_instruction "You are a helpful assistant that can assist users with mathematical reasoning." \
  --fewshot_as_multiturn \
  --gen_kwargs "criterion=entropy-diff,exit_threshold=1e-2,cache_lookup_strategy=latest-m4-compress-s16"
"""

"""
CUDA_VISIBLE_DEVICES=1 python evaluate_raven/record_steps_in_bench.py \
  --model hf \
  --model_args "pretrained=tomg-group-umd/huginn_swa_75_7_ema_0.9_merge,trust_remote_code=True,dtype=bfloat16,mean_recurrence=32" \
  --tasks gsm8k_cot
"""
