from __future__ import annotations

import argparse
import os
import sys
import time
from dataclasses import dataclass
from heapq import nlargest
from types import SimpleNamespace
from typing import Any, Dict, List, Optional


def _repo_root() -> str:
    """Return the absolute path of the parent of the directory holding this file."""
    here = os.path.dirname(__file__)
    return os.path.abspath(os.path.join(here, os.pardir))


def _default_molopt_repo(rl_repo: str) -> str:
    """Return the absolute path of a ``mol_opt`` directory that is a sibling of *rl_repo*."""
    sibling_parent = os.path.join(rl_repo, os.pardir)
    return os.path.abspath(os.path.join(sibling_parent, "mol_opt"))


def _topk_scores(mol_buffer: Dict[str, List[Any]], k: int) -> List[float]:
    """Return the ``k`` highest scores in *mol_buffer*, largest first.

    The score of each entry is taken as ``float(value[0])``. The keys are
    presumably SMILES strings; they are not used here.

    Args:
        mol_buffer: Mapping whose values are sequences with the score at index 0.
        k: Number of scores to return; ``k <= 0`` yields an empty list.

    Returns:
        Up to ``k`` floats sorted in descending order (fewer if the buffer is smaller).
    """
    if k <= 0 or not mol_buffer:
        return []
    scores = [float(entry[0]) for entry in mol_buffer.values()]
    return sorted(nlargest(k, scores), reverse=True)


def _mean(xs: List[float]) -> float:
    """Return the arithmetic mean of *xs* as a float, or ``0.0`` when *xs* is empty."""
    if not xs:
        return 0.0
    total = sum(xs)
    return float(total / len(xs))


def _env_truthy(key: str) -> bool:
    """Return True if environment variable *key* is set to a truthy word.

    Accepted values (case-insensitive, surrounding whitespace ignored) are
    1, true, t, yes, y, on. An unset variable counts as False.
    """
    raw = os.environ.get(key)
    return raw is not None and str(raw).strip().lower() in {"1", "true", "t", "yes", "y", "on"}


@dataclass
class EarlyStopConfig:
    """Settings for stopping a run once the mean top-k score stops improving."""

    # Whether early stopping is active at all.
    enabled: bool
    # Oracle calls (newly added molecules) without improvement tolerated before stopping.
    patience: int
    # An improvement must exceed the best mean top-k score by more than this amount.
    eps: float
    # No early stop while fewer than this many molecules have been scored.
    min_calls: int
    # How many of the best scores are averaged for the plateau check.
    topk: int


