from __future__ import annotations

import argparse
import json
import os
import re
import sys
import traceback
from dataclasses import dataclass
from typing import Any, Dict, Optional


def _oracle_to_grpo_cfg_name(oracle_name: str) -> str:
    """Map a mol_opt oracle name onto the matching ``configs/grpo/<name>`` entry.

    The name is trimmed and lower-cased; spaces, dashes and slashes become
    underscores and runs of underscores collapse to a single one. A trailing
    ``_current`` and then a trailing ``_latest`` are removed, and
    ``median_1``/``median_2`` are spelled without the underscore.

    >>> _oracle_to_grpo_cfg_name("Ranolazine_MPO")
    'ranolazine_mpo'
    >>> _oracle_to_grpo_cfg_name("Median 2")
    'median2'
    """
    separators = str.maketrans({" ": "_", "-": "_", "/": "_"})
    normalized = re.sub(r"__+", "_", str(oracle_name).strip().lower().translate(separators))
    # Order matters: "_current" is checked before "_latest".
    for suffix in ("_current", "_latest"):
        if normalized.endswith(suffix):
            normalized = normalized[: -len(suffix)]
    median_aliases = {"median_1": "median1", "median_2": "median2"}
    return median_aliases.get(normalized, normalized)


def _abspath_if_relative(root: str, path: Any) -> Any:
    """Resolve *path* against *root* unless it is already absolute.

    ``~`` is expanded first and the result is returned as ``str``. Falsy
    inputs (``None``, ``""``, ...) are passed back unchanged.
    """
    if path:
        expanded = os.path.expanduser(str(path))
        if not os.path.isabs(expanded):
            expanded = os.path.join(root, expanded)
        return expanded
    return path


def _write_jsonl(out_fp, payload: Dict[str, Any]) -> None:
    """Write *payload* to *out_fp* as one compact JSON line, then flush.

    Non-ASCII characters are emitted verbatim rather than ``\\u``-escaped. The
    flush pushes each response out immediately instead of leaving it buffered.
    """
    encoded = json.dumps(payload, separators=(",", ":"), ensure_ascii=False)
    out_fp.write(f"{encoded}\n")
    out_fp.flush()


def _env_truthy(key: str) -> bool:
    """Report whether environment variable *key* holds a truthy flag value.

    Accepted values (case-insensitive, surrounding whitespace ignored) are
    ``1``, ``true``, ``t``, ``yes``, ``y`` and ``on``. An unset variable or
    any other value yields ``False``.
    """
    raw = os.environ.get(key)
    return raw is not None and raw.strip().lower() in ("1", "true", "t", "yes", "y", "on")


@dataclass
class _Session:
    """State shared by every request that ``_serve`` handles."""

    # Object answering propose()/observe(); save_topk_visualizations() is used if present.
    proposer: Any


