from __future__ import annotations

import argparse
import time

import torch

from util import set_seed
from ..diffusionGrid_env import DiffGrid
from ..diffusionGrid_losses import tb
from ..diffusionGrid_rewards import build_log_reward_fn
from ..diffusionGrid_nets import FourierTimePolicy
from ..diffusionGrid_sampling import marginal_log_reward, backward_reward
from ..diffusionGrid_util import plot_epoch_panels
from ..diffusionGrid_utils import load_runs, new_run_dir, log_jsonl, save_ckpt


def parse_args():
    """Parse the command line.

    Exactly one of the following must be supplied:

    * ``--run <id>`` -- a single run identifier (string).
    * ``--all``      -- boolean flag.

    Returns:
        The ``argparse.Namespace`` with attributes ``run`` and ``all``.
    """
    parser = argparse.ArgumentParser()
    # The two selectors are mutually exclusive and one of them is mandatory.
    selector = parser.add_mutually_exclusive_group(required=True)
    selector.add_argument("--run", type=str)
    selector.add_argument("--all", action="store_true")
    namespace = parser.parse_args()
    return namespace


@torch.no_grad()
def l1_tv_distance(eval_env: DiffGrid, fnet, bnet, logz, batch: int = 10) -> float:
    """Return half the L1 distance between the normalised model and true rewards.

    The environment is first put into its "full grid" state via
    ``eval_env.set_full_grid_T()``; the model's log-reward estimate is then
    obtained from ``marginal_log_reward`` and compared with
    ``eval_env.log_reward()``. Both are normalised to sum to one over all
    elements before the comparison.

    Args:
        eval_env: Environment used for evaluation; its state is modified by
            this call (``set_full_grid_T`` is invoked on it).
        fnet: Forward policy network, passed through to ``marginal_log_reward``.
        bnet: Backward policy network, passed through to ``marginal_log_reward``.
        logz: Log-partition parameter, passed through to ``marginal_log_reward``.
        batch: Forwarded as ``batch=`` to ``marginal_log_reward``. Defaults to
            10, the value that was previously hard-coded.

    Returns:
        ``0.5 * sum(|p_model - p_true|)`` as a Python float.
    """
    eval_env.set_full_grid_T()
    log_r_hat = marginal_log_reward(eval_env, fnet, bnet, logz, batch=batch)
    log_r_true = eval_env.log_reward()

    # Normalise in log space: subtracting logsumexp before exponentiating is
    # mathematically the same as exp(x) / exp(x).sum(), but does not overflow
    # to inf (and hence NaN) when the log-rewards are large. Shapes are kept
    # as returned so the subtraction below behaves as it did before.
    dist_model = (log_r_hat - torch.logsumexp(log_r_hat.flatten(), dim=0)).exp()
    dist_true = (log_r_true - torch.logsumexp(log_r_true.flatten(), dim=0)).exp()

    return float((dist_model - dist_true).abs().sum().item() / 2.0)


def _is_due(epoch: int, every: int) -> bool:
    """Return True when ``epoch`` is a multiple of ``every``; ``every == 0`` disables the action."""
    return every != 0 and epoch % every == 0


def _checkpoint_payload(epoch: int, cfg: dict, fnet, bnet, logz, opt) -> dict:
    """Build the dict handed to ``save_ckpt`` (model/optimizer state plus the CPU torch RNG state)."""
    return {
        "epoch": epoch,
        "cfg": cfg,
        "fnet": fnet.state_dict(),
        "bnet": bnet.state_dict(),
        "logz": float(logz.item()),
        "opt": opt.state_dict(),
        "rng_torch": torch.random.get_rng_state(),
    }


