import json
import os
import random
from collections import Counter
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from openai import OpenAI
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

try:
    import anthropic
except ImportError:
    anthropic = None

# Where topic datasets are read from, and the root under which results are written.
DATASET_DIR = Path("./dataset")
OUTPUT_ROOT = Path("./output")

# Topic datasets to evaluate, in order; each is a JSON file inside DATASET_DIR.
DATASET_FILES = [
    f"{topic}.json"
    for topic in (
        "linear_algebra",
        "single_variable_calculus",
        "multivariable_calculus",
        "differential_equations",
        "discrete_math",
        "trigonometry",
        "pre_calculus",
    )
]

MAX_TOKENS = 2048   # output-token budget passed to every model call
SC_SAMPLES = 5      # samples drawn per item for self-consistency voting
TOT_PATHS = 3       # independent calls made per item for "tree_of_thought"
SAMPLE_SIZE = None  # evaluate only the first N items of each dataset; None = all

# Models under evaluation; "provider" selects which API client serves the call.
MODELS = [
    {"name": model_name, "provider": provider}
    for model_name, provider in (
        ("gpt-4.1", "openai"),
        ("gpt-3.5-turbo-0125", "openai"),
        ("o3", "openai"),
        ("claude-3-7-sonnet-20250219", "anthropic"),
    )
]

# Prompting strategies, run in this order for every model.
STRATEGIES = [
    "zero_shot",
    "few_shot_cot",
    "tree_of_thought",
    "self_consistency",
]

# Sampling temperature per strategy: deterministic for single-shot strategies,
# higher where several diverse samples are drawn. (call_model does not send a
# temperature to o-series reasoning models.)
TEMPS = dict(
    zero_shot=0.0,
    few_shot_cot=0.0,
    tree_of_thought=0.7,
    self_consistency=0.9,
)

# Module-level API clients. The OpenAI client is always built; the Anthropic
# client exists only when the SDK imported and ANTHROPIC_API_KEY is set, and is
# None otherwise.
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
if anthropic is not None and os.getenv("ANTHROPIC_API_KEY"):
    anthropic_client = anthropic.Client(api_key=os.getenv("ANTHROPIC_API_KEY"))
else:
    anthropic_client = None

def build_zero_shot_prompt(problem: str) -> List[Dict[str, str]]:
    """Build a bare zero-shot chat prompt for ``problem``.

    The system message carries no expert persona; it only fixes the
    ``\\boxed{...}`` answer format. The user message asks for a solution.
    """
    instructions = "Conclude the final answer in the form: \\boxed{your final answer here}."
    request = f"Solve the following math problem: {problem}"
    system_msg = {"role": "system", "content": instructions}
    user_msg = {"role": "user", "content": request}
    return [system_msg, user_msg]