def _build_proposer(
    *,
    repo_root: str,
    oracle_name: str,
    grpo_cfg: Optional[str],
    checkpoint_path: Optional[str],
    num_reward_workers: Optional[int],
    device: str,
    experiment: Optional[str] = None,
    dataset: Optional[str] = None,
    extra_hydra_overrides: Optional[list[str]] = None,
) -> Any:
    """Compose the Hydra config for *oracle_name* and construct a ``GraphGRPOProposer``.

    Configs are loaded from ``<repo_root>/configs`` (config name ``"config"``). The
    ``+grpo=<name>`` entry is ``grpo_cfg`` when given, otherwise it is derived from
    ``oracle_name`` via ``_oracle_to_grpo_cfg_name``.

    Args:
        repo_root: Repository root; relative dataset/checkpoint paths are resolved
            against it.
        oracle_name: mol_opt oracle name (e.g. ``"Ranolazine_MPO"``).
        grpo_cfg: Optional explicit GRPO config name.
        checkpoint_path: Optional checkpoint override, falling back to
            ``$GRAPH_GRPO_CKPT``. Both are ignored when
            ``$GRAPH_GRPO_USE_DEFAULT_CKPT`` is truthy.
        num_reward_workers: Optional override for ``cfg.grpo.num_reward_workers``.
        device: Torch device string, or ``"auto"`` for ``cuda:0`` when CUDA is
            available and ``cpu`` otherwise.
        experiment: Optional ``+experiment=<name>`` override.
        dataset: Optional ``dataset=<name>`` override.
        extra_hydra_overrides: Additional raw Hydra overrides. Entries are
            whitespace-stripped and blank ones are dropped.

    Returns:
        The constructed ``GraphGRPOProposer``.

    Config patches applied after composition are best-effort: a failure is
    reported on stderr and the remaining setup continues.
    """
    cfg_name = (grpo_cfg or _oracle_to_grpo_cfg_name(oracle_name)).strip()

    # Imported lazily: these are only needed when a real (non-mock) proposer is built.
    import torch
    from hydra import compose, initialize_config_dir
    from hydra.core.global_hydra import GlobalHydra

    from eval_grpo_sampler import GraphGRPOProposer

    def _warn(action: str, exc: Exception) -> None:
        # Surface best-effort patch failures instead of silently running with an
        # unpatched config (e.g. the wrong checkpoint).
        print(f"⚠️ [Config Patch] Failed to {action}: {exc}", file=sys.stderr, flush=True)

    config_dir = os.path.join(repo_root, "configs")
    # Hydra keeps process-global state; drop any earlier initialisation first.
    if GlobalHydra.instance().is_initialized():
        GlobalHydra.instance().clear()

    with initialize_config_dir(version_base="1.3", config_dir=config_dir):
        overrides: list[str] = []
        if experiment:
            overrides.append(f"+experiment={experiment}")
        if dataset:
            overrides.append(f"dataset={dataset}")
        if extra_hydra_overrides:
            for raw in extra_hydra_overrides:
                override = str(raw).strip()
                if override:
                    overrides.append(override)

        # Config names that receive the ZINC experiment/dataset defaults below
        # (any "*_mpo" name qualifies via the suffix checks).
        zinc_cfg_names = {
            "drd2", "gsk3b", "jnk3", "qed", "sa", "logp",
            "albuterol_similarity", "mestranol_similarity",
            "celecoxib_rediscovery", "thiothixene_rediscovery", "troglitazone_rediscovery",
            "deco_hop", "scaffold_hop", "valsartan_smarts", "median1", "median2",
            "perindopril_rings",
        }
        is_zinc_based = (
            str(oracle_name).strip().lower().endswith("_mpo")
            or cfg_name.endswith("_mpo")
            or cfg_name.startswith("isomers_")
            or cfg_name in zinc_cfg_names
        )
        # Only fall back to the ZINC defaults when the caller chose nothing explicitly.
        if is_zinc_based and not experiment and not dataset and not extra_hydra_overrides:
            overrides.extend(["+experiment=zinc", "dataset=zinc"])

        overrides.append(f"+grpo={cfg_name}")
        cfg = compose(config_name="config", overrides=overrides)

    # These MPO tasks are forced onto the zinc250k dataset settings.
    zinc250k_tasks = {
        "osimertinib_mpo", "fexofenadine_mpo", "ranolazine_mpo", "perindopril_mpo",
        "amlodipine_mpo", "sitagliptin_mpo", "zaleplon_mpo",
    }
    target_task = str(getattr(getattr(cfg, "grpo", None), "target_task", "")).lower()
    if target_task in zinc250k_tasks and cfg.dataset.name != "zinc250k":
        print(
            f"⚠️ [Config Patch] Forcing dataset.name='zinc250k' for {target_task} (was {cfg.dataset.name})",
            file=sys.stderr,
            flush=True,
        )
        cfg.dataset.name = "zinc250k"
        cfg.dataset.datadir = "data/zinc/"
        cfg.dataset.empty = True
        cfg.dataset.remove_h = True
        cfg.dataset.aromatic = False

    try:
        cfg.dataset.datadir = _abspath_if_relative(repo_root, cfg.dataset.datadir)
    except Exception as exc:
        _warn("resolve dataset.datadir", exc)

    if not _env_truthy("GRAPH_GRPO_USE_DEFAULT_CKPT"):
        ckpt_override = checkpoint_path or os.environ.get("GRAPH_GRPO_CKPT")
        if ckpt_override:
            ckpt_path = _abspath_if_relative(repo_root, ckpt_override)
            try:
                cfg.grpo.pretrained_checkpoint = ckpt_path
            except Exception as exc:
                _warn("set grpo.pretrained_checkpoint", exc)
            try:
                # Use the resolved override directly so resume never inherits a stale value.
                cfg.grpo.resume_from_checkpoint = ckpt_path
            except Exception as exc:
                _warn("set grpo.resume_from_checkpoint", exc)

    if num_reward_workers is not None:
        try:
            cfg.grpo.num_reward_workers = int(num_reward_workers)
        except Exception as exc:
            _warn("set grpo.num_reward_workers", exc)

    try:
        cfg.general.test_only = True
    except Exception as exc:
        _warn("set general.test_only", exc)

    if device == "auto":
        torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        torch_device = torch.device(device)

    return GraphGRPOProposer(cfg=cfg, device=torch_device)


