"""Utilities for visualizing high-dimensional image latents."""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Iterable, List, Optional, Sequence, Tuple

import numpy as np
import torch
from matplotlib import pyplot as plt
from torchvision.utils import make_grid


def _convert_for_imshow(image: np.ndarray) -> np.ndarray:
    """Return ``image`` in a layout that ``matplotlib.pyplot.imshow`` accepts.

    Accepted inputs:
      * 2-D ``(H, W)`` arrays are returned unchanged.
      * 3-D channel-first ``(C, H, W)`` arrays with ``C`` in {1, 3} are moved to
        channel-last.
      * 3-D channel-last ``(H, W, C)`` arrays with ``C`` in {1, 3} stay channel-last.

    A single channel is squeezed to ``(H, W)``, so imshow always treats it as
    scalar data (colormap and colorbar apply) instead of relying on the installed
    matplotlib version to accept a trailing singleton axis. The channel-first
    test runs first, so an ambiguous array such as ``(3, H, 3)`` is read as CHW.

    Raises:
        ValueError: for any other rank or channel count.
    """
    if image.ndim == 2:
        return image
    if image.ndim == 3:
        if image.shape[0] in {1, 3}:
            # CHW -> HWC
            image = np.moveaxis(image, 0, -1)
        elif image.shape[-1] not in {1, 3}:
            raise ValueError("Unexpected image layout for imshow")
        if image.shape[-1] == 1:
            # Drop the singleton channel so the data is rendered via a colormap.
            return image[..., 0]
        return image
    raise ValueError("Unexpected image layout for imshow")


def make_mean_std_figure(latents: torch.Tensor, image_shape: Sequence[int]) -> plt.Figure:
    """Plot the per-feature mean and standard deviation of a latent batch as images.

    Args:
        latents: Flattened latent batch; row ``i`` is reshaped to ``image_shape``.
        image_shape: Per-sample layout (e.g. ``(C, H, W)``) whose reshaped
            mean/std images must be accepted by ``_convert_for_imshow``.

    Returns:
        A figure with two panels, "Latent mean" and "Latent std", each with a colorbar.

    Note:
        For 3-channel images matplotlib renders the data as RGB. The colormap is
        then ignored and float values outside [0, 1] are clipped for display.
    """
    # detach() + cpu() so CUDA tensors and tensors that require grad can be
    # converted (the other plotting helpers here also go through .cpu()).
    data = latents.detach().cpu()
    # numpy has no bfloat16, and half-precision statistics are coarse.
    if data.dtype not in (torch.float32, torch.float64):
        data = data.to(torch.float32)
    reshaped = data.numpy().reshape(data.shape[0], *image_shape)
    mean_img = reshaped.mean(axis=0)
    std_img = reshaped.std(axis=0)

    fig, axes = plt.subplots(1, 2, figsize=(10, 4), dpi=140)
    ax_mean, ax_std = axes

    im_mean = ax_mean.imshow(_convert_for_imshow(mean_img), cmap="viridis")
    ax_mean.set_title("Latent mean")
    ax_mean.axis("off")
    fig.colorbar(im_mean, ax=ax_mean, fraction=0.046, pad=0.04)

    im_std = ax_std.imshow(_convert_for_imshow(std_img), cmap="magma")
    ax_std.set_title("Latent std")
    ax_std.axis("off")
    fig.colorbar(im_std, ax=ax_std, fraction=0.046, pad=0.04)
    fig.tight_layout()
    return fig


def _standard_normal_quantiles(n: int, device: torch.device) -> torch.Tensor:
    """Return ``n`` evenly spaced N(0, 1) quantiles as a float32 tensor on ``device``.

    Uses the midpoint plotting positions ``p_i = (i + 0.5) / n`` for
    ``i = 0 .. n-1`` together with ``Phi^{-1}(p) = sqrt(2) * erfinv(2p - 1)``.
    """
    midpoints = torch.arange(n, device=device, dtype=torch.float32) + 0.5
    probabilities = midpoints / float(n)
    return torch.erfinv(probabilities * 2.0 - 1.0) * math.sqrt(2.0)