def build_few_shot_cot_prompt(problem: str, few_shot_examples: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    """Build a few-shot chain-of-thought prompt.

    Each example is rendered as ``Q: <question>``, ``A: <steps joined by newlines>``
    and a ``\\boxed{<answer>}`` line; missing fields render as empty text.
    Rendered examples are separated by a newline and followed by the target
    question with an open ``A:`` for the model to complete.
    """
    persona = "You are a highly skilled mathematics expert. Solve the problem step by step. Conclude with your final answer in the form: \\boxed{your final answer here}."

    def render(example: Dict[str, Any]) -> str:
        question = example.get("question", "")
        reasoning = "\n".join(example.get("steps", []) or [])
        final = example.get("answer", "")
        return f"Q: {question}\n\nA: {reasoning}\n\n\\boxed{{{final}}}\n"

    demonstrations = "\n".join(render(example) for example in few_shot_examples)
    prompt_body = demonstrations + f"\nQ: {problem}\n\nA:"
    return [
        {"role": "system", "content": persona},
        {"role": "user", "content": prompt_body},
    ]

def build_tot_prompt(problem: str) -> List[Dict[str, str]]:
    """Build the "tree of thought" prompt: an expert persona asked to brainstorm
    several distinct solution paths and finish with a ``\\boxed{...}`` answer.
    The problem text is sent verbatim as the user message."""
    guidance = (
        "You are a highly skilled mathematics expert. Brainstorm multiple distinct solution paths for the given problem. "
        "At the end, clearly state the final answer in the form: \\boxed{your final answer here}."
    )
    return [
        dict(role="system", content=guidance),
        dict(role="user", content=problem),
    ]

def build_self_consistency_prompt(problem: str) -> List[Dict[str, str]]:
    """Build the prompt sampled repeatedly for self-consistency voting: an
    expert persona asked for clear reasoning and a final ``\\boxed{...}``
    answer, with the problem text as the user message."""
    guidance = (
        "You are a highly skilled mathematics expert. Solve the problem with clear reasoning. "
        "At the end, clearly state the final answer in the form: \\boxed{your final answer here}."
    )
    return [
        dict(role="system", content=guidance),
        dict(role="user", content=problem),
    ]

def pick_few_shot_examples(
    dataset: List[Dict[str, Any]],
    target_item: Dict[str, Any],
    k: int = 3,
    rng: Optional[random.Random] = None,
) -> List[Dict[str, Any]]:
    """Pick up to ``k`` worked examples from the same subtopic as ``target_item``.

    A candidate must:
      * share ``target_item``'s ``subtopic``,
      * have non-empty ``steps`` and a usable ``answer`` (numeric answers such
        as ``0`` count; other values must be truthy),
      * not be the target itself (same object, or same ``id`` when the target
        has one).

    Args:
        dataset: pool of items to draw demonstrations from.
        target_item: the item being asked about; it is never returned.
        k: maximum number of examples to return (negative values act as 0).
        rng: random source. When omitted, a generator seeded from the target's
            subtopic and id (or question text) is used. The same item then gets
            the same demonstrations on every run and for every model, which
            keeps results reproducible and model comparisons fair.

    Returns:
        Up to ``k`` distinct candidate dicts (fewer if the pool is smaller).
    """
    subtopic = target_item.get("subtopic")
    target_id = target_item.get("id")

    def is_target(ex: Dict[str, Any]) -> bool:
        if ex is target_item:
            return True
        # Only compare ids when both sides have one; otherwise id-less datasets
        # would treat every item as "the target" and yield no examples.
        ex_id = ex.get("id")
        return target_id is not None and ex_id is not None and str(ex_id) == str(target_id)

    def has_answer(ex: Dict[str, Any]) -> bool:
        ans = ex.get("answer")
        # 0 / 0.0 / False are legitimate math answers, so don't rely on truthiness for numbers.
        return isinstance(ans, (int, float)) or bool(ans)

    candidates = [
        ex for ex in dataset
        if ex.get("subtopic") == subtopic and ex.get("steps") and has_answer(ex) and not is_target(ex)
    ]
    if rng is None:
        # Seeding with a str is deterministic across processes (independent of PYTHONHASHSEED).
        key = target_id if target_id is not None else target_item.get("question", "")
        rng = random.Random(f"{subtopic}|{key}")
    return rng.sample(candidates, min(max(k, 0), len(candidates)))

class TransientAPIError(Exception):
    """Wrapper exception raised by ``call_model``; its ``@retry`` policy
    retries only on this type."""

def _is_transient_error(exc: BaseException) -> bool:
    """Return True if ``exc`` looks like a retryable provider or network failure.

    Retryable:
      * HTTP 408, 409, 429 and 5xx responses. These are the status codes the
        official OpenAI/Anthropic Python SDKs treat as retryable.
      * Connection and timeout errors.

    Not retryable: anything else, such as bad requests, auth failures, unknown
    models, or programming errors.
    """
    status = getattr(exc, "status_code", None)
    if isinstance(status, int):
        return status in (408, 409, 429) or status >= 500
    if isinstance(exc, (ConnectionError, TimeoutError)):
        return True
    # Both SDKs raise APIConnectionError / APITimeoutError for network failures;
    # match by class name so neither SDK has to be imported here.
    return any(cls.__name__ in ("APIConnectionError", "APITimeoutError") for cls in type(exc).__mro__)


def _call_openai(model: str, messages, temperature: float, max_tokens: int) -> str:
    """Make one OpenAI Chat Completions request and return the stripped reply text."""
    # o-series reasoning models take ``max_completion_tokens`` and no custom temperature.
    # NOTE(review): their hidden reasoning tokens count against that budget, so a
    # small MAX_TOKENS can yield empty replies — consider a larger budget for them.
    is_reasoning = model.lower().startswith(("o1", "o3", "o4"))
    params = {"model": model, "messages": messages}
    if is_reasoning:
        params["max_completion_tokens"] = max_tokens
    else:
        params["temperature"] = temperature
        params["max_tokens"] = max_tokens
    resp = openai_client.chat.completions.create(**params)
    # ``content`` may be None (e.g. refusals, or a reasoning model that spent its
    # whole budget reasoning); treat that as an empty reply instead of crashing.
    return (resp.choices[0].message.content or "").strip()


def _call_anthropic(model: str, messages, temperature: float, max_tokens: int) -> str:
    """Make one Anthropic Messages request and return the concatenated text blocks.

    All system messages are merged into ``system``, and all user messages are
    merged into a single user turn. Messages with other roles (e.g.
    "assistant") are dropped.
    """
    system_text = "\n".join(m["content"] for m in messages if m["role"] == "system").strip()
    user_text = "\n".join(m["content"] for m in messages if m["role"] == "user").strip()
    request = {
        "model": model,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "messages": [{"role": "user", "content": user_text}],
    }
    if system_text:
        # Only send ``system`` when there is text, rather than an explicit null.
        request["system"] = system_text
    resp = anthropic_client.messages.create(**request)
    texts = [block.text for block in resp.content if getattr(block, "type", "") == "text"]
    return "\n".join(texts).strip()


@retry(reraise=True, stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=1, max=10), retry=retry_if_exception_type(TransientAPIError))
def call_model(provider: str, model: str, messages, temperature: float, max_tokens: int) -> str:
    """Send chat ``messages`` to ``model`` on ``provider`` and return the reply text.

    Args:
        provider: "openai" or "anthropic".
        model: provider-specific model name.
        messages: list of ``{"role": ..., "content": ...}`` dicts.
        temperature: sampling temperature (not sent to OpenAI o-series models).
        max_tokens: output-token budget.

    Raises:
        TransientAPIError: retryable provider or network failure. The call is
            attempted up to 5 times, with exponential backoff of 1-10 s between
            attempts, before this is re-raised.
        ValueError: unknown ``provider``. Not retried.
        RuntimeError: "anthropic" requested but no Anthropic client is
            configured. Not retried.
        Exception: any other (non-retryable) SDK or programming error,
            propagated unchanged without retries.
    """
    # Configuration errors are checked outside the try: retrying them only wastes time.
    if provider not in ("openai", "anthropic"):
        raise ValueError(f"Unknown provider: {provider!r}")
    if provider == "anthropic" and anthropic_client is None:
        raise RuntimeError("Anthropic client unavailable.")
    try:
        if provider == "openai":
            return _call_openai(model, messages, temperature, max_tokens)
        return _call_anthropic(model, messages, temperature, max_tokens)
    except Exception as e:
        if _is_transient_error(e):
            raise TransientAPIError(str(e)) from e
        raise

