#!/usr/bin/env python3
"""Convenience wrapper around lm-evaluation-harness for MolecularIQ."""

from __future__ import annotations

import argparse
import ast
import json
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

import yaml

from lm_eval.evaluator import simple_evaluate
from lm_eval.tasks import TaskManager


# Absolute directory containing this script; the search paths below hang off it.
REPO_ROOT = Path(__file__).resolve().parents[0]
# Chemistry model configs shipped with the chemsets tasks; searched last when
# resolving a chemistry config reference.
CHEM_CONFIG_DIR = REPO_ROOT.joinpath("lm_eval", "tasks", "chemsets", "model_configs")
# Top-level config directory; searched before CHEM_CONFIG_DIR.
DEFAULT_CONFIG_DIR = REPO_ROOT.joinpath("configs")


def _load_yaml_config(path: Path) -> Dict[str, Any]:
    """Load a YAML file and return its top-level mapping.

    An empty (or otherwise falsy) document yields ``{}``.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the document's top level is not a mapping; every caller
            treats the result as a dict and would otherwise fail later with an
            unhelpful ``AttributeError``.
    """
    if not path.exists():
        raise FileNotFoundError(f"Model config not found: {path}")
    with path.open("r", encoding="utf-8") as handle:
        data = yaml.safe_load(handle) or {}
    if not isinstance(data, dict):
        raise ValueError(
            f"Expected a YAML mapping at the top level of {path}, got {type(data).__name__}."
        )
    return data


def _coerce_value(raw: str) -> Any:
    """Convert a ``--set`` value string into a Python object.

    Surrounding whitespace is stripped. ``none``/``null`` (any case) map to
    ``None`` and ``true``/``false`` (any case) to booleans. Anything else is
    parsed with ``ast.literal_eval`` (numbers, quoted strings, lists, dicts,
    ...). When the text is not a valid literal, the stripped text itself is
    returned.
    """
    text = raw.strip()
    lowered = text.lower()
    if lowered in {"none", "null"}:
        return None
    if lowered in {"true", "false"}:
        return lowered == "true"
    try:
        return ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # literal_eval can raise any of these on malformed input, e.g.
        # "{[1]: 2}" raises TypeError (unhashable key). Treat such input as text.
        return text


def _apply_override(config: Dict[str, Any], dotted_key: str, value: Any) -> None:
    """Assign ``value`` in ``config`` at the nested location named by ``dotted_key``.

    Each dot-separated segment except the last selects a nested dict.
    Intermediate dicts are created as needed; an existing non-dict value in
    the path is replaced by a new empty dict. ``config`` is mutated in place.
    """
    *parents, leaf = dotted_key.split(".")
    node = config
    for part in parents:
        child = node.get(part)
        if not isinstance(child, dict):
            child = {}
            node[part] = child
        node = child
    node[leaf] = value


def _parse_overrides(items: Iterable[str]) -> Dict[str, Any]:
    """Parse ``KEY=VALUE`` strings (from ``--set``) into ``{dotted_key: value}``.

    Only the first ``=`` separates key from value, so values may themselves
    contain ``=``. Keys are whitespace-stripped; values are converted with
    ``_coerce_value``. A later entry for the same key wins.

    Raises:
        ValueError: if an entry has no ``=``, or its key is empty or contains
            an empty dotted segment (e.g. ``=5``, ``a..b=1``, ``.a=1``). Such
            keys would otherwise silently create ``""`` entries in the config.
    """
    overrides: Dict[str, Any] = {}
    for entry in items:
        key, sep, raw_value = entry.partition("=")
        if not sep:
            raise ValueError(f"Override must be in KEY=VALUE format, got '{entry}'")
        key = key.strip()
        # "".split(".") == [""], so this also rejects a completely empty key.
        if "" in key.split("."):
            raise ValueError(f"Override key must be a non-empty dotted path, got '{entry}'")
        overrides[key] = _coerce_value(raw_value)
    return overrides


def _first_non_none(*values: Any) -> Any:
    """Return the leftmost argument that is not ``None``, or ``None`` if all are.

    Falsy values such as ``0``, ``""`` and ``False`` count as present.
    """
    return next((candidate for candidate in values if candidate is not None), None)


def _split_task_names(values: Iterable[Any]) -> List[Any]:
    """Flatten comma-separated task strings, dropping blanks and ``None`` entries."""
    names: List[Any] = []
    for value in values:
        if value is None:
            continue
        if isinstance(value, str):
            names.extend(part.strip() for part in value.split(",") if part.strip())
        else:
            # Non-string entries are passed through untouched, as before
            # (presumably inline task configs understood by lm-eval — confirm).
            names.append(value)
    return names


