# cirbench/utils/api/base.py
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List
from .gemini import GeminiRunner
from .qwen import QwenRunner
from .deepseek import DeepseekRunner
from .gpt import GPTRunner
from .claude import ClaudeRunner
from .grok import GrokRunner
from .llama import LlamaRunner

import re
import os

def select_model_cfg(cfg):
    """Resolve which model entry the benchmark should run.

    Resolution order:

    1. When both ``CIRBENCH_MODEL_KIND`` and ``CIRBENCH_MODEL_NAME`` are set
       to non-empty values, look for an entry of ``cfg.models`` whose
       ``kind`` and ``name`` both match. A matching entry contributes its
       ``params``; with no match, the environment values are returned with
       empty ``params``.
    2. Otherwise, use the first entry of ``cfg.models`` when it is non-empty.
    3. Otherwise, fall back to the built-in rule runner (``rule``/``golden``).

    Args:
        cfg: configuration object exposing a ``models`` attribute (may be
            ``None`` or empty) whose entries expose ``kind``, ``name`` and
            ``params``.

    Returns:
        A dict with keys ``"kind"``, ``"name"`` and ``"params"``; ``params``
        is ``{}`` whenever the chosen entry's ``params`` is falsy.
    """
    env_kind = os.getenv("CIRBENCH_MODEL_KIND")
    env_name = os.getenv("CIRBENCH_MODEL_NAME")
    models = cfg.models

    def as_spec(entry):
        # Normalize a config entry into the plain dict shape callers consume.
        return {"kind": entry.kind, "name": entry.name, "params": (entry.params or {})}

    if not (env_kind and env_name):
        if not models:
            return {"kind": "rule", "name": "golden", "params": {}}
        return as_spec(models[0])

    # Explicit env override: prefer a configured entry so its params are kept.
    chosen = next(
        (
            entry
            for entry in (models or [])
            if getattr(entry, "kind", None) == env_kind
            and getattr(entry, "name", None) == env_name
        ),
        None,
    )
    if chosen is None:
        return {"kind": env_kind, "name": env_name, "params": {}}
    return as_spec(chosen)

@dataclass
class Completion:
    """One generated output produced by a runner for a single prompt."""

    # Output text for the prompt.
    text: str
    # Per-completion metadata; each instance gets its own fresh dict.
    # RuleRunner fills finish_reason, token counts and latency_ms here.
    meta: Dict[str, Any] = field(default_factory=dict)