def _parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the CLI parser and parse *argv* (``None`` means ``sys.argv[1:]``).

    Returns:
        Namespace with run options (ckpt, oracle, max_oracle_calls, seed,
        output_dir, freq_log) and early-stop options (early_stop,
        early_stop_patience, early_stop_eps, early_stop_min_calls,
        early_stop_topk).
    """
    parser = argparse.ArgumentParser(
        description="mol_opt Graph-GRPO runner with optional early-stop for AUC@10."
    )

    # Core run options.
    parser.add_argument("--ckpt", required=True, help="Checkpoint path used by RL_Graph_Generation proposer.")
    parser.add_argument("--oracle", required=True, help="mol_opt oracle name, e.g. median2 / Ranolazine_MPO.")
    parser.add_argument("--max-oracle-calls", type=int, default=10000)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--output-dir", default=None, help="Base output dir (default: mol_opt/main/graph_grpo/results).")
    parser.add_argument("--freq-log", type=int, default=100, help="mol_opt logging frequency.")

    # Early-stop options; registration order is kept so --help output is unchanged.
    parser.add_argument("--early-stop", action="store_true", help="Enable early stop when top-k scores plateau.")
    early_stop_options = (
        ("--early-stop-patience", int, 1000, "Patience in *oracle calls* without top-k change."),
        ("--early-stop-eps", float, 0.0, "Tolerance for top-k score changes."),
        ("--early-stop-min-calls", int, 1000, "Do not early-stop before this many oracle calls."),
        ("--early-stop-topk", int, 10, "Track top-k list (default 10, i.e. AUC@10)."),
    )
    for flag, value_type, default_value, help_text in early_stop_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    return parser.parse_args(argv)


def main(argv: Optional[List[str]] = None) -> int:
    """Run Graph-GRPO on a single mol_opt oracle/seed, optionally stopping early.

    Steps:
      1. Parse CLI args and locate the RL_Graph_Generation and mol_opt repos
         (``RL_GRAPH_REPO`` / ``MOLOPT_REPO`` override the defaults).
      2. Resolve the checkpoint. ``--ckpt ""`` or ``--ckpt __default__`` combined
         with a truthy ``GRAPH_GRPO_USE_DEFAULT_CKPT`` leaves it unset.
      3. Export ``GRAPH_GRPO_OUTPUT_DIR`` / ``GRAPH_GRPO_CKPT`` and load
         ``hparams_default.yaml`` from the mol_opt graph_grpo directory.
      4. Run a ``GraphGRPO_Optimizer`` subclass whose ``_optimize`` loop adds
         plateau-based early stopping on the mean of the top-k scores.
         When it stops early, it writes ``early_stop.txt`` to the run dir.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit code (0 on normal completion).

    Raises:
        RuntimeError: If the mol_opt repository directory does not exist.
        FileNotFoundError: If an explicit checkpoint path does not exist.
    """
    args = _parse_args(argv)
    use_default_ckpt_env = _env_truthy("GRAPH_GRPO_USE_DEFAULT_CKPT")

    rl_repo = os.environ.get("RL_GRAPH_REPO") or _repo_root()
    molopt_repo = os.environ.get("MOLOPT_REPO") or _default_molopt_repo(rl_repo)
    if not os.path.isdir(molopt_repo):
        raise RuntimeError(f"MOLOPT_REPO not found: {molopt_repo}")
    # Put the mol_opt checkout first on sys.path so `main.graph_grpo` resolves to it.
    sys.path.insert(0, molopt_repo)

    # Imported lazily: `main.graph_grpo` is only importable after the sys.path insert above.
    import yaml
    from tdc import Oracle as TDCOracle

    from main.graph_grpo import run as graph_run

    ckpt_arg = str(args.ckpt or "")
    # "" and "__default__" are placeholders meaning "no explicit checkpoint".
    ckpt_provided = ckpt_arg not in {"", "__default__"}
    use_default_ckpt = use_default_ckpt_env and not ckpt_provided

    ckpt_path = ""
    if not use_default_ckpt:
        ckpt_path = os.path.abspath(os.path.expanduser(ckpt_arg))
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    early = EarlyStopConfig(
        enabled=bool(args.early_stop),
        patience=max(int(args.early_stop_patience), 0),
        eps=float(args.early_stop_eps),
        min_calls=max(int(args.early_stop_min_calls), 0),
        topk=max(int(args.early_stop_topk), 1),
    )
    # A mean top-k at (or within float noise of) 1.0 ends the run immediately.
    perfect_avg = 1.0 - 1e-12

    default_base = os.path.join(molopt_repo, "main", "graph_grpo", "results")
    base_output_dir = os.path.abspath(os.path.expanduser(args.output_dir or default_base))
    run_output_dir = os.path.join(base_output_dir, f"{args.oracle}_{int(args.seed)}")
    os.makedirs(run_output_dir, exist_ok=True)
    # The run dir is reused per oracle/seed. Drop a marker left by an earlier
    # early-stopped run so its presence reflects *this* run only.
    try:
        os.remove(os.path.join(run_output_dir, "early_stop.txt"))
    except FileNotFoundError:
        pass

    os.environ["GRAPH_GRPO_OUTPUT_DIR"] = run_output_dir
    if ckpt_path:
        os.environ["GRAPH_GRPO_CKPT"] = ckpt_path
    else:
        os.environ.pop("GRAPH_GRPO_CKPT", None)

    cfg_path = os.path.join(molopt_repo, "main", "graph_grpo", "hparams_default.yaml")
    with open(cfg_path, "r", encoding="utf-8") as f:
        config_default = yaml.safe_load(f) or {}
    config_default["checkpoint_path"] = ckpt_path  # "" when no explicit checkpoint

    optimizer_args = SimpleNamespace(
        n_jobs=-1,
        smi_file=None,
        output_dir=run_output_dir,
        max_oracle_calls=int(args.max_oracle_calls),
        freq_log=int(args.freq_log),
        log_results=False,
        log_code=False,
        wandb="disabled",
        oracles=[args.oracle],
        seed=[int(args.seed)],
        task="simple",
        method="graph_grpo",
    )

    class FastGraphGRPOOptimizer(graph_run.GraphGRPO_Optimizer):
        """GraphGRPO_Optimizer whose optimisation loop supports early stopping."""

        def _report_early_stop(
            self,
            oracle_name: str,
            stale_calls: int,
            best_avg: Optional[float],
            best_top: Optional[List[float]],
        ) -> None:
            """Announce an early stop on stderr and record its details in early_stop.txt."""
            seed = getattr(self, "seed", 0)
            n_oracle = len(self.oracle.mol_buffer)
            if best_avg is not None and best_avg >= perfect_avg:
                message = (
                    f"[early-stop] oracle={oracle_name} seed={seed} "
                    f"n_oracle={n_oracle} avg_top{early.topk}=1.0 reached"
                )
            else:
                message = (
                    f"[early-stop] oracle={oracle_name} seed={seed} "
                    f"n_oracle={n_oracle} stale_calls={stale_calls} "
                    f"avg_top{early.topk} did not improve >= {early.patience} (best_avg={best_avg})"
                )
            print(message, file=sys.stderr, flush=True)

            report = (
                f"oracle={oracle_name}\n"
                f"seed={seed}\n"
                f"n_oracle={n_oracle}\n"
                f"stale_calls={stale_calls}\n"
                f"patience={early.patience}\n"
                f"eps={early.eps}\n"
                f"min_calls={early.min_calls}\n"
                f"topk={early.topk}\n"
                f"best_avg_topk={best_avg}\n"
                f"best_topk_scores={best_top}\n"
            )
            out_path = os.path.join(str(self.args.output_dir), "early_stop.txt")
            try:
                with open(out_path, "w", encoding="utf-8") as f:
                    f.write(report)
            except OSError as exc:
                # Best effort: the run already stopped; surface the failure instead of hiding it.
                print(
                    f"[early-stop] warning: could not write {out_path}: {exc}",
                    file=sys.stderr,
                    flush=True,
                )

        def _optimize(self, oracle, config):
            self.oracle.assign_evaluator(oracle)
            oracle_name = getattr(oracle, "name", None) or getattr(self.args, "oracles", ["unknown"])[0]

            def maybe_early_stop(
                stale_calls: int,
                best_avg: Optional[float],
                best_top: Optional[List[float]],
                last_n: int,
            ) -> tuple[int, Optional[float], Optional[List[float]], int, bool]:
                """Advance the plateau tracker after one proposal batch.

                Returns the updated ``(stale_calls, best_avg, best_top, last_n)``
                plus a ``should_stop`` flag.
                """
                n_now = len(self.oracle.mol_buffer)
                # Only molecules added since the previous check count toward patience.
                newly = max(n_now - last_n, 0)
                last_n = n_now
                if not early.enabled or newly <= 0:
                    return stale_calls, best_avg, best_top, last_n, False
                # Warm-up: no tracking until top-k is populated and min_calls is reached.
                if n_now < max(early.topk, early.min_calls):
                    return 0, best_avg, best_top, last_n, False

                cur_top = _topk_scores(self.oracle.mol_buffer, early.topk)
                cur_avg = _mean(cur_top)
                if cur_avg >= perfect_avg:
                    # Mean top-k reached 1.0 (presumably the oracle maximum); stop now.
                    return 0, cur_avg, cur_top, last_n, True
                if best_avg is None:
                    # First measurement after warm-up becomes the baseline.
                    return 0, cur_avg, cur_top, last_n, False

                improved = cur_avg > (best_avg + early.eps)
                if improved:
                    return 0, cur_avg, cur_top, last_n, False

                stale_calls += newly
                if stale_calls >= early.patience:
                    return stale_calls, best_avg, best_top, last_n, True
                return stale_calls, best_avg, best_top, last_n, False

            if graph_run._should_use_external_proposer():
                client = graph_run._start_external_proposer(oracle_name, config, self.args)
                try:
                    state: Dict[str, Any] = {"seed": int(getattr(self, "seed", 0)), "round_idx": 0, "propose_idx": 0}

                    stale_calls = 0
                    best_avg: Optional[float] = None
                    best_top: Optional[List[float]] = None
                    last_n = 0

                    while not self.finish:
                        smiles, state = client.propose(batch_size=int(config.get("batch_size", 0) or 0), state=state)
                        if not smiles:
                            # NOTE(review): nothing bounds consecutive empty proposals;
                            # a proposer that never yields SMILES would spin here forever.
                            continue

                        scores = self.oracle(smiles)
                        if not isinstance(scores, list):
                            scores = [scores]
                        scores = [float(x) for x in scores]
                        state["n_oracle"] = len(self.oracle.mol_buffer)
                        state = client.observe(smiles=smiles, scores=scores, state=state)

                        stale_calls, best_avg, best_top, last_n, should_stop = maybe_early_stop(
                            stale_calls, best_avg, best_top, last_n
                        )
                        if should_stop:
                            self._report_early_stop(oracle_name, stale_calls, best_avg, best_top)
                            break
                finally:
                    client.close()
                return

            import torch
            from pmo_bridge.grpo_proposer import GraphGRPOProposer

            cfg = self._compose_grpo_cfg(oracle_name, config)
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            proposer = GraphGRPOProposer(cfg=cfg, device=device)

            replay = {}
            state = {"seed": int(getattr(self, "seed", 0)), "round_idx": 0, "propose_idx": 0}

            stale_calls = 0
            best_avg = None
            best_top = None
            last_n = 0

            while not self.finish:
                smiles = proposer.propose(batch_size=int(config.get("batch_size", 0) or 0), replay=replay, state=state)
                if not smiles:
                    # NOTE(review): same unbounded-empty-proposal caveat as the external path.
                    continue

                scores = self.oracle(smiles)
                state["n_oracle"] = len(self.oracle.mol_buffer)
                proposer.observe(smiles, scores, replay=replay, state=state)

                stale_calls, best_avg, best_top, last_n, should_stop = maybe_early_stop(
                    stale_calls, best_avg, best_top, last_n
                )
                if should_stop:
                    self._report_early_stop(oracle_name, stale_calls, best_avg, best_top)
                    break

    oracle = TDCOracle(name=str(args.oracle))
    optimizer = FastGraphGRPOOptimizer(args=optimizer_args)

    # perf_counter is monotonic, so the elapsed time is immune to wall-clock adjustments.
    start = time.perf_counter()
    optimizer.optimize(oracle=oracle, config=config_default, seed=int(args.seed))
    dur = time.perf_counter() - start
    print(f"[done] output_dir={run_output_dir} elapsed={dur:.1f}s", file=sys.stderr)
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
