import torch
from pathlib import Path
from math import pi
from typing import Callable, Optional
from .diffusionGrid_rewards import build_log_reward_fn
from .diffusionGrid_sampling import marginal_log_reward, forward_traj, backward_trajectory_log_prob
from copy import deepcopy
import argparse
import os
import re

def save_checkpoint(run_dir: Path, step: int,
                    fnet, bnet, logz, opt,
                    div_fnet, div_bnet, div_logz, div_opt):
    """Write both model groups to ``run_dir / "checkpoint_<step>"``.

    A single ``torch.save`` call stores two groups of objects: the
    unprefixed group (``fnet``, ``bnet``, ``logz``, ``opt``) and the same
    four roles under a ``div_`` key prefix. The directory is not created
    here, so ``run_dir`` has to exist already.

    Parameters
    ----------
    run_dir : Path
        Directory that receives the checkpoint file.
    step : int
        Training step; passed through ``int`` for the file name and payload.
    fnet, bnet, div_fnet, div_bnet
        Objects exposing ``state_dict()``; stored under ``"fnet"``/``"bnet"``
        (with and without the ``div_`` prefix).
    logz, div_logz
        Tensor-like objects; stored as ``.detach().cpu()`` copies.
    opt, div_opt
        Objects exposing ``state_dict()``; stored under ``"optimizer"`` and
        ``"div_optimizer"``.
    """
    step_id = int(step)
    groups = (
        ("", fnet, bnet, logz, opt),
        ("div_", div_fnet, div_bnet, div_logz, div_opt),
    )

    state = {"step": step_id}
    for prefix, forward_net, backward_net, log_z, optimizer in groups:
        state[f"{prefix}fnet"] = forward_net.state_dict()
        state[f"{prefix}bnet"] = backward_net.state_dict()
        # Detach and move to CPU so the saved tensor carries no graph/device.
        state[f"{prefix}logz"] = log_z.detach().cpu()
        state[f"{prefix}optimizer"] = optimizer.state_dict()

    torch.save(state, run_dir / f"checkpoint_{step_id}")

def get_run_dir(kind: str, size, eps, lr_pf, lr_pb, lr_logz, batch_size, seed, model_name) -> Path:
    """Return the output directory for one run configuration, creating it.

    The directory lives under ``runs/`` next to this source file and has one
    nesting level per hyperparameter, in this order: ``kind-*``, ``size-*``,
    ``eps-*``, ``lr_pf-*``, ``lr_pb-*``, ``lr_logz-*``, ``batch-*``, the seed,
    and finally ``model_name``. Values are interpolated with their default
    string formatting (e.g. ``None`` becomes ``lr_pb-None``).

    Returns
    -------
    Path
        The resolved directory; missing parents are created and an existing
        directory is not an error.
    """
    segments = (
        "runs",
        f"kind-{kind}",
        f"size-{size}",
        f"eps-{eps}",
        f"lr_pf-{lr_pf}",
        f"lr_pb-{lr_pb}",
        f"lr_logz-{lr_logz}",
        f"batch-{batch_size}",
        str(seed),
        model_name,
    )
    target = Path(__file__).resolve().parent.joinpath(*segments)
    target.mkdir(parents=True, exist_ok=True)
    return target


def make_log_reward(kind: str, size: int, N: Optional[float] = None, **overrides) -> Callable:
    """
    Constrói a função de log-recompensa (logR_fn) para o ambiente, a partir de um 'kind' pré-definido.

    Parâmetros
    ----------
    kind : str
        Um dentre {'8g', 'rings', 'moons', 'spiral', 'checker'}.
    size : float
        Largura e altura do grid/ambiente, usados para escalar hiperparâmetros geométricos.
    N : Optional[float]
        Tamanho “típico” de célula/grade (usado em 'checker'). Se None, usa min(W, H).
    **overrides :
        Qualquer parâmetro específico passado aqui sobrepõe os *defaults* do `kind`.

    Retorna
    -------
    logR_fn : Callable
        Função de log-recompensa para ser passada ao ambiente.
    """
    Rm = float(size)
    N = float(N if N is not None else Rm)

    if kind == '8g':
        params = dict(R=0.8 * Rm, sigma=1.0, lam=1e-6)

    elif kind == 'rings':
        params = dict(
            radii=[0.4 * Rm, 0.75 * Rm],
            sigma_r=0.8,
            weights=[1.0, 0.7],
            lam=1e-6,
        )

    elif kind == 'moons':
        Rm2 = 0.6 * Rm
        params = dict(
            R=Rm2,
            delta=0.5 * Rm2,
            gap=0.3 * Rm2,
            sigma=1.0,
            n_anchors=256,
            lam=1e-6,
        )

    elif kind == 'spiral':
        # espiral de Arquimedes: r(θ) = a + bθ, até θ_max
        theta_max = 3 * pi
        b = (0.8 * Rm) / theta_max  # termina ~0.8*Rm no raio final
        params = dict(
            a=0.0,
            b=b,
            theta_max=theta_max,
            sigma=1.0,
            n_anchors=512,
            lam=1e-6,
        )

    elif kind == 'checker':
        # padrão quadriculado; 'cell' controla o período espacial
        cell = max(2.0, N / 3.0)
        params = dict(cell=cell, alpha=0.02, lam=1e-6)

    else:
        raise ValueError(f"kind desconhecido: {kind!r}. Opções: '8g', 'rings', 'moons', 'spiral', 'checker'.")

    # sobrescreve defaults com qualquer coisa passada via kwargs
    params.update(overrides)

    return build_log_reward_fn(kind, **params)


