from __future__ import annotations

import argparse
import math
import time
from pathlib import Path

import torch
from torch.nn.functional import softplus

from util import set_seed
from .diffusionGrid_env import DiffGrid
from .diffusionGrid_losses import tb
from .diffusionGrid_rewards import build_log_reward_fn
from .diffusionGrid_nets import FourierTimePolicy
from .diffusionGrid_sampling import marginal_log_reward, forward_reward, backward_reward
from .diffusionGrid_util import plot_epoch_panels, plot_epoch_panels_v2, plot_epoch_panels_v3
from .diffusionGrid_utils import load_runs, new_run_dir, log_jsonl, save_ckpt


def parse_args():
    """Parse the command line.

    Exactly one selector is required: ``--run RUN_ID`` trains the single
    configuration with that id, ``--all`` trains every configuration.
    """
    parser = argparse.ArgumentParser()
    selector = parser.add_mutually_exclusive_group(required=True)
    selector.add_argument("--run", type=str)
    selector.add_argument("--all", action="store_true")
    parsed = parser.parse_args()
    return parsed


@torch.no_grad()
def l1_tv_distance(eval_env: DiffGrid, fnet, bnet, logz, batch: int = 10) -> float:
    """Return the total-variation distance between model and target distributions.

    Two log-rewards are evaluated after ``eval_env.set_full_grid_T()``:
    the model's estimate (from ``marginal_log_reward``) and the environment's
    true ``log_reward()``. Each is normalised into a probability distribution
    over the evaluated positions. The result is
    ``0.5 * sum(|p_model - p_true|)``, which lies in ``[0, 1]`` when both
    tensors have the same shape.

    Args:
        eval_env: Environment used for evaluation. Its state is set by
            ``set_full_grid_T()`` first (presumably onto every grid position;
            confirm in ``DiffGrid``).
        fnet: Forward policy, forwarded to ``marginal_log_reward``.
        bnet: Backward policy, forwarded to ``marginal_log_reward``.
        logz: Log-partition estimate, forwarded to ``marginal_log_reward``.
        batch: ``batch`` argument passed to ``marginal_log_reward``. The
            default of 10 is the previously hard-coded value.

    Returns:
        The total-variation distance as a Python float.
    """
    eval_env.set_full_grid_T()
    log_r_hat = marginal_log_reward(eval_env, fnet, bnet, logz, batch=batch)
    log_r_true = eval_env.log_reward()

    # Normalise in log space. softmax(x) == exp(x) / exp(x).sum(), but it
    # cannot overflow to inf or underflow to all zeros. Either failure made the
    # plain exp-then-divide form return NaN. Flatten so the normalisation runs
    # over every element, then restore the original shape.
    dist_model = torch.softmax(log_r_hat.reshape(-1), dim=0).reshape(log_r_hat.shape)
    dist_true = torch.softmax(log_r_true.reshape(-1), dim=0).reshape(log_r_true.shape)
    return float((dist_model - dist_true).abs().sum().item() / 2.0)


def _build_policy(cfg: dict, prefix: str, side: str, hidden_dim: int, num_layers: int, n_freq: int, device):
    """Build a FourierTimePolicy from the ``{prefix}{field}_{side}`` keys of *cfg*, with the given defaults."""
    return FourierTimePolicy(
        hidden_dim=int(cfg.get(f"{prefix}hidden_dim_{side}", hidden_dim)),
        num_layers=int(cfg.get(f"{prefix}num_layers_{side}", num_layers)),
        n_freq=int(cfg.get(f"{prefix}n_freq_{side}", n_freq)),
    ).to(device)


def _build_optimizer(cfg: dict, prefix: str, fnet, bnet, logz) -> torch.optim.AdamW:
    """Build an AdamW optimizer with per-group learning rates ``{prefix}lr_pf``, ``{prefix}lr_pb`` and ``{prefix}lr_logz``."""
    return torch.optim.AdamW(
        [
            {"params": fnet.parameters(), "lr": float(cfg[f"{prefix}lr_pf"])},
            {"params": bnet.parameters(), "lr": float(cfg[f"{prefix}lr_pb"])},
            {"params": [logz], "lr": float(cfg[f"{prefix}lr_logz"])},
        ]
    )