class ModelRunner:
    """Common interface for every backend runner.

    Concrete runners override :meth:`generate`; this base version only
    signals that it is not implemented.
    """

    def generate(self, prompts: List[str], **decode) -> List[Completion]:
        """Return completions for ``prompts``.

        Args:
            prompts: prompt strings to complete.
            **decode: decoding options; their meaning is left to subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

class RuleRunner(ModelRunner):
    """Offline runner that answers every prompt with a canned response.

    No model is called. For each prompt, the text between ``<IR>`` and
    ``</IR>`` is echoed back inside ``<IR_OUT>``. A fixed JSON payload is
    chosen by the first task marker found in the prompt.
    """

    # Ordered (prompt marker, JSON payload) pairs; the first marker contained
    # in the prompt selects the payload, exactly like an if/elif chain.
    _CANNED_ANSWERS = (
        ("objective: MINIMIZE CODE SIZE",
         '{"objective":"os","edits":[{"site":"i0","kind":"simplifycfg"}]}'),
        ("objective: MAXIMIZE RUNTIME SPEED",
         '{"objective":"o3","edits":[{"site":"i0","kind":"strength_reduce"}]}'),
        ("Repair inconsistent phi",
         '{"fix":{"kind":"phi_repair","bb":"L1","strategy":"align_predecessors"}}'),
        ("Perform mem2reg promotion",
         '{"refactor":{"kind":"mem2reg","allocas":["i5"]}}'),
    )
    # Payload used when no marker matches.
    _DEFAULT_ANSWER = '{"labels":[{"type":"alias","pair":["%p","%q"],"value":"May"}]}'
    # Non-greedy and DOTALL, so multi-line IR is captured up to the first </IR>.
    _IR_PATTERN = re.compile(r"<IR>(.*?)</IR>", re.S)

    def __init__(self, mode: str = "golden"):
        # Kept on the instance; generate() does not currently read it.
        self.mode = mode

    def generate(self, prompts: List[str], **decode) -> List[Completion]:
        """Return one canned Completion per prompt, in input order.

        ``decode`` options are accepted for interface compatibility and ignored.
        """
        return [self._respond(prompt) for prompt in prompts]

    def _respond(self, prompt: str) -> Completion:
        """Build the canned completion for a single prompt."""
        ir_match = self._IR_PATTERN.search(prompt)
        ir_body = ir_match.group(1) if ir_match else ""
        payload = next(
            (answer for marker, answer in self._CANNED_ANSWERS if marker in prompt),
            self._DEFAULT_ANSWER,
        )
        # No real model call: token counts are unknown and latency is reported as 0.
        stats = {
            "finish_reason": "stop",
            "prompt_tokens": None,
            "out_tokens": None,
            "total_tokens": None,
            "latency_ms": 0,
        }
        text = f"<CIR_JSON>{payload}</CIR_JSON>\n<IR_OUT>\n{ir_body}\n</IR_OUT>"
        return Completion(text, meta=stats)
from . import gemini
from . import qwen
from . import deepseek
from . import gpt
from . import claude
from . import grok
from . import llama
# Provider id -> provider module (the `from . import ...` lines just above).
# NOTE(review): make_runner does not consult this mapping; it dispatches on the
# runner classes imported at the top of the file. When adding a provider, keep
# both in sync.
REGISTRY = {
    "gemini": gemini,
    "qwen":   qwen,
    "deepseek":   deepseek,
    "gpt":   gpt,
    "claude":   claude,
    "grok":   grok,
    "llama":   llama,
}
def _construct_runner(runner_cls: type, name, params: Dict[str, Any], *,
                      by_keyword: bool = False) -> ModelRunner:
    """Instantiate ``runner_cls`` under either supported constructor convention.

    Provider runners may take ``(name, params)`` or only the older
    ``(params)``. The two-argument form is tried first (by keyword when
    ``by_keyword`` is true). If that raises ``TypeError``, the params-only
    form is used instead.
    """
    try:
        if by_keyword:
            return runner_cls(name=name, params=params)
        return runner_cls(name, params)
    except TypeError:
        # NOTE(review): this also catches TypeErrors raised *inside* a
        # constructor that does accept (name, params), not only signature
        # mismatches; the retry is kept for backward compatibility.
        return runner_cls(params)


def make_runner(model_cfg: Dict[str, Any]) -> ModelRunner:
    """
    Build a concrete runner from a model config dict.

    Accepted keys (for compatibility across callers / YAML / CLI override):
      - kind:      canonical provider id, e.g. "gemini", "qwen", "rule"
      - provider:  alias of 'kind' (CLI may use this naming)
      - name:      model name/slug (e.g. "qwen-plus", "gemini-2.5-flash");
                   falls back to params["model"] when absent
      - params:    provider-specific kwargs (stop_sequences, max_output_tokens, ...)

    Supported kinds: rule, gemini, qwen, deepseek, gpt, claude, grok, llama.
    ``kind`` is case-insensitive and defaults to "rule". An unknown kind does
    not raise here; the returned runner raises RuntimeError when its
    ``generate`` is called.
    """
    # Accept both 'kind' and 'provider' as synonyms; default to 'rule'.
    kind = (model_cfg.get("kind") or model_cfg.get("provider") or "rule").strip().lower()
    # Copy so runners cannot mutate the caller's config.
    params = dict(model_cfg.get("params") or {})
    name = (model_cfg.get("name") or params.get("model") or "").strip() or None

    if kind == "rule":
        return RuleRunner(params.get("mode", "golden"))

    # Built at call time so this block only needs the runner classes when used.
    runner_classes = {
        "gemini": GeminiRunner,
        "qwen": QwenRunner,
        "deepseek": DeepseekRunner,
        "gpt": GPTRunner,
        "claude": ClaudeRunner,
        "grok": GrokRunner,
        "llama": LlamaRunner,
    }

    runner_cls = runner_classes.get(kind)
    if runner_cls is not None:
        # GeminiRunner historically accepted only (params); some forks accept
        # (name=..., params=...) by keyword. The others take (name, params)
        # positionally, with the same params-only fallback.
        return _construct_runner(runner_cls, name, params, by_keyword=(kind == "gemini"))

    # Derived from the dispatch table so the message never drifts from reality.
    available = ", ".join(("rule", *runner_classes))

    class NotReady(ModelRunner):
        def generate(self, prompts: List[str], **decode) -> List[Completion]:
            raise RuntimeError(
                f"Model kind/provider '{kind}' not implemented. "
                f"Available kinds: {available}."
            )
    return NotReady()