def _finalize_tasks(
    cli_tasks: List[str] | None,
    eval_config: Dict[str, Any],
    config: Dict[str, Any],
    chem_defaults: Optional[Dict[str, Any]] = None,
) -> List[str]:
    """Return the list of tasks to evaluate.

    ``--task`` values (repeatable, each possibly comma-separated) take
    precedence. Otherwise ``tasks`` is read from the eval section, then the
    top-level config, then the chemistry defaults. A string or list/tuple is
    accepted, and strings are comma-split the same way as CLI values.

    Raises:
        ValueError: if no task names remain, e.g. ``--task ","``, an empty
            ``tasks: []`` entry, a missing entry, or an unsupported type.
    """
    if cli_tasks:
        tasks = _split_task_names(cli_tasks)
    else:
        configured = _first_non_none(
            eval_config.get("tasks"),
            config.get("tasks"),
            (chem_defaults or {}).get("tasks"),
        )
        if isinstance(configured, str):
            configured = [configured]
        tasks = _split_task_names(configured) if isinstance(configured, (list, tuple)) else []
    if not tasks:
        raise ValueError("No tasks provided. Supply --task or add a 'tasks' entry to the config file.")
    return tasks


def _resolve_model_args(
    eval_config: Dict[str, Any],
    config: Dict[str, Any],
    chem: Dict[str, Any],
) -> Any:
    """Return ``model_args`` for lm-eval, supplying ``pretrained`` from ``model_path`` if missing."""
    model_args = _first_non_none(eval_config.get("model_args"), config.get("model_args"))
    pretrained = _first_non_none(
        eval_config.get("model_path"),
        config.get("model_path"),
        chem.get("model_path"),
    )
    if model_args is None:
        return {"pretrained": pretrained} if pretrained is not None else {}
    if isinstance(model_args, dict) and pretrained is not None and "pretrained" not in model_args:
        # e.g. `--set model_args.tensor_parallel_size=2` creates a model_args
        # dict holding only that key. Without this merge the configured
        # model_path would be silently dropped and never reach the backend.
        return {"pretrained": pretrained, **model_args}
    return model_args


def _build_eval_kwargs(
    config: Dict[str, Any],
    chem_defaults: Optional[Dict[str, Any]],
    args: argparse.Namespace,
) -> Dict[str, Any]:
    """Assemble the keyword arguments for ``simple_evaluate``.

    Option values are taken from the first non-``None`` source, in this order:
    the CLI flag (where one exists); the ``eval``/``evaluation`` section (or
    the whole config when neither exists); the top-level config; the
    chemistry defaults. ``model`` and ``include_path`` are read only from the
    first two config sources.

    Raises:
        KeyError: if no model backend is configured.
        ValueError: propagated from ``_finalize_tasks`` when no tasks resolve.
    """
    eval_config: Dict[str, Any] = config.get("eval") or config.get("evaluation") or config
    chem = chem_defaults or {}

    def configured(key: str) -> Any:
        # Config-side precedence: eval section, top-level config, chemistry defaults.
        return _first_non_none(eval_config.get(key), config.get(key), chem.get(key))

    model = _first_non_none(eval_config.get("model"), config.get("model"))
    if model is None:
        raise KeyError(
            "Model backend must be specified either at the top level of the config or under the 'eval' section."
        )

    eval_kwargs: Dict[str, Any] = {
        "model": model,
        "model_args": _resolve_model_args(eval_config, config, chem),
        "tasks": _finalize_tasks(args.task, eval_config, config, chem_defaults),
        "task_manager": TaskManager(
            include_path=_first_non_none(eval_config.get("include_path"), config.get("include_path"))
        ),
    }

    # Options that have a CLI flag; the flag wins over any config source.
    cli_options = {
        "num_fewshot": args.num_fewshot,
        "batch_size": args.batch_size,
        "max_batch_size": args.max_batch_size,
        "limit": args.limit,
    }
    # Generation/prompting options that come from config sources only.
    config_only_options = (
        "gen_kwargs",
        "system_instruction",
        "apply_chat_template",
        "fewshot_as_multiturn",
        "chat_template_args",
        "device",
        "use_cache",
        "cache_requests",
        "rewrite_requests_cache",
        "delete_requests_cache",
        "log_samples",
        "write_out",
        "metadata",
    )
    for key, cli_value in {**cli_options, **dict.fromkeys(config_only_options)}.items():
        value = _first_non_none(cli_value, configured(key))
        if value is not None:
            eval_kwargs[key] = value

    # ``system_prompt`` is accepted as an alias when no ``system_instruction``
    # was found in any source above.
    if "system_instruction" not in eval_kwargs:
        system_prompt = configured("system_prompt")
        if system_prompt is not None:
            eval_kwargs["system_instruction"] = system_prompt

    return eval_kwargs