def run_one(cfg: dict):
    """Train the main TB sampler and the auxiliary ("div") sampler for one config.

    A fresh run directory is created and training runs for
    ``cfg["epochs"] + 1`` steps. The run writes:

    - ``metrics.jsonl`` every ``log_every`` steps;
    - ``eval.jsonl`` and figures every ``eval_every`` steps;
    - an ``epoch`` checkpoint every ``save_every`` steps;
    - a final ``latest`` checkpoint.

    Each step does the following:

    1. Compute the trajectory-balance loss of the main policy on a fresh batch.
    2. Sample a batch with the div policy and score it with the main policy's
       backward estimate ``log_R_hat``.
    3. Update the main policy on ``w1 * tb + (1 - w1) * back``, where
       ``back`` is the squared error of ``log_R_hat`` and
       ``w1 = sigmoid(logz - div_logz)`` is treated as a constant.
    4. Update the div policy toward ``reward_alpha * log_R``. On samples where
       ``log_R_hat - log_R > log(threshold)``, the log-ratio is passed through
       a softplus before being squared.

    Args:
        cfg: Experiment configuration. The caller's dict is shallow-copied first.
            Required keys: seed, reward_kind, size, threshold, reward_alpha,
            batch_size, epochs, lr_pf, lr_pb, lr_logz, div_lr_pf, div_lr_pb,
            div_lr_logz. Optional keys are read with the defaults below.
            ``easy_radius`` (default 9) is the radius separating "easy" from
            "hard" positions in the evaluation metrics.
    """
    cfg = dict(cfg)

    run_dir = new_run_dir(exp="diffusionGrid", method="dtb", cfg=cfg, out_root="runs")
    metrics_path = run_dir / "metrics.jsonl"
    eval_path = run_dir / "eval.jsonl"
    fig_dir = run_dir / "figures"

    device = cfg.get("device", "cpu")
    seed = int(cfg["seed"])
    set_seed(seed)

    reward_kind = str(cfg["reward_kind"])
    reward_mult = cfg.get("reward_multiplier", 1)
    size = int(cfg["size"])
    eps = float(cfg.get("eps", 0.1))
    threshold = float(cfg["threshold"])
    # Required key. The old ``float(cfg.get("reward_alpha"))`` turned a missing
    # key into an opaque "float() argument ... NoneType" TypeError.
    reward_alpha = float(cfg["reward_alpha"])
    # The training env uses half of the configured batch; the eval env uses twice that.
    batch_size = int(cfg["batch_size"]) // 2
    easy_radius = float(cfg.get("easy_radius", 9))
    log_reward = build_log_reward_fn(reward_kind, size=size, multiplier=reward_mult)

    env = DiffGrid(size=size, batch_size=batch_size, log_reward=log_reward, seed=seed, eps=eps)
    eval_env = DiffGrid(size=size, batch_size=2 * batch_size, log_reward=log_reward, seed=seed, eps=eps)

    # Construction order is kept fixed. set_seed presumably seeds torch, so
    # parameter initialisation stays reproducible only if this order is stable.
    fnet = _build_policy(cfg, "", "f", 128, 3, 8, device)
    bnet = _build_policy(cfg, "", "b", 128, 1, 8, device)
    logz = torch.nn.Parameter(torch.zeros(1, device=device))

    div_fnet = _build_policy(cfg, "div_", "f", 64, 2, 16, device)
    div_bnet = _build_policy(cfg, "div_", "b", 64, 2, 8, device)
    div_logz = torch.nn.Parameter(torch.zeros(1, device=device))

    opt = _build_optimizer(cfg, "", fnet, bnet, logz)
    div_opt = _build_optimizer(cfg, "div_", div_fnet, div_bnet, div_logz)

    epochs = int(cfg["epochs"])

    # Learning rates decay linearly to 10% over `epochs` scheduler steps, then stay there.
    sch = torch.optim.lr_scheduler.LinearLR(opt, start_factor=1.0, end_factor=0.1, total_iters=epochs)
    div_sch = torch.optim.lr_scheduler.LinearLR(div_opt, start_factor=1.0, end_factor=0.1, total_iters=epochs)

    log_every = int(cfg.get("log_every", 10))
    eval_every = int(cfg.get("eval_every", 100))
    save_every = int(cfg.get("save_every", 1000))
    marginal_batch = int(cfg.get("marginal_batch", 15))

    # Loop-invariant. Computing it here also fails fast if threshold <= 0.
    log_threshold = math.log(threshold)

    def checkpoint_payload(epoch: int) -> dict:
        """Snapshot the models, optimizers, schedulers and torch RNG state at *epoch*."""
        return {
            "epoch": epoch,
            "cfg": cfg,
            "fnet": fnet.state_dict(),
            "bnet": bnet.state_dict(),
            "logz": float(logz.item()),
            "div_fnet": div_fnet.state_dict(),
            "div_bnet": div_bnet.state_dict(),
            "div_logz": float(div_logz.item()),
            "opt": opt.state_dict(),
            "div_opt": div_opt.state_dict(),
            # Scheduler state is needed to resume with the correct learning rate.
            "sch": sch.state_dict(),
            "div_sch": div_sch.state_dict(),
            "rng_torch": torch.random.get_rng_state(),
        }

    @torch.no_grad()
    def evaluate(epoch: int) -> None:
        """Plot both samplers' marginals and append TV and easy/hard errors to eval.jsonl."""
        # Evaluation only reads the models, so no autograd graph is recorded.
        eval_env.set_full_grid_T()
        log_p_hat = marginal_log_reward(eval_env, fnet, bnet, 0, batch=marginal_batch)
        log_p_hat_div = marginal_log_reward(eval_env, div_fnet, div_bnet, 0, batch=marginal_batch)
        plot_epoch_panels_v3(eval_env, log_p_hat, log_p_hat_div, epoch, out_dir=fig_dir)

        log_r = eval_env.log_reward()
        sq_radius = eval_env.pos.pow(2).sum(dim=-1)
        easy_mask = sq_radius < easy_radius ** 2
        hard_mask = sq_radius >= easy_radius ** 2
        abs_err = (backward_reward(eval_env, fnet, bnet, logz) - log_r).abs()
        # The mean of an empty selection is NaN, e.g. when no position is "hard".
        easy_loss = abs_err[easy_mask].mean().item()
        hard_loss = abs_err[hard_mask].mean().item()
        l1 = l1_tv_distance(eval_env, fnet, bnet, logz)
        log_jsonl(
            eval_path,
            {
                "epoch": epoch,
                "l1_tv": l1,
                "easy_pos_loss": easy_loss,
                "hard_pos_loss": hard_loss,
                "run_id": cfg.get("run_id"),
                "seed": seed,
            },
        )

    t0 = time.time()

    for epoch in range(epochs + 1):
        # --- main TB loss on a fresh batch (backpropagated in the combined update) ---
        env.reset()
        opt.zero_grad(set_to_none=True)
        main_loss = tb(fnet, bnet, logz, env)

        # --- combined update (backward reward) + div update ---
        env.reset()
        opt.zero_grad(set_to_none=True)
        div_opt.zero_grad(set_to_none=True)

        # Roll out the div policy. Judging by the original use of env.pos as
        # "div_samples", env is presumably left at the div policy's samples.
        log_R_theta = forward_reward(env, div_fnet, div_bnet, div_logz)

        log_R = env.log_reward()
        log_R_hat = backward_reward(env, fnet, bnet, logz)
        back_loss = (log_R_hat - log_R).pow(2).mean()

        with torch.no_grad():
            # The mixing weight is a constant: no gradient flows into logz or div_logz through it.
            w1 = torch.sigmoid(logz - div_logz)

        # Equal to w1 * main_loss + (1 - w1) * back_loss.
        loss_main = back_loss + (main_loss - back_loss) * w1
        loss_main.backward()
        opt.step()
        sch.step()

        with torch.no_grad():
            # Samples where the main backward estimate exceeds the true
            # log-reward by more than log(threshold).
            mask = (log_R_hat - log_R).gt(log_threshold)
            known_frac = mask.float().mean().item()

        log_ratio = log_R_theta - reward_alpha * log_R
        div_loss = torch.where(mask, softplus(log_ratio), log_ratio).pow(2).mean()
        div_loss.backward()
        div_opt.step()
        div_sch.step()

        if epoch % log_every == 0:
            row = {
                "epoch": epoch,
                "time_sec": time.time() - t0,
                "loss_main": float(loss_main.item()),
                "tb_loss": float(main_loss.item()),
                "back_loss": float(back_loss.item()),
                "div_loss": float(div_loss.item()),
                "known_frac": float(known_frac),
                "logz_main": float(logz.item()),
                "logz_aux": float(div_logz.item()),
                "Z_main": float(logz.exp().item()),
                "Z_aux": float(div_logz.exp().item()),
                "avg_logR": float(log_R.mean().item()),
                "max_logR": float(log_R.max().item()),
                "run_id": cfg.get("run_id"),
                "seed": seed,
            }
            log_jsonl(metrics_path, row)
            print(
                f"[{cfg.get('run_id','?')}] [seed={seed}] [{epoch:5d}] "
                f"tb={row['tb_loss']:.4f} div={row['div_loss']:.4f} "
                f"known={row['known_frac']:.3f} logz={row['logz_main']:.3f}"
            )

        if epoch % eval_every == 0:
            evaluate(epoch)

        if epoch % save_every == 0:
            save_ckpt(run_dir, epoch=epoch, tag="epoch", payload=checkpoint_payload(epoch))

    save_ckpt(run_dir, epoch=epochs, tag="latest", payload=checkpoint_payload(epochs))


def main():
    """Train one configuration (``--run``) or all of them (``--all``) from experiments.toml."""
    args = parse_args()
    # A run_id of None asks load_runs for every configured run.
    selected_run = None if args.all else args.run
    for run_cfg in load_runs("diffusionGrid/experiments.toml", run_id=selected_run):
        run_one(run_cfg)


if __name__ == "__main__":
    main()