# peptide/baselines/teacher_student_train.py
from __future__ import annotations

import argparse
import json
import time
from pathlib import Path

import torch

from util import set_seed
from ..peptide_env import Sequences, Policy
from ..peptide_reward import LogReward
from ..peptide_sampling import dual_forward_traj_log_prob
from ..peptide_utils import load_runs, new_run_dir, log_jsonl, save_ckpt


def load_seq_size(peptide_dir: Path) -> int:
    """Return the ``seq_size`` entry stored in the peptide vocabulary file.

    The file is expected at ``<peptide_dir>/rf_models/encoders/vocab.json``
    and is decoded as UTF-8 JSON. A missing ``"seq_size"`` key raises
    ``KeyError``.
    """
    vocab_file = peptide_dir.joinpath("rf_models", "encoders", "vocab.json")
    vocab = json.loads(vocab_file.read_text(encoding="utf-8"))
    return vocab["seq_size"]


def parse_args():
    """Parse command-line options.

    ``--config`` (a path) is mandatory. Exactly one of ``--run <id>`` (train a
    single run from the config) or ``--all`` (train every run) must be given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)

    selection = parser.add_mutually_exclusive_group(required=True)
    selection.add_argument("--run", type=str)
    selection.add_argument("--all", action="store_true")

    parsed = parser.parse_args()
    return parsed


def _require_float(cfg: dict, key: str) -> float:
    """Return ``cfg[key]`` converted to float, naming the key if it is absent.

    ``float(cfg.get(key))`` on a missing key raises an opaque ``TypeError``
    about ``NoneType``; this raises ``KeyError`` with the key name instead.
    """
    value = cfg.get(key)
    if value is None:
        raise KeyError(f"config is missing required key {key!r}")
    return float(value)


def _build_policy(cfg: dict, device):
    """Construct a ``Policy`` from the architecture keys of ``cfg`` and move it to ``device``."""
    return Policy(
        emb_dim=int(cfg["emb_dim"]),
        hidden=int(cfg["hidden"]),
        pos_dim=int(cfg["pos_dim"]),
        window=int(cfg["window"]),
    ).to(device)


def _teacher_log_reward(
    logR: torch.Tensor,
    student_logz: torch.Tensor,
    logp_student: torch.Tensor,
    *,
    C: float,
    eps: float,
    alpha: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the teacher's log-reward from the student's TB discrepancy.

    delta          = log R(x) - (log Z_student + log P_F^student(tau))
    log R_teacher  = log(eps + (1 + C * 1[delta > 0]) * delta^2) + alpha * log R(x)

    A positive ``delta`` (reward under-fitted by the student) is up-weighted by
    ``1 + C``. Everything is computed under ``torch.no_grad()``, so no gradient
    reaches the student through the teacher's target.

    The value is evaluated on the sampled trajectory only. It is a
    single-sample estimate of an expectation over backward trajectories, and
    exact only if each x has a unique trajectory (TODO confirm for
    ``Sequences``).

    Returns:
        ``(delta, logR_teacher)``, both with the batch shape of ``logR``.
    """
    with torch.no_grad():
        delta = logR - (student_logz + logp_student)
        weight = 1.0 + C * (delta > 0).float()
        logR_teacher = torch.log(eps + weight * delta.pow(2)) + alpha * logR
    return delta, logR_teacher


def _checkpoint_payload(
    epoch: int,
    cfg: dict,
    student_fnet,
    student_logz: torch.Tensor,
    teacher_fnet,
    teacher_logz: torch.Tensor,
    student_opt,
    teacher_opt,
) -> dict:
    """Assemble the checkpoint dict (model/optimizer state, logZ values, CPU torch RNG state)."""
    return {
        "epoch": epoch,
        "cfg": cfg,
        "student_fnet": student_fnet.state_dict(),
        "student_logz": float(student_logz.item()),
        "teacher_fnet": teacher_fnet.state_dict(),
        "teacher_logz": float(teacher_logz.item()),
        "student_opt": student_opt.state_dict(),
        "teacher_opt": teacher_opt.state_dict(),
        "rng_torch": torch.random.get_rng_state(),
    }


