from ..diffusionGrid_env import DiffGrid
from ..diffusionGrid_losses import tb
from ..diffusionGrid_rewards import build_log_reward_fn
from ..diffusionGrid_nets import FourierTimePolicy
from ..diffusionGrid_sampling import marginal_log_reward, forward_reward, backward_reward, backward_trajectory_log_prob
from ..diffusionGrid_util import save_checkpoint, get_run_dir, build_parser, plot_epoch_panels
from util import set_seed
import torch

from .base_core import cosh_tb, linex_tb

# ---------------------------------------------------------------------------
# Run configuration (module-level constants read by the script below).
# ---------------------------------------------------------------------------
OUT_DIR = "beyond_sq_plots_gfn"  # output directory handed to plot_epoch_panels
reward_kind = "rings"  # reward selector passed to build_log_reward_fn

size = 15  # forwarded to both the reward builder and DiffGrid
seed = 42  # forwarded to DiffGrid and set_seed

# Learning rates, one per optimizer parameter group (fnet, bnet, logz).
lr_pf = 5e-3
lr_pb = 5e-3
lr_logz = 5e-2

batch_size = 512
threshold = 0.3  # NOTE(review): not referenced in the visible part of this script
eps = 0.1  # forwarded to DiffGrid as `eps`
epoches = 4000  # number of training epochs; also T_max of the cosine LR schedule
device = "cpu"

# Reward function shared by the training and the evaluation environment.
log_reward = build_log_reward_fn(reward_kind, size=size)

# Two identically configured environments: `env` is the one the loss is
# computed on, `eval_env` is only used for the periodic full-grid evaluation.
_env_kwargs = dict(
    size=size,
    batch_size=batch_size,
    log_reward=log_reward,
    seed=seed,
    eps=eps,
)
env = DiffGrid(**_env_kwargs)
eval_env = DiffGrid(**_env_kwargs)

# Seeding happens after the environments are built and before the networks
# are created; this ordering is kept as-is so runs stay reproducible.
set_seed(seed)

# fnet / bnet differ only in depth (3 layers vs. 1 layer).
# NOTE(review): presumably the forward and backward policies — confirm.
_policy_kwargs = dict(hidden_dim=128, n_freq=8)
fnet = FourierTimePolicy(num_layers=3, **_policy_kwargs).to(device)
bnet = FourierTimePolicy(num_layers=1, **_policy_kwargs).to(device)
logz = torch.nn.Parameter(torch.zeros(1, device=device))

# One optimizer with a separate learning rate per parameter group.
_param_groups = [
    {"params": fnet.parameters(), "lr": lr_pf},
    {"params": bnet.parameters(), "lr": lr_pb},
    {"params": [logz], "lr": lr_logz},
]
opt = torch.optim.AdamW(_param_groups)
# Cosine schedule spanning the configured number of epochs.
sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epoches)

# Training loop: one optimizer + scheduler step per epoch, plus an evaluation
# plot every 100 epochs. The range is inclusive of `epoches`, so with the
# current configuration the final epoch is evaluated and plotted as well.
for epoch in range(epoches + 1):

    env.reset()
    opt.zero_grad()

    loss = cosh_tb(env, fnet, bnet, logz)
    # Copy env.pos right after the loss call; these are the samples handed to
    # the plotting routine on evaluation epochs.
    samples = env.pos.clone()

    # Bound to its own name: assigning to `log_reward` here would overwrite the
    # module-level reward function with a per-batch value.
    # NOTE(review): the value is currently unused — the call can be dropped
    # once env.log_reward() is confirmed to have no side effects.
    batch_log_reward = env.log_reward()
    loss.backward()
    opt.step()
    sch.step()

    print(f"Epoch {epoch}: TB Loss: {loss.item():.4f}")

    if epoch % 100 == 0:
        # NOTE(review): evaluation runs with autograd enabled; consider
        # torch.no_grad() if marginal_log_reward does not need gradients.
        eval_env.set_full_grid_T()
        log_r_hat = marginal_log_reward(eval_env, fnet, bnet, logz, batch=15)
        plot_epoch_panels(
            eval_env,
            log_r_hat,
            div_samples=samples,
            samples=samples,
            epoch=epoch,
            out_dir=OUT_DIR,
        )