def _resolve_chemistry_config(chem_ref: str, base_path: Path) -> Path:
    chem_path = Path(chem_ref).expanduser()
    if chem_path.suffix.lower() != ".yaml":
        candidate = chem_path.with_suffix(".yaml")
        if candidate.exists():
            chem_path = candidate
    if chem_path.exists():
        if chem_path.is_dir():
            raise FileNotFoundError(
                f"Chemistry config reference '{chem_ref}' points to a directory; expected a YAML file."
            )
        return chem_path.resolve()

    # look relative to provided config directory
    relative_candidate = (base_path / chem_ref).with_suffix(".yaml")
    if relative_candidate.exists():
        return relative_candidate.resolve()

    fallback_new = DEFAULT_CONFIG_DIR / f"{chem_ref}.yaml"
    if fallback_new.exists():
        return fallback_new.resolve()

    fallback = CHEM_CONFIG_DIR / f"{chem_ref}.yaml"
    if fallback.exists():
        return fallback.resolve()

    raise FileNotFoundError(
        f"Unable to locate chemistry model config for reference '{chem_ref}'."
    )


def _load_chemistry_defaults(chem_ref: str, base_path: Path) -> tuple[Dict[str, Any], Path]:
    """Resolve ``chem_ref`` to a YAML file and load it.

    Returns:
        A ``(defaults, path)`` pair: the parsed mapping and the resolved path
        it was loaded from. The previous ``Dict`` annotation did not match
        this tuple return.
    """
    chem_path = _resolve_chemistry_config(chem_ref, base_path)
    return _load_yaml_config(chem_path), chem_path


def _setup_environment(config: Dict[str, Any], config_path: Path) -> Dict[str, Any]:
    """Load the chemistry defaults for ``config`` and export related env vars.

    When ``config`` has no ``chem_model_config`` but its file sits directly in
    ``CHEM_CONFIG_DIR`` or ``DEFAULT_CONFIG_DIR``, the file's own stem is used
    and written back into ``config``. The resolved defaults path is exported
    as ``LM_EVAL_MODEL_CONFIG``. A boolean ``disable_inline_prompt`` is
    exported as ``LM_EVAL_DISABLE_INLINE_PROMPT`` ("true"/"false"); for any
    other value that variable is removed from the environment.

    Returns:
        The parsed chemistry defaults mapping.

    Raises:
        KeyError: if no chemistry config reference can be determined.
    """
    chem_ref = config.get("chem_model_config")
    if chem_ref is None:
        if config_path.parent not in (CHEM_CONFIG_DIR, DEFAULT_CONFIG_DIR):
            raise KeyError(
                "'chem_model_config' must be specified when the model config resides outside the configured search paths (e.g. configs/)."
            )
        chem_ref = config["chem_model_config"] = config_path.stem

    chem_defaults, chem_path = _load_chemistry_defaults(chem_ref, config_path.parent)
    os.environ["LM_EVAL_MODEL_CONFIG"] = str(chem_path)

    flag = config.get("disable_inline_prompt")
    if not isinstance(flag, bool):
        os.environ.pop("LM_EVAL_DISABLE_INLINE_PROMPT", None)
    else:
        os.environ["LM_EVAL_DISABLE_INLINE_PROMPT"] = "true" if flag else "false"

    return chem_defaults


def _create_run_directory(base_dir: Path, config_path: Path) -> Path:
    """Create and return a new, uniquely named run directory under ``base_dir``.

    The name is ``<config stem>_<UTC time as YYYYmmdd-HHMMSS>``. If that
    directory already exists (e.g. two runs started in the same second), a
    numeric suffix ``-1``, ``-2``, ... is appended so earlier artifacts are
    never overwritten. ``base_dir`` and its parents are created as needed.
    """
    # Local import keeps this block self-contained; datetime.utcnow() is
    # deprecated in favour of an aware datetime in UTC.
    from datetime import timezone

    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    run_name = f"{config_path.stem}_{timestamp}"
    base_dir.mkdir(parents=True, exist_ok=True)

    run_dir = base_dir / run_name
    attempt = 0
    while True:
        try:
            run_dir.mkdir()
        except FileExistsError:
            attempt += 1
            run_dir = base_dir / f"{run_name}-{attempt}"
            continue
        return run_dir


