# peptide/baselines/train.py
from __future__ import annotations

import argparse
import json
import time
from pathlib import Path

import torch

from util import set_seed
from ..peptide_env import Sequences, Policy
from ..peptide_reward import LogReward
from ..peptide_sampling import forward_reward
from ..peptide_utils import load_runs, new_run_dir, log_jsonl, save_ckpt


def load_seq_size(peptide_dir: Path) -> int:
    """Return the ``seq_size`` entry of the peptide vocabulary file.

    The vocabulary is read from ``<peptide_dir>/rf_models/encoders/vocab.json``
    (UTF-8 JSON object). A missing file or missing ``seq_size`` key propagates
    as the usual ``FileNotFoundError`` / ``KeyError``.
    """
    encoders_dir = peptide_dir.joinpath("rf_models", "encoders")
    with (encoders_dir / "vocab.json").open("r", encoding="utf-8") as fh:
        vocab = json.load(fh)
    return vocab["seq_size"]


def parse_args():
    """Parse the command line for the training script.

    ``--config`` is mandatory. Exactly one of ``--run <id>`` (train a single
    run from the config) or ``--all`` (train every run) must be supplied.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)

    selector = parser.add_mutually_exclusive_group(required=True)
    selector.add_argument("--run", type=str)
    selector.add_argument("--all", action="store_true")

    args = parser.parse_args()
    return args


def _checkpoint_payload(epoch: int, cfg: dict, fnet: torch.nn.Module,
                        logz: torch.Tensor, opt: torch.optim.Optimizer) -> dict:
    """Build the checkpoint dict shared by the periodic and final saves."""
    return {
        "epoch": epoch,
        "cfg": cfg,
        "fnet": fnet.state_dict(),
        "logz": float(logz.item()),
        "opt": opt.state_dict(),
        # Only the CPU torch generator state is captured here.
        "rng_torch": torch.random.get_rng_state(),
    }


def run_one(cfg: dict, seq_size: int):
    """Train one run of the peptide trajectory-balance ("tb") baseline.

    A new run directory is created via ``new_run_dir``. Metrics are appended
    to ``<run_dir>/metrics.jsonl``, and checkpoints are written with
    ``save_ckpt``: periodic ones tagged "epoch" and a final one tagged "latest".

    Args:
        cfg: Run configuration. It is copied, not mutated; ``seq_size`` is
            added to the copy.
            Required keys: ``seed``, ``cut_off``, ``batch_size``, ``eps``,
            ``emb_dim``, ``hidden``, ``pos_dim``, ``window``, ``lr_pf``,
            ``lr_logz``, ``epochs``.
            Optional keys: ``device`` (default "cpu"), ``log_every``
            (default 10), ``save_every`` (default 1000), ``run_id``.
            A ``log_every`` / ``save_every`` of 0 disables periodic logging /
            periodic checkpointing. The final "latest" checkpoint is always
            written.
        seq_size: Sequence length passed to the environment.
    """
    cfg = dict(cfg)
    cfg["seq_size"] = seq_size

    run_dir = new_run_dir(exp="peptide", method="tb", cfg=cfg, out_root="runs")
    metrics_path = run_dir / "metrics.jsonl"

    device = cfg.get("device", "cpu")
    seed = int(cfg["seed"])
    set_seed(seed)

    log_reward = LogReward(cutoff=float(cfg["cut_off"]))
    env = Sequences(
        seq_size=seq_size,
        batch_size=int(cfg["batch_size"]),
        log_reward=log_reward,
        eps=float(cfg["eps"]),
        seed=seed,
    )

    fnet = Policy(
        emb_dim=int(cfg["emb_dim"]),
        hidden=int(cfg["hidden"]),
        pos_dim=int(cfg["pos_dim"]),
        window=int(cfg["window"]),
    ).to(device)
    # Learned log-partition estimate, trained with its own learning rate.
    logz = torch.nn.Parameter(torch.zeros(1, device=device))

    opt = torch.optim.AdamW(
        [{"params": fnet.parameters(), "lr": float(cfg["lr_pf"])},
         {"params": [logz], "lr": float(cfg["lr_logz"])}]
    )

    epochs = int(cfg["epochs"])
    log_every = int(cfg.get("log_every", 10))
    save_every = int(cfg.get("save_every", 1000))

    def tb_loss():
        # Mean squared gap between the model-side log estimate from
        # forward_reward(env, fnet, logz) and the environment's log reward.
        # NOTE(review): presumably forward_reward returns logZ + sum log P_F
        # over the sampled trajectories -- confirm in peptide_sampling.
        gfn_log_r = forward_reward(env, fnet, logz)
        log_r = env.log_reward()
        return (gfn_log_r - log_r).pow(2).mean(), log_r

    t0 = time.time()

    # Inclusive range: epochs + 1 optimisation steps, the last index being
    # `epochs` so the final step is logged/saved on the same schedule.
    for epoch in range(epochs + 1):
        env.reset()
        opt.zero_grad(set_to_none=True)
        loss, log_r = tb_loss()
        loss.backward()
        opt.step()

        # `!= 0` guards against ZeroDivisionError and lets 0 mean "disabled".
        if log_every != 0 and epoch % log_every == 0:
            loss_val = float(loss.item())
            stats_r = log_r.detach()
            row = {
                "epoch": epoch,
                "time_sec": time.time() - t0,
                "loss_total": loss_val,
                "tb_loss": loss_val,
                "logz_main": float(logz.item()),
                "avg_logR": float(stats_r.mean().item()),
                "max_logR": float(stats_r.max().item()),
                "run_id": cfg.get("run_id"),
            }
            log_jsonl(metrics_path, row)
            print(f"[{cfg.get('run_id','?')}] [{epoch:5d}] tb_loss={row['tb_loss']:.4f} logz={row['logz_main']:.3f}")

        if save_every != 0 and epoch % save_every == 0:
            save_ckpt(run_dir, epoch=epoch, tag="epoch",
                      payload=_checkpoint_payload(epoch, cfg, fnet, logz, opt))

            # TODO: save samples

    save_ckpt(run_dir, epoch=epochs, tag="latest",
              payload=_checkpoint_payload(epochs, cfg, fnet, logz, opt))


def main():
    """CLI entry point: train the selected run, or every run, from ``--config``."""
    args = parse_args()

    # This file lives in .../peptide/baselines/, so parents[1] is .../peptide.
    peptide_dir = Path(__file__).resolve().parents[1]
    seq_size = load_seq_size(peptide_dir)

    # With --all no run id is passed, so load_runs receives None.
    selected_run = None if args.all else args.run
    for run_cfg in load_runs(args.config, run_id=selected_run):
        run_one(run_cfg, seq_size=seq_size)


if __name__ == "__main__":
    main()