from __future__ import annotations

import argparse
import time

import torch

from util import set_seed
from ..diffusionGrid_env import DiffGrid
from ..diffusionGrid_rewards import build_log_reward_fn
from ..diffusionGrid_nets import FourierTimePolicy
from ..diffusionGrid_sampling import marginal_log_reward, backward_reward
from ..diffusionGrid_util import plot_epoch_panels, plot_epoch_panels_v3
from ..diffusionGrid_utils import load_runs, new_run_dir, log_jsonl, save_ckpt
from .base_core import dual_forward_trajectory_log_prob


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options that select which experiment run(s) to train.

    Exactly one of ``--run`` / ``--all`` must be given; argparse exits with an
    error (``SystemExit``) otherwise.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the original zero-argument behavior.

    Returns:
        Namespace with ``run`` (``str`` or ``None``) and ``all`` (``bool``).
    """
    p = argparse.ArgumentParser(description="Teacher-student training on the diffusion grid.")
    g = p.add_mutually_exclusive_group(required=True)
    g.add_argument("--run", type=str, help="run_id of a single entry in diffusionGrid/experiments.toml")
    g.add_argument("--all", action="store_true", help="train every run in diffusionGrid/experiments.toml")
    return p.parse_args(argv)


@torch.no_grad()
def l1_tv_distance(eval_env: DiffGrid, fnet, bnet, logz, batch: int = 10) -> float:
    """Total-variation distance between the model's and the target terminal distributions.

    Both sides are evaluated on the full grid selected by ``eval_env.set_full_grid_T()``:
    the model side normalizes the log rewards estimated by ``marginal_log_reward``,
    the target side normalizes ``eval_env.log_reward()``.

    Args:
        eval_env: Environment to evaluate on; it is switched to the full grid
            via ``set_full_grid_T`` before anything else.
        fnet: Forward policy, forwarded to ``marginal_log_reward``.
        bnet: Backward policy, forwarded to ``marginal_log_reward``.
        logz: Log-partition estimate, forwarded to ``marginal_log_reward``.
        batch: ``batch`` argument forwarded to ``marginal_log_reward``
            (default 10, the previously hard-coded value).

    Returns:
        ``0.5 * sum(|p_model - p_true|)``.
    """
    eval_env.set_full_grid_T()
    log_r_hat = marginal_log_reward(eval_env, fnet, bnet, logz, batch=batch)
    log_r_true = eval_env.log_reward()
    # Normalize in log space: exp(x) / sum(exp(x)) == exp(x - logsumexp(x)),
    # but the latter cannot overflow to inf or underflow to an all-zero
    # vector (which would make the ratio NaN) for large/very negative logs.
    dist_model = (log_r_hat - torch.logsumexp(log_r_hat.reshape(-1), dim=0)).exp()
    dist_true = (log_r_true - torch.logsumexp(log_r_true.reshape(-1), dim=0)).exp()
    return float((dist_model - dist_true).abs().sum().item() / 2.0)


def _masked_mean(values: torch.Tensor, mask: torch.Tensor) -> float | None:
    """Return the mean of ``values[mask]``, or ``None`` when the mask selects nothing.

    ``None`` is written as JSON ``null``; the ``.mean()`` of an empty tensor is
    NaN, which is not valid JSON.
    """
    selected = values[mask]
    if selected.numel() == 0:
        return None
    return float(selected.mean().item())


def run_one(cfg: dict):
    """Train a student sampler jointly with a teacher on the diffusion grid.

    The behavior policy alternates between student (even epochs) and teacher
    (odd epochs). The student is trained with a squared trajectory-balance-style
    residual against ``env.log_reward()``. The teacher is trained the same way,
    but against a reward built from the student's (stop-gradient) residual:
    ``log(teacher_eps + w * delta**2) + reward_alpha * logR`` with
    ``w = 1 + teacher_C`` where ``delta > 0`` and ``w = 1`` elsewhere.

    Metrics go to ``metrics.jsonl``, evaluation results to ``eval.jsonl`` and
    figures to ``figures/`` inside a fresh run directory. Checkpoints are saved
    every ``save_every`` epochs and once more at the end with tag ``"latest"``.

    Args:
        cfg: Run configuration. ``reward_alpha`` is required; every other key
            has a default. Interval keys (``log_every``, ``eval_every``,
            ``save_every``) disable their action when set to 0 or less.

    Raises:
        KeyError: If ``reward_alpha`` is missing (or ``None``). This is checked
            before the run directory is created.
    """
    cfg = dict(cfg)

    # Validate the only required key before creating any on-disk artifacts.
    if cfg.get("reward_alpha") is None:
        raise KeyError("config is missing required key 'reward_alpha' (teacher weight on logR)")
    teacher_alpha = float(cfg["reward_alpha"])

    run_dir = new_run_dir(exp="diffusionGrid", method="teacher_student", cfg=cfg, out_root="runs")
    metrics_path = run_dir / "metrics.jsonl"
    eval_path = run_dir / "eval.jsonl"
    fig_dir = run_dir / "figures"

    device = cfg.get("device", "cpu")
    seed = int(cfg.get("seed", 42))
    set_seed(seed)

    reward_kind = str(cfg.get("reward_kind", "8g"))
    size = int(cfg.get("size", 15))
    batch_size = int(cfg.get("batch_size", 512))
    eps = float(cfg.get("eps", 0.1))

    # Teacher-reward hyper-parameters (same spirit as the peptide teacher-student setup).
    teacher_C = float(cfg.get("teacher_C", 19.0))
    teacher_eps = float(cfg.get("teacher_eps", 1e-8))

    epochs = int(cfg.get("epochs", 4000))
    log_every = int(cfg.get("log_every", 10))
    eval_every = int(cfg.get("eval_every", 100))
    save_every = int(cfg.get("save_every", 1000))
    marginal_batch = int(cfg.get("marginal_batch", 15))
    # Terminal positions with squared norm below easy_radius**2 count as "easy".
    easy_radius = float(cfg.get("easy_radius", 9.0))

    # Learning rates.
    lr_pf = float(cfg.get("lr_pf", 5e-3))
    lr_pb = float(cfg.get("lr_pb", 5e-3))
    lr_logz = float(cfg.get("lr_logz", 5e-2))
    teacher_lr_pf = float(cfg.get("teacher_lr_pf", 5e-4))
    teacher_lr_pb = float(cfg.get("teacher_lr_pb", 1e-3))
    teacher_lr_logz = float(cfg.get("teacher_lr_logz", 5e-3))

    log_reward_fn = build_log_reward_fn(reward_kind, size=size)
    env = DiffGrid(size=size, batch_size=batch_size, log_reward=log_reward_fn, seed=seed, eps=eps)
    eval_env = DiffGrid(size=size, batch_size=batch_size, log_reward=log_reward_fn, seed=seed, eps=eps)

    # -----------------------
    # Student (main) + Teacher (aux)
    # -----------------------
    student_fnet = FourierTimePolicy(
        hidden_dim=int(cfg.get("student_hidden_dim", 64)),
        num_layers=int(cfg.get("student_num_layers_f", 2)),
        n_freq=int(cfg.get("student_n_freq_f", 16)),
    ).to(device)
    student_bnet = FourierTimePolicy(
        hidden_dim=int(cfg.get("student_hidden_dim", 64)),
        num_layers=int(cfg.get("student_num_layers_b", 2)),
        n_freq=int(cfg.get("student_n_freq_b", 8)),
    ).to(device)
    student_logz = torch.nn.Parameter(torch.zeros(1, device=device))

    teacher_fnet = FourierTimePolicy(
        hidden_dim=int(cfg.get("teacher_hidden_dim", 128)),
        num_layers=int(cfg.get("teacher_num_layers_f", 3)),
        n_freq=int(cfg.get("teacher_n_freq_f", 8)),
    ).to(device)
    teacher_bnet = FourierTimePolicy(
        hidden_dim=int(cfg.get("teacher_hidden_dim", 128)),
        num_layers=int(cfg.get("teacher_num_layers_b", 1)),
        n_freq=int(cfg.get("teacher_n_freq_b", 8)),
    ).to(device)
    teacher_logz = torch.nn.Parameter(torch.zeros(1, device=device))

    student_opt = torch.optim.AdamW(
        [
            {"params": student_fnet.parameters(), "lr": lr_pf},
            {"params": student_bnet.parameters(), "lr": lr_pb},
            {"params": [student_logz], "lr": lr_logz},
        ]
    )
    teacher_opt = torch.optim.AdamW(
        [
            {"params": teacher_fnet.parameters(), "lr": teacher_lr_pf},
            {"params": teacher_bnet.parameters(), "lr": teacher_lr_pb},
            {"params": [teacher_logz], "lr": teacher_lr_logz},
        ]
    )

    sch_student = torch.optim.lr_scheduler.LinearLR(student_opt, start_factor=1.0, end_factor=0.1, total_iters=epochs)
    sch_teacher = torch.optim.lr_scheduler.LinearLR(teacher_opt, start_factor=1.0, end_factor=0.1, total_iters=epochs)

    def is_due(epoch: int, every: int) -> bool:
        # A non-positive interval disables the action instead of raising
        # ZeroDivisionError on `epoch % 0`.
        return every > 0 and epoch % every == 0

    def checkpoint_payload(epoch: int) -> dict:
        # Snapshot of model, optimizer, LR-scheduler and torch CPU RNG state.
        # Scheduler state is included so a resumed run continues the LR decay.
        return {
            "epoch": epoch,
            "cfg": cfg,
            "student_fnet": student_fnet.state_dict(),
            "student_bnet": student_bnet.state_dict(),
            "student_logz": float(student_logz.item()),
            "teacher_fnet": teacher_fnet.state_dict(),
            "teacher_bnet": teacher_bnet.state_dict(),
            "teacher_logz": float(teacher_logz.item()),
            "student_opt": student_opt.state_dict(),
            "teacher_opt": teacher_opt.state_dict(),
            "sch_student": sch_student.state_dict(),
            "sch_teacher": sch_teacher.state_dict(),
            "rng_torch": torch.random.get_rng_state(),
        }

    t0 = time.perf_counter()

    for epoch in range(epochs + 1):
        env.reset()
        student_opt.zero_grad(set_to_none=True)
        teacher_opt.zero_grad(set_to_none=True)

        # Alternate the behavior policy: the first policy passed to
        # dual_forward_trajectory_log_prob presumably drives sampling; log-ratios
        # are returned for both policies in argument order.
        if epoch % 2 == 0:
            logratio_student, logratio_teacher = dual_forward_trajectory_log_prob(
                env, student_fnet, student_bnet, teacher_fnet, teacher_bnet
            )
        else:
            logratio_teacher, logratio_student = dual_forward_trajectory_log_prob(
                env, teacher_fnet, teacher_bnet, student_fnet, student_bnet
            )

        logR = env.log_reward()  # assumed shape [B] — TODO confirm

        # Student TB residual (assuming logratio = logPF - logPB).
        student_res = student_logz + logratio_student - logR
        student_loss = student_res.pow(2).mean()

        # Teacher reward from the student's discrepancy; stop-grad so the
        # teacher objective never backpropagates into the student.
        with torch.no_grad():
            delta = logR - (student_logz + logratio_student)  # == -student_res
            w = 1.0 + teacher_C * (delta > 0).float()
            logR_teacher = torch.log(teacher_eps + w * delta.pow(2)) + teacher_alpha * logR

            pos_delta_frac = float((delta > 0).float().mean().item())
            avg_delta = float(delta.mean().item())

        teacher_res = teacher_logz + logratio_teacher - logR_teacher
        teacher_loss = teacher_res.pow(2).mean()

        student_loss.backward()
        student_opt.step()
        sch_student.step()

        teacher_loss.backward()
        teacher_opt.step()
        sch_teacher.step()

        if is_due(epoch, log_every):
            row = {
                "epoch": epoch,
                "time_sec": time.perf_counter() - t0,
                "student_tb_loss": float(student_loss.item()),
                "teacher_tb_loss": float(teacher_loss.item()),
                "logz_main": float(student_logz.item()),
                "logz_aux": float(teacher_logz.item()),
                "Z_main": float(student_logz.exp().item()),
                "Z_aux": float(teacher_logz.exp().item()),
                "avg_logR": float(logR.mean().item()),
                "max_logR": float(logR.max().item()),
                "avg_delta": avg_delta,
                "pos_delta_frac": pos_delta_frac,
                "run_id": cfg.get("run_id"),
                "seed": seed,
            }
            log_jsonl(metrics_path, row)
            print(
                f"[{cfg.get('run_id','?')}] [seed={seed}] [{epoch:5d}] "
                f"studentTB={row['student_tb_loss']:.4f} teacherTB={row['teacher_tb_loss']:.4f} "
                f"logz={row['logz_main']:.3f} posδ={row['pos_delta_frac']:.3f}"
            )

        if is_due(epoch, eval_every):
            # Evaluation never backpropagates; skip building autograd graphs.
            with torch.no_grad():
                eval_env.set_full_grid_T()
                log_p_hat = marginal_log_reward(eval_env, student_fnet, student_bnet, 0, batch=marginal_batch)
                log_p_hat_teacher = marginal_log_reward(
                    eval_env, teacher_fnet, teacher_bnet, 0, batch=marginal_batch
                )
                plot_epoch_panels_v3(eval_env, log_p_hat, log_p_hat_teacher, epoch, out_dir=fig_dir)

                # Re-select the full terminal grid before reading positions and
                # rewards, in case the marginal estimates above changed eval_env
                # (NOTE(review): unverified whether marginal_log_reward mutates it).
                eval_env.set_full_grid_T()
                sq_radius = eval_env.pos.pow(2).sum(dim=-1)
                easy_mask = sq_radius < easy_radius**2
                hard_mask = ~easy_mask
                log_r = eval_env.log_reward()
                abs_err = (backward_reward(eval_env, student_fnet, student_bnet, student_logz) - log_r).abs()
                easy_loss = _masked_mean(abs_err, easy_mask)
                hard_loss = _masked_mean(abs_err, hard_mask)
                l1 = l1_tv_distance(eval_env, student_fnet, student_bnet, student_logz)

            log_jsonl(
                eval_path,
                {
                    "epoch": epoch,
                    "l1_tv": l1,
                    "easy_pos_loss": easy_loss,
                    "hard_pos_loss": hard_loss,
                    "run_id": cfg.get("run_id"),
                    "seed": seed,
                },
            )

        if is_due(epoch, save_every):
            save_ckpt(run_dir, epoch=epoch, tag="epoch", payload=checkpoint_payload(epoch))

    save_ckpt(run_dir, epoch=epochs, tag="latest", payload=checkpoint_payload(epochs))


def main():
    """Parse the CLI, load the selected run configs and train each one in turn."""
    cli = parse_args()
    # `--run` and `--all` are mutually exclusive; `--all` means "no run_id filter".
    selected_run = None if cli.all else cli.run
    for run_cfg in load_runs("diffusionGrid/experiments.toml", run_id=selected_run):
        run_one(run_cfg)


# Script entry point: train the run(s) selected on the command line.
if __name__ == "__main__":
    main()