def _save_run_artifacts(run_dir: Path, config: Dict[str, Any], overrides: Dict[str, Any], results: Dict[str, Any]) -> None:
    """Write the evaluation artifacts into ``run_dir``.

    Files written:
      * ``metrics.json``: the full ``results`` object (non-JSON values via ``str``).
      * ``config_used.yaml``: the effective config after overrides.
      * ``overrides.json``: the parsed ``--set`` overrides, only when any were given.
    """
    (run_dir / "metrics.json").write_text(
        json.dumps(results, indent=2, default=str), encoding="utf-8"
    )
    (run_dir / "config_used.yaml").write_text(
        yaml.safe_dump(config, sort_keys=False), encoding="utf-8"
    )
    if overrides:
        # Override values come from ast.literal_eval and may be sets, bytes or
        # complex numbers, which json cannot encode. Stringify them, as
        # metrics.json does, rather than crash after a finished evaluation.
        (run_dir / "overrides.json").write_text(
            json.dumps(overrides, indent=2, default=str), encoding="utf-8"
        )


def _print_summary(results: Optional[Dict[str, Any]]) -> None:
    """Print each task's metrics from ``results["results"]`` to stdout.

    A ``None`` results object (returned by ``simple_evaluate`` on non-primary
    ranks) or a missing/empty ``"results"`` entry prints a hint instead of
    raising.
    """
    summary = (results or {}).get("results") or {}
    if not summary:
        print("No metrics returned. Check model/task settings.")
        return
    print("\n=== Evaluation Summary ===")
    for task_name, metrics in summary.items():
        print(f"Task: {task_name}")
        for metric_name, value in metrics.items():
            print(f"  {metric_name}: {value}")


def main() -> None:
    """Command-line entry point: evaluate a model on MolecularIQ tasks.

    Steps:
      1. Load the model config and apply ``--set`` overrides.
      2. Resolve the chemistry defaults (which also exports environment
         variables) and build the ``simple_evaluate`` kwargs.
      3. Run the evaluation.
      4. Write artifacts into a fresh timestamped directory under
         ``--output_dir`` and print a summary.
    """
    parser = argparse.ArgumentParser(description="Evaluate a model on MolecularIQ tasks.")
    parser.add_argument("--model_config", required=True, type=Path, help="Path to the model YAML config.")
    parser.add_argument(
        "--task",
        action="append",
        default=None,
        help="Task name(s) to evaluate. Can be provided multiple times or as a comma-separated list.",
    )
    parser.add_argument("--output_dir", default=Path("results"), type=Path, help="Directory to store evaluation outputs.")
    parser.add_argument("--num_fewshot", type=int, default=None, help="Override number of few-shot examples.")
    parser.add_argument("--batch_size", type=str, default=None, help="Override evaluation batch size.")
    parser.add_argument("--max_batch_size", type=int, default=None, help="Override max batch size when using auto batching.")
    parser.add_argument("--limit", type=float, default=None, help="Optional limit per task (int or float).")
    parser.add_argument(
        "--set",
        dest="overrides",
        action="append",
        default=[],
        help="Override config values using dotted keys, e.g. --set model_args.tensor_parallel_size=2",
    )

    args = parser.parse_args()

    config_path = args.model_config.resolve()
    config = _load_yaml_config(config_path)

    try:
        overrides = _parse_overrides(args.overrides)
    except ValueError as err:
        # A malformed --set is a usage error: report it the way argparse does
        # (message plus exit status 2) instead of with a traceback.
        parser.error(str(err))
    for key, value in overrides.items():
        _apply_override(config, key, value)

    # No separate "model" check here. _build_eval_kwargs accepts the backend at
    # the top level or under 'eval'/'evaluation' and runs after --set overrides,
    # so a raw top-level check would reject valid configs.
    chem_defaults = _setup_environment(config, config_path)
    eval_kwargs = _build_eval_kwargs(config, chem_defaults, args)

    results = simple_evaluate(**eval_kwargs)
    if results is None:
        # simple_evaluate returns None on non-primary ranks of a multi-process
        # run. Only the rank holding the results writes artifacts, which avoids
        # empty duplicate run directories.
        return

    output_dir = Path(args.output_dir).expanduser().resolve()
    run_dir = _create_run_directory(output_dir, config_path)
    _save_run_artifacts(run_dir, config, overrides, results)
    _print_summary(results)
    print(f"\nResults written to: {run_dir}")


if __name__ == "__main__":
    main()
