from __future__ import annotations

import argparse
import time

import torch
import torch.nn as nn

from util import set_seed
from ..diffusionGrid_env import DiffGrid
from ..diffusionGrid_losses import tb
from ..diffusionGrid_rewards import build_log_reward_fn
from ..diffusionGrid_nets import FourierTimePolicy
from ..diffusionGrid_sampling import marginal_log_reward, backward_reward
from ..diffusionGrid_util import plot_epoch_panels, plot_epoch_panels_v2, plot_epoch_panels_v3
from ..diffusionGrid_utils import load_runs, new_run_dir, log_jsonl, save_ckpt
from .base_core import sa_double_loss


def parse_args():
    """Parse the command line.

    Exactly one of the two selectors must be supplied:

    * ``--run ID`` -- a single run id (string).
    * ``--all``    -- flag; select every run.

    Returns:
        The ``argparse.Namespace`` with attributes ``run`` and ``all``.
    """
    parser = argparse.ArgumentParser()
    # The two selectors are mutually exclusive, and one of them is mandatory.
    selector = parser.add_mutually_exclusive_group(required=True)
    selector.add_argument("--run", type=str)
    selector.add_argument("--all", action="store_true")
    namespace = parser.parse_args()
    return namespace


@torch.no_grad()
def l1_tv_distance(eval_env: DiffGrid, fnet, bnet, logz) -> float:
    """Return half the L1 distance between the model's and the true normalized reward.

    The environment is first reset with ``set_full_grid_T()``; the model's
    log-reward estimate comes from ``marginal_log_reward`` (fixed ``batch=10``)
    and the reference from ``eval_env.log_reward()``. Both are exponentiated
    and normalized to sum to one before being compared.

    Args:
        eval_env: Environment used for evaluation; its state is overwritten.
        fnet: Forward policy network, passed through to ``marginal_log_reward``.
        bnet: Backward policy network, passed through to ``marginal_log_reward``.
        logz: Log-partition estimate, passed through to ``marginal_log_reward``.

    Returns:
        ``0.5 * sum(|p_model - p_true|)`` as a Python float.
    """
    eval_env.set_full_grid_T()

    # Unnormalized masses (model first, then ground truth -- same call order
    # as the environment may be stateful).
    learned_mass = marginal_log_reward(eval_env, fnet, bnet, logz, batch=10).exp()
    target_mass = eval_env.log_reward().exp()

    p_model = learned_mass / learned_mass.sum()
    p_true = target_mass / target_mass.sum()

    l1_gap = (p_model - p_true).abs().sum()
    return float(l1_gap.item() / 2.0)


class _Scale(nn.Module):
    """Parameter-free layer that multiplies its input by a constant factor."""

    def __init__(self, factor: float) -> None:
        super().__init__()
        self.factor = factor

    def forward(self, x):
        return x * self.factor


def _make_rnd(in_dim: int = 3, hid: int = 64, multiplier: float = 1) -> nn.Module:
    """Build the small MLP used as an RND network (predictor or target).

    Architecture: ``Linear(in_dim, hid) -> ReLU -> Linear(hid, hid) -> ReLU ->
    Linear(hid, 1)``, optionally followed by a constant output scaling.

    Args:
        in_dim: Size of the input feature vector.
        hid: Width of the two hidden layers.
        multiplier: Constant factor applied to the network output. With the
            default of ``1`` no scaling layer is added.

    Returns:
        An ``nn.Sequential``. The scaling layer has no parameters, so the
        ``state_dict`` keys are the same whatever ``multiplier`` is.
    """
    # A number cannot be multiplied with an nn.Module (that raises TypeError),
    # so the scaling is expressed as a trailing layer instead.
    layers = [
        nn.Linear(in_dim, hid),
        nn.ReLU(),
        nn.Linear(hid, hid),
        nn.ReLU(),
        nn.Linear(hid, 1),
    ]
    if multiplier != 1:
        layers.append(_Scale(multiplier))
    return nn.Sequential(*layers)