def run_zero_shot(item: Dict[str, Any], model_cfg: Dict[str, str]) -> str:
    """Answer ``item["question"]`` with one zero-shot call and return the raw reply."""
    provider, name = model_cfg["provider"], model_cfg["name"]
    messages = build_zero_shot_prompt(item["question"])
    return call_model(provider, name, messages, TEMPS["zero_shot"], MAX_TOKENS)

def run_few_shot_cot(item: Dict[str, Any], model_cfg: Dict[str, str], dataset: List[Dict[str, Any]]) -> str:
    """Answer ``item["question"]`` with a few-shot chain-of-thought prompt.

    Up to three demonstrations are drawn from ``dataset`` by
    ``pick_few_shot_examples``. Returns the model's raw reply.
    """
    demos = pick_few_shot_examples(dataset, item, k=3)
    provider, name = model_cfg["provider"], model_cfg["name"]
    messages = build_few_shot_cot_prompt(item["question"], demos)
    return call_model(provider, name, messages, TEMPS["few_shot_cot"], MAX_TOKENS)

def run_tree_of_thought(item: Dict[str, Any], model_cfg: Dict[str, str]) -> str:
    """Sample ``TOT_PATHS`` solutions and return them as one labelled transcript.

    Each path is a separate call with the "brainstorm multiple solution paths"
    prompt at ``TEMPS["tree_of_thought"]``. Paths are joined as ``Path i:``
    sections separated by blank lines. A failed call is recorded inline as
    ``[ERROR PATH] <error>`` so the other paths survive.

    NOTE(review): this is independent sampling, not a tree search with
    evaluation/pruning of intermediate thoughts.

    Raises:
        RuntimeError: if every path failed. The caller then records an error
            instead of a "successful" answer made only of error markers.
    """
    messages = build_tot_prompt(item["question"])
    provider, name = model_cfg["provider"], model_cfg["name"]
    sections = []
    failures = 0
    for path_no in range(1, TOT_PATHS + 1):
        try:
            text = call_model(provider, name, messages, TEMPS["tree_of_thought"], MAX_TOKENS)
        except Exception as e:
            failures += 1
            text = f"[ERROR PATH] {e}"
        sections.append(f"Path {path_no}:\n{text}")
    if sections and failures == len(sections):
        raise RuntimeError(f"All {failures} tree-of-thought paths failed:\n" + "\n".join(sections))
    return "\n\n".join(sections)