def _serve(session: _Session, *, in_fp, out_fp) -> int:
    """Run the line-delimited JSON request/response loop until EOF or ``shutdown``.

    Each non-empty input line is expected to be a JSON object with ``id``,
    ``method`` and ``params``. For each such line, one response
    ``{"id", "result", "error"}`` is written to *out_fp*. Handler failures are
    reported in ``error`` (type, message, traceback) and the loop keeps going.

    Supported methods: ``ping``, ``propose``, ``observe``, ``shutdown``.

    Returns:
        ``0`` on EOF, on an unreadable input stream, and after ``shutdown``.
    """
    while True:
        try:
            line = in_fp.readline()
        except Exception as exc:  # also covers EOFError
            # The input stream is unusable (closed, undecodable bytes, ...). Stop
            # serving, but leave a trace of why instead of exiting silently.
            print(
                f"⚠️ Failed to read request from stdin, shutting down: {exc!r}",
                file=sys.stderr,
                flush=True,
            )
            break

        if not line:  # EOF
            break

        line = line.strip()
        if not line:
            print("📩 Received empty line on stdin", file=sys.stderr, flush=True)
            continue

        req_id = None
        try:
            req = json.loads(line)
            if not isinstance(req, dict):
                # Give clients a clear error rather than "'list' object has no attribute 'get'".
                raise TypeError(f"Request must be a JSON object, got {type(req).__name__}")
            req_id = req.get("id")
            method = req.get("method")
            params = req.get("params") or {}

            print(f"📩 [{method}] Request received (id={req_id})", file=sys.stderr, flush=True)

            if method == "ping":
                _write_jsonl(out_fp, {"id": req_id, "result": {"ok": True}, "error": None})
                continue

            if method == "shutdown":
                out_dir = os.environ.get("GRAPH_GRPO_OUTPUT_DIR")
                if out_dir and hasattr(session.proposer, "save_topk_visualizations"):
                    try:
                        print(
                            f"🖼️ [shutdown] Saving top-10 visualizations to {out_dir}",
                            file=sys.stderr,
                            flush=True,
                        )
                        session.proposer.save_topk_visualizations(out_dir, topk=10)
                    except Exception as e:
                        print(f"⚠️ [shutdown] Failed to save top-10 visualizations: {e}", file=sys.stderr, flush=True)
                _write_jsonl(out_fp, {"id": req_id, "result": {"ok": True}, "error": None})
                return 0

            if method == "propose":
                batch_size = int(params.get("batch_size", 0) or 0)
                # The proposer may update this dict in place; it is echoed back to the client.
                state = dict(params.get("state") or {})
                smiles = session.proposer.propose(batch_size=batch_size, replay=None, state=state)
                _write_jsonl(
                    out_fp,
                    {"id": req_id, "result": {"smiles": smiles, "state": state}, "error": None},
                )
                print(f"✅ [{method}] Request processed (id={req_id})", file=sys.stderr, flush=True)
                continue

            if method == "observe":
                smiles = list(params.get("smiles") or [])
                scores = params.get("scores")
                state = dict(params.get("state") or {})
                session.proposer.observe(smiles=smiles, scores=scores, replay=None, state=state)
                _write_jsonl(out_fp, {"id": req_id, "result": {"state": state}, "error": None})
                continue

            raise ValueError(f"Unknown method: {method}")
        except Exception as e:
            _write_jsonl(
                out_fp,
                {
                    "id": req_id,
                    "result": None,
                    "error": {
                        "type": type(e).__name__,
                        "message": str(e),
                        "traceback": traceback.format_exc(limit=50),
                    },
                },
            )
    return 0


