from __future__ import annotations

import math
import os
from typing import Optional, Callable

import numpy as np
import torch
import wandb
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
from matplotlib.patches import Rectangle  # (not used now, but fine to keep)
from torchdiffeq import odeint

import learn_noise.utils.evaluation as evaluation
import learn_noise.utils.plotting_traj as plot_traj
from learn_noise.utils.image_eval import compute_fid, reshape_flat_samples, save_image_grid
from learn_noise.utils.image_latent_viz import (
    build_latent_visualizations,
    make_channel_pixel_histograms,
)


def log_real_rgb_histogram_once(
    *,
    args,
    sampler,
    image_shape,
    device: torch.device,
    step: int = 0,
    samples_key: str = "latent/real_rgb_hist",
) -> None:
    """Log per-channel pixel histograms of real data to wandb, at most once.

    The helper is best-effort: it returns silently (without logging) whenever
    the histogram cannot be produced.

    Args:
        args: Run configuration object. Reads the optional attribute
            ``real_hist_samples`` (default 4096) and uses the private flag
            ``_logged_real_rgb_hist`` to remember that logging already happened.
        sampler: Source of real samples. ``sample(n, device=, dtype=)`` is tried
            first; if that raises ``TypeError`` the helper falls back to
            ``sample_with_labels(n, device=, dtype=)`` and drops the labels.
            An optional ``num_samples`` attribute caps the number drawn.
        image_shape: Expected to be a 3-element shape whose first entry is the
            channel count; anything else (or ``None``) disables logging.
        device: Device passed to the sampler.
        step: wandb step to log at.
        samples_key: wandb key under which the histogram image is stored.
    """
    if getattr(args, "_logged_real_rgb_hist", False):
        return

    if image_shape is None or len(image_shape) != 3:
        return

    channels = image_shape[0]
    if channels not in {1, 3}:  # only meaningful for grayscale/RGB
        return

    try:
        import wandb  # local import to avoid hard dependency when disabled
    except ImportError:  # pragma: no cover - wandb optional
        return

    sample_count = int(getattr(args, "real_hist_samples", 4096))
    max_available = getattr(sampler, "num_samples", None)
    if max_available is not None:
        sample_count = max(1, min(sample_count, int(max_available)))

    try:
        real_flat = sampler.sample(sample_count, device=device, dtype=torch.float32)
    except TypeError:
        # ``sample`` rejected the call signature; try the labelled variant.
        # A sampler that lacks it is treated like one without ``sample``:
        # skip the histogram instead of crashing the training run.
        try:
            real_flat, _ = sampler.sample_with_labels(
                sample_count, device=device, dtype=torch.float32
            )
        except AttributeError:
            return
    except AttributeError:
        return

    real_flat = real_flat.detach().cpu()
    hist_fig = make_channel_pixel_histograms(real_flat, image_shape)
    try:
        wandb.log({samples_key: wandb.Image(hist_fig)}, step=step)
    finally:
        # Always release the figure, even if logging fails.
        plt.close(hist_fig)
    args._logged_real_rgb_hist = True

# -------------------- COLORS --------------------
# Hex colour strings shared by the trajectory-panel helpers below.
# Eight-digit values carry an alpha channel in the last two hex digits.
COL_BG_LIGHT = "#F6F7F9"   # kept for consistency; we don't paint it (background is transparent)
COL_BG_DARK  = "#E9EDF2"   # fill colour of the painted checker tiles
COL_PATH     = "#5A5F69"   # trajectory lines
COL_START    = "#000509B0"  # scatter markers at the first trajectory point
COL_END      = "#F9C43D"   # scatter markers at the last trajectory point
COL_DENSITY  = "#0B6DBD95"  # density overlay; only its RGB part is used, alpha is computed per pixel

# -------------------- UTILS --------------------
def _infer_tile_size(sampler, default=1.0):
    """Guess the sampler's checker tile size."""
    for name in ("tile_size", "checker_size", "period", "grid_step", "cell"):
        if hasattr(sampler, name):
            v = float(getattr(sampler, name))
            if v > 0:
                return v
    return float(default)