def _last_boxed_content(text: str) -> Optional[str]:
    """Return the stripped contents of the last complete ``\\boxed{...}`` in ``text``.

    Braces are matched by depth, so nested LaTeX such as ``\\boxed{\\frac{1}{2}}``
    is returned whole. An unterminated ``\\boxed{`` is skipped in favour of an
    earlier complete box. Returns None when there is no complete box.
    """
    marker = "\\boxed"
    pos = text.rfind(marker)
    while pos != -1:
        i = pos + len(marker)
        while i < len(text) and text[i].isspace():
            i += 1
        if i < len(text) and text[i] == "{":
            depth = 0
            for j in range(i, len(text)):
                ch = text[j]
                if ch == "{":
                    depth += 1
                elif ch == "}":
                    depth -= 1
                    if depth == 0:
                        return text[i + 1:j].strip()
        pos = text.rfind(marker, 0, pos)
    return None


def extract_final_answer(text: str) -> str:
    """Extract the model's final answer from a free-form solution.

    Every prompt asks for ``\\boxed{...}``, so the contents of the last complete
    box are returned. Without a box, fall back to the last line mentioning
    "answer": the text after its first colon, or the whole line if it has no
    colon. Failing that, return the last line.

    Returns "" for empty or None input.
    """
    if not text:
        return ""
    boxed = _last_boxed_content(text)
    if boxed is not None:
        return boxed
    lines = text.strip().splitlines()
    for line in reversed(lines):
        if "answer" in line.lower():
            return line.split(":", 1)[-1].strip()
    return lines[-1].strip() if lines else ""

def run_self_consistency(item: Dict[str, Any], model_cfg: Dict[str, str]) -> str:
    """Draw ``SC_SAMPLES`` solutions, majority-vote on their final answers, and
    return a text report.

    The report header is ``Most frequent answer (<votes>/<SC_SAMPLES>): <answer>``,
    followed by every sample's raw text. Failed samples appear in the report as
    ``[ERROR SAMPLE] <error>``. Failed samples and samples with an empty
    extracted answer do not vote, so errors can never win the majority. Ties go
    to the answer seen first.

    Raises:
        RuntimeError: if every sample failed. The caller then records an error
            instead of a report with nothing to vote on.
    """
    messages = build_self_consistency_prompt(item["question"])
    provider, name = model_cfg["provider"], model_cfg["name"]
    samples = []  # (extracted answer, or None if the call failed; raw text)
    for _ in range(SC_SAMPLES):
        try:
            raw = call_model(provider, name, messages, TEMPS["self_consistency"], MAX_TOKENS)
            answer = extract_final_answer(raw)
        except Exception as e:
            samples.append((None, f"[ERROR SAMPLE] {e}"))
            continue
        samples.append((answer, raw))

    if samples and all(answer is None for answer, _ in samples):
        raise RuntimeError(
            f"All {len(samples)} self-consistency samples failed:\n" + "\n".join(raw for _, raw in samples)
        )

    votes = Counter(answer for answer, _ in samples if answer)
    winner, count = votes.most_common(1)[0] if votes else ("", 0)
    report = [f"Most frequent answer ({count}/{SC_SAMPLES}): {winner}", "", "All samples:"]
    for i, (_, raw) in enumerate(samples, 1):
        report.append(f"--- Sample {i} ---")
        report.append(raw)
        report.append("")
    return "\n".join(report).strip()

