import os, json, random
from pathlib import Path
from typing import List, Dict, Any
from collections import Counter
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from huggingface_hub import InferenceClient
from requests import HTTPError

# Input problem sets are read from DATASET_DIR; per-run result files go under OUTPUT_ROOT.
DATASET_DIR = Path(".../dataset")
OUTPUT_ROOT = Path(".../output")

# One JSON file per math topic; main() evaluates each file independently.
DATASET_FILES = [
    f"{topic}.json"
    for topic in (
        "linear_algebra",
        "single_variable_calculus",
        "multivariable_calculus",
        "differential_equations",
        "discrete_math",
        "trigonometry",
        "pre_calculus",
    )
]

MAX_TOKENS = 2048   # generation budget passed to every model call
SC_SAMPLES = 5      # samples drawn per problem by the self-consistency strategy
TOT_PATHS = 3       # independent solution paths drawn by the tree-of-thought strategy
SAMPLE_SIZE = None  # None evaluates every item; an int keeps only the first N items

# Sampling temperature per prompting strategy. Insertion order also fixes the
# order in which strategies are run.
TEMPS = dict(
    zero_shot=0.0,
    few_shot_cot=0.0,
    tree_of_thought=0.7,
    self_consistency=0.9,
)
STRATEGIES = list(TEMPS)

# Hugging Face credentials, read from the environment at import time. The
# module refuses to load without a non-empty token.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("HF_TOKEN is not set")
# Optional inference-provider routing; may be None. It is handed unchanged to InferenceClient.
HF_PROVIDER = os.environ.get("HF_INFERENCE_PROVIDER")

# Models to evaluate. Every entry uses the "hf" backend, the only provider
# value call_model accepts.
MODELS = [
    {"name": model_name, "provider": "hf"}
    for model_name in (
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
        "google/gemma-2-9b-it",
        "meta-llama/Meta-Llama-3-8B-Instruct",
        "meta-llama/Meta-Llama-3-70B-Instruct",
        "meta-llama/Llama-4-Scout-17B-16E-Instruct",
        "Qwen/Qwen2.5-7B-Instruct",
        "mistralai/Mistral-7B-Instruct-v0.3",
    )
]

def build_zero_shot_prompt(problem: str):
    """Return chat messages that ask for a direct solution ending in a boxed answer.

    No worked examples are included. The system message only fixes the answer
    format, and the user message carries the problem text.
    """
    instruction = "Conclude the final answer in the form: \\boxed{your final answer here}."
    request = f"Solve the following math problem: {problem}"
    return [
        dict(role="system", content=instruction),
        dict(role="user", content=request),
    ]

def build_few_shot_cot_prompt(problem: str, few_shot_examples: List[Dict[str, Any]]):
    """Build a chat prompt that shows worked examples before the target problem.

    Each example is rendered as ``Q: <question>`` followed by ``A: <steps>``
    (one step per line) and ``\\boxed{<answer>}``. The target problem is then
    appended as an open ``Q:``/``A:`` pair for the model to complete.

    Args:
        problem: The problem text to solve.
        few_shot_examples: Dicts read through their ``question``, ``steps`` and
            ``answer`` keys. Missing keys render as empty text. ``steps`` may be
            a list of steps (any element type, converted with ``str``) or a
            single string.

    Returns:
        A two-message list: the system instruction and the user prompt.
    """
    system = "You are a highly skilled mathematics expert. Solve the problem step by step. Conclude with \\boxed{your final answer here}."
    fs_blocks = []
    for ex in few_shot_examples:
        q = ex.get("question", "")
        steps = ex.get("steps", []) or []
        # A single string of steps must not be exploded into one character per line.
        if isinstance(steps, str):
            steps = [steps]
        ans = ex.get("answer", "")
        steps_text = "\n".join(str(step) for step in steps)
        fs_blocks.append(f"Q: {q}\n\nA: " + steps_text + f"\n\n\\boxed{{{ans}}}\n")
    user = "\n".join(fs_blocks) + f"\nQ: {problem}\n\nA:"
    return [{"role": "system", "content": system}, {"role": "user", "content": user}]

def build_tot_prompt(problem: str):
    """Return chat messages asking the model to explore several distinct solution paths.

    The raw problem text is sent as the user message without any wrapper.
    """
    system_msg = {
        "role": "system",
        "content": "You are a highly skilled mathematics expert. Brainstorm multiple distinct solution paths. End with \\boxed{final answer}.",
    }
    user_msg = {"role": "user", "content": problem}
    return [system_msg, user_msg]

def build_self_consistency_prompt(problem: str):
    """Return chat messages asking for one clearly reasoned solution ending in a boxed answer.

    The same messages are reused for every self-consistency sample. The raw
    problem text is the user message.
    """
    system_msg = {
        "role": "system",
        "content": "You are a highly skilled mathematics expert. Solve with clear reasoning. End with \\boxed{final answer}.",
    }
    user_msg = {"role": "user", "content": problem}
    return [system_msg, user_msg]