def run_one(cfg: dict, seq_size: int):
    """Train one student/teacher pair with trajectory balance for a single run config.

    The student (main model) minimizes the TB residual against the environment
    reward ``log R(x)``. The teacher (auxiliary model) minimizes the TB residual
    against a reward built from the student's per-sample discrepancy (see
    ``_teacher_log_reward``). The rollout policy alternates each epoch between
    the student (even epochs) and the teacher (odd epochs).

    Metrics are appended to ``<run_dir>/metrics.jsonl`` every ``log_every``
    epochs. Checkpoints are saved every ``save_every`` epochs and once more at
    the end with tag ``"latest"``.

    Args:
        cfg: Run configuration. It is copied, not mutated.
        seq_size: Sequence length, stored into the copied config as ``"seq_size"``.

    Raises:
        KeyError: If a required config key is missing.
    """
    cfg = dict(cfg)  # copy so the caller's config is left untouched
    cfg["seq_size"] = seq_size

    # Parse hyperparameters before creating the run directory, so a bad config
    # fails fast without leaving an empty run folder behind.
    teacher_C = float(cfg.get("teacher_C", 1.0))
    teacher_eps = float(cfg.get("teacher_eps", 1e-8))
    teacher_alpha = _require_float(cfg, "reward_alpha")
    lr_pf = float(cfg["lr_pf"])
    lr_logz = float(cfg["lr_logz"])
    teacher_lr_pf = _require_float(cfg, "teacher_lr_pf")
    teacher_lr_logz = _require_float(cfg, "teacher_lr_logz")
    epochs = int(cfg["epochs"])
    log_every = int(cfg.get("log_every", 10))
    save_every = int(cfg.get("save_every", 1000))

    run_dir = new_run_dir(exp="peptide", method="teacher_student", cfg=cfg, out_root="runs")
    metrics_path = run_dir / "metrics.jsonl"

    device = cfg.get("device", "cpu")
    seed = int(cfg["seed"])
    set_seed(seed)

    log_reward = LogReward(cutoff=float(cfg["cut_off"]))
    env = Sequences(
        seq_size=seq_size,
        batch_size=int(cfg["batch_size"]),
        log_reward=log_reward,
        eps=float(cfg["eps"]),
        seed=seed,
    )

    # Main model = STUDENT, auxiliary model = TEACHER. Construction order is
    # kept fixed (student first) so seeded weight initialization is reproducible.
    student_fnet = _build_policy(cfg, device)
    student_logz = torch.nn.Parameter(torch.zeros(1, device=device))
    teacher_fnet = _build_policy(cfg, device)
    teacher_logz = torch.nn.Parameter(torch.zeros(1, device=device))

    student_opt = torch.optim.AdamW(
        [{"params": student_fnet.parameters(), "lr": lr_pf},
         {"params": [student_logz], "lr": lr_logz}]
    )
    teacher_opt = torch.optim.AdamW(
        [{"params": teacher_fnet.parameters(), "lr": teacher_lr_pf},
         {"params": [teacher_logz], "lr": teacher_lr_logz}]
    )

    t0 = time.time()

    # range(epochs + 1): epochs 0..epochs inclusive, so the last index equals
    # the epoch recorded in the final "latest" checkpoint.
    for epoch in range(epochs + 1):
        env.reset()
        student_opt.zero_grad(set_to_none=True)
        teacher_opt.zero_grad(set_to_none=True)

        # The first policy passed to the dual sampler drives the rollout; both
        # trajectory log-probs come back in argument order.
        if epoch % 2 == 0:
            # sample from student
            logp_student, logp_teacher = dual_forward_traj_log_prob(env, student_fnet, teacher_fnet, training=True)
        else:
            # sample from teacher
            logp_teacher, logp_student = dual_forward_traj_log_prob(env, teacher_fnet, student_fnet, training=True)

        logR = env.log_reward()  # log R(x) per sample

        # Student TB loss: residual (logZ + logPF - logR).
        student_loss = (student_logz + logp_student - logR).pow(2).mean()

        # Teacher target is built from the student discrepancy (stop-grad on student).
        delta, logR_teacher = _teacher_log_reward(
            logR, student_logz, logp_student,
            C=teacher_C, eps=teacher_eps, alpha=teacher_alpha,
        )
        teacher_loss = (teacher_logz + logp_teacher - logR_teacher).pow(2).mean()

        # The two losses depend on disjoint parameter sets, so one backward pass
        # on their sum gives each model exactly its own TB gradient. It also works
        # if the dual sampler shares graph nodes between the two log-probs, where
        # two separate backward() calls would fail on the already-freed graph.
        (student_loss + teacher_loss).backward()
        student_opt.step()
        teacher_opt.step()

        if epoch % log_every == 0:
            row = {
                "epoch": epoch,
                "time_sec": time.time() - t0,
                # loss_total reports the student (main model) objective; the
                # teacher objective is logged separately.
                "loss_total": float(student_loss.item()),
                "student_tb_loss": float(student_loss.item()),
                "teacher_tb_loss": float(teacher_loss.item()),
                "logz_main": float(student_logz.item()),
                "logz_aux": float(teacher_logz.item()),
                "avg_logR": float(logR.mean().item()),
                "max_logR": float(logR.max().item()),
                "avg_delta": float(delta.mean().item()),
                "pos_delta_frac": float((delta > 0).float().mean().item()),
                "run_id": cfg.get("run_id"),
            }
            log_jsonl(metrics_path, row)
            print(
                f"[{cfg.get('run_id','?')}] [{epoch:5d}] "
                f"studentTB={row['student_tb_loss']:.4f} teacherTB={row['teacher_tb_loss']:.4f} "
                f"logz={row['logz_main']:.3f} posδ={row['pos_delta_frac']:.3f}"
            )

        if epoch % save_every == 0:
            payload = _checkpoint_payload(
                epoch, cfg, student_fnet, student_logz, teacher_fnet, teacher_logz, student_opt, teacher_opt
            )
            save_ckpt(run_dir, epoch=epoch, tag="epoch", payload=payload)

    payload = _checkpoint_payload(
        epochs, cfg, student_fnet, student_logz, teacher_fnet, teacher_logz, student_opt, teacher_opt
    )
    save_ckpt(run_dir, epoch=epochs, tag="latest", payload=payload)


def main():
    """CLI entry point: train either one selected run or every run in the config."""
    args = parse_args()
    # parents[1] is the directory one level above the one containing this file.
    seq_size = load_seq_size(Path(__file__).resolve().parents[1])

    # If you remove the "script" field from the TOML, adjust this call (or change load_runs).
    selected_run = None if args.all else args.run
    for run_cfg in load_runs(args.config, run_id=selected_run):
        run_one(run_cfg, seq_size=seq_size)


if __name__ == "__main__":
    # Script entry point: parse CLI arguments and train the selected run(s).
    main()