def run_item(item: Dict[str, Any], model_cfg: Dict[str, str], strategy: str, dataset: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Run one dataset item through one model/strategy and return a result record.

    The record is a copy of ``item`` extended with ``model``, ``provider``,
    ``strategy``, ``temperature`` (None for an unknown strategy) and
    ``model_answer``.

    Exceptions raised while running the strategy are captured rather than
    propagated: ``model_answer`` is set to None and ``error`` holds the
    message. An unknown strategy is captured the same way.
    """
    record = {
        **item,
        "model": model_cfg["name"],
        "provider": model_cfg["provider"],
        "strategy": strategy,
        # .get: an unknown strategy must still produce an error record, not a KeyError.
        "temperature": TEMPS.get(strategy),
    }
    runners = {
        "zero_shot": lambda: run_zero_shot(item, model_cfg),
        "few_shot_cot": lambda: run_few_shot_cot(item, model_cfg, dataset),
        "tree_of_thought": lambda: run_tree_of_thought(item, model_cfg),
        "self_consistency": lambda: run_self_consistency(item, model_cfg),
    }
    try:
        runner = runners.get(strategy)
        if runner is None:
            raise ValueError(f"Unknown strategy: {strategy}")
        answer = runner()
    except Exception as e:
        return {**record, "model_answer": None, "error": str(e)}
    return {**record, "model_answer": answer}

def run_task(model_cfg: Dict[str, str], strategy: str, dataset: List[Dict[str, Any]], sample_size: Optional[int], out_dir: Path, tag_prefix: str) -> Path:
    """Evaluate one (model, strategy) pair over ``dataset`` and write the results as JSON.

    Only the first ``sample_size`` items are evaluated (all items when None).
    The full ``dataset`` is still passed to ``run_item`` so few-shot examples
    can come from anywhere in it.

    Output goes to ``out_dir / "<tag_prefix>_<strategy>_<model>.json"``, with
    "/" in the tag replaced by "-". Progress is printed every 50 items.

    Returns:
        The path of the written file.
    """
    tag = f"{tag_prefix}_{strategy}_{model_cfg['name']}".replace("/", "-")
    out_path = out_dir / f"{tag}.json"
    # Create the output directory before making any API calls: failing here
    # costs nothing, whereas failing after the loop would discard every response.
    out_dir.mkdir(parents=True, exist_ok=True)
    items = dataset if sample_size is None else dataset[:sample_size]
    total = len(items)
    results = []
    for i, itm in enumerate(items, 1):
        results.append(run_item(itm, model_cfg, strategy, dataset))
        if i % 50 == 0:
            print(f"[{tag}] {i}/{total}")
    # Write to a sibling temp file and rename, so an interrupted write never
    # leaves a truncated JSON file at out_path.
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    with tmp_path.open("w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    tmp_path.replace(out_path)
    print(f"[DONE] {tag} -> {out_path}")
    return out_path

def main():
    """Run every model in MODELS with every strategy in STRATEGIES over each dataset file.

    Results are written to ``OUTPUT_ROOT/<dataset>/<dataset>_<strategy>_<model>.json``.

    Failures are reported and skipped rather than aborting the sweep:
      * a dataset file that cannot be read or parsed,
      * an Anthropic model when no Anthropic client is configured,
      * a task that raises.
    """
    for fname in DATASET_FILES:
        ds_path = DATASET_DIR / fname
        ds_name = Path(fname).stem
        print(f"\n===== Dataset: {ds_name} =====")
        try:
            with ds_path.open("r", encoding="utf-8") as f:
                dataset = json.load(f)
        except (OSError, json.JSONDecodeError) as e:
            print(f"[DATASET ERROR][{ds_name}] {e}")
            continue
        # Do not truncate `dataset` here: run_task already limits the evaluated
        # items to SAMPLE_SIZE. The full dataset must stay available as the
        # few-shot example pool.
        out_dir = OUTPUT_ROOT / ds_name
        for m in MODELS:
            if m["provider"] == "anthropic" and anthropic_client is None:
                # Every call would fail; skip instead of producing all-error result files.
                print(f"[SKIP][{ds_name}][{m['name']}] Anthropic client unavailable (package missing or ANTHROPIC_API_KEY unset)")
                continue
            for s in STRATEGIES:
                try:
                    run_task(m, s, dataset, SAMPLE_SIZE, out_dir, ds_name)
                except Exception as e:
                    print(f"[TASK ERROR][{ds_name}][{m['name']}][{s}] {e}")
        print(f"===== Finished: {ds_name} =====\n")

# Script entry point: run the full evaluation sweep only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