@torch.no_grad()
def _draw_sampler_background(
    ax,
    *,
    sampler,
    device: torch.device,
    x_min: float,
    x_max: float,
    y_min: float,
    y_max: float,
    density_grid: int = 240,
) -> None:
    """
    Paint a checkerboard (and, if available, a faint density overlay) onto ``ax``.

    Transparent background; checker rendered as a single RGBA image (no seams).
    Colored tiles only for cells fully inside [x_min,x_max]×[y_min,y_max] (no skinny border tiles).
    Density overlay appears ONLY on those colored tiles.

    Args:
        ax: Matplotlib axes to draw on; its face color is set fully transparent.
        sampler: Object inspected for a tile size (via ``_infer_tile_size``) and,
            when its ``has_log_prob`` attribute is truthy, queried through
            ``log_prob`` on a grid of 2-D points.
        device: Device tried first for the ``log_prob`` evaluation; on any
            exception the evaluation is retried on CPU.
        x_min, x_max, y_min, y_max: Plotting window in data coordinates.
        density_grid: Number of grid points per axis for the density overlay.
    """
    ax.set_facecolor((1, 1, 1, 0))

    # --- grid aligned to sampler tile size ---
    # Snap the window outward to whole multiples of the tile size so tile
    # edges fall on the sampler's grid lines.
    s  = _infer_tile_size(sampler, default=1.0)
    x0 = np.floor(x_min / s) * s
    x1 = np.ceil(x_max / s) * s
    y0 = np.floor(y_min / s) * s
    y1 = np.ceil(y_max / s) * s

    nx = int(round((x1 - x0) / s))
    ny = int(round((y1 - y0) / s))

    # ii / jj are column / row tile indices counted from the snapped origin
    # (x0, y0); both arrays have shape (ny, nx).
    ii, jj = np.meshgrid(np.arange(nx), np.arange(ny), indexing="xy")
    prev_dark = ((ii + jj) % 2 == 1).astype(np.float32)

    # Parity is inverted relative to ``prev_dark``: tiles with even (ii + jj)
    # are the ones that get painted, odd ones stay clear.
    colored = 1.0 - prev_dark

    # keep ONLY tiles fully inside the plotting window (kills skinny border tiles)
    tile_x_left   = x0 + ii * s
    tile_x_right  = tile_x_left + s
    tile_y_bottom = y0 + jj * s
    tile_y_top    = tile_y_bottom + s
    fully_inside = (
        (tile_x_left >= x_min) &
        (tile_x_right <= x_max) &
        (tile_y_bottom >= y_min) &
        (tile_y_top   <= y_max)
    ).astype(np.float32)

    color_mask = colored * fully_inside  # 1 on colored full tiles, 0 otherwise

    # --- checker as one RGBA image (no borders, no AA seams) ---
    # One image pixel per tile; unpainted tiles keep alpha 0.
    checker = np.zeros((ny, nx, 4), dtype=np.float32)
    r, g, b, _ = to_rgba(COL_BG_DARK)
    checker[..., 0] = r * color_mask
    checker[..., 1] = g * color_mask
    checker[..., 2] = b * color_mask
    checker[..., 3] = color_mask  # alpha 1 only on fully-inside colored tiles

    ax.imshow(
        checker,
        extent=(x0, x1, y0, y1),
        origin="lower",
        interpolation="nearest",
        resample=False,
        filternorm=False,
        aspect="auto",
        zorder=0.1,
    )

    # --- faint density overlay, masked to the SAME fully-inside colored tiles ---
    if getattr(sampler, 'has_log_prob', False):
        # Evaluation grid spans the plotting window, endpoints included.
        gx = np.linspace(x_min, x_max, density_grid, dtype=np.float32)
        gy = np.linspace(y_min, y_max, density_grid, dtype=np.float32)
        XX, YY = np.meshgrid(gx, gy, indexing='xy')
        coords = np.stack([XX, YY], axis=-1).reshape(-1, 2)

        try:
            grid_t = torch.from_numpy(coords).to(device=device, dtype=torch.float32)
            logp = sampler.log_prob(grid_t)
        except Exception:
            # Any failure on the requested device triggers one retry on CPU;
            # a second failure propagates to the caller.
            grid_t = torch.from_numpy(coords).to('cpu', dtype=torch.float32)
            logp = sampler.log_prob(grid_t)

        if logp is not None:
            lp = logp.detach().cpu().numpy().reshape(gy.size, gx.size)
            finite = np.isfinite(lp)
            if finite.any():
                # Normalise log-density to [0, 1] between its 5th and 95th
                # percentile (finite values only); a degenerate range falls
                # back to a flat overlay wherever the value is finite.
                lo, hi = np.percentile(lp[finite], [5.0, 95.0])
                if hi - lo < 1e-6:
                    alpha = finite.astype(np.float32)
                else:
                    alpha = np.clip((lp - lo) / (hi - lo), 0.0, 1.0)
                    alpha[~finite] = 0.0

                # pixel-resolution mask for fully-inside colored tiles
                ix = np.floor((XX - x0) / s).astype(int)
                iy = np.floor((YY - y0) / s).astype(int)

                # bounds check (points falling outside x0..x1 may happen numerically)
                valid = (
                    (ix >= 0) & (ix < nx) &
                    (iy >= 0) & (iy < ny)
                )
                pix_mask = np.zeros_like(alpha, dtype=np.float32)
                if valid.any():
                    # map to the same color_mask grid
                    cm = color_mask  # (ny, nx)
                    pix_mask[valid] = cm[iy[valid], ix[valid]]

                rgba = np.zeros((gy.size, gx.size, 4), dtype=np.float32)
                # The alpha component of COL_DENSITY is discarded; opacity
                # comes from the normalised density below.
                r, g, b, _ = to_rgba(COL_DENSITY)
                rgba[..., 0] = r
                rgba[..., 1] = g
                rgba[..., 2] = b
                rgba[..., 3] = alpha * pix_mask * 0.18  # faint & masked

                # Drawn just above the checker (zorder 0.2 vs 0.1).
                ax.imshow(
                    rgba,
                    extent=(x_min, x_max, y_min, y_max),
                    origin='lower',
                    interpolation='nearest',
                    resample=False,
                    filternorm=False,
                    aspect='auto',
                    zorder=0.2,
                )


