# peptide/baselines/sa_train.py
from __future__ import annotations

import argparse
import json
import time
from pathlib import Path

import torch
import torch.nn as nn

from util import set_seed
from ..peptide_env import Sequences, Policy
from ..peptide_reward import LogReward
from ..peptide_sampling import forward_reward
from ..peptide_utils import load_runs, new_run_dir, log_jsonl, save_ckpt
from .base_core import sa_double_loss


def load_seq_size(peptide_dir: Path) -> int:
    """Return the ``seq_size`` entry stored in the encoder vocabulary file.

    The file is read from ``<peptide_dir>/rf_models/encoders/vocab.json`` and
    decoded as UTF-8 JSON.

    Raises:
        FileNotFoundError: if the vocabulary file does not exist.
        KeyError: if the JSON object has no ``"seq_size"`` key.
    """
    vocab_file = peptide_dir.joinpath("rf_models", "encoders", "vocab.json")
    vocab = json.loads(vocab_file.read_text(encoding="utf-8"))
    return vocab["seq_size"]


def parse_args() -> argparse.Namespace:
    """Parse the command line for the SA training script.

    Options:
        --config  (required) config path/identifier, later handed to ``load_runs``.
        --run     id of a single run to execute.
        --all     flag requesting every run.

    Exactly one of ``--run`` / ``--all`` must be supplied; argparse exits
    with an error otherwise.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, type=str)

    # The run selection is either one named run or everything, never both.
    selection = parser.add_mutually_exclusive_group(required=True)
    selection.add_argument("--run", type=str)
    selection.add_argument("--all", action="store_true")

    return parser.parse_args()


def _build_policy(cfg: dict, device) -> nn.Module:
    """Build a ``Policy`` from the architecture keys in *cfg* and move it to *device*."""
    return Policy(
        emb_dim=int(cfg["emb_dim"]),
        hidden=int(cfg["hidden"]),
        pos_dim=int(cfg["pos_dim"]),
        window=int(cfg["window"]),
    ).to(device)


def _build_rnd_mlp(seq_size: int, device) -> nn.Sequential:
    """Build the MLP (seq_size -> 64 -> 64 -> 1) used for both RND networks, on *device*."""
    return nn.Sequential(
        nn.Linear(seq_size, 64), nn.ReLU(),
        nn.Linear(64, 64), nn.ReLU(),
        nn.Linear(64, 1),
    ).to(device)


def _cfg_lr(cfg: dict, key: str, fallback_key: str) -> float:
    """Return ``cfg[key]`` as a float, using ``cfg[fallback_key]`` when it is absent or None."""
    value = cfg.get(key)
    if value is None:
        value = cfg[fallback_key]
    return float(value)


def run_one(cfg: dict, seq_size: int) -> None:
    """Train one "sa" baseline run described by *cfg* and write its artifacts.

    Three groups of parameters are trained, each with its own AdamW optimizer:
    the main policy ``fnet`` with ``logz``, the auxiliary ("SA") policy
    ``sa_fnet`` with ``sa_logz``, and an RND predictor ``rnd`` paired with a
    frozen target ``rnd_tgt``.  Metrics are appended to ``metrics.jsonl`` in
    the run directory and checkpoints are written through ``save_ckpt``.

    Args:
        cfg: Run configuration; it is copied, so the caller's dict is not
            mutated.  Required keys: ``seed``, ``reward_alpha``, ``cut_off``,
            ``batch_size``, ``eps``, ``emb_dim``, ``hidden``, ``pos_dim``,
            ``window``, ``lr_pf``, ``lr_logz``, ``epochs``.  Optional keys:
            ``device`` (default ``"cpu"``), ``sa_lr_pf`` / ``sa_lr_logz``
            (default to ``lr_pf`` / ``lr_logz``), ``log_every`` (default 10),
            ``save_every`` (default 1000), ``run_id`` (only used for logging).
        seq_size: Sequence size; stored into the config copy as
            ``cfg["seq_size"]`` and used for the environment and as the RND
            input width.

    Raises:
        KeyError: if a required config key is missing.
    """
    cfg = dict(cfg)
    cfg["seq_size"] = seq_size

    run_dir = new_run_dir(exp="peptide", method="sa", cfg=cfg, out_root="runs")
    metrics_path = run_dir / "metrics.jsonl"

    device = cfg.get("device", "cpu")
    seed = int(cfg["seed"])
    reward_alpha = float(cfg["reward_alpha"])
    set_seed(seed)

    # NOTE: everything below is constructed in a fixed order so that the
    # random initialisation after set_seed() stays reproducible.
    log_reward = LogReward(cutoff=float(cfg["cut_off"]))
    env = Sequences(
        seq_size=seq_size,
        # Half of the configured batch: each epoch draws two batches from the
        # env (one inside sa_double_loss, one for the on-policy loss).
        batch_size=int(cfg["batch_size"]) // 2,
        log_reward=log_reward,
        eps=float(cfg["eps"]),
        seed=seed,
    )

    # Main policy and its log-partition estimate.
    fnet = _build_policy(cfg, device)
    logz = torch.nn.Parameter(torch.zeros(1, device=device))
    opt = torch.optim.AdamW(
        [{"params": fnet.parameters(), "lr": float(cfg["lr_pf"])},
         {"params": [logz], "lr": float(cfg["lr_logz"])}]
    )

    # Auxiliary (SA) policy; same architecture, separate parameters/optimizer.
    sa_fnet = _build_policy(cfg, device)
    sa_logz = torch.nn.Parameter(torch.zeros(1, device=device))
    sa_opt = torch.optim.AdamW(
        # The SA learning rates are optional; fall back to the main ones
        # instead of failing on float(None) when they are not configured.
        [{"params": sa_fnet.parameters(), "lr": _cfg_lr(cfg, "sa_lr_pf", "lr_pf")},
         {"params": [sa_logz], "lr": _cfg_lr(cfg, "sa_lr_logz", "lr_logz")}]
    )

    # RND predictor and frozen random target.  Both live on `device`, like
    # every other module here.
    # NOTE(review): assumes sa_double_loss feeds them tensors on `device` -- confirm.
    rnd = _build_rnd_mlp(seq_size, device)
    rnd_tgt = _build_rnd_mlp(seq_size, device)
    rnd_tgt.requires_grad_(False)
    rnd_opt = torch.optim.AdamW([{"params": rnd.parameters(), "lr": float(cfg["lr_pf"])}])

    def tb_loss_on():
        """Return (mean squared gap between forward_reward and env.log_reward(), log-rewards)."""
        gfn_log_r = forward_reward(env, fnet, logz)
        log_r = env.log_reward()
        return (gfn_log_r - log_r).pow(2).mean(), log_r

    def checkpoint_payload(epoch: int) -> dict:
        """Snapshot all model/optimizer state needed to describe the run at *epoch*."""
        return {
            "epoch": epoch,
            "cfg": cfg,
            "fnet": fnet.state_dict(),
            "logz": float(logz.item()),
            "sa_fnet": sa_fnet.state_dict(),
            "sa_logz": float(sa_logz.item()),
            "rnd": rnd.state_dict(),
            # The target is random and never trained, so it can only be
            # recovered from a checkpoint; without it `rnd` is meaningless.
            "rnd_tgt": rnd_tgt.state_dict(),
            "opt": opt.state_dict(),
            "sa_opt": sa_opt.state_dict(),
            "rnd_opt": rnd_opt.state_dict(),
            "rng_torch": torch.random.get_rng_state(),
        }

    epochs = int(cfg["epochs"])
    log_every = int(cfg.get("log_every", 10))
    save_every = int(cfg.get("save_every", 1000))

    t0 = time.time()

    # Inclusive upper bound: the loop runs for epoch = 0 .. epochs.
    for epoch in range(epochs + 1):
        # 1) SA / off-policy losses, all produced by a single call.
        env.reset()
        opt.zero_grad(set_to_none=True)
        sa_opt.zero_grad(set_to_none=True)
        rnd_opt.zero_grad(set_to_none=True)

        sa_tb_loss, tb_loss_off, rnd_loss = sa_double_loss(
            env, sa_fnet, sa_logz, fnet, logz, rnd, rnd_tgt, beta_e=reward_alpha
        )

        # Update the SA networks.  The graph is retained because the other two
        # losses from the same call are backpropagated later.
        # NOTE(review): presumably they share part of the graph -- confirm
        # against sa_double_loss.
        sa_tb_loss.backward(retain_graph=True)
        sa_opt.step()

        # 2) Main update on the average of on-policy and off-policy losses.
        # Gradients are cleared again so that anything the SA backward pass
        # may have written into fnet/logz does not leak into this step.
        opt.zero_grad(set_to_none=True)
        env.reset()
        tb_on, log_r = tb_loss_on()

        tb_total = 0.5 * (tb_on + tb_loss_off)
        tb_total.backward(retain_graph=True)
        opt.step()

        # 3) RND predictor update, again starting from clean gradients.
        rnd_opt.zero_grad(set_to_none=True)
        rnd_loss.backward()
        rnd_opt.step()

        if epoch % log_every == 0:
            row = {
                "epoch": epoch,
                "time_sec": time.time() - t0,
                "loss_total": float(tb_total.item()),
                "tb_loss": float(tb_total.item()),
                "tb_loss_on": float(tb_on.item()),
                "tb_loss_off": float(tb_loss_off.item()),
                "sa_tb_loss": float(sa_tb_loss.item()),
                "rnd_loss": float(rnd_loss.item()),
                "logz_main": float(logz.item()),
                "logz_aux": float(sa_logz.item()),
                "avg_logR": float(log_r.mean().item()),
                "max_logR": float(log_r.max().item()),
                "run_id": cfg.get("run_id"),
            }
            log_jsonl(metrics_path, row)
            print(
                f"[{cfg.get('run_id','?')}] [{epoch:5d}] "
                f"tb={row['tb_loss']:.4f} on={row['tb_loss_on']:.4f} off={row['tb_loss_off']:.4f} "
                f"sa={row['sa_tb_loss']:.4f} rnd={row['rnd_loss']:.4f}"
            )

        if epoch % save_every == 0:
            save_ckpt(run_dir, epoch=epoch, tag="epoch", payload=checkpoint_payload(epoch))

    # Final snapshot, always written regardless of save_every.
    save_ckpt(run_dir, epoch=epochs, tag="latest", payload=checkpoint_payload(epochs))


def main():
    """CLI entry point: resolve the requested runs and train them one after another.

    The sequence size is read once from the vocabulary file under the
    directory one level above this file's directory and shared by every run.
    """
    args = parse_args()

    # With --all no specific id is requested, so None is passed to load_runs.
    requested_run = None if args.all else args.run

    package_root = Path(__file__).resolve().parents[1]
    seq_size = load_seq_size(package_root)

    for run_cfg in load_runs(args.config, run_id=requested_run):
        run_one(run_cfg, seq_size=seq_size)


# Run the CLI only when this file is executed as a script, not when imported.
# NOTE: the module uses relative imports, so it needs package context
# (e.g. launched with ``python -m``) rather than being run by file path.
if __name__ == "__main__":
    main()