def l1(env, fnet, bnet, logz, batch=10):
    """Mean absolute gap between the target and the model probabilities.

    The environment is first switched to its full grid via
    ``env.set_full_grid_T()`` (this mutates ``env``). The target
    log-probabilities are ``env.log_reward()`` normalized with a
    ``logsumexp`` over dimension 0. The model log-probabilities are
    ``marginal_log_reward(...) - logz``.

    NOTE(review): presumably ``marginal_log_reward`` returns one value per
    grid state, aligned with ``env.log_reward()`` — confirm against its
    definition.

    Returns
    -------
    float
        ``mean(|exp(target_log_p) - exp(model_log_p)|)`` as a Python number.
    """
    env.set_full_grid_T()

    target_log_r = env.log_reward()
    target_log_p = target_log_r - torch.logsumexp(target_log_r, dim=0)

    model_log_p = marginal_log_reward(env, fnet, bnet, logz, batch=batch) - logz

    gap = torch.exp(target_log_p) - torch.exp(model_log_p)
    return gap.abs().mean().item()


def build_parser():
    """Create the command-line parser for a training run.

    Every option is a ``--flag value`` pair with a type and a default; none
    is required. Options are registered in the order listed in the table
    below.

    Returns
    -------
    argparse.ArgumentParser
        Parser exposing the options in ``option_table``.
    """
    # (flag, type, default) — one row per command-line option.
    option_table = (
        ("--reward_kind", str, "8g"),
        ("--epoches", int, 5000),
        ("--size", int, 15),
        ("--seed", int, 10),
        ("--lr_pf", float, 1e-3),          # forward GFN
        ("--lr_pf_div", float, 1e-3),      # forward div-GFN
        ("--lr_pb", float, None),          # original note: None -> UniformPolicy (decided by the caller)
        ("--lr_logz", float, 5e-2),
        ("--lr_logz_div", float, 5e-2),
        ("--hidden_dim", int, 128),
        ("--hidden_dim_div", int, 128),
        ("--num_layers", int, 2),
        ("--num_layers_div", int, 2),
        ("--batch_size_main", int, 128),
        ("--batch_size_div", int, 128),
        ("--eps", float, 0.1),
        ("--device", str, "cpu"),
        ("--save_every", int, 100),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in option_table:
        parser.add_argument(flag, type=value_type, default=default_value)
    return parser


def get_min_epoch(checkpoint_dir: str) -> int | None:
    pattern = re.compile(r"^checkpoint_(\d+)$")
    epochs = []

    for fname in os.listdir(checkpoint_dir):
        match = pattern.match(fname)
        if match:
            epochs.append(int(match.group(1)))

    return min(epochs) if epochs else None

def get_max_epoch(checkpoint_dir: str) -> int | None:
    pattern = re.compile(r"^checkpoint_(\d+)$")
    epochs = []

    for fname in os.listdir(checkpoint_dir):
        match = pattern.match(fname)
        if match:
            epochs.append(int(match.group(1)))

    return max(epochs) if epochs else None


def find_all_checkpoints(root_dir: str) -> list[str]:
    """Recursively collect files whose name contains ``"checkpoint"``.

    The tree under ``root_dir`` is traversed with ``os.walk``. The test is a
    plain substring check on the file name, so names such as
    ``my_checkpoint.bak`` are included. Directory names are not tested.

    Returns
    -------
    list[str]
        Paths built as ``os.path.join(dirpath, filename)``, in traversal
        order.
    """
    return [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(root_dir)
        for name in names
        if "checkpoint" in name  # substring match, not an exact pattern
    ]

import numpy as np
import matplotlib.pyplot as plt

def plot_learned_reward(eval_env, log_r_hat, epoch, out_dir):
    """Plot ``exp(log_r_hat)`` on the evaluation grid and save it as a PNG.

    ``log_r_hat`` is turned into a 2D image by ``eval_env._to_grid_image``
    and exponentiated before drawing. The figure is written to
    ``out_dir/epoch_{epoch:06d}_reward.png``; ``out_dir`` is created if it
    does not exist.

    Parameters
    ----------
    eval_env
        Environment providing ``_to_grid_image``, ``width`` and ``height``.
    log_r_hat
        Log-values accepted by ``eval_env._to_grid_image``.
    epoch : int
        Used (zero-padded to 6 digits) in the output file name.
    out_dir
        Destination directory.
    """
    os.makedirs(out_dir, exist_ok=True)

    # Log-values -> grid image -> probabilities as a NumPy array.
    # (presumably laid out as [ny, nx], per the original annotation)
    prob_grid = eval_env._to_grid_image(log_r_hat).exp().detach().cpu().numpy()

    # Axis limits: the plot spans [-width, width] x [-height, height].
    bounds = [-eval_env.width, eval_env.width,
              -eval_env.height, eval_env.height]

    fig, ax = plt.subplots(figsize=(6, 5))
    heatmap = ax.imshow(prob_grid, origin="lower", extent=bounds)
    ax.set_title("Reward implícita")
    fig.colorbar(heatmap, ax=ax, fraction=0.046, pad=0.04)

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_aspect("equal", adjustable="box")

    # `fig` is the current figure here, so this equals plt.tight_layout().
    fig.tight_layout()

    fig.savefig(os.path.join(out_dir, f"epoch_{epoch:06d}_reward.png"), dpi=200)
    plt.close(fig)


def plot_epoch_panels_v2(eval_env, log_p_hat, div_log_p_hat, epoch, out_dir):
    """Save a three-panel comparison figure for one epoch.

    Panels, left to right:
      [0] ``exp(eval_env.log_reward())`` — titled "True Reward (prob)"
      [1] ``exp(log_p_hat)`` — titled "Canonical GFN Induced Probability"
      [2] ``exp(div_log_p_hat)`` — titled "Divergent GFN Induced Probability"

    Each input is converted to a 2D image with ``eval_env._to_grid_image``
    before exponentiation. No colorbars or axis labels are drawn. The figure
    is saved to ``out_dir/epoch_{epoch:06d}.png`` at 300 dpi; ``out_dir`` is
    created if it does not exist.
    """
    os.makedirs(out_dir, exist_ok=True)

    def _as_prob_image(log_values):
        # Log-values -> grid image -> exponentiated NumPy array.
        return eval_env._to_grid_image(log_values).exp().detach().cpu().numpy()

    canonical = _as_prob_image(log_p_hat)
    divergent = _as_prob_image(div_log_p_hat)
    truth = _as_prob_image(eval_env.log_reward())

    # Axis limits: the plots span [-width, width] x [-height, height].
    bounds = [-eval_env.width, eval_env.width,
              -eval_env.height, eval_env.height]

    panels = (
        (truth, "True Reward (prob)"),
        (canonical, "Canonical GFN Induced Probability"),
        (divergent, "Divergent GFN Induced Probability"),
    )

    fig, axs = plt.subplots(1, 3, figsize=(15, 4.5), sharex=True, sharey=True)
    for ax, (image, title) in zip(axs, panels):
        ax.imshow(image, origin="lower", extent=bounds)
        ax.set_title(title)

    fig.suptitle(f"Epoch {epoch}")
    # `fig` is the current figure here, so this equals plt.tight_layout().
    fig.tight_layout()

    fig.savefig(os.path.join(out_dir, f"epoch_{epoch:06d}.png"), dpi=300)
    plt.close(fig)
    
def plot_epoch_panels_v3(eval_env, log_p_hat, div_log_p_hat, epoch, out_dir):
    """Save three bare images (no title, ticks or frame) for one epoch.

    Each image is written both as PNG (300 dpi) and as PDF inside a
    per-epoch directory, which is created if needed:

      out_dir/epoch_{epoch:06d}/true.(png|pdf)       <- exp(eval_env.log_reward())
      out_dir/epoch_{epoch:06d}/canonical.(png|pdf)  <- exp(log_p_hat)
      out_dir/epoch_{epoch:06d}/divergent.(png|pdf)  <- exp(div_log_p_hat)

    Every input goes through ``eval_env._to_grid_image`` before being
    exponentiated.
    """
    epoch_dir = os.path.join(out_dir, f"epoch_{epoch:06d}")
    os.makedirs(epoch_dir, exist_ok=True)

    def _as_prob_image(log_values):
        # Log-values -> grid image -> exponentiated NumPy array.
        return eval_env._to_grid_image(log_values).exp().detach().cpu().numpy()

    canonical = _as_prob_image(log_p_hat)
    divergent = _as_prob_image(div_log_p_hat)
    truth = _as_prob_image(eval_env.log_reward())

    # Axis limits: the plots span [-width, width] x [-height, height].
    bounds = [-eval_env.width, eval_env.width,
              -eval_env.height, eval_env.height]

    for image, stem in ((truth, "true"), (canonical, "canonical"), (divergent, "divergent")):
        fig, ax = plt.subplots(1, 1, figsize=(5, 4.5))
        ax.imshow(image, origin="lower", extent=bounds)

        # Strip every decoration: ticks, title and axis labels.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title("")
        ax.set_xlabel("")
        ax.set_ylabel("")

        # Hide the frame around the image.
        for spine in ax.spines.values():
            spine.set_visible(False)

        # Let the axes fill the whole figure.
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1)

        # Raster (PNG) first, then PDF, both cropped tightly with no padding.
        fig.savefig(os.path.join(epoch_dir, f"{stem}.png"),
                    dpi=300, bbox_inches="tight", pad_inches=0)
        fig.savefig(os.path.join(epoch_dir, f"{stem}.pdf"),
                    bbox_inches="tight", pad_inches=0)

        plt.close(fig)
    
    
def plot_epoch_panels(eval_env, log_r_hat, div_samples, samples, epoch, out_dir):
    """Save a three-panel figure: induced reward plus two sample histograms.

    Panels, left to right (each with its own colorbar):
      [0] ``exp(log_r_hat)`` on the grid — "Induced Reward (prob)"
      [1] 2D histogram of ``div_samples`` — "Exploration GFN (hist)"
      [2] 2D histogram of ``samples`` — "Exploitation GFN (hist)"

    Histograms use the bin counts returned by ``eval_env._grid_shape()`` over
    the range ``[-width, width] x [-height, height]``. Columns 0 and 1 of
    each sample tensor are used as x and y. The figure is saved to
    ``out_dir/epoch_{epoch:06d}.png`` at 200 dpi; ``out_dir`` is created if
    it does not exist.
    """
    os.makedirs(out_dir, exist_ok=True)

    # Log-values -> grid image -> exponentiated NumPy array.
    prob_grid = eval_env._to_grid_image(log_r_hat).exp().detach().cpu().numpy()

    # Axis limits shared by the image and the histograms.
    bounds = [-eval_env.width, eval_env.width,
              -eval_env.height, eval_env.height]

    nx, ny = eval_env._grid_shape()

    def _sample_counts(points):
        xy = points.detach().cpu()
        counts, _, _ = np.histogram2d(
            xy[:, 0].numpy(), xy[:, 1].numpy(),
            bins=[nx, ny],
            range=[bounds[:2], bounds[2:]],
        )
        # histogram2d indexes [x, y]; transpose so rows are y for imshow.
        return counts.T

    panels = (
        (prob_grid, "Induced Reward (prob)"),
        (_sample_counts(div_samples), "Exploration GFN (hist)"),
        (_sample_counts(samples), "Exploitation GFN (hist)"),
    )

    fig, axs = plt.subplots(1, 3, figsize=(15, 4.5), sharex=True, sharey=True)
    for ax, (image, title) in zip(axs, panels):
        drawn = ax.imshow(image, origin="lower", extent=bounds)
        ax.set_title(title)
        fig.colorbar(drawn, ax=ax, fraction=0.046, pad=0.04)

    for ax in axs:
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_aspect("equal", adjustable="box")

    fig.suptitle(f"Epoch {epoch}")
    # `fig` is the current figure here, so this equals plt.tight_layout().
    fig.tight_layout()

    fig.savefig(os.path.join(out_dir, f"epoch_{epoch:06d}.png"), dpi=200)
    plt.close(fig)