def _render_trajectory_panel(
    *,
    args,
    step: int,
    sampler,
    device: torch.device,
    trajectories: np.ndarray,
    starts: np.ndarray,
    ends: np.ndarray,
    show_legend: bool = True,
    show_title: bool = False,
) -> None:
    """Draw low-dimensional trajectories on a transparent panel, save and log it.

    The panel is written to ``<args.runs_dir>/trajectories_step_<step>.png``
    and logged to wandb under ``fm/trajectories_clean``.

    Args:
        args: Run configuration; reads ``target_dataset`` (``'funnel'`` selects
            a wider plotting window) and ``runs_dir`` (output directory).
        step: Training step used for the file name, the optional title and the
            wandb step.
        sampler: Forwarded to ``_draw_sampler_background``.
        device: Forwarded to ``_draw_sampler_background``.
        trajectories: Array indexed as ``[time, path, coordinate]``; only
            coordinates 0 and 1 are plotted.
        starts: Array indexed as ``[path, coordinate]`` with the start points;
            its first dimension determines how many paths are drawn.
        ends: Array indexed as ``[path, coordinate]`` with the end points.
        show_legend: Draw a legend if any plotted artist carries a label.
        show_title: Add a ``"Trajectories @ step N"`` title.
    """
    if args.target_dataset == 'funnel':
        x_min, x_max = -5.0, 5.0
        y_min, y_max = -10.0, 10.0
    else:
        x_min, x_max = -4.0, 4.0
        y_min, y_max = -4.0, 4.0

    fig, ax = plt.subplots(1, 1, figsize=(7, 7), dpi=140)
    try:
        # Make the entire figure transparent (important for exports).
        fig.patch.set_facecolor('none')
        fig.patch.set_alpha(0.0)

        _draw_sampler_background(
            ax,
            sampler=sampler,
            device=device,
            x_min=x_min,
            x_max=x_max,
            y_min=y_min,
            y_max=y_max,
        )

        # --- trajectories ---
        line_kwargs = dict(color=COL_PATH, alpha=0.35, linewidth=1.0, solid_capstyle='round', zorder=1)
        for idx in range(starts.shape[0]):
            ax.plot(trajectories[:, idx, 0], trajectories[:, idx, 1], **line_kwargs)

        # --- start/end points (intentionally unlabeled) ---
        ax.scatter(
            starts[:, 0], starts[:, 1],
            s=14, c=COL_START, alpha=0.9, edgecolors='none',
            zorder=2,
        )
        ax.scatter(
            ends[:, 0], ends[:, 1],
            s=14, c=COL_END, alpha=0.9, edgecolors='none',
            zorder=3,
        )

        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)

        if show_legend:
            # Only build a legend when something is labeled; calling
            # ``ax.legend()`` with no labeled artists just emits a warning.
            handles, labels = ax.get_legend_handles_labels()
            if handles:
                legend = ax.legend(handles, labels, frameon=False, fontsize=9, loc='upper right')
                for text in legend.get_texts():
                    text.set_color('#2B2B2B')

        ax.set_aspect('equal', 'box')
        ax.set_title("" if not show_title else f"Trajectories @ step {step}", pad=12, color='#2B2B2B')
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

        # Transparent export: background is fully clear except our drawn elements.
        if args.runs_dir:
            os.makedirs(args.runs_dir, exist_ok=True)
        traj_path = os.path.join(args.runs_dir, f"trajectories_step_{step:06d}.png")
        fig.savefig(
            traj_path,
            dpi=180,
            bbox_inches='tight',
            pad_inches=0.02,
            transparent=True,
            facecolor='none',
            edgecolor='none',
        )
        wandb.log({"fm/trajectories_clean": wandb.Image(traj_path)}, step=step)
    finally:
        # Release the figure even if drawing, saving or logging failed.
        plt.close(fig)

