import argparse
import json
import os
import random
import re
import zlib
from typing import Dict, Any, List

import yaml

from .llm import LLMClient
from .sanitize import sanitize
from .provenance import hash_str, manifest_record
from .artifacts import write_jsonl, count_tokens

# Names of the review criteria. Score dicts throughout this module are keyed by
# these names, and Q is their equal-weight mean (see Q_from_means).
CRITERIA = ["clarity","novelty","methodology","reproducibility","ethics"]

# ---------- helpers ----------

def load_yaml(path):
    """Parse the UTF-8 YAML file at ``path`` with ``yaml.safe_load`` and return the result."""
    with open(path, encoding="utf-8") as handle:
        parsed = yaml.safe_load(handle)
    return parsed

def load_topics(path):
    """Load topic records from a JSON Lines file.

    Args:
        path: Path to a UTF-8 text file holding one JSON value per line.

    Returns:
        A list with the parsed value of every non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # The context manager closes the handle even if a line fails to parse.
    with open(path, encoding="utf-8") as fh:
        # Blank lines (e.g. a trailing newline at end of file) are skipped
        # rather than handed to json.loads, which would reject them.
        return [json.loads(line) for line in fh if line.strip()]

def clamp01(x: float) -> float:
    """Coerce ``x`` to a float and clamp it into the closed interval [0.0, 1.0].

    Args:
        x: Any value; typically a number or numeric string from parsed JSON.

    Returns:
        The clamped float. Values that cannot be converted to a float
        (``None``, non-numeric strings, containers, ints too large for a
        float) and NaN all yield ``0.0``.
    """
    try:
        value = float(x)
    except (TypeError, ValueError, OverflowError):
        return 0.0
    # NaN is the only float that is not equal to itself. Without this guard
    # min(1.0, nan) evaluates to 1.0, so a NaN score would be reported as a
    # perfect score.
    if value != value:
        return 0.0
    return max(0.0, min(1.0, value))

def scores_from_seed(seed:int)->Dict[str,float]:
    """Build per-criterion scores that depend only on ``seed``.

    A private ``random.Random(seed)`` draws one shared base level, then one
    offset per entry of ``CRITERIA`` (in list order); each sum is passed
    through ``clamp01``. Used as the fallback when no usable LLM review is
    available.
    """
    generator = random.Random(seed)
    center = generator.uniform(0.48,0.62)
    result: Dict[str, float] = {}
    for criterion in CRITERIA:
        offset = generator.uniform(-0.08,0.12)
        result[criterion] = clamp01(center + offset)
    return result

def recommendation_from_Q(Q: float, tau_acc: float, tau_min: float, k_means: Dict[str,float], k_min: int) -> str:
    """Map an aggregate quality score to "Accept", "Continue" or "Reject".

    Args:
        Q: Aggregate quality score.
        tau_acc: Acceptance threshold for ``Q``.
        tau_min: Per-criterion threshold.
        k_means: Per-criterion mean scores.
        k_min: Minimum number of criteria that must reach ``tau_min``.

    Returns:
        "Accept" when ``Q`` reaches ``tau_acc`` and at least ``k_min``
        criteria reach ``tau_min``; otherwise "Continue" when ``Q`` is within
        0.05 below ``tau_acc`` (or above it); otherwise "Reject".
    """
    passing = 0
    for mean in k_means.values():
        if mean >= tau_min:
            passing += 1
    enough_criteria = passing >= k_min

    if enough_criteria and Q >= tau_acc:
        return "Accept"
    return "Continue" if Q >= (tau_acc - 0.05) else "Reject"

def write_manifest(path, kind, model_family, model_version, temperature, seed, prompt_text, redacted_text):
    """Write one provenance record for a generated artifact to ``path``.

    The prompt and the (redacted) output are not stored verbatim: each is
    passed through ``hash_str`` first, with ``None``/empty treated as the
    empty string. The record itself is built by ``manifest_record`` and
    emitted with ``write_jsonl``.
    """
    prompt_digest = hash_str(prompt_text if prompt_text else "")
    output_digest = hash_str(redacted_text if redacted_text else "")
    record = manifest_record(
        kind, model_family, model_version, temperature, seed,
        prompt_digest, output_digest,
    )
    write_jsonl(path, record)

def extract_last_json(text: str) -> Any:
    """Find and parse the last top-level JSON object embedded in ``text``.

    The text is scanned left to right. At every ``{`` a JSON decode is
    attempted; when it succeeds the scan resumes after that object, so nested
    objects are never returned on their own. Braces that do not start valid
    JSON (e.g. in surrounding prose) are skipped.

    Args:
        text: Free-form text, typically an LLM response.

    Returns:
        The last successfully decoded object as a dict, or ``None`` if
        ``text`` is not a string or contains no decodable JSON object.
    """
    if not isinstance(text, str):
        return None

    decoder = json.JSONDecoder()
    found = None
    pos = text.find("{")
    while pos != -1:
        try:
            obj, end = decoder.raw_decode(text, pos)
        except (ValueError, RecursionError):
            # Not valid JSON from this brace; try the next opening brace.
            pos = text.find("{", pos + 1)
            continue
        found = obj
        # Jump past the decoded object so its inner braces are not rescanned.
        pos = text.find("{", end)
    return found

def aggregate_equal_weights(scores_list: List[Dict[str,float]]) -> Dict[str,float]:
    """Average reviewer scores per criterion, every reviewer weighted equally.

    Each individual score is passed through ``clamp01`` before averaging and a
    criterion missing from a reviewer's dict counts as 0.0. With no reviews
    at all, every criterion in ``CRITERIA`` maps to 0.0.
    """
    reviewer_count = len(scores_list) if scores_list else 0
    if reviewer_count == 0:
        return dict.fromkeys(CRITERIA, 0.0)
    return {
        name: sum(clamp01(entry.get(name,0.0)) for entry in scores_list) / reviewer_count
        for name in CRITERIA
    }

def Q_from_means(k_means: Dict[str,float]) -> float:
    """Return the overall score Q: the unweighted mean of the per-criterion means.

    Every name in ``CRITERIA`` must be present in ``k_means``; a missing one
    raises ``KeyError``.
    """
    values = [k_means[name] for name in CRITERIA]
    return sum(values) / len(values)

def apply_patches(text: str, patches: List[Dict[str,str]]) -> str:
    """Apply best-effort ``before`` -> ``after`` string replacements to ``text``.

    Patches are applied in order, each against the result of the previous
    one, replacing every occurrence of ``before``. A patch is silently
    skipped when it is not a dict, when ``before`` or ``after`` is not a
    non-empty string, or when ``before`` does not occur in the current text.

    Args:
        text: The text to patch.
        patches: Patch dicts with ``"before"`` and ``"after"`` keys. Since
            they usually come from parsed LLM output, anything that is not a
            list/tuple of well-formed patches is tolerated rather than raised on.

    Returns:
        The patched text, or ``text`` unchanged if nothing could be applied.
    """
    if not patches or not isinstance(patches, (list, tuple)):
        return text
    out = text
    for p in patches:
        if not isinstance(p, dict):
            continue
        before = p.get("before","")
        after = p.get("after","")
        # Non-string fields would make the `in` test / replace() raise.
        if not (isinstance(before, str) and isinstance(after, str)):
            continue
        if before and after and before in out:
            out = out.replace(before, after)
    return out

# ---------- main pipeline ----------

def _stable_topic_hash(topic_id: Any) -> int:
    """Return a 32-bit checksum of ``topic_id`` that is the same in every process."""
    # The built-in hash() is salted per interpreter run for strings, which
    # would make the "deterministic" fallback seeds differ between runs.
    return zlib.crc32(str(topic_id).encode("utf-8"))


def _read_text(path: str) -> str:
    """Read a whole UTF-8 text file and close the handle afterwards."""
    with open(path, "r", encoding="utf-8") as fh:
        return fh.read()


def _dict_or_empty(value: Any) -> Dict[str, Any]:
    """Return ``value`` if it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def _list_or_empty(value: Any) -> List[Any]:
    """Return ``value`` if it is a list, otherwise an empty list."""
    return value if isinstance(value, list) else []


def main():
    """Run the author -> reviewers -> reviser -> meta-decision pipeline from the CLI.

    Command-line options: ``--config`` (YAML with ``thresholds`` and
    ``budgets``), ``--topics`` (JSONL, each record with ``id`` and
    ``prompt``), ``--out`` (output directory), ``--rounds``, ``--reviewers``,
    ``--cross-review`` (parsed but currently unused), ``--planner`` and
    ``--seed``.

    For every topic an initial manuscript is generated, then up to ``Tmax``
    rounds of review, revision and meta decision are executed. Records are
    written under ``--out`` to ``manuscripts.jsonl``, ``reviews.jsonl``,
    ``responses.jsonl``, ``decisions.jsonl``, ``logs.jsonl`` and
    ``manifest.jsonl``. Prompt templates are read from ``prompts/*.txt``
    relative to the current working directory.

    LLM replies are parsed on a best-effort basis: an unusable review falls
    back to seeded stub scores, an unusable reviser reply to a canned
    response, and an unusable meta reply to the threshold rule.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True)
    ap.add_argument("--topics", required=True)
    ap.add_argument("--out", required=True)
    ap.add_argument("--rounds", type=int, default=2)
    ap.add_argument("--reviewers", type=int, default=3)
    ap.add_argument("--cross-review", choices=["on","off"], default="off")  # reserved
    ap.add_argument("--planner", choices=["on","off"], default="on")
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    os.makedirs(args.out, exist_ok=True)
    cfg = load_yaml(args.config)
    thresholds = cfg["thresholds"]
    # The CLI round count can only lower the configured budget, never raise it.
    Tmax = min(args.rounds, cfg["budgets"]["Tmax"])

    topics = load_topics(args.topics)

    # LLM clients (stub or openai depending on env)
    mode = os.getenv("LLM4R_MODE","stub")
    model_version = os.getenv("LLM4R_MODEL","stub-1.0")
    llm_author = LLMClient(mode=mode, temperature=0.5)
    llm_reviewer = LLMClient(mode=mode, temperature=0.3)
    llm_reviser = LLMClient(mode=mode, temperature=0.4)
    llm_meta = LLMClient(mode=mode, temperature=0.3)

    # IO paths
    mans_path = os.path.join(args.out,"manuscripts.jsonl")
    revs_path = os.path.join(args.out,"reviews.jsonl")
    resps_path = os.path.join(args.out,"responses.jsonl")
    decs_path = os.path.join(args.out,"decisions.jsonl")
    logs_path = os.path.join(args.out,"logs.jsonl")
    prov_path = os.path.join(args.out,"manifest.jsonl")

    # load prompts
    author_prompt_tpl = _read_text("prompts/author.txt")
    reviewer_prompt_tpl = _read_text("prompts/reviewer.txt")
    reviser_prompt_tpl = _read_text("prompts/reviser.txt")
    meta_prompt_tpl = _read_text("prompts/meta.txt")

    # Source of the simulated latency values written to the log records.
    rnd = random.Random(args.seed)

    for topic in topics:
        topic_id = topic["id"]
        topic_hash = _stable_topic_hash(topic_id)

        # --------- Round 0: Author draft ----------
        author_prompt = f"{author_prompt_tpl}\n\nTOPIC: {topic['prompt']}"
        manuscript = llm_author.generate(author_prompt, max_tokens=1200)
        manuscript = sanitize(manuscript)
        write_jsonl(mans_path, {"topic_id": topic_id, "round": 0, "text": manuscript})
        write_manifest(prov_path, "manuscript", llm_author.mode, model_version, 0.5, args.seed, author_prompt, manuscript)
        write_jsonl(logs_path, {"topic_id": topic_id, "round": 0, "role":"author", "tokens": count_tokens(manuscript), "latency_s": rnd.uniform(1.0,2.0)})

        last_Q = 0.0

        # --------- Review/Revise/Meta rounds ----------
        for t in range(1, Tmax+1):
            reviews = []

            # ===== Reviewers =====
            # Every reviewer of a round sees the same manuscript, so the
            # prompt is built once per round.
            rv_call = f"{reviewer_prompt_tpl}\n\n--- MANUSCRIPT START ---\n{manuscript}\n--- MANUSCRIPT END ---"
            for j in range(args.reviewers):
                # Per (topic, round, reviewer) seed for the fallback scores.
                seed = (topic_hash ^ (t<<8) ^ (j<<2) ^ args.seed) & 0xffffffff
                rv_text = llm_reviewer.generate(rv_call, max_tokens=900)
                rv_text = sanitize(rv_text)
                data = extract_last_json(rv_text)

                if not isinstance(data, dict) or not isinstance(data.get("scores"), dict):
                    # fallback: deterministic stub scores
                    scores = scores_from_seed(seed)
                    pros = ["well motivated","clear structure"]
                    cons = ["missing ablations","limited baselines"]
                    rec = recommendation_from_Q(Q_from_means(scores), thresholds["tau_acc"], thresholds["tau_min"], scores, thresholds["k_min"])
                    data = {
                        "scores": scores,
                        "pros": pros, "cons": cons, "risks": [],
                        "recommendation": rec,
                        "evidence": {}, "safety": {"flags": [], "notes": "fallback-stub"},
                        "confidence": 0.5
                    }

                # sanitize and clamp; fields of the wrong JSON type are
                # treated as absent instead of raising.
                raw_scores = _dict_or_empty(data.get("scores"))
                evidence = _dict_or_empty(data.get("evidence"))
                scores = {k: clamp01(raw_scores.get(k,0.0)) for k in CRITERIA}
                rec = data.get("recommendation","Continue")
                rv_obj = {
                    "topic_id": topic_id, "round": t, "reviewer_id": f"R{j+1}",
                    "scores": scores,
                    "pros": _list_or_empty(data.get("pros"))[:2],
                    "cons": _list_or_empty(data.get("cons"))[:2],
                    "risks": _list_or_empty(data.get("risks"))[:2],
                    "recommendation": rec,
                    "evidence": {k: str(evidence.get(k,"")) for k in CRITERIA},
                    "safety": data.get("safety", {"flags":[], "notes":""}),
                    "confidence": clamp01(data.get("confidence", 0.5))
                }
                reviews.append(rv_obj)
                write_jsonl(revs_path, rv_obj)
                write_manifest(prov_path, "review", llm_reviewer.mode, model_version, 0.3, seed, reviewer_prompt_tpl, json.dumps(rv_obj, ensure_ascii=False))
                write_jsonl(logs_path, {"topic_id": topic_id, "round": t, "role":"reviewer", "tokens": 180, "latency_s": rnd.uniform(0.4,0.9)})

            # Aggregate (equal weights)
            k_means = aggregate_equal_weights([r["scores"] for r in reviews])
            Q = Q_from_means(k_means)

            # ===== Reviser =====
            if args.planner == "on":
                # Construct compact context for reviser
                short_reviews = [
                    {
                        "reviewer_id": r["reviewer_id"],
                        "scores": r["scores"],
                        "pros": r["pros"],
                        "cons": r["cons"],
                        "recommendation": r["recommendation"]
                    } for r in reviews
                ]
                rev_call = (
                    f"{reviser_prompt_tpl}\n\n"
                    f"--- MANUSCRIPT (round {t-1}) ---\n{manuscript}\n--- END ---\n\n"
                    f"--- REVIEWS (JSON) ---\n{json.dumps(short_reviews, ensure_ascii=False)}\n--- END ---"
                )
                resp_text = llm_reviser.generate(rev_call, max_tokens=900)
                resp_text = sanitize(resp_text)
                resp = extract_last_json(resp_text)
                if not isinstance(resp, dict):
                    resp = {
                        "change_plan": [{"id":"C1","section":"Experiments","action":"add ablation","patch_hint":"report ΔQ and Accept@R2","verification":"tab updated"}],
                        "response_letter": {"to":"Reviewers","summary":"We clarified methods and added ablations.","per_reviewer":{}},
                        "patches": []
                    }
                change_plan = resp.get("change_plan", [])
                response_letter = resp.get("response_letter", {})
                # Keep only dict entries so a malformed patch list cannot abort the run.
                patches = [p for p in _list_or_empty(resp.get("patches")) if isinstance(p, dict)]

                # apply patches (best-effort)
                revised = apply_patches(manuscript, patches)
                if revised == manuscript:
                    # Nothing changed: append a marker so each round still
                    # yields a distinct manuscript text.
                    revised = manuscript + f"\n\n[Round {t} revisions applied.]"

                write_jsonl(resps_path, {"topic_id": topic_id, "round": t, "change_plan": change_plan, "response": response_letter})
                write_manifest(prov_path, "response", llm_reviser.mode, model_version, 0.4, args.seed + t, reviser_prompt_tpl, json.dumps(response_letter, ensure_ascii=False))
                write_jsonl(mans_path, {"topic_id": topic_id, "round": t, "text": revised})
                write_manifest(prov_path, "manuscript", llm_author.mode, model_version, 0.5, args.seed + t, "[revised]"+reviser_prompt_tpl, revised)
                write_jsonl(logs_path, {"topic_id": topic_id, "round": t, "role":"reviser", "tokens": count_tokens(revised), "latency_s": rnd.uniform(0.6,1.0)})
                manuscript = revised
            else:
                # no planner: simple append
                manuscript = manuscript + f"\n\n[Round {t} revisions applied.]"
                write_jsonl(mans_path, {"topic_id": topic_id, "round": t, "text": manuscript})
                write_jsonl(logs_path, {"topic_id": topic_id, "round": t, "role":"author", "tokens": count_tokens(manuscript), "latency_s": rnd.uniform(0.5,1.0)})

            # ===== Meta decision =====
            meta_input = {
                "k_means": k_means,
                "Q": Q,
                "num_reviews": len(reviews),
                "recommendations": [r["recommendation"] for r in reviews]
            }
            meta_call = f"{meta_prompt_tpl}\n\n--- AGGREGATE ---\n{json.dumps(meta_input, ensure_ascii=False)}\n--- END ---"
            meta_text = llm_meta.generate(meta_call, max_tokens=500)
            meta_text = sanitize(meta_text)
            meta_json = extract_last_json(meta_text)

            if isinstance(meta_json, dict) and "decision" in meta_json and "Q" in meta_json and "k_means" in meta_json:
                # trust LLM meta if it returns sane values; clamp and backstop with thresholds
                llm_means = _dict_or_empty(meta_json.get("k_means"))
                k_means_llm = {k: clamp01(llm_means.get(k, k_means.get(k,0.0))) for k in CRITERIA}
                try:
                    Q_llm = float(meta_json.get("Q", Q))
                except (TypeError, ValueError, OverflowError):
                    # Non-numeric Q from the model: keep the aggregated value.
                    Q_llm = Q
                decision = meta_json.get("decision","Continue")
                # backstop rule to avoid Accept with low criteria
                decision_rule = recommendation_from_Q(Q_llm, thresholds["tau_acc"], thresholds["tau_min"], k_means_llm, thresholds["k_min"])
                if decision == "Accept" and decision_rule != "Accept":
                    decision = decision_rule
                k_means = k_means_llm
                Q = Q_llm
            else:
                decision = recommendation_from_Q(Q, thresholds["tau_acc"], thresholds["tau_min"], k_means, thresholds["k_min"])

            # small "no progress" heuristic: a non-final, non-accepted round
            # whose Q improved by less than epsilon is recorded as Reject.
            # NOTE(review): this only relabels the decision; the round loop is
            # not interrupted, so later rounds still run. Confirm that is intended.
            if Q - last_Q < thresholds["epsilon"] and t < Tmax and decision != "Accept":
                decision = "Reject"

            write_jsonl(decs_path, {"topic_id": topic_id, "round": t, "decision": decision, "Q": Q, "k_means": k_means})
            write_manifest(prov_path, "meta", llm_meta.mode, model_version, 0.3, args.seed + 100 + t, meta_prompt_tpl, json.dumps({"decision":decision,"Q":Q}, ensure_ascii=False))
            write_jsonl(logs_path, {"topic_id": topic_id, "round": t, "role":"meta", "tokens": 40, "latency_s": 0.2})
            last_Q = Q

    print("Pipeline done. Artifacts written to", args.out)

# Run the pipeline only when this module is executed as the main program,
# not when it is imported.
if __name__ == "__main__":
    main()