def main(argv: Optional[list[str]] = None) -> int:
    """Build a GraphGRPO proposer and serve JSONL requests over stdin/stdout.

    Protocol responses go to the real ``sys.stdout``. While the server runs,
    ``sys.stdout`` points at ``sys.stderr`` so that Python-level ``print``
    calls (here or in imported libraries) cannot corrupt the response stream.
    The original ``sys.stdout`` is restored before returning.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.

    Returns:
        The status returned by ``_serve``.
    """
    parser = argparse.ArgumentParser(description="GraphGRPO proposer JSONL server (stdin/stdout).")
    parser.add_argument("--oracle-name", required=True, help="mol_opt oracle name, e.g. Ranolazine_MPO")
    parser.add_argument("--grpo-cfg", default=None, help="Override configs/grpo/<name>.yaml (without .yaml)")
    parser.add_argument("--checkpoint-path", default=None, help="Override checkpoint (or set GRAPH_GRPO_CKPT)")
    parser.add_argument("--num-reward-workers", type=int, default=None, help="Optional override; proposer never calls oracle")
    parser.add_argument("--device", default="auto", help='Torch device string, e.g. "cpu", "cuda:0", or "auto"')
    parser.add_argument(
        "--experiment",
        default=None,
        help="Optional Hydra override (e.g. 'zinc'); equivalent to '+experiment=<name>' used in training.",
    )
    parser.add_argument(
        "--dataset",
        default=None,
        help="Optional Hydra override (e.g. 'zinc'); equivalent to 'dataset=<name>' used in training.",
    )
    parser.add_argument(
        "--hydra-overrides",
        default=None,
        help="Optional extra Hydra overrides, separated by ';' (e.g. '+experiment=zinc;dataset=zinc').",
    )
    parser.add_argument("--mock", action="store_true", help="Use a tiny mock proposer (for protocol testing).")
    args = parser.parse_args(argv)

    # Reserve the real stdout for protocol responses only.
    proto_out = sys.stdout
    sys.stdout = sys.stderr
    try:
        # The repo root is the parent of the directory containing this file.
        repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
        src_dir = os.path.join(repo_root, "src")
        if src_dir not in sys.path:
            sys.path.insert(0, src_dir)

        print("🚀 Proposer building...", file=sys.stderr, flush=True)
        if args.mock:
            class _MockProposer:
                """Protocol-test stand-in: one fake SMILES per propose(), observe() is a no-op."""

                def __init__(self):
                    self._i = 0

                def propose(self, batch_size: int, replay: Any, state: Dict[str, Any]) -> list[str]:
                    _ = batch_size, replay
                    self._i += 1
                    state["round_idx"] = int(state.get("round_idx", 0)) + 1
                    return [f"MOCK_{self._i}"]

                def observe(self, smiles: list[str], scores: Any, replay: Any, state: Dict[str, Any]) -> None:
                    _ = smiles, scores, replay, state

            proposer = _MockProposer()
        else:
            extra_overrides: Optional[list[str]] = None
            if args.hydra_overrides:
                extra_overrides = [s.strip() for s in str(args.hydra_overrides).split(";") if s.strip()]
            proposer = _build_proposer(
                repo_root=repo_root,
                oracle_name=args.oracle_name,
                grpo_cfg=args.grpo_cfg,
                checkpoint_path=args.checkpoint_path,
                num_reward_workers=args.num_reward_workers,
                device=args.device,
                experiment=args.experiment,
                dataset=args.dataset,
                extra_hydra_overrides=extra_overrides,
            )

        session = _Session(proposer=proposer)
        print("✅ Proposer server ready. Listening on stdin...", file=sys.stderr, flush=True)
        return _serve(session, in_fp=sys.stdin, out_fp=proto_out)
    finally:
        # Undo the redirection so in-process callers (e.g. a second main() call) get the real stdout.
        sys.stdout = proto_out


if __name__ == "__main__":
    # main()'s return value becomes the process exit status.
    sys.exit(main())