# -------------------- LOGGING HOOKS --------------------
def log_baseline_evaluation(
    *,
    args,
    step: int,
    ema_model,
    wrapper,
    ode_func,
    sampler,
    noise_sampler: Callable[[tuple[int, ...]], torch.Tensor],
    x0_batch: torch.Tensor,
    device: torch.device,
    do_light: bool,
    do_heavy: bool,
) -> None:
    """Run evaluation and logging for the baseline FM trainer.

    ``ema_model`` is switched to eval mode and installed as ``wrapper.model``
    for the duration of the call. On exit (including on error) the model is
    put back into train mode only if it was in train mode on entry, matching
    the other logging hooks in this module.

    Args:
        args: Run configuration; reads ``eval_sample``, ``dim``,
            ``target_dataset``, ``runs_dir`` and ``num_steps_eval``.
        step: Training step used for file names and wandb steps.
        ema_model: Model to evaluate.
        wrapper: Object whose ``model`` attribute is set to ``ema_model``.
        ode_func: Vector field passed to the ODE solver and plotting helpers.
            Optional ``reset_nfe()`` / ``nfe`` members are used when present
            (presumably a function-evaluation counter).
        sampler: Target sampler, forwarded to plotting and evaluation.
        noise_sampler: Callable mapping a shape tuple to a noise tensor.
        x0_batch: Batch forwarded to ``evaluation.heavy_eval_batched``.
        device: Device for generated tensors.
        do_light: Produce trajectory visualisations and log ``fm/nfe_light``.
        do_heavy: Forwarded as ``big_eval`` to ``heavy_eval_batched``.
    """
    # Default True keeps the previous "always return to train mode" behaviour
    # for model objects that do not expose a ``training`` flag.
    was_training = getattr(ema_model, "training", True)
    ema_model.eval()
    wrapper.model = ema_model
    try:
        with torch.inference_mode():
            if do_light:
                if hasattr(ode_func, 'reset_nfe'):
                    ode_func.reset_nfe()
                num_traj = min(2000, args.eval_sample)
                if num_traj > 0:
                    eps1 = noise_sampler((num_traj, args.dim)).to(device)

                    if args.target_dataset == 'funnel':
                        xlim = (-20.0, 20.0)
                        ylim = (-100.0, 100.0)
                    else:
                        xlim = (-4.0, 4.0)
                        ylim = (-4.0, 4.0)

                    plot_traj.visualize_and_save(
                        ode_func,
                        noise=eps1,
                        T=1.0,
                        output_dir=args.runs_dir,
                        num_steps=50,
                        # Keep in sync with the noise batch actually passed,
                        # which has fewer than 2000 rows for small eval_sample.
                        num_samples=num_traj,
                        dim=args.dim,
                        device=device,
                        step=step,
                        wandb_key="fm/trajectory_gif",
                        filename=f"trajectory_step_{step:06d}",
                        xlim=xlim,
                        ylim=ylim,
                    )

                    t_vals = torch.linspace(1.0, 0.0, args.num_steps_eval, device=device)
                    n_paths = min(500, eps1.shape[0])
                    if n_paths > 0:
                        x_traj = odeint(ode_func, eps1[:n_paths], t_vals, method='dopri5')
                        X = x_traj.detach().cpu().numpy()
                        starts = X[0]
                        ends = X[-1]

                        _render_trajectory_panel(
                            args=args,
                            step=step,
                            sampler=sampler,
                            device=device,
                            trajectories=X,
                            starts=starts,
                            ends=ends,
                            show_legend=True,
                            show_title=False,
                        )

            # Reset again so the value logged below only covers the
            # evaluation call that follows, not the plots above.
            if do_light and hasattr(ode_func, 'reset_nfe'):
                ode_func.reset_nfe()

            evaluation.heavy_eval_batched(
                args,
                x0_batch,
                ode_func,
                sampler,
                noise=noise_sampler,
                step=step,
                big_eval=do_heavy,
                device=device,
            )

            if do_light and hasattr(ode_func, 'nfe'):
                wandb.log({"fm/nfe_light": int(ode_func.nfe)}, step=step)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    finally:
        if was_training:
            ema_model.train()