def pick_few_shot_examples(dataset, target_item, k=3, rng=None):
    """Randomly choose up to ``k`` worked examples from the target item's subtopic.

    Candidates must share ``target_item``'s ``subtopic`` value and have non-empty
    ``steps`` and ``answer``. They must also not be the target itself. An item
    counts as the target when it is the same object, when it carries the same
    non-empty ``id`` (compared as strings), or when it has the same non-empty
    ``question`` text, so the model is never shown the answer to its own
    question. Items without ids remain eligible as examples for one another.

    Args:
        dataset: Sequence of item dicts to draw from.
        target_item: The item currently being solved.
        k: Maximum number of examples to return.
        rng: Optional ``random.Random`` for reproducible selection. Defaults
            to the module-level ``random`` functions.

    Returns:
        A list of at most ``k`` distinct candidates (fewer if not enough exist).
    """
    sampler = random if rng is None else rng
    subtopic = target_item.get("subtopic")
    raw_id = target_item.get("id")
    qid = "" if raw_id is None else str(raw_id)
    question = target_item.get("question")

    def _is_target(ex):
        if ex is target_item:
            return True
        # Only compare ids when the target actually has one; otherwise every
        # id-less item would be excluded and no examples could ever be found.
        if qid and str(ex.get("id", "")) == qid:
            return True
        return bool(question) and ex.get("question") == question

    cands = [
        ex for ex in dataset
        if ex.get("subtopic") == subtopic
        and ex.get("steps") and ex.get("answer")
        and not _is_target(ex)
    ]
    return sampler.sample(cands, max(0, min(k, len(cands))))

class TransientAPIError(Exception):
    """Signals a model-call failure that call_model's retry policy should retry."""

def _flatten_messages(messages: List[Dict[str,str]]) -> str:
    parts = []
    sys = [m.get("content","") for m in messages if m.get("role")=="system" and isinstance(m.get("content"), str)]
    if sys: parts.append("System:\n" + "\n".join(sys))
    for m in messages:
        role = m.get("role"); content = m.get("content")
        if isinstance(content, list):
            content = "\n".join(c.get("text","") for c in content if isinstance(c, dict) and c.get("type")=="input_text")
        if role=="user" and isinstance(content, str): parts.append("User:\n" + content)
        elif role=="assistant" and isinstance(content, str): parts.append("Assistant:\n" + content)
    parts.append("Assistant:")
    return "\n\n".join(parts)

def _make_client(model_id: str) -> InferenceClient:
    """Create an InferenceClient bound to ``model_id`` with a 180-second timeout.

    Authentication and provider routing come from the module-level HF_TOKEN and
    HF_PROVIDER values.
    """
    settings = {"token": HF_TOKEN, "timeout": 180, "provider": HF_PROVIDER}
    return InferenceClient(model=model_id, **settings)

# HTTP statuses below 500 that are still worth retrying: request timeout,
# "too early", and rate limiting. Every 5xx status is retried as well.
_RETRYABLE_HTTP_STATUS = frozenset({408, 425, 429})


def _is_transient(exc: BaseException) -> bool:
    """Decide whether a failed model call should be retried.

    Exceptions that carry an HTTP response are retried only for 5xx statuses and
    for the codes in ``_RETRYABLE_HTTP_STATUS``. Any other 4xx (bad credentials,
    unknown model, invalid parameters) would fail the same way again. Exceptions
    without an HTTP status (timeouts, dropped connections, malformed replies)
    are treated as transient.
    """
    status = getattr(getattr(exc, "response", None), "status_code", None)
    if not isinstance(status, int):
        return True
    return status >= 500 or status in _RETRYABLE_HTTP_STATUS


