from ..diffusionGrid_env import DiffGrid
from ..diffusionGrid_losses import tb
from ..diffusionGrid_rewards import build_log_reward_fn
from ..diffusionGrid_nets import FourierTimePolicy
from ..diffusionGrid_sampling import marginal_log_reward, forward_reward, backward_reward
from ..diffusionGrid_util import save_checkpoint, get_run_dir, build_parser, plot_epoch_panels, plot_learned_reward
from util import set_seed
import torch


# Output directories for the plots of the "on" (eps=0.) and "off" (eps=1.) runs.
ON_OUT_DIR = "baseline_plots_gfn_on"
OFF_OUT_DIR = "baseline_plots_gfn_off"

# ---- experiment configuration ----
reward_kind = "rings"
size = 15
seed = 42
lr_pf = 5e-3    # learning rate for the forward policy (fnet) parameters
lr_pb = 5e-3    # learning rate for the backward policy (bnet) parameters
lr_logz = 5e-2  # learning rate for the learnable logZ scalar
batch_size = 128
epoches = 4000
device = "cpu"

log_reward = build_log_reward_fn(reward_kind, size=size)


def _make_env(eps):
    """Build a DiffGrid sharing this script's size, batch size, reward and seed.

    ``eps`` is forwarded unchanged; it is 0. for the "on" and eval envs and 1.
    for the "off" env (presumably an exploration rate -- confirm in DiffGrid).
    """
    return DiffGrid(size=size, batch_size=batch_size, log_reward=log_reward,
                    seed=seed, eps=eps)


env_on = _make_env(0.)
env_off = _make_env(1.)
eval_env = _make_env(0.)

set_seed(seed)


def _make_agent():
    """Create a forward net, backward net, logZ parameter and their AdamW optimizer.

    Each of the three parameter sets gets its own learning rate
    (lr_pf, lr_pb, lr_logz respectively).
    """
    forward_net = FourierTimePolicy().to(device)
    backward_net = FourierTimePolicy().to(device)
    log_z = torch.nn.Parameter(torch.zeros(1, device=device))
    param_groups = [
        {"params": forward_net.parameters(), "lr": lr_pf},
        {"params": backward_net.parameters(), "lr": lr_pb},
        {"params": [log_z], "lr": lr_logz},
    ]
    return forward_net, backward_net, log_z, torch.optim.AdamW(param_groups)


# Construction order (on before off) is kept so seeded initialization matches.
fnet_on, bnet_on, logz_on, opt_on = _make_agent()
fnet_off, bnet_off, logz_off, opt_off = _make_agent()

# Evaluation curves, one value appended per evaluation.
on_easy_pos_loss, on_hard_pos_loss = [], []
off_easy_pos_loss, off_hard_pos_loss = [], []

EVAL_EVERY = 50  # evaluate and plot every this many training steps


def _evaluate_agent(env, fnet, bnet, logz, out_dir, epoch):
    """Plot one agent's learned reward and return its easy/hard evaluation errors.

    Runs under ``torch.no_grad()``: the results are only consumed through
    ``.item()`` and plotting, so building an autograd graph here would only
    cost time and memory.

    Args:
        env: the agent's training environment, passed to
            ``marginal_log_reward`` and ``plot_learned_reward``.
        fnet, bnet, logz: the agent's forward net, backward net and logZ.
        out_dir: output directory passed to ``plot_learned_reward``.
        epoch: current training step, used to label the plot.

    Returns:
        ``(easy_err, hard_err)``: mean of ``|backward_reward - log_reward|``
        over ``eval_env`` positions whose ``pos.abs().sum(-1)`` lies in
        [2, 4] and in [11, 13] respectively. An empty mask yields NaN.
    """
    with torch.no_grad():
        eval_env.set_full_grid_T()
        # NOTE(review): the marginal is estimated/plotted on the training env
        # (not eval_env), exactly as before -- confirm this is intended.
        log_r_hat = marginal_log_reward(env, fnet, bnet, logz, batch=15)
        plot_learned_reward(env, log_r_hat, epoch, out_dir)

        # Masks are taken before backward_reward runs, preserving the
        # original order in case backward_reward modifies eval_env.
        log_r = eval_env.log_reward()
        l1 = eval_env.pos.abs().sum(dim=-1)
        easy_mask = (l1 >= 2) & (l1 <= 4)
        hard_mask = (l1 >= 11) & (l1 <= 13)

        err = (backward_reward(eval_env, fnet, bnet, logz) - log_r).abs()
        return err[easy_mask].mean().item(), err[hard_mask].mean().item()


def _train_step(opt, fnet, bnet, logz, env):
    """Take one optimizer step on the ``tb`` loss (presumably trajectory balance) for a fresh batch."""
    opt.zero_grad()
    env.reset()
    loss = tb(fnet, bnet, logz, env)
    loss.backward()
    opt.step()


for epoch in range(epoches + 1):
    if epoch % EVAL_EVERY == 0:
        print("epoch:", epoch)
        easy, hard = _evaluate_agent(env_on, fnet_on, bnet_on, logz_on, ON_OUT_DIR, epoch)
        on_easy_pos_loss.append(easy)
        on_hard_pos_loss.append(hard)

        easy, hard = _evaluate_agent(env_off, fnet_off, bnet_off, logz_off, OFF_OUT_DIR, epoch)
        off_easy_pos_loss.append(easy)
        off_hard_pos_loss.append(hard)

    if epoch == epoches:
        # The models have now taken `epoches` steps; the loop only goes to
        # epoches + 1 to get this final evaluation. One more step would never
        # be observed.
        break

    _train_step(opt_on, fnet_on, bnet_on, logz_on, env_on)
    _train_step(opt_off, fnet_off, bnet_off, logz_off, env_off)


import os
import json
import numpy as np

def save_losses(out_path: str,
                on_easy_pos_loss,
                on_hard_pos_loss,
                off_easy_pos_loss,
                off_hard_pos_loss,
                step: int = 50,
                epoches: int | None = None):
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)

    payload = {
        "meta": {
            "step": int(step),
            "epoches": None if epoches is None else int(epoches),
        },
        "on_easy_pos_loss":  np.asarray(on_easy_pos_loss, dtype=float).tolist(),
        "on_hard_pos_loss":  np.asarray(on_hard_pos_loss, dtype=float).tolist(),
        "off_easy_pos_loss": np.asarray(off_easy_pos_loss, dtype=float).tolist(),
        "off_hard_pos_loss": np.asarray(off_hard_pos_loss, dtype=float).tolist(),
    }

    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)

# ---- after training: persist the evaluation curves once the lists are complete ----
save_losses(
    "artifacts/losses.json",
    on_easy_pos_loss,
    on_hard_pos_loss,
    off_easy_pos_loss,
    off_hard_pos_loss,
    step=50,          # evaluation interval of the training loop (epoch % 50)
    epoches=epoches,  # optional; recorded in the file's metadata
)