def run_one(cfg: dict):
    """Train one configuration with the ``tb`` loss and record its artefacts.

    A new run directory is created through ``new_run_dir``; training metrics
    go to ``metrics.jsonl``, evaluation metrics to ``eval.jsonl``, figures to
    ``figures/`` and checkpoints are written with ``save_ckpt``.

    Args:
        cfg: Run configuration. It is shallow-copied, so the caller's dict is
            not rebound. Keys read here (with their defaults):
            ``device`` ("cpu"), ``seed`` (42), ``reward_kind`` ("8g"),
            ``size`` (15), ``batch_size`` (512), ``eps`` (0.1),
            ``epochs`` (4000), ``log_every`` (10), ``eval_every`` (100),
            ``save_every`` (1000), ``marginal_batch`` (15),
            ``lr_pf`` / ``lr_pb`` (5e-3), ``lr_logz`` (5e-2),
            ``hidden_dim_f`` / ``hidden_dim_b`` (128),
            ``num_layers_f`` (3), ``num_layers_b`` (1),
            ``n_freq_f`` / ``n_freq_b`` (8), and ``run_id`` (used for logging).
            Setting ``log_every``, ``eval_every`` or ``save_every`` to 0
            disables the corresponding action.

    Returns:
        None.
    """
    cfg = dict(cfg)

    run_dir = new_run_dir(exp="diffusionGrid", method="tb", cfg=cfg, out_root="runs")
    metrics_path = run_dir / "metrics.jsonl"
    eval_path = run_dir / "eval.jsonl"
    fig_dir = run_dir / "figures"

    device = cfg.get("device", "cpu")
    seed = int(cfg.get("seed", 42))
    set_seed(seed)

    reward_kind = str(cfg.get("reward_kind", "8g"))
    size = int(cfg.get("size", 15))
    batch_size = int(cfg.get("batch_size", 512))
    eps = float(cfg.get("eps", 0.1))

    epochs = int(cfg.get("epochs", 4000))
    log_every = int(cfg.get("log_every", 10))
    eval_every = int(cfg.get("eval_every", 100))
    save_every = int(cfg.get("save_every", 1000))
    marginal_batch = int(cfg.get("marginal_batch", 15))

    lr_pf = float(cfg.get("lr_pf", 5e-3))
    lr_pb = float(cfg.get("lr_pb", 5e-3))
    lr_logz = float(cfg.get("lr_logz", 5e-2))

    # Separate environments for training and evaluation, built identically.
    log_reward_fn = build_log_reward_fn(reward_kind, size=size)
    env = DiffGrid(size=size, batch_size=batch_size, log_reward=log_reward_fn, seed=seed, eps=eps)
    eval_env = DiffGrid(size=size, batch_size=batch_size, log_reward=log_reward_fn, seed=seed, eps=eps)

    fnet = FourierTimePolicy(
        hidden_dim=int(cfg.get("hidden_dim_f", 128)),
        num_layers=int(cfg.get("num_layers_f", 3)),
        n_freq=int(cfg.get("n_freq_f", 8)),
    ).to(device)
    bnet = FourierTimePolicy(
        hidden_dim=int(cfg.get("hidden_dim_b", 128)),
        num_layers=int(cfg.get("num_layers_b", 1)),
        n_freq=int(cfg.get("n_freq_b", 8)),
    ).to(device)
    logz = torch.nn.Parameter(torch.zeros(1, device=device))

    # One optimizer with a separate learning rate per parameter group.
    opt = torch.optim.AdamW(
        [{"params": fnet.parameters(), "lr": lr_pf},
         {"params": bnet.parameters(), "lr": lr_pb},
         {"params": [logz], "lr": lr_logz}]
    )
    # Learning rates decay linearly to 10% of their start value over `epochs` steps.
    sch = torch.optim.lr_scheduler.LinearLR(opt, start_factor=1.0, end_factor=0.1, total_iters=epochs)
    t0 = time.time()

    # range(epochs + 1) so that the final epoch index equals `epochs`.
    for epoch in range(epochs + 1):
        env.reset()
        opt.zero_grad(set_to_none=True)

        loss = tb(fnet, bnet, logz, env)
        # Snapshot of env.pos right after the loss computation; used for plotting.
        samples = env.pos.clone()

        loss.backward()
        opt.step()
        sch.step()

        if _is_due(epoch, log_every):
            # Evaluate the reward once and reuse it for both statistics
            # (assumes env.log_reward() is deterministic for a fixed env state).
            batch_log_r = env.log_reward()
            row = {
                "epoch": epoch,
                "time_sec": time.time() - t0,
                "tb_loss": float(loss.item()),
                "logz": float(logz.item()),
                "Z": float(logz.exp().item()),
                "avg_logR": float(batch_log_r.mean().item()),
                "max_logR": float(batch_log_r.max().item()),
                "run_id": cfg.get("run_id"),
                "seed": seed,
            }
            log_jsonl(metrics_path, row)
            print(f"[{cfg.get('run_id','?')}] [seed={seed}] [{epoch:5d}] tb={row['tb_loss']:.4f} logz={row['logz']:.3f}")

        if _is_due(epoch, eval_every):
            # Evaluation results are only consumed via plotting and .item(),
            # so no autograd graph is needed for these passes.
            # NOTE(review): assumes backward_reward does not rely on autograd
            # internally -- confirm.
            with torch.no_grad():
                eval_env.set_full_grid_T()
                log_r_hat = marginal_log_reward(eval_env, fnet, bnet, logz, batch=marginal_batch)

                plot_epoch_panels(
                    eval_env,
                    log_r_hat,
                    div_samples=samples,
                    samples=samples,
                    epoch=epoch,
                    out_dir=fig_dir,
                )

                # Reward and masks are taken before backward_reward is called,
                # preserving the original evaluation order.
                log_r = eval_env.log_reward()
                sq_norm = eval_env.pos.pow(2).sum(dim=-1)
                # Positions are split by a hard-coded squared radius of 9**2.
                easy_mask = sq_norm < 9**2
                hard_mask = sq_norm >= 9**2
                loss_abs = (backward_reward(eval_env, fnet, bnet, logz) - log_r).abs()
                easy_loss = loss_abs[easy_mask].mean().item()
                hard_loss = loss_abs[hard_mask].mean().item()

            l1 = l1_tv_distance(eval_env, fnet, bnet, logz)
            log_jsonl(
                eval_path,
                {
                    "epoch": epoch,
                    "l1_tv": l1,
                    "easy_pos_loss": easy_loss,
                    "hard_pos_loss": hard_loss,
                    "run_id": cfg.get("run_id"),
                    "seed": seed,
                },
            )

        if _is_due(epoch, save_every):
            save_ckpt(
                run_dir,
                epoch=epoch,
                tag="epoch",
                payload=_checkpoint_payload(epoch, cfg, fnet, bnet, logz, opt),
            )

    # Always leave a "latest" checkpoint behind, regardless of save_every.
    save_ckpt(
        run_dir,
        epoch=epochs,
        tag="latest",
        payload=_checkpoint_payload(epochs, cfg, fnet, bnet, logz, opt),
    )


def main():
    """Entry point: load the requested run configs and train each in turn.

    With ``--all`` no run id filter is passed to ``load_runs`` (``run_id=None``);
    otherwise the value of ``--run`` is used as the filter.
    """
    cli = parse_args()
    selected_run = None if cli.all else cli.run
    for run_cfg in load_runs("diffusionGrid/experiments.toml", run_id=selected_run):
        run_one(run_cfg)


if __name__ == "__main__":
    main()