@retry(reraise=True, stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=1, max=10), retry=retry_if_exception_type(TransientAPIError))
def call_model(provider: str, model_id: str, messages, temperature: float, max_tokens: int) -> str:
    """Send chat ``messages`` to ``model_id`` and return the stripped reply text.

    The chat-completion endpoint is tried first. If it fails with an HTTP error
    whose message says the model only supports text-generation, the messages
    are flattened with ``_flatten_messages`` and sent to ``text_generation``
    instead.

    Transient failures are wrapped in ``TransientAPIError``. The decorator
    retries those up to 5 attempts with exponential backoff of 1-10 seconds.

    Raises:
        ValueError: ``provider`` is not ``"hf"``. This is raised immediately and
            not retried.
        TransientAPIError: after the last attempt, for errors ``_is_transient``
            accepts.
        Exception: HTTP errors with any other 4xx status are re-raised unchanged
            and not retried.
    """
    if provider != "hf":
        raise ValueError(f"Unknown provider {provider}")
    try:
        client = _make_client(model_id)
        try:
            resp = client.chat_completion(messages=messages, max_tokens=max_tokens, temperature=temperature)
            if hasattr(resp, "choices"):
                content = resp.choices[0].message["content"]
            else:
                content = resp["choices"][0]["message"]["content"]
            # The chat schema allows a null content; treat it as an empty reply.
            return (content or "").strip()
        except HTTPError as e:
            if "supported task: text-generation" not in str(e).lower():
                raise
        prompt = _flatten_messages(messages)
        greedy = temperature <= 0
        out = client.text_generation(
            prompt,
            max_new_tokens=max_tokens,
            # Text Generation Inference rejects non-positive temperatures, so
            # omit the parameter entirely for greedy decoding.
            temperature=None if greedy else temperature,
            do_sample=not greedy,
            return_full_text=False,
        )
        return out.strip()
    except Exception as err:
        if _is_transient(err):
            raise TransientAPIError(str(err)) from err
        raise

def run_zero_shot(itm, m):
    """Answer ``itm["question"]`` with one call using the zero-shot prompt.

    ``m`` supplies the ``provider`` and ``name`` of the model. The sampling
    temperature is TEMPS["zero_shot"], and the raw reply text is returned.
    """
    provider, model_id = m["provider"], m["name"]
    messages = build_zero_shot_prompt(itm["question"])
    return call_model(provider, model_id, messages, TEMPS["zero_shot"], MAX_TOKENS)

def run_few_shot_cot(itm, m, dataset):
    """Answer ``itm["question"]`` after showing up to three worked examples.

    Examples are drawn from ``dataset`` by ``pick_few_shot_examples``. The
    single model call uses TEMPS["few_shot_cot"], and the raw reply text is
    returned.
    """
    examples = pick_few_shot_examples(dataset, itm, k=3)
    provider, model_id = m["provider"], m["name"]
    messages = build_few_shot_cot_prompt(itm["question"], examples)
    return call_model(provider, model_id, messages, TEMPS["few_shot_cot"], MAX_TOKENS)

def run_tree_of_thought(itm, m):
    """Sample TOT_PATHS independent solutions and join them under "Path N:" headers.

    Every path reuses the same prompt at TEMPS["tree_of_thought"]. A failed
    call is recorded inline as "[ERROR PATH] <message>" instead of aborting
    the item. No answer aggregation happens here.
    """
    prompt = build_tot_prompt(itm["question"])
    sections = []
    for index in range(1, TOT_PATHS + 1):
        try:
            body = call_model(m["provider"], m["name"], prompt, TEMPS["tree_of_thought"], MAX_TOKENS)
        except Exception as exc:
            body = f"[ERROR PATH] {exc}"
        sections.append(f"Path {index}:\n{body}")
    return "\n\n".join(sections)

def _last_boxed_content(text: str):
    """Return the stripped contents of the last complete ``\\boxed{...}`` in ``text``, or None.

    Braces are matched by depth, so nested LaTeX such as ``\\boxed{\\frac{1}{2}}``
    is returned whole. An unterminated occurrence, for example output cut off
    by the token limit, is skipped in favour of the previous one.
    """
    marker = "\\boxed"
    start = text.rfind(marker)
    while start != -1:
        pos = start + len(marker)
        while pos < len(text) and text[pos].isspace():
            pos += 1
        if pos < len(text) and text[pos] == "{":
            depth = 0
            for end in range(pos, len(text)):
                if text[end] == "{":
                    depth += 1
                elif text[end] == "}":
                    depth -= 1
                    if depth == 0:
                        return text[pos + 1:end].strip()
        start = text.rfind(marker, 0, start)
    return None


def extract_final_answer(text: str) -> str:
    """Pull a short final answer out of a model response for majority voting.

    Every prompt in this file asks the model to finish with ``\\boxed{...}``,
    so the contents of the last complete boxed expression are preferred. Without
    one, the function falls back in order to:

    1. the last line mentioning "answer", returning the text after its first
       colon (or the whole line if it has no colon);
    2. the last line;
    3. "" for empty text.
    """
    boxed = _last_boxed_content(text)
    if boxed is not None:
        return boxed
    lines = text.strip().splitlines()
    for line in reversed(lines):
        if "answer" in line.lower():
            return line.split(":", 1)[-1].strip()
    return lines[-1].strip() if lines else ""

