# peptide/train.py
from __future__ import annotations

import argparse
import json
import math
import time
from pathlib import Path

import torch
from torch.nn.functional import softplus

from util import set_seed
from .peptide_env import Sequences, Policy
from .peptide_reward import LogReward
from .peptide_sampling import forward_reward, dual_forward_reward
from .peptide_utils import load_runs, new_run_dir, log_jsonl, save_ckpt


def load_seq_size(peptide_dir: Path) -> int:
    """Return the ``seq_size`` entry stored in the encoder vocabulary file.

    The value is read from ``rf_models/encoders/vocab.json`` underneath
    ``peptide_dir``; the file is decoded as UTF-8 JSON.
    """
    vocab_file = peptide_dir / "rf_models" / "encoders" / "vocab.json"
    vocab = json.loads(vocab_file.read_text(encoding="utf-8"))
    return vocab["seq_size"]


def parse_args():
    """Parse the command line for the training script.

    ``--config`` is always required. Exactly one of ``--run`` (a string) or
    ``--all`` (a flag) must be supplied; they are mutually exclusive.

    Returns:
        The ``argparse.Namespace`` with ``config``, ``run`` and ``all``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, type=str)

    # Run selection: a single named run or every run, never both.
    selection = parser.add_mutually_exclusive_group(required=True)
    selection.add_argument("--run", type=str)
    selection.add_argument("--all", action="store_true")

    return parser.parse_args()


def run_one(cfg: dict, seq_size: int):
    """Train one "dtb" peptide run described by ``cfg`` and checkpoint it.

    Two ``Policy`` networks are trained side by side, each with its own
    learnable scalar (``logz`` / ``div_logz``) and its own AdamW optimizer.
    The loop runs for ``epochs + 1`` iterations (epoch ``0`` .. ``epochs``):

    * Main policy: two squared-error losses against ``env.log_reward()`` are
      computed on two separately reset environments -- one from
      ``forward_reward`` and one from the main-policy output of
      ``dual_forward_reward`` -- and blended with the detached weight
      ``w1 = sigmoid(logz - div_logz)`` as
      ``back_loss + (main_loss - back_loss) * w1``.
    * Divergent policy: trained on
      ``log_ratio = log_R_hat_div - reward_alpha * log_R``. Where
      ``log_R_hat - log_R > log(threshold)`` the term is passed through
      ``softplus`` first; the result is squared and averaged.

    Side effects: creates a run directory via ``new_run_dir``, appends a
    metrics row to ``metrics.jsonl`` and prints a summary whenever
    ``epoch % log_every == 0``, saves an ``"epoch"`` checkpoint whenever
    ``epoch % save_every == 0``, and saves a ``"latest"`` checkpoint after
    the loop.

    Args:
        cfg: Run configuration. Required keys: ``seed``, ``cut_off``,
            ``batch_size``, ``eps``, ``emb_dim``, ``hidden``, ``pos_dim``,
            ``window``, ``lr_pf``, ``lr_logz``, ``div_lr_pf``,
            ``div_lr_logz``, ``epochs``, ``threshold``, ``reward_alpha``.
            Optional keys: ``device`` (default ``"cpu"``), ``log_every``
            (default 10), ``save_every`` (default 1000), ``run_id``.
            The caller's dict is copied, not mutated.
        seq_size: Sequence size; stored in the config copy and passed to the
            environment.

    Raises:
        KeyError: If a required config key is missing.
        ValueError: If a config value cannot be converted to the expected
            number type, or ``threshold`` is not positive (its logarithm is
            taken before training starts).
    """
    # attach seq_size to config BEFORE naming (so hash + config.json match)
    cfg = dict(cfg)
    cfg["seq_size"] = seq_size

    run_dir = new_run_dir(exp="peptide", method="dtb", cfg=cfg, out_root="runs")
    metrics_path = run_dir / "metrics.jsonl"

    device = cfg.get("device", "cpu")
    seed = int(cfg["seed"])
    set_seed(seed)

    log_reward = LogReward(cutoff=float(cfg["cut_off"]))
    env = Sequences(
        seq_size=seq_size,
        # NOTE(review): presumably halved because env.reset() is called twice
        # per epoch (once per loss) -- confirm.
        batch_size=int(cfg["batch_size"]) // 2,
        log_reward=log_reward,
        eps=float(cfg["eps"]),
        seed=seed,
    )

    def make_policy():
        """Build one Policy on ``device`` from the shared architecture keys."""
        return Policy(
            emb_dim=int(cfg["emb_dim"]),
            hidden=int(cfg["hidden"]),
            pos_dim=int(cfg["pos_dim"]),
            window=int(cfg["window"]),
        ).to(device)

    # Construction order (main first, then divergent) is kept so that seeded
    # parameter initialisation matches earlier runs.
    # policy
    fnet = make_policy()
    logz = torch.nn.Parameter(torch.zeros(1, device=device))

    # divergent policy
    div_fnet = make_policy()
    div_logz = torch.nn.Parameter(torch.zeros(1, device=device))

    opt = torch.optim.AdamW(
        [{"params": fnet.parameters(), "lr": float(cfg["lr_pf"])},
         {"params": [logz], "lr": float(cfg["lr_logz"])}]
    )
    div_opt = torch.optim.AdamW(
        [{"params": div_fnet.parameters(), "lr": float(cfg["div_lr_pf"])},
         {"params": [div_logz], "lr": float(cfg["div_lr_logz"])}]
    )

    epochs = int(cfg["epochs"])
    log_every = int(cfg.get("log_every", 10))
    save_every = int(cfg.get("save_every", 1000))
    threshold = float(cfg["threshold"])
    reward_alpha = float(cfg["reward_alpha"])
    # Loop-invariant; computing it here also rejects a non-positive threshold
    # before any training step is taken.
    log_threshold = math.log(threshold)

    def tb_loss() -> torch.Tensor:
        """Mean squared gap between ``forward_reward`` and the env log-reward."""
        gfn_log_r = forward_reward(env, fnet, logz)
        log_r = env.log_reward()
        return (gfn_log_r - log_r).pow(2).mean()

    def checkpoint_payload(epoch: int) -> dict:
        """Snapshot models, optimizers, config and the torch CPU RNG state."""
        return {
            "epoch": epoch,
            "cfg": cfg,
            "fnet": fnet.state_dict(),
            "logz": float(logz.item()),
            "div_fnet": div_fnet.state_dict(),
            "div_logz": float(div_logz.item()),
            "opt": opt.state_dict(),
            "div_opt": div_opt.state_dict(),
            "rng_torch": torch.random.get_rng_state(),
        }

    t0 = time.time()

    for epoch in range(epochs + 1):
        # --- main TB ---
        env.reset()
        # Cleared once per epoch: nothing calls backward() between here and
        # loss_main.backward(), so a second zero_grad would be a no-op.
        opt.zero_grad(set_to_none=True)
        main_loss = tb_loss()

        # --- dual forward / combined update ---
        env.reset()
        div_opt.zero_grad(set_to_none=True)

        log_R_hat_div, log_R_hat = dual_forward_reward(env, div_fnet, fnet, div_logz, logz)
        log_R = env.log_reward()

        back_loss = (log_R_hat - log_R).pow(2).mean()

        # The blend weight is treated as a constant: no gradient flows into
        # logz / div_logz through it.
        with torch.no_grad():
            w1 = torch.sigmoid(logz - div_logz)

        # Equivalent to w1 * main_loss + (1 - w1) * back_loss.
        loss_main = back_loss + (main_loss - back_loss) * w1
        loss_main.backward()
        opt.step()

        # Boolean selector only; it must not carry gradient history.
        with torch.no_grad():
            mask = (log_R_hat - log_R).gt(log_threshold)

        log_ratio = log_R_hat_div - reward_alpha * log_R
        div_loss = torch.where(mask, softplus(log_ratio), log_ratio).pow(2).mean()
        div_loss.backward()
        div_opt.step()

        if epoch % log_every == 0:
            row = {
                "epoch": epoch,
                "time_sec": time.time() - t0,
                "loss_main": float(loss_main.item()),
                "main_loss_tb": float(main_loss.item()),
                "back_loss": float(back_loss.item()),
                "div_loss": float(div_loss.item()),
                # Only materialised on logging epochs to avoid a per-step
                # tensor-to-Python transfer.
                "known_frac": float(mask.float().mean().item()),
                "logz_main": float(logz.item()),
                "logz_aux": float(div_logz.item()),
                "avg_logR": float(log_R.mean().item()),
                "max_logR": float(log_R.max().item()),
                "run_id": cfg.get("run_id"),
            }
            log_jsonl(metrics_path, row)
            print(
                f"[{cfg.get('run_id','?')}] [{epoch:5d}] "
                f"loss_main={row['loss_main']:.4f} div={row['div_loss']:.4f} "
                f"known={row['known_frac']:.3f} logz={row['logz_main']:.3f}"
            )

        if epoch % save_every == 0:
            save_ckpt(run_dir, epoch=epoch, tag="epoch", payload=checkpoint_payload(epoch))

    # Final snapshot, always written regardless of save_every.
    save_ckpt(run_dir, epoch=epochs, tag="latest", payload=checkpoint_payload(epochs))


def main():
    """Entry point: resolve the selected run configs and train each in turn.

    The sequence size is loaded once from the directory containing this file
    and shared by every run. With ``--all`` no run id filter is passed to
    ``load_runs``; otherwise the ``--run`` value is used.
    """
    cli = parse_args()

    package_dir = Path(__file__).resolve().parent
    seq_size = load_seq_size(package_dir)

    selected_run = None if cli.all else cli.run
    for run_cfg in load_runs(cli.config, run_id=selected_run):
        run_one(run_cfg, seq_size=seq_size)


# Script entry point: start training only when this file is executed
# directly, not when it is imported.
# NOTE(review): the relative imports at the top of this file need package
# context, so this presumably has to be launched as a module (e.g.
# `python -m peptide.train`) -- confirm.
if __name__ == "__main__":
    main()