import os
import json
import math
from pathlib import Path
from typing import Optional

import wandb
import torch
from torch.nn.functional import softplus

from .peptide_env import Sequences, Policy
from .peptide_reward import LogReward
from .peptide_sampling import forward_reward, dual_forward_reward
from util import set_seed


# --- Resolve seq_size from vocab.json ---
# The vocabulary file is located relative to this module's own directory, so
# the lookup does not depend on the current working directory.
_HERE = Path(__file__).resolve().parent
_REPO_ROOT = _HERE
vocab_path = _REPO_ROOT / "rf_models" / "encoders" / "vocab.json"
# JSON text is UTF-8 by specification; pin the encoding so that decoding does
# not fall back to the platform's locale default.
with open(vocab_path, "r", encoding="utf-8") as f:
    seq_size = json.load(f)["seq_size"]


def tb_loss(env, fnet, logz):
    """Return the mean squared gap between the model and environment log-rewards.

    The model-side value comes from ``forward_reward(env, fnet, logz)`` and the
    target from ``env.log_reward()``; the result is the mean of the squared
    element-wise difference between the two.
    (The ``tb`` prefix presumably stands for trajectory balance; confirm
    against ``peptide_sampling.forward_reward``.)
    """
    predicted = forward_reward(env, fnet, logz)
    target = env.log_reward()
    residual = predicted - target
    return residual.pow(2).mean()


def log_checkpoint_artifact(
    *,
    epoch: int,
    fnet,
    logz,
    div_fnet,
    div_logz,
    opt=None,
    div_opt=None,
    is_best: bool = False,
    ckpt_dir: str = "checkpoints",
    artifact_name: Optional[str] = None,
):
    """Save a checkpoint to disk and log it as a W&B ``model`` artifact.

    The checkpoint is always written to ``<ckpt_dir>/model.pt`` (overwriting
    any previous file there) and then attached to a ``wandb.Artifact``.

    Args:
        epoch: Epoch number stored in the checkpoint and used for the
            ``epoch_<n>`` alias.
        fnet: Main policy; its ``state_dict()`` is stored.
        logz: Main log-partition parameter; stored as a Python float.
        div_fnet: Divergence policy; its ``state_dict()`` is stored.
        div_logz: Divergence log-partition parameter; stored as a float.
        opt: Optional optimizer for the main policy; its state is stored
            under ``"opt"`` when given.
        div_opt: Optional optimizer for the divergence policy; its state is
            stored under ``"div_opt"`` when given.
        is_best: When true, the artifact additionally gets the ``best`` alias.
        ckpt_dir: Directory for the checkpoint file; created if missing.
        artifact_name: Artifact name. When ``None``, a name derived from the
            active run id is used.

    Returns:
        The path of the checkpoint file that was written.
    """
    Path(ckpt_dir).mkdir(parents=True, exist_ok=True)
    ckpt_path = os.path.join(ckpt_dir, "model.pt")

    # Besides the weights, store the run config and the module-level seq_size
    # alongside them in the same file.
    state = dict(
        epoch=epoch,
        fnet=fnet.state_dict(),
        logz=float(logz.item()),
        div_fnet=div_fnet.state_dict(),
        div_logz=float(div_logz.item()),
        config=dict(wandb.config),
        seq_size=seq_size,
    )
    for key, optimizer in (("opt", opt), ("div_opt", div_opt)):
        if optimizer is not None:
            state[key] = optimizer.state_dict()

    torch.save(state, ckpt_path)

    # A fixed name means successive calls produce versions of one artifact.
    name = artifact_name if artifact_name is not None else f"peptide_gfn_{wandb.run.id}"
    tags = ["latest", f"epoch_{epoch}"] + (["best"] if is_best else [])

    artifact = wandb.Artifact(name=name, type="model")
    artifact.add_file(ckpt_path)
    wandb.log_artifact(artifact, aliases=tags)

    return ckpt_path