def run_self_consistency(itm, m):
    """Sample SC_SAMPLES answers and report the majority vote plus every raw sample.

    Each sample uses TEMPS["self_consistency"], and its final answer is taken
    with ``extract_final_answer``. Failed calls appear in the report as
    "[ERROR SAMPLE] <message>" but do not vote, and neither do samples whose
    extracted answer is empty. Only when no sample yields a usable answer does
    every sample vote, so an all-failed run reports "[ERROR]".

    Returns:
        A text report. It starts with "Most frequent answer (<count>/<SC_SAMPLES>): <answer>",
        followed by each sample's raw output.
    """
    msgs = build_self_consistency_prompt(itm["question"])
    samples = []  # (extracted answer, raw text, call succeeded)
    for _ in range(SC_SAMPLES):
        try:
            out = call_model(m["provider"], m["name"], msgs, TEMPS["self_consistency"], MAX_TOKENS)
            samples.append((extract_final_answer(out), out, True))
        except Exception as e:
            samples.append(("[ERROR]", f"[ERROR SAMPLE] {e}", False))
    # Errors and empty extractions must not outvote real answers.
    votes = [ans for ans, _, ok in samples if ok and ans]
    if not votes:
        votes = [ans for ans, _, _ in samples]
    ranked = Counter(votes).most_common(1)
    # Guard against SC_SAMPLES == 0, which would leave nothing to rank.
    winner, count = ranked[0] if ranked else ("[ERROR]", 0)
    report = [f"Most frequent answer ({count}/{SC_SAMPLES}): {winner}", "", "All samples:"]
    for i, (_, raw, _) in enumerate(samples, 1):
        report += [f"--- Sample {i} ---", raw, ""]
    return "\n".join(report).strip()

def run_item(item, model_cfg, strategy, dataset):
    """Run one dataset item through one model with one prompting strategy.

    Returns a copy of ``item`` extended with run metadata (``model``,
    ``provider``, ``strategy``, ``temperature``) and the raw ``model_answer``.
    Any exception from the strategy is caught and recorded: ``model_answer`` is
    then None and ``error`` holds the message, so one failing item does not
    abort a whole task.
    """
    meta = {
        "model": model_cfg["name"],
        "provider": model_cfg["provider"],
        "strategy": strategy,
        # .get(): an unknown strategy must still yield an error record rather
        # than a KeyError raised from inside the error handler.
        "temperature": TEMPS.get(strategy),
    }
    try:
        if strategy == "zero_shot":
            out = run_zero_shot(item, model_cfg)
        elif strategy == "few_shot_cot":
            out = run_few_shot_cot(item, model_cfg, dataset)
        elif strategy == "tree_of_thought":
            out = run_tree_of_thought(item, model_cfg)
        elif strategy == "self_consistency":
            out = run_self_consistency(item, model_cfg)
        else:
            raise ValueError(f"Unknown strategy: {strategy}")
    except Exception as e:
        return {**item, **meta, "model_answer": None, "error": str(e)}
    return {**item, **meta, "model_answer": out}

def run_task(model_cfg, strategy, dataset, sample_size=None, out_dir=Path("."), tag_prefix=""):
    """Run one (model, strategy) pair over a dataset and save every record as JSON.

    The output file is ``out_dir/[<tag_prefix>_]<strategy>_<model name with "/" -> "-">.json``.
    ``out_dir`` is created if missing. Only the first ``sample_size`` items are
    evaluated when it is given. The full ``dataset`` is still passed to each
    item run (the few-shot strategy draws examples from it).

    Returns:
        The output path as a string.
    """
    safe_model = model_cfg['name'].replace('/', '-')
    tag = f"{tag_prefix}_{strategy}_{safe_model}" if tag_prefix else f"{strategy}_{safe_model}"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"{tag}.json"
    selected = dataset if sample_size is None else dataset[:sample_size]
    records = [run_item(entry, model_cfg, strategy, dataset) for entry in selected]
    with out_path.open("w", encoding="utf-8") as fh:
        json.dump(records, fh, indent=2, ensure_ascii=False)
    return str(out_path)

def main():
    """Evaluate every dataset file with every model and every strategy.

    Each dataset's results go to ``OUTPUT_ROOT/<dataset stem>/open_sourced``,
    one JSON file per (strategy, model) pair.

    The full dataset is handed to run_task, so few-shot examples can be drawn
    from all of it. run_task applies SAMPLE_SIZE itself to choose which items
    are evaluated. Truncating here as well would needlessly shrink the
    few-shot pool to the evaluated items.
    """
    for fname in DATASET_FILES:
        ds_path = DATASET_DIR / fname
        ds_name = Path(fname).stem
        with ds_path.open("r", encoding="utf-8") as f:
            dataset = json.load(f)
        out_dir = OUTPUT_ROOT / ds_name / "open_sourced"
        for m in MODELS:
            for s in STRATEGIES:
                run_task(m, s, dataset, SAMPLE_SIZE, out_dir, ds_name)

if __name__ == "__main__":
    # Run the full evaluation sweep only when executed as a script, not on import.
    main()