def make_channel_pixel_histograms(
    latents: torch.Tensor,
    image_shape: Sequence[int],
    pixel_coords: Optional[Iterable[Tuple[int, int]]] = None,
    *,
    num_bins: int = 60,
) -> plt.Figure:
    """Plot per-channel value histograms and N(0, 1) QQ plots at selected pixels.

    Each figure row is one ``(y, x)`` location. Each channel contributes two
    panels: a density histogram of that pixel's values across the batch with the
    N(0, 1) pdf overlaid, and a QQ plot of the sorted samples against N(0, 1)
    quantiles.

    Args:
        latents: Flattened latent batch; reshaped to ``(N, *image_shape)``.
        image_shape: ``(C, H, W)`` layout of one latent.
        pixel_coords: ``(y, x)`` locations to inspect. The default is
            ``(0,0), (0,16), (16,0), (16,16), (24,24), (8,24)`` when those fit inside
            ``H x W``. Otherwise the same positions are rescaled as fractions of 32,
            and duplicates from tiny images are dropped.
        num_bins: Number of histogram bins.

    Returns:
        The matplotlib figure with ``len(coords)`` rows and ``2 * C`` columns.
    """
    cdim, h, w = image_shape
    # detach() so the numpy conversions below also work for tensors that require grad.
    reshaped = latents.detach().reshape(latents.shape[0], *image_shape)

    if pixel_coords is None:
        # Probe points originally chosen for 32x32 latents.
        reference = [(0, 0), (0, 16), (16, 0), (16, 16), (24, 24), (8, 24)]
        if all(y < h and x < w for y, x in reference):
            pixel_coords = reference
        else:
            # Smaller latents would index out of range. Keep the same relative
            # positions (fractions of 32) so every probe stays in bounds.
            pixel_coords = list(dict.fromkeys((y * h // 32, x * w // 32) for y, x in reference))
    coords = list(pixel_coords)

    colors = ["#d62728", "#2ca02c", "#1f77b4"]
    # Use Red/Green/Blue names only for 3-channel data. Other channel counts get
    # generic labels (a 2- or 4+-channel input used to raise IndexError here).
    if cdim == 3:
        channel_labels = ["Red", "Green", "Blue"]
    else:
        channel_labels = [f"Channel {i}" for i in range(cdim)]

    fig, axes = plt.subplots(len(coords), cdim * 2, figsize=(4 * cdim, 3 * len(coords)), dpi=140, squeeze=False)

    # The theoretical quantiles depend only on the batch size: compute them once.
    theor_np = _standard_normal_quantiles(reshaped.shape[0], device=latents.device).cpu().numpy()

    for row, (y, x) in enumerate(coords):
        for ch in range(cdim):
            samples_np = reshaped[:, ch, y, x].to(torch.float32).cpu().numpy()
            color = colors[ch % len(colors)]

            ax_hist = axes[row, 2 * ch]
            ax_hist.hist(samples_np, bins=num_bins, density=True, color=color, alpha=0.7)
            xs = np.linspace(samples_np.min(), samples_np.max(), 400)
            normal_pdf = (1.0 / math.sqrt(2 * math.pi)) * np.exp(-0.5 * xs ** 2)
            ax_hist.plot(xs, normal_pdf, color="#333333", linewidth=1.0)
            ax_hist.set_title(f"{channel_labels[ch]} @ (y={y}, x={x})")
            ax_hist.set_xlabel("Value")
            ax_hist.set_ylabel("Density")

            ax_qq = axes[row, 2 * ch + 1]
            ax_qq.scatter(theor_np, np.sort(samples_np), s=10, alpha=0.6, color=color)
            # Dark identity line (same color as the pdf overlay) so it stays
            # visible on top of the red channel's scatter.
            ax_qq.plot(theor_np, theor_np, color="#333333", linewidth=1.0)
            ax_qq.set_title("QQ vs N(0,1)")
            ax_qq.set_xlabel("Theoretical quantile")
            ax_qq.set_ylabel("Sample quantile")

    fig.tight_layout()
    return fig


def make_latent_pca_scatter(
    latents: torch.Tensor,
    *,
    color_source: Optional[torch.Tensor] = None,
) -> plt.Figure:
    """Scatter the latent batch projected onto its top-2 principal components.

    Args:
        latents: 2-D tensor ``(N, D)`` with ``D >= 2``.
        color_source: Optional per-sample values used to color the points. When
            omitted, points are colored by their distance from the batch mean.

    Returns:
        The matplotlib figure (x axis = PC1, y axis = PC2).

    Raises:
        ValueError: if ``latents`` is not 2-D or has fewer than two features.
    """
    if latents.dim() != 2 or latents.shape[1] < 2:
        raise ValueError(
            "make_latent_pca_scatter expects a 2-D (N, D) tensor with D >= 2, "
            f"got shape {tuple(latents.shape)}"
        )
    data = latents.detach()
    # eigh has no half-precision kernels and numpy has no bfloat16.
    if data.dtype not in (torch.float32, torch.float64):
        data = data.to(torch.float32)

    centered = data - data.mean(dim=0, keepdim=True)
    cov = torch.matmul(centered.T, centered) / max(1, centered.shape[0] - 1)
    # eigh is the stable solver for symmetric matrices. It returns eigenvalues in
    # ascending order, so the LAST column is PC1 and the second-to-last is PC2.
    # Reorder explicitly so column 0 of the projection really is PC1.
    _, evecs = torch.linalg.eigh(cov)
    top2 = evecs[:, [-1, -2]]
    proj_np = torch.matmul(centered, top2).cpu().numpy()

    if color_source is None:
        colors = centered.norm(dim=1).cpu().numpy()
        cbar_label = "Latent norm"
    else:
        colors = color_source.detach().to(torch.float32).cpu().numpy()
        cbar_label = "Sample mean intensity"

    fig, ax = plt.subplots(1, 1, figsize=(6, 6), dpi=140)
    sc = ax.scatter(proj_np[:, 0], proj_np[:, 1], c=colors, cmap="viridis", s=12, alpha=0.7)
    ax.set_title("Latent PCA (top-2 components)")
    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    cbar = fig.colorbar(sc, ax=ax)
    cbar.set_label(cbar_label)
    fig.tight_layout()
    return fig


def make_correlation_heatmap(
    latents: torch.Tensor,
    image_shape: Sequence[int],
    *,
    patch_size: int = 10,
) -> plt.Figure:
    """Plot the Pearson correlation matrix of features in a centered spatial patch.

    The patch is ``ps x ps`` with ``ps = min(patch_size, H, W)``, taken across all
    channels. Features are flattened channel-major (channel, then row, then column),
    and correlations are computed across the batch dimension.

    Args:
        latents: Flattened latent batch; reshaped to ``(N, *image_shape)``.
        image_shape: ``(C, H, W)`` layout of one latent.
        patch_size: Requested side length of the square patch; must be >= 1.

    Returns:
        The matplotlib figure showing a ``(C*ps*ps, C*ps*ps)`` heatmap.

    Raises:
        ValueError: if ``patch_size < 1``.

    Note:
        Features that are constant across the batch (including every feature when
        ``N == 1``) have undefined correlation. numpy returns NaN for them, and
        they show up as blank cells.
    """
    if patch_size < 1:
        raise ValueError(f"patch_size must be >= 1, got {patch_size}")
    _, h, w = image_shape
    # detach() so tensors that require grad can be converted to numpy.
    reshaped = latents.detach().reshape(latents.shape[0], *image_shape)
    ps = min(patch_size, h, w)
    y0 = (h - ps) // 2
    x0 = (w - ps) // 2
    patch = reshaped[:, :, y0 : y0 + ps, x0 : x0 + ps]
    flat = patch.reshape(patch.shape[0], -1)
    # numpy has no bfloat16; upcast low-precision inputs.
    if flat.dtype not in (torch.float32, torch.float64):
        flat = flat.to(torch.float32)
    # rowvar=False: rows are samples, columns are features.
    corr = np.corrcoef(flat.cpu().numpy(), rowvar=False)

    fig, ax = plt.subplots(1, 1, figsize=(6, 5), dpi=140)
    im = ax.imshow(corr, cmap="coolwarm", vmin=-1.0, vmax=1.0)
    ax.set_title(f"Latent correlation, patch {ps}x{ps}")
    ax.set_xlabel("Feature index")
    ax.set_ylabel("Feature index")
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    fig.tight_layout()
    return fig


def make_atlas_grid(
    images: torch.Tensor,
    nrow: int,
    *,
    value_range: Tuple[float, float] = (-1.0, 1.0),
) -> np.ndarray:
    """Tile a batch of images into a single channel-last numpy array for display.

    Args:
        images: Image batch passed to ``torchvision.utils.make_grid``
            (presumably ``(B, C, H, W)``; confirm against callers).
        nrow: Number of images per grid row.
        value_range: Input value range that ``make_grid(normalize=True)`` maps to
            [0, 1]. The default ``(-1.0, 1.0)`` matches the previous hard-coded range.

    Returns:
        The grid permuted from CHW to HWC as a numpy array.
    """
    # detach() so images that require grad can be converted to numpy.
    grid = make_grid(images.detach(), nrow=nrow, normalize=True, value_range=value_range)
    return grid.permute(1, 2, 0).cpu().numpy()


@dataclass
class LatentVizPayload:
    """Container for every artifact returned by ``build_latent_visualizations``.

    The field order defines the positional constructor and is kept stable.
    """

    # Side-by-side images of the per-feature batch mean and std.
    mean_std_fig: plt.Figure
    # Per-pixel histogram and QQ-plot panels.
    hist_qq_fig: plt.Figure
    # Scatter of the batch projected onto two principal components.
    pca_fig: plt.Figure
    # Correlation heatmap over a centered spatial patch.
    corr_fig: plt.Figure
    # HWC numpy image grid from make_atlas_grid, or None when no atlas images were given.
    atlas_grid: Optional[np.ndarray]


def build_latent_visualizations(
    latents: torch.Tensor,
    *,
    image_shape: Sequence[int],
    atlas_images: Optional[torch.Tensor] = None,
    color_source: Optional[torch.Tensor] = None,
) -> LatentVizPayload:
    """Build every latent diagnostic figure for a flattened latent batch.

    Args:
        latents: 2-D tensor ``(N, C*H*W)`` of flattened latents.
        image_shape: ``(C, H, W)`` layout of one latent; its product must equal
            ``latents.shape[1]``.
        atlas_images: Optional image batch tiled into a roughly square grid
            (``isqrt(B)`` images per row, at least 1).
        color_source: Optional per-sample values used to color the PCA scatter.

    Returns:
        A ``LatentVizPayload``; ``atlas_grid`` is None when ``atlas_images`` is None.

    Raises:
        ValueError: if ``latents`` is not 2-D or ``image_shape`` does not describe
            a ``(C, H, W)`` layout matching the latent width.
    """
    if latents.dim() != 2:
        raise ValueError("Expected flattened latent batch for visualization")
    # Validate up front. Otherwise a bad shape only surfaces as an opaque
    # reshape/unpack error inside one of the plotting helpers.
    if len(image_shape) != 3:
        raise ValueError(f"image_shape must be (C, H, W), got {tuple(image_shape)}")
    if math.prod(image_shape) != latents.shape[1]:
        raise ValueError(
            f"image_shape {tuple(image_shape)} has {math.prod(image_shape)} elements "
            f"but latents have {latents.shape[1]} features"
        )

    mean_std = make_mean_std_figure(latents, image_shape)
    hist_qq = make_channel_pixel_histograms(latents, image_shape, pixel_coords=None)
    pca_fig = make_latent_pca_scatter(latents, color_source=color_source)
    corr_fig = make_correlation_heatmap(latents, image_shape)

    atlas_np = None
    if atlas_images is not None:
        # Integer square root avoids float rounding; max() keeps nrow >= 1.
        nrow = max(1, math.isqrt(atlas_images.shape[0]))
        atlas_np = make_atlas_grid(atlas_images, nrow=nrow)

    return LatentVizPayload(
        mean_std_fig=mean_std,
        hist_qq_fig=hist_qq,
        pca_fig=pca_fig,
        corr_fig=corr_fig,
        atlas_grid=atlas_np,
    )