def train():
    """Run one W&B-tracked training run of the main and divergence policies.

    Hyperparameters come from ``wandb.init(config=defaults)``, so any value in
    ``defaults`` below can be overridden by a sweep or by the caller's W&B
    configuration.

    Each epoch performs two updates:

    1. Main policy (``fnet``, ``logz``): the loss mixes ``tb_loss`` on a fresh
       environment batch (``main_loss``) with the squared error of the main
       estimate from ``dual_forward_reward`` (``back_loss``). The mixing
       weight ``w1 = sigmoid(logz - div_logz)`` is computed without gradient,
       giving ``w1 * main_loss + (1 - w1) * back_loss``.
    2. Divergence policy (``div_fnet``, ``div_logz``): squared residual
       against the environment log-reward, where the residual is passed
       through ``softplus`` for samples whose main-policy residual exceeds
       ``log(threshold)``.

    Metrics are logged every ``log_every`` epochs, and checkpoints are written
    every ``save_every`` epochs, optionally on each new best ``logz``, and
    always on the final epoch.

    Raises:
        ValueError: If ``pos_dim`` is odd, or if ``threshold`` is not
            strictly positive (its logarithm is taken).
    """
    defaults = dict(
        # --- Environment ---
        cut_off=0.96,
        epochs=2000,
        seed=42,
        batch_size=128,
        threshold=0.3,
        eps=0.1,
        device="cpu",

        # --- Optimization ---
        lr_pf=1e-2,
        lr_logz=1e-1,
        div_lr_pf=1e-3,
        div_lr_logz=1e-2,

        # --- Policy architecture (now tunable via wandb) ---
        emb_dim=64,
        window=6,
        pos_dim=16,     # must be even
        hidden=128,
        force_stop_on_full=True,

        # --- Logging / checkpoints ---
        log_every=10,
        save_every=200,          # save an artifact every N epochs
        save_on_best=False,      # also save whenever logz reaches a new best
        ckpt_dir="checkpoints",
        artifact_name=None,      # to share across runs, e.g. "peptide_gfn_global"
    )

    with wandb.init(config=defaults) as run:
        config = run.config
        device = config.device

        if config.pos_dim % 2 != 0:
            raise ValueError("pos_dim deve ser par (para sinusoidal positional encoding).")

        set_seed(config.seed)

        log_reward = LogReward(cutoff=config.cut_off)
        env = Sequences(
            seq_size=seq_size,
            batch_size=config.batch_size,
            log_reward=log_reward,
            eps=config.eps,
            seed=config.seed,
        )

        # --- Models ---
        fnet = Policy(
            emb_dim=config.emb_dim,
            window=config.window,
            pos_dim=config.pos_dim,
            hidden=config.hidden,
            force_stop_on_full=config.force_stop_on_full,
        ).to(device)
        logz = torch.nn.Parameter(torch.zeros(1, device=device))

        div_fnet = Policy(
            emb_dim=config.emb_dim,
            window=config.window,
            pos_dim=config.pos_dim,
            hidden=config.hidden,
            force_stop_on_full=config.force_stop_on_full,
        ).to(device)
        div_logz = torch.nn.Parameter(torch.zeros(1, device=device))

        # --- Optimizers ---
        opt = torch.optim.AdamW([
            {"params": fnet.parameters(), "lr": config.lr_pf},
            {"params": [logz], "lr": config.lr_logz},
        ])

        div_opt = torch.optim.AdamW([
            {"params": div_fnet.parameters(), "lr": config.div_lr_pf},
            {"params": [div_logz], "lr": config.div_lr_logz},
        ])

        # Make epoch the x-axis for every metric in the dashboard.
        wandb.define_metric("epoch")
        wandb.define_metric("*", step_metric="epoch")

        best_logz = float("-inf")

        # Loop-invariant; computing it up front also makes an invalid
        # (non-positive) threshold fail before any optimizer step is taken.
        log_threshold = math.log(config.threshold)

        for epoch in range(config.epochs + 1):
            # --- Main TB ---
            env.reset()
            opt.zero_grad()
            main_loss = tb_loss(env, fnet, logz)

            # --- Dual/divergence step ---
            env.reset()
            div_opt.zero_grad()
            # NOTE(review): presumably redundant with the zero_grad above (no
            # backward pass is visible in between); kept as a defensive reset.
            opt.zero_grad()

            log_R_hat_div, log_R_hat = dual_forward_reward(env, div_fnet, fnet, div_logz, logz)
            log_R = env.log_reward()

            back_loss = (log_R_hat - log_R).pow(2).mean()
            # The mixing weight is a constant for this step: no gradient flows
            # through it into logz or div_logz.
            with torch.no_grad():
                w1 = torch.sigmoid(logz - div_logz)

            # Equivalent to w1 * main_loss + (1 - w1) * back_loss.
            loss = back_loss + (main_loss - back_loss) * w1
            loss.backward()
            opt.step()

            # Samples whose main-policy residual exceeds log(threshold).
            with torch.no_grad():
                mask = (log_R_hat - log_R).gt(log_threshold)

            div_loss = torch.where(
                mask,
                softplus(log_R_hat_div - log_R),
                log_R_hat_div - log_R
            ).pow(2).mean()

            div_loss.backward()
            div_opt.step()

            # logz is needed every epoch for best-tracking; the remaining
            # scalars are only extracted on logging epochs.
            cur_logz = float(logz.item())

            improved = cur_logz > best_logz
            if improved:
                best_logz = cur_logz

            # --- Log ---
            # Guarded the same way as save_every, so log_every=0 or None
            # disables logging instead of failing on the modulo.
            do_log = (config.log_every is not None) and (config.log_every > 0) and (epoch % config.log_every == 0)
            if do_log:
                tb_val = float(main_loss.item())
                div_val = float(div_loss.item())
                exploration_ratio = float(mask.float().mean().item())
                cur_Z = float(logz.exp().item())
                cur_div_logz = float(div_logz.item())
                cur_div_Z = float(div_logz.exp().item())

                wandb.log({
                    "epoch": epoch,
                    "loss/tb": tb_val,
                    "loss/div": div_val,
                    "params/logz": cur_logz,
                    "params/Z": cur_Z,
                    "params/div_logz": cur_div_logz,
                    "params/div_Z": cur_div_Z,
                    "metrics/exploration_ratio": exploration_ratio,
                })

                print(
                    f"Epoch {epoch}: loss={tb_val:.4f}, logz={cur_logz:.4f}, "
                    f"div_loss={div_val:.4f}, div_logz={cur_div_logz:.4f}, "
                    f"div_samples={exploration_ratio:.4f}"
                )

            # --- Save artifacts ---
            do_periodic_save = (config.save_every is not None) and (config.save_every > 0) and (epoch % config.save_every == 0)
            do_best_save = bool(config.save_on_best) and improved
            do_last_save = (epoch == config.epochs)

            if do_periodic_save or do_best_save or do_last_save:
                log_checkpoint_artifact(
                    epoch=epoch,
                    fnet=fnet, logz=logz,
                    div_fnet=div_fnet, div_logz=div_logz,
                    opt=opt, div_opt=div_opt,
                    is_best=bool(do_best_save),   # tag "best" only when saving due to an improvement
                    ckpt_dir=config.ckpt_dir,
                    artifact_name=config.artifact_name,
                )


# Script entry point: start a training run when executed directly.
# NOTE(review): this module uses package-relative imports, so it presumably
# has to be launched as a module (``python -m <package>.<module>``) rather
# than by file path -- confirm against how the project is run.
if __name__ == "__main__":
    train()