def run_one(cfg: dict) -> None:
    """Train a single "sa" run of the diffusionGrid experiment.

    Builds the training and evaluation environments, a main policy triple
    (``fnet``, ``bnet``, ``logz``), a sibling triple (``sa_fnet``, ``sa_bnet``,
    ``sa_logz``) and an RND predictor/target pair, then runs the epoch loop.
    Each epoch performs three updates (sibling, main, RND) and, on the
    configured cadence, logs metrics, evaluates, and writes checkpoints.

    Args:
        cfg: Hyper-parameter mapping. It is shallow-copied before use. Every
            key is read with ``cfg.get`` and has a default, except
            ``"reward_alpha"``: if it is missing, ``float(None)`` raises
            ``TypeError``.

    Side effects:
        Creates a run directory via ``new_run_dir`` (under ``"runs"``), appends
        to ``metrics.jsonl`` / ``eval.jsonl`` there, writes figures into
        ``figures/``, saves checkpoints via ``save_ckpt``, and prints a
        progress line every ``log_every`` epochs.
    """
    cfg = dict(cfg)

    run_dir = new_run_dir(exp="diffusionGrid", method="sa", cfg=cfg, out_root="runs")
    metrics_path = run_dir / "metrics.jsonl"
    eval_path = run_dir / "eval.jsonl"
    fig_dir = run_dir / "figures"

    device = cfg.get("device", "cpu")
    seed = int(cfg.get("seed", 42))
    set_seed(seed)

    reward_kind = str(cfg.get("reward_kind", "8g"))
    reward_mult = cfg.get("reward_multiplier", 1)
    size = int(cfg.get("size", 15))
    # The configured batch size is halved for the training env; two rollouts
    # are drawn per epoch (Step A and Step B). eval_env gets 2 * batch_size.
    batch_size = int(cfg.get("batch_size", 512)) // 2
    eps = float(cfg.get("eps", 0.1))
    # Required key: there is no default, so a missing value fails here with
    # TypeError from float(None).
    reward_alpha = float(cfg.get("reward_alpha"))

    epochs = int(cfg.get("epochs", 4000))
    log_every = int(cfg.get("log_every", 10))
    eval_every = int(cfg.get("eval_every", 100))
    save_every = int(cfg.get("save_every", 1000))
    marginal_batch = int(cfg.get("marginal_batch", 15))

    # Learning rates for the main model.
    lr_pf = float(cfg.get("lr_pf", 5e-3))
    lr_pb = float(cfg.get("lr_pb", 5e-3))
    lr_logz = float(cfg.get("lr_logz", 5e-2))

    # Learning rates for the sibling ("sa") model.
    sa_lr_pf = float(cfg.get("sa_lr_pf", 5e-4))
    sa_lr_pb = float(cfg.get("sa_lr_pb", 1e-3))
    sa_lr_logz = float(cfg.get("sa_lr_logz", 5e-3))

    log_reward_fn = build_log_reward_fn(reward_kind, size=size, multiplier=reward_mult)
    env = DiffGrid(size=size, batch_size=batch_size, log_reward=log_reward_fn, seed=seed, eps=eps)
    eval_env = DiffGrid(size=size, batch_size=2 * batch_size, log_reward=log_reward_fn, seed=seed, eps=eps)

    # RND: `rnd` is the trainable network; `rnd_tgt` has the same architecture
    # but is frozen (no grads, eval mode).
    rnd = _make_rnd(in_dim=int(cfg.get("rnd_in_dim", 3)), hid=int(cfg.get("rnd_hid", 64))).to(device)
    rnd_tgt = _make_rnd(in_dim=int(cfg.get("rnd_in_dim", 3)), hid=int(cfg.get("rnd_hid", 64))).to(device)
    rnd_tgt.requires_grad_(False)
    rnd_tgt.eval()

    # Main model: forward policy, backward policy, and a learnable scalar logz.
    fnet = FourierTimePolicy(
        hidden_dim=int(cfg.get("hidden_dim_f", 128)),
        num_layers=int(cfg.get("num_layers_f", 3)),
        n_freq=int(cfg.get("n_freq_f", 8)),
    ).to(device)
    bnet = FourierTimePolicy(
        hidden_dim=int(cfg.get("hidden_dim_b", 128)),
        num_layers=int(cfg.get("num_layers_b", 1)),
        n_freq=int(cfg.get("n_freq_b", 8)),
    ).to(device)
    logz = torch.nn.Parameter(torch.zeros(1, device=device))

    # Sibling ("sa") model: same structure, separately configured sizes.
    sa_fnet = FourierTimePolicy(
        hidden_dim=int(cfg.get("sa_hidden_dim_f", 64)),
        num_layers=int(cfg.get("sa_num_layers_f", 2)),
        n_freq=int(cfg.get("sa_n_freq_f", 16)),
    ).to(device)
    sa_bnet = FourierTimePolicy(
        hidden_dim=int(cfg.get("sa_hidden_dim_b", 64)),
        num_layers=int(cfg.get("sa_num_layers_b", 2)),
        n_freq=int(cfg.get("sa_n_freq_b", 8)),
    ).to(device)
    sa_logz = torch.nn.Parameter(torch.zeros(1, device=device))

    # One optimizer per model, with a separate learning rate per parameter group.
    opt = torch.optim.AdamW(
        [{"params": fnet.parameters(), "lr": lr_pf},
         {"params": bnet.parameters(), "lr": lr_pb},
         {"params": [logz], "lr": lr_logz}]
    )
    sa_opt = torch.optim.AdamW(
        [{"params": sa_fnet.parameters(), "lr": sa_lr_pf},
         {"params": sa_bnet.parameters(), "lr": sa_lr_pb},
         {"params": [sa_logz], "lr": sa_lr_logz}]
    )

    # The RND learning rate falls back to the main forward-policy rate.
    rnd_opt = torch.optim.AdamW([{"params": rnd.parameters(), "lr": float(cfg.get("rnd_lr", lr_pf))}])

    # Linear decay of every learning rate from 1.0x to 0.1x over `epochs` steps.
    # `div_sch` is the scheduler attached to the sibling optimizer `sa_opt`.
    sch = torch.optim.lr_scheduler.LinearLR(opt, start_factor=1.0, end_factor=0.1, total_iters=epochs)
    div_sch = torch.optim.lr_scheduler.LinearLR(sa_opt, start_factor=1.0, end_factor=0.1, total_iters=epochs)
    rnd_sch = torch.optim.lr_scheduler.LinearLR(rnd_opt, start_factor=1.0, end_factor=0.1, total_iters=epochs)
    t0 = time.time()

    # Latest sampled positions from each rollout. Currently only referenced by
    # the commented-out plot_epoch_panels call in the eval section.
    sa_samples = None
    samples = None

    # Note: the loop runs for epochs + 1 iterations (epoch = 0 .. epochs inclusive).
    for epoch in range(epochs + 1):
        # -----------------------------------------
        # Step A: sibling update + off-policy TB + rnd loss (shared graph)
        # -----------------------------------------
        env.reset()
        # NOTE(review): `opt` gradients are cleared only here, not again before
        # Step B. If sa_tb_loss depends on fnet/bnet/logz, its gradients would
        # still be accumulated when opt.step() runs -- confirm against
        # sa_double_loss whether that is intended.
        opt.zero_grad(set_to_none=True)
        sa_opt.zero_grad(set_to_none=True)

        # A single call yields three losses; they are backpropagated separately
        # in Steps A, B and C below.
        sa_tb_loss, tb_loss_off, rnd_loss = sa_double_loss(
            env,
            sa_fnet, sa_bnet, sa_logz,
            fnet, bnet, logz,
            rnd, rnd_tgt, beta_e=reward_alpha
        )
        sa_samples = env.pos.clone()

        # retain_graph=True: tb_loss_off and rnd_loss are backpropagated later.
        sa_tb_loss.backward(retain_graph=True)
        sa_opt.step()
        div_sch.step()

        # -----------------------------------------
        # Step B: on-policy TB update
        # -----------------------------------------
        env.reset()
        tb_loss_on = tb(fnet, bnet, logz, env)
        samples = env.pos.clone()

        # Main-model objective: equal-weight average of the on-policy loss and
        # the off-policy loss produced in Step A.
        tb_loss_total = 0.5 * (tb_loss_on + tb_loss_off)
        # retain_graph=True: rnd_loss is still to be backpropagated in Step C.
        tb_loss_total.backward(retain_graph=True)
        opt.step()
        sch.step()

        # -----------------------------------------
        # Step C: RND update
        # -----------------------------------------
        # Zeroing here (rather than at the top of the loop) discards any
        # gradients the earlier backward passes may have left on `rnd`.
        rnd_opt.zero_grad(set_to_none=True)
        rnd_loss.backward()
        rnd_opt.step()
        rnd_sch.step()

        if epoch % log_every == 0:
            row = {
                "epoch": epoch,
                "time_sec": time.time() - t0,
                "tb_loss_on": float(tb_loss_on.item()),
                "tb_loss_off": float(tb_loss_off.item()),
                "tb_loss": float(tb_loss_total.item()),
                "sa_tb_loss": float(sa_tb_loss.item()),
                "rnd_loss": float(rnd_loss.item()),
                "logz_main": float(logz.item()),
                "logz_aux": float(sa_logz.item()),
                "Z_main": float(logz.exp().item()),
                "Z_aux": float(sa_logz.exp().item()),
                "run_id": cfg.get("run_id"),
                "seed": seed,
            }
            log_jsonl(metrics_path, row)
            print(
                f"[{cfg.get('run_id','?')}] [seed={seed}] [{epoch:5d}] "
                f"tb={row['tb_loss']:.4f} sa={row['sa_tb_loss']:.4f} rnd={row['rnd_loss']:.4f}"
            )

        if epoch % eval_every == 0:
            eval_env.set_full_grid_T()
            # Both marginals are computed with 0 passed in the logz position.
            log_p_hat = marginal_log_reward(eval_env, fnet, bnet, 0, batch=marginal_batch)
            log_p_hat_sa = marginal_log_reward(eval_env, sa_fnet, sa_bnet, 0, batch=marginal_batch)
            plot_epoch_panels_v3(eval_env, log_p_hat, log_p_hat_sa, epoch, out_dir=fig_dir)
            # Older plotting variants, kept for reference:
            # plot_epoch_panels_v2(eval_env, log_p_hat, log_p_hat_sa, epoch, out_dir=fig_dir)
            # log_r_hat = marginal_log_reward(eval_env, fnet, bnet, logz, batch=marginal_batch)
            # plot_epoch_panels(
            #     eval_env,
            #     log_r_hat,
            #     div_samples=sa_samples,
            #     samples=samples,
            #     epoch=epoch,
            #     out_dir=fig_dir,
            # )

            # Split positions by squared norm against a hard-coded radius of 9:
            # "easy" is inside the radius, "hard" is on or outside it.
            # NOTE(review): this reads eval_env.pos after the marginal_log_reward
            # calls above -- presumably the env still holds the full grid at
            # that point; confirm that those calls do not mutate it.
            easy_mask = (eval_env.pos.pow(2).sum(dim=-1) < 9**2)
            hard_mask = (eval_env.pos.pow(2).sum(dim=-1) >= 9**2)
            log_r = eval_env.log_reward()
            # Absolute error between backward_reward's output and the true log-reward.
            loss_abs = (backward_reward(eval_env, fnet, bnet, logz) - log_r).abs()
            easy_loss = loss_abs[easy_mask].mean().item()
            hard_loss = loss_abs[hard_mask].mean().item()
            l1 = l1_tv_distance(eval_env, fnet, bnet, logz)

            log_jsonl(
                eval_path,
                {
                    "epoch": epoch,
                    "l1_tv": l1,
                    "easy_pos_loss": easy_loss,
                    "hard_pos_loss": hard_loss,
                    "run_id": cfg.get("run_id"),
                    "seed": seed,
                },
            )

        if epoch % save_every == 0:
            # Periodic checkpoint. Scheduler states are not included, and only
            # torch's default (CPU) RNG state is captured.
            payload = {
                "epoch": epoch,
                "cfg": cfg,
                "fnet": fnet.state_dict(),
                "bnet": bnet.state_dict(),
                "logz": float(logz.item()),
                "sa_fnet": sa_fnet.state_dict(),
                "sa_bnet": sa_bnet.state_dict(),
                "sa_logz": float(sa_logz.item()),
                "rnd": rnd.state_dict(),
                "rnd_tgt": rnd_tgt.state_dict(),
                "opt": opt.state_dict(),
                "sa_opt": sa_opt.state_dict(),
                "rnd_opt": rnd_opt.state_dict(),
                "rng_torch": torch.random.get_rng_state(),
            }
            save_ckpt(run_dir, epoch=epoch, tag="epoch", payload=payload)

    # Final checkpoint after the loop, saved under the "latest" tag. It has the
    # same fields as the periodic payload above; keep the two in sync.
    payload = {
        "epoch": epochs,
        "cfg": cfg,
        "fnet": fnet.state_dict(),
        "bnet": bnet.state_dict(),
        "logz": float(logz.item()),
        "sa_fnet": sa_fnet.state_dict(),
        "sa_bnet": sa_bnet.state_dict(),
        "sa_logz": float(sa_logz.item()),
        "rnd": rnd.state_dict(),
        "rnd_tgt": rnd_tgt.state_dict(),
        "opt": opt.state_dict(),
        "sa_opt": sa_opt.state_dict(),
        "rnd_opt": rnd_opt.state_dict(),
        "rng_torch": torch.random.get_rng_state(),
    }
    save_ckpt(run_dir, epoch=epochs, tag="latest", payload=payload)


def main():
    """CLI entry point: load the selected run configs and train each in turn.

    With ``--all`` no run-id filter is passed to ``load_runs`` (``run_id=None``);
    otherwise the id given via ``--run`` is used. Configs are read from
    ``diffusionGrid/experiments.toml`` and executed sequentially.
    """
    cli = parse_args()
    selected_id = None if cli.all else cli.run
    for run_cfg in load_runs("diffusionGrid/experiments.toml", run_id=selected_id):
        run_one(run_cfg)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()