import wandb
import torch
import torch.nn as nn
import math
from torch.nn.functional import softplus
import matplotlib.pyplot as plt


from .diffusionGrid_env import DiffGrid
from .diffusionGrid_losses import tb
from .diffusionGrid_rewards import build_log_reward_fn
from .diffusionGrid_nets import FourierTimePolicy, UniformPolicy
from .diffusionGrid_sampling import marginal_log_reward, forward_reward, backward_reward
from .diffusionGrid_util import save_checkpoint, get_run_dir, build_parser, plot_epoch_panels
from util import set_seed


def _build_policy(config, prefix, device):
    """Instantiate one ``FourierTimePolicy`` from the ``<prefix>_*`` config keys.

    Args:
        config: Run configuration exposing ``<prefix>_hidden_dim``,
            ``<prefix>_num_layers`` and ``<prefix>_n_freq`` as attributes.
        prefix: Key prefix selecting the network (e.g. ``"fnet"``).
        device: Device the freshly built network is moved to.

    Returns:
        The constructed policy network, already moved to ``device``.
    """
    return FourierTimePolicy(
        hidden_dim=getattr(config, f"{prefix}_hidden_dim"),
        num_layers=getattr(config, f"{prefix}_num_layers"),
        n_freq=getattr(config, f"{prefix}_n_freq"),
    ).to(device)


def _build_optimizer(config, forward_net, backward_net, log_partition):
    """Create an AdamW optimizer with one learning rate per parameter group.

    The forward network uses ``config.lr_pf``, the backward network
    ``config.lr_pb`` and the scalar log-partition parameter ``config.lr_logz``.
    """
    return torch.optim.AdamW([
        {"params": forward_net.parameters(), "lr": config.lr_pf},
        {"params": backward_net.parameters(), "lr": config.lr_pb},
        {"params": [log_partition], "lr": config.lr_logz},
    ])


def train():
    """Run one training run whose hyperparameters all come from ``wandb``.

    Two model triples (forward net, backward net, log-Z parameter) are trained
    side by side: the "main" triple and the "divergence" triple. Every epoch
    performs one optimizer step for each. Scalar metrics are logged every 10
    epochs and a total-variation metric over the full grid is added every 100
    epochs.

    The function takes no arguments and returns nothing; every setting is read
    from ``run.config`` and every result is reported through ``wandb.log``.
    """
    # 1. Start the run (a sweep passes the config in automatically).
    with wandb.init() as run:
        config = run.config
        device = "cpu"

        # 2. Seed and environments.
        set_seed(config.seed)
        log_reward = build_log_reward_fn(config.reward_kind, size=config.size)
        env = DiffGrid(size=config.size, batch_size=config.batch_size,
                       log_reward=log_reward, seed=config.seed, eps=config.eps)

        # Evaluation environment: same settings as `env` except eps=0.
        eval_env = DiffGrid(size=config.size, batch_size=config.batch_size,
                            log_reward=log_reward, seed=config.seed, eps=0)

        # 3. Networks and log-Z parameters.
        # Construction order is kept (fnet, bnet, div_fnet, div_bnet) so that a
        # given seed yields the same initial weights as before.
        # --- Main networks ---
        fnet = _build_policy(config, "fnet", device)
        bnet = _build_policy(config, "bnet", device)

        # --- Divergence (exploration) networks ---
        div_fnet = _build_policy(config, "div_fnet", device)
        div_bnet = _build_policy(config, "div_bnet", device)

        logz = nn.Parameter(torch.zeros(1, device=device))
        div_logz = nn.Parameter(torch.zeros(1, device=device))

        # 4. Optimizers; both triples share the same learning rates.
        opt = _build_optimizer(config, fnet, bnet, logz)
        div_opt = _build_optimizer(config, div_fnet, div_bnet, div_logz)

        # Training loop. `epochs + 1` iterations, so epoch == config.epochs is
        # also visited.
        for epoch in range(config.epochs + 1):
            env.reset()
            opt.zero_grad()
            main_loss = tb(fnet, bnet, logz, env)

            # Divergence-side quantities, computed on a freshly reset env.
            env.reset()
            div_opt.zero_grad()
            log_R_theta = forward_reward(env, div_fnet, div_bnet, div_logz)

            log_R = env.log_reward()
            log_R_hat = backward_reward(env, fnet, bnet, logz)
            back_loss = (log_R_hat - log_R).pow(2).mean()

            # The mixing weight is treated as a constant: no gradient flows
            # into logz / div_logz through it.
            with torch.no_grad():
                w1 = torch.sigmoid(logz - div_logz)

            # Equivalent to w1 * main_loss + (1 - w1) * back_loss.
            loss = back_loss + (main_loss - back_loss) * w1
            loss.backward()
            opt.step()

            # Elements where the main model's estimate exceeds the reward by
            # more than log(threshold).
            with torch.no_grad():
                mask = (log_R_hat - log_R).gt(math.log(config.threshold))

            # Masked elements use softplus of the residual, the rest use the
            # raw residual; both are then squared and averaged.
            div_loss = torch.where(mask, softplus(log_R_theta - log_R), log_R_theta - log_R).pow(2).mean()
            div_loss.backward()
            div_opt.step()

            # 5. Metric logging. Only epochs that compute fresh metrics are
            # logged; previously the stale dictionary from the last multiple
            # of 10 was re-sent to wandb on every intermediate epoch.
            if epoch % 10 != 0:
                continue

            log_dict = {
                "epoch": epoch,
                "loss/main_tb": main_loss.item(),
                "loss/div_loss": div_loss.item(),
                "params/logz": logz.item(),
                "params/div_logz": div_logz.item(),
                "metrics/exploration_ratio": mask.float().mean().item()
            }

            # 6. Full-grid evaluation (every 100 epochs).
            # NOTE(review): this runs with autograd enabled; if
            # marginal_log_reward does not need gradients, wrapping it in
            # torch.no_grad() would save memory — confirm before changing.
            if epoch % 100 == 0:
                eval_env.set_full_grid_T()
                log_r_hat = marginal_log_reward(eval_env, fnet, bnet, logz, batch=10)
                model_r_hat = log_r_hat.exp()
                r = eval_env.log_reward().exp()

                # Normalize both to sum to one, then take half the L1
                # distance (total variation).
                dist_model = model_r_hat / model_r_hat.sum()
                dist_true = r / r.sum()
                l1_error = (dist_model - dist_true).abs().sum().item() / 2

                log_dict["metrics/l1_total_variation"] = l1_error

            # An explicit step keeps the wandb step axis equal to the epoch
            # index, as it was when every epoch produced a log call.
            wandb.log(log_dict, step=epoch)

if __name__ == "__main__":
    # Entry point for launching the script directly, e.g. locally without a
    # sweep agent.
    # NOTE(review): train() reads every hyperparameter from run.config, so a
    # direct launch presumably needs the config supplied some other way
    # (wandb defaults, environment, ...) — confirm.
    train()