def log_baseline_image_metrics(
    *,
    args,
    step: int,
    eval_model,
    wrapper,
    device: torch.device,
    image_shape,
    sampler,
    sample_vis_interval: int,
    sample_vis_count: int,
    sample_vis_nrow: int,
    sample_dir: str,
    fid_interval: int,
    fid_num_gen: int,
    fid_batch_size: int,
    fid_image_size: int,
    fid_gen_batch: int,
    fid_real_cache,
    noise_sampler: Callable[[tuple[int, ...]], torch.Tensor],
    generate_samples: Callable[..., torch.Tensor],
    fixed_noise: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    """Log baseline image metrics: sample grids and FID.

    Work is done only on steps where ``(step + 1)`` is a multiple of the
    respective interval. While active, ``eval_model`` is put in eval mode and
    installed as ``wrapper.model``; its previous train/eval mode is restored
    on exit, including when an error propagates.

    Args:
        args: Run configuration; reads ``dim`` and stores the private flag
            ``_fid_import_warned`` after the first FID import failure.
        step: Training step used for scheduling, file names and wandb steps.
        eval_model: Model used for generation.
        wrapper: Object whose ``model`` attribute is set to ``eval_model``.
        device: Device forwarded to ``compute_fid``.
        image_shape: Shape used to reshape flat samples into images.
        sampler: Unused here; kept for a signature parallel to sibling hooks.
        sample_vis_interval: Interval for sample grids (``<= 0`` disables).
        sample_vis_count: Number of images in the grid (``<= 0`` disables).
        sample_vis_nrow: Forwarded as ``nrow`` to ``save_image_grid``.
        sample_dir: Directory for grid PNGs; created when non-empty.
        fid_interval: Interval for FID (``<= 0`` disables).
        fid_num_gen: Number of generated samples for FID (``<= 0`` disables).
        fid_batch_size: Forwarded to ``compute_fid``.
        fid_image_size: Forwarded to ``compute_fid``.
        fid_gen_batch: Unused here; kept for interface compatibility.
        fid_real_cache: Real-data statistics for FID; ``None`` disables FID.
        noise_sampler: Callable mapping a shape tuple to a noise tensor.
        generate_samples: Callable producing flat samples; receives the count
            and, for grids, ``latents=``.
        fixed_noise: Cached latents reused across calls so grids are
            comparable; regenerated when missing or too small.

    Returns:
        The (possibly newly created) fixed latents to pass to the next call.
    """
    log_samples = (
        sample_vis_interval > 0
        and sample_vis_count > 0
        and ((step + 1) % sample_vis_interval == 0)
    )
    log_fid = (
        fid_interval > 0
        and fid_num_gen > 0
        and fid_real_cache is not None
        and ((step + 1) % fid_interval == 0)
    )

    if not (log_samples or log_fid):
        return fixed_noise

    prev_mode = eval_model.training
    eval_model.eval()
    wrapper.model = eval_model

    try:
        with torch.inference_mode():
            if log_samples:
                if sample_dir:
                    os.makedirs(sample_dir, exist_ok=True)
                if fixed_noise is None or fixed_noise.shape[0] < sample_vis_count:
                    fixed_noise = noise_sampler((sample_vis_count, args.dim))
                # The cache may hold more rows than requested; pass exactly
                # as many latents as samples asked for.
                generated_samples = generate_samples(
                    sample_vis_count,
                    latents=fixed_noise[:sample_vis_count],
                )
                vis_imgs = reshape_flat_samples(generated_samples, torch.Size(image_shape))
                grid_np = save_image_grid(
                    vis_imgs,
                    path=os.path.join(sample_dir, f'step_{step:06d}.png'),
                    nrow=sample_vis_nrow,
                )
                wandb.log({"samples/grid": wandb.Image(grid_np)}, step=step)

            if log_fid:
                try:
                    gen_samples = generate_samples(fid_num_gen)
                    gen_imgs = reshape_flat_samples(gen_samples, torch.Size(image_shape))
                    fid_val = compute_fid(
                        fid_real_cache,
                        gen_imgs,
                        device=device,
                        image_size=fid_image_size,
                        batch_size=fid_batch_size,
                    )
                    wandb.log({"metrics/fid": float(fid_val)}, step=step)
                except ImportError as exc:
                    # FID dependencies are optional: warn once, then skip quietly.
                    warned = getattr(args, "_fid_import_warned", False)
                    if not warned:
                        print(f"[fid] Skipping FID evaluation: {exc}")
                        args._fid_import_warned = True
    finally:
        # Never leave the model stuck in eval mode if evaluation failed.
        if prev_mode:
            eval_model.train()

    return fixed_noise


def log_quantile_image_metrics(
    *,
    args,
    step: int,
    eval_model,
    wrapper,
    quantile,
    device: torch.device,
    image_shape,
    sampler,
    sample_vis_interval: int,
    sample_vis_count: int,
    sample_vis_nrow: int,
    sample_dir: str,
    fid_interval: int,
    fid_num_gen: int,
    fid_batch_size: int,
    fid_image_size: int,
    fid_gen_batch: int,
    fid_real_cache,
    generate_samples: Callable[..., torch.Tensor],
    fixed_u_vis: Optional[torch.Tensor],
    u_eps: float,
) -> Optional[torch.Tensor]:
    """Log image-based metrics (sample grids, FID, latent visualisations).

    Sample grids and FID run on steps where ``(step + 1)`` is a multiple of
    their interval. Latent visualisations are produced together with sample
    grids when ``args.latent_viz_samples > 0``. While active, ``eval_model``
    and ``quantile`` are put in eval mode and ``eval_model`` is installed as
    ``wrapper.model``; both modes are restored on exit, including on error.

    Args:
        args: Run configuration; reads ``latent_viz_samples``,
            ``latent_atlas_grid`` and ``dim``; stores the private flag
            ``_fid_import_warned`` after the first FID import failure.
        step: Training step used for scheduling, file names and wandb steps.
        eval_model: Model used for generation.
        wrapper: Object whose ``model`` attribute is set to ``eval_model``.
        quantile: Module called as ``quantile(u, ones, x_aux=zeros)`` to map
            uniform draws to latents.
        device: Device for generated tensors and ``compute_fid``.
        image_shape: Shape used to reshape flat samples into images.
        sampler: Unused here; kept for a signature parallel to sibling hooks.
        sample_vis_interval: Interval for sample grids (``<= 0`` disables).
        sample_vis_count: Number of images in the grid (``<= 0`` disables).
        sample_vis_nrow: Forwarded as ``nrow`` to ``save_image_grid``.
        sample_dir: Directory for grid PNGs; created when non-empty.
        fid_interval: Interval for FID (``<= 0`` disables).
        fid_num_gen: Number of generated samples for FID (``<= 0`` disables).
        fid_batch_size: Forwarded to ``compute_fid``.
        fid_image_size: Forwarded to ``compute_fid``.
        fid_gen_batch: Unused here; kept for interface compatibility.
        fid_real_cache: Real-data statistics for FID; ``None`` disables FID.
        generate_samples: Callable producing flat samples; receives the count
            and, for grids, ``u_source=``.
        fixed_u_vis: Cached uniform draws reused across calls; regenerated
            when missing or too small.
        u_eps: Margin used to squeeze uniform draws into
            ``[u_eps, 1 - u_eps]`` before the quantile map.

    Returns:
        The (possibly newly created) fixed uniform draws for the next call.
    """
    log_samples = (
        sample_vis_interval > 0
        and sample_vis_count > 0
        and ((step + 1) % sample_vis_interval == 0)
    )
    log_fid = (
        fid_interval > 0
        and fid_num_gen > 0
        and fid_real_cache is not None
        and ((step + 1) % fid_interval == 0)
    )

    # Class-conditional FID removed in anonymized minimal version

    latent_viz_samples = int(args.latent_viz_samples)
    atlas_grid = int(args.latent_atlas_grid)

    # Latent visualisations only run together with sample grids, so there is
    # nothing to do (and no reason to toggle model modes) unless a grid or a
    # FID evaluation is due on this step.
    if not (log_samples or log_fid):
        return fixed_u_vis

    prev_mode = eval_model.training
    prev_quant_mode = quantile.training
    eval_model.eval()
    wrapper.model = eval_model
    quantile.eval()

    try:
        with torch.inference_mode():
            if log_samples:
                if sample_dir:
                    os.makedirs(sample_dir, exist_ok=True)
                if fixed_u_vis is None or fixed_u_vis.shape[0] < sample_vis_count:
                    fixed_u_vis = torch.rand(sample_vis_count, args.dim, device=device)
                # The cache may hold more rows than requested; pass exactly
                # as many draws as samples asked for.
                generated_samples = generate_samples(
                    sample_vis_count,
                    u_source=fixed_u_vis[:sample_vis_count],
                )
                vis_imgs = reshape_flat_samples(generated_samples, torch.Size(image_shape))
                grid_np = save_image_grid(
                    vis_imgs,
                    path=os.path.join(sample_dir, f'step_{step:06d}.png'),
                    nrow=sample_vis_nrow,
                )
                wandb.log({"samples/grid": wandb.Image(grid_np)}, step=step)

            if log_fid:
                try:
                    gen_samples = generate_samples(fid_num_gen)
                    gen_imgs = reshape_flat_samples(gen_samples, torch.Size(image_shape))
                    fid_val = compute_fid(
                        fid_real_cache,
                        gen_imgs,
                        device=device,
                        image_size=fid_image_size,
                        batch_size=fid_batch_size,
                    )
                    wandb.log({"metrics/fid": float(fid_val)}, step=step)
                except ImportError as exc:
                    # FID dependencies are optional: warn once, then skip quietly.
                    warned = getattr(args, "_fid_import_warned", False)
                    if not warned:
                        print(f"[fid] Skipping FID evaluation: {exc}")
                        args._fid_import_warned = True

            if log_samples and latent_viz_samples > 0:
                latent_viz_samples = min(latent_viz_samples, 1024)
                unit_u = torch.rand(latent_viz_samples, args.dim, device=device)
                # Keep draws away from 0 and 1 before the quantile map.
                U_latent = u_eps + (1 - 2 * u_eps) * unit_u
                ones_latent = torch.ones(latent_viz_samples, 1, device=device)
                x_aux_latent = torch.zeros(latent_viz_samples, args.dim, device=device)
                eps_latent = quantile(U_latent, ones_latent, x_aux=x_aux_latent)

                latents_cpu = eps_latent.detach().cpu()

                atlas_images = None
                if atlas_grid > 1:
                    n_atlas = min(atlas_grid * atlas_grid, latents_cpu.shape[0])
                    # Latents can only be shown as images when their flat
                    # size matches the image shape.
                    if n_atlas > 0 and math.prod(image_shape) == latents_cpu.shape[1]:
                        perm = torch.randperm(latents_cpu.shape[0])[:n_atlas]
                        atlas_images = latents_cpu[perm].reshape(n_atlas, *image_shape)

                viz_payload = build_latent_visualizations(
                    latents_cpu,
                    image_shape=image_shape,
                    atlas_images=atlas_images,
                )

                try:
                    wandb_viz = {
                        "latent/mean_std": wandb.Image(viz_payload.mean_std_fig),
                        "latent/hist_qq": wandb.Image(viz_payload.hist_qq_fig),
                        "latent/pca": wandb.Image(viz_payload.pca_fig),
                        "latent/correlation": wandb.Image(viz_payload.corr_fig),
                    }
                    if viz_payload.atlas_grid is not None:
                        wandb_viz["latent/atlas"] = wandb.Image(viz_payload.atlas_grid)
                    wandb.log(wandb_viz, step=step)
                finally:
                    # Release the figures even if logging failed.
                    for fig in [
                        viz_payload.mean_std_fig,
                        viz_payload.hist_qq_fig,
                        viz_payload.pca_fig,
                        viz_payload.corr_fig,
                    ]:
                        plt.close(fig)
    finally:
        # Never leave the modules stuck in eval mode if evaluation failed.
        if prev_mode:
            eval_model.train()
        if prev_quant_mode:
            quantile.train()

    return fixed_u_vis

def log_quantile_low_dim_metrics(
    *,
    args,
    step: int,
    eval_model,
    wrapper,
    ode_func,
    sampler,
    quantile,
    x0_batch: torch.Tensor,
    device: torch.device,
    do_light: bool,
    do_heavy: bool,
    u_eps: float,
    fixed_eval_u: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    """Run trajectory plots and Sinkhorn/MMD evaluation for low-dimensional targets.

    Returns immediately when neither ``do_light`` nor ``do_heavy`` is set.
    Otherwise ``eval_model`` and ``quantile`` are put in eval mode and
    ``eval_model`` is installed as ``wrapper.model``; both modes are restored
    on exit, including when an error propagates.

    Args:
        args: Run configuration; reads ``eval_sample``, ``dim``,
            ``target_dataset``, ``runs_dir`` and ``num_steps_eval``.
        step: Training step used for file names and wandb steps.
        eval_model: Model to evaluate.
        wrapper: Object whose ``model`` attribute is set to ``eval_model``.
        ode_func: Vector field passed to the ODE solver and plotting helpers.
            Optional ``reset_nfe()`` / ``nfe`` members are used when present
            (presumably a function-evaluation counter).
        sampler: Target sampler, forwarded to plotting and evaluation.
        quantile: Module called as ``quantile(u, ones, x_aux=zeros)`` to map
            uniform draws to starting noise.
        x0_batch: Batch forwarded to ``evaluation.heavy_eval_batched``.
        device: Device for generated tensors.
        do_light: Produce trajectory visualisations and log ``fm/nfe_val``.
        do_heavy: Forwarded as ``big_eval`` to ``heavy_eval_batched``.
        u_eps: Margin used to squeeze uniform draws into
            ``[u_eps, 1 - u_eps]`` before the quantile map.
        fixed_eval_u: Cached uniform draws reused across calls; regenerated
            when missing or too small.

    Returns:
        The (possibly newly created) fixed uniform draws for the next call.
    """
    if not (do_light or do_heavy):
        return fixed_eval_u

    prev_mode = eval_model.training
    eval_model.eval()
    wrapper.model = eval_model
    prev_quant_mode = quantile.training
    quantile.eval()

    try:
        with torch.inference_mode():
            if do_light:
                if hasattr(ode_func, 'reset_nfe'):
                    ode_func.reset_nfe()
                num_traj = min(2000, args.eval_sample)
                if num_traj > 0:
                    if fixed_eval_u is None or fixed_eval_u.shape[0] < num_traj:
                        fixed_eval_u = torch.rand(num_traj, args.dim, device=device)
                    # Keep draws away from 0 and 1 before the quantile map.
                    Uv = u_eps + (1 - 2 * u_eps) * fixed_eval_u[:num_traj]
                    eps1 = quantile(
                        Uv,
                        torch.ones(num_traj, 1, device=device),
                        x_aux=torch.zeros(num_traj, args.dim, device=device),
                    )

                    if args.target_dataset == 'funnel':
                        xlim = (-20.0, 20.0)
                        ylim = (-100.0, 100.0)
                    else:
                        xlim = (-4.0, 4.0)
                        ylim = (-4.0, 4.0)
                    plot_traj.visualize_and_save(
                        ode_func,
                        noise=eps1,
                        T=1.0,
                        output_dir=args.runs_dir,
                        num_steps=50,
                        # Keep in sync with the noise batch actually passed,
                        # which has fewer than 2000 rows for small eval_sample.
                        num_samples=num_traj,
                        dim=args.dim,
                        device=device,
                        step=step,
                        wandb_key="fm/trajectory_gif",
                        filename=f"trajectory_step_{step:06d}",
                        xlim=xlim,
                        ylim=ylim,
                    )

                    t_vals = torch.linspace(1.0, 0.0, args.num_steps_eval, device=device)
                    n_paths = min(500, eps1.shape[0])
                    if n_paths > 0:
                        x_traj = odeint(ode_func, eps1[:n_paths], t_vals, method='dopri5')
                        X = x_traj.detach().cpu().numpy()
                        starts = X[0]
                        ends = X[-1]

                        _render_trajectory_panel(
                            args=args,
                            step=step,
                            sampler=sampler,
                            device=device,
                            trajectories=X,
                            starts=starts,
                            ends=ends,
                            show_legend=True,
                            show_title=False,
                        )

            # Reset again so the value logged below only covers the
            # evaluation call that follows, not the plots above.
            if do_light and hasattr(ode_func, 'reset_nfe'):
                ode_func.reset_nfe()

            evaluation.heavy_eval_batched(
                args,
                x0_batch,
                ode_func,
                sampler,
                step=step,
                big_eval=do_heavy,
                device=device,
                quantile=quantile,
            )
            if do_light and hasattr(ode_func, 'nfe'):
                wandb.log({"fm/nfe_val": int(ode_func.nfe)}, step=step)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    finally:
        # Never leave the modules stuck in eval mode if evaluation failed.
        if prev_mode:
            eval_model.train()
        if prev_quant_mode:
            quantile.train()

    return fixed_eval_u
