"""Analyze the distribution of sigma_x (per-sample clean standard deviation), caching samples, plotting histograms, and testing for a power-law tail."""

from __future__ import annotations

import copy
import itertools
import time
from dataclasses import dataclass, asdict, field
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.stats import gaussian_kde
from tqdm import tqdm

from se.configs import DatasetConfig, PROJECT_ROOT
from snr_hypothesis.utils import AxisContext, build_train_loader

# Lower bound applied to every per-sample σ_x (via ``clamp_min``) so that the
# collected values are strictly positive.
EPS = 1e-6


@dataclass
class SigmaXConfig:
    """Study the distribution of σ_x and test for a power-law tail."""

    # Dataset settings from which ``main`` builds the training loader.
    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    # Copied onto the dataset config's ``s_samples_per_epoch`` by ``main``.
    samples_per_epoch: int = 1_000_000
    # Copied onto the dataset config's ``batch_size`` by ``main``.
    batch_size: int = 2**16
    # Number of bins for the histogram, the tail fit and the log-log plot.
    bins: int = 60
    # Percentile of σ_x at or above which histogram bins enter the tail fit.
    tail_percentile: float = 90.0
    # Minimum number of usable tail bins; with fewer the fit is skipped.
    min_tail_points: int = 5
    # Root directory holding the ``sigma_x_cache`` and ``sigma_x`` outputs.
    save_dir: Path = PROJECT_ROOT / Path("artifacts/snr_hypothesis")
    # Torch device string; ``None`` picks CUDA when available, else CPU.
    device: str | None = None
    # Cap on batches read from the loader; ``None`` reads the whole loader.
    max_batches: int | None = None
    # Reuse a previously saved ``sigma_x.npz`` when one exists.
    use_cache: bool = True


# Configuration used when the script runs directly. Every value matches the
# SigmaXConfig default except ``device``, which is pinned to the CPU instead
# of being auto-detected.
DEFAULT_CONFIG = SigmaXConfig(
    # Execution / caching.
    device="cpu",
    max_batches=None,
    use_cache=True,
    save_dir=PROJECT_ROOT / Path("artifacts/snr_hypothesis"),
    # Histogram and tail-fit parameters.
    bins=60,
    tail_percentile=90.0,
    min_tail_points=5,
)


def _sigma_x_from_batch(clean: torch.Tensor) -> torch.Tensor:
    """Return the per-sample population standard deviation of a batch.

    The leading axis of ``clean`` indexes samples; every remaining axis is
    reduced with the biased estimator (``unbiased=False``). Each result is
    floored at ``EPS`` so it is strictly positive.

    Args:
        clean: Batch tensor whose first axis indexes samples. A rank-0 or
            rank-1 tensor is treated as a batch of scalar samples.

    Returns:
        A 1-D tensor holding one σ_x value per sample.
    """
    if clean.ndim <= 1:
        # With no per-sample axes, ``range(1, ndim)`` is empty, and PyTorch
        # treats an empty ``dim`` as "reduce every axis". That would collapse
        # the whole batch into one value (or NaN for an empty batch). Give
        # each scalar its own length-1 feature axis instead; its σ_x is 0,
        # which the clamp below raises to EPS.
        clean = clean.reshape(-1, 1)
    dims = tuple(range(1, clean.ndim))
    sigma_x = clean.std(dim=dims, unbiased=False).clamp_min(EPS)
    return sigma_x.reshape(-1)


def sigma_cache_path(root: Path) -> Path:
    """Return the location of the σ_x cache archive inside ``root``."""
    cache_name = "sigma_x.npz"
    return root.joinpath(cache_name)


def save_sigma_cache(values: np.ndarray, cache_path: Path) -> None:
    """Write ``values`` to ``cache_path`` as a compressed ``.npz`` archive.

    Missing parent directories are created first. The array is stored under
    the key ``"sigma_x"``.
    """
    target_dir = cache_path.parent
    target_dir.mkdir(exist_ok=True, parents=True)
    payload = {"sigma_x": values}
    np.savez_compressed(cache_path, **payload)


def load_sigma_cache(cache_path: Path) -> np.ndarray:
    """Read the ``"sigma_x"`` array back from a ``.npz`` archive.

    Pickled objects are refused (``allow_pickle=False``). The archive is
    closed before returning; the returned array is fully materialised.
    """
    with np.load(cache_path, allow_pickle=False) as archive:
        sigma_values = archive["sigma_x"]
    return sigma_values


def collect_sigma_x(
    cfg: DatasetConfig,
    device: torch.device,
    max_batches: int | None,
    loader: torch.utils.data.DataLoader | None = None,
) -> np.ndarray:
    """Compute σ_x for every sample yielded by a data loader.

    Args:
        cfg: Dataset configuration; only used to build a loader through
            ``build_train_loader`` when ``loader`` is not supplied.
        device: Device each clean batch is moved to before σ_x is computed.
        max_batches: Upper bound on the number of batches consumed, or
            ``None`` to consume the whole loader. Negative values are
            treated as 0.
        loader: Optional pre-built loader. Each batch is either a tensor or a
            list/tuple whose first element is the clean tensor.

    Returns:
        1-D array of σ_x values concatenated across batches, or an empty
        float64 array when no batch was processed.
    """
    if loader is None:
        loader = build_train_loader(cfg)

    # Loaders over an iterable-style dataset may not define ``__len__``;
    # treat the total as unknown instead of failing with TypeError.
    try:
        total_batches: int | None = len(loader)
    except TypeError:
        total_batches = None

    if max_batches is None:
        limit = total_batches
    else:
        limit = max(0, max_batches)
        if total_batches is not None:
            limit = min(limit, total_batches)

    sigma_samples: list[np.ndarray] = []
    with torch.no_grad():
        # ``islice`` stops before requesting the batch after ``limit``. No
        # extra batch is loaded just to be discarded, and ``None`` means
        # "iterate everything".
        batches = itertools.islice(loader, limit)
        for batch in tqdm(batches, total=limit, desc="σ_x"):
            clean = batch[0] if isinstance(batch, (list, tuple)) else batch
            sigma_x = _sigma_x_from_batch(clean.to(device))
            sigma_samples.append(sigma_x.cpu().numpy())

    if not sigma_samples:
        return np.array([], dtype=np.float64)
    return np.concatenate(sigma_samples, axis=0)


def _determine_bin_range(values: np.ndarray) -> tuple[float, float]:
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return (0.0, 1.0)
    low, high = np.percentile(finite, [0.5, 99.5])
    if not np.isfinite(low) or not np.isfinite(high):
        return (0.0, 1.0)
    if high <= low:
        high = low + 1e-6
    span = max(high - low, 1e-6)
    pad = 0.05 * span
    start = float(max(0.0, low - pad))
    end = float(high + pad)
    return (start, end)


def _fit_power_law(
    values: np.ndarray,
    bins: int,
    x_range: tuple[float, float],
    tail_percentile: float,
    min_tail_points: int,
) -> dict[str, float | np.ndarray] | None:
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return None

    hist, bin_edges = np.histogram(finite, bins=bins, range=x_range, density=True)
    centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

    cutoff = np.percentile(finite, tail_percentile)
    mask = (centers >= cutoff) & (hist > 0.0) & (centers > 0.0)
    if mask.sum() < min_tail_points:
        return None

    log_x = np.log10(centers[mask])
    log_y = np.log10(hist[mask])
    slope, intercept = np.polyfit(log_x, log_y, 1)
    pred = slope * log_x + intercept
    ss_res = float(np.sum((log_y - pred) ** 2))
    ss_tot = float(np.sum((log_y - log_y.mean()) ** 2))
    r2 = float(1.0 - ss_res / ss_tot) if ss_tot > 0 else float("nan")

    fit_x = np.linspace(centers[mask].min(), centers[mask].max(), 200)
    fit_y = 10.0 ** (slope * np.log10(fit_x) + intercept)

    return {
        "slope": float(slope),
        "intercept": float(intercept),
        "r2": r2,
        "cutoff": float(cutoff),
        "centers": centers,
        "hist": hist,
        "fit_x": fit_x,
        "fit_y": fit_y,
    }


def plot_sigma_histogram(
    sigma_x: np.ndarray,
    axis_ctx: AxisContext,
    bins: int,
    save_path: Path,
) -> None:
    """Plot a density histogram of σ_x with a Gaussian KDE overlay and save it.

    Args:
        sigma_x: σ_x samples; non-finite entries are ignored.
        axis_ctx: Supplies the histogram/KDE range (``bin_range``) and the
            x-axis label.
        bins: Number of histogram bins over ``axis_ctx.bin_range``.
        save_path: Output file; parent directories are created as needed.

    When there are no finite values, prints a message and returns without
    writing a file. When ``gaussian_kde`` cannot handle the data (e.g. a
    single sample or zero variance), only the KDE curve is skipped, with a
    message, rather than aborting the whole run.
    """
    style_path = PROJECT_ROOT / Path("icml_like.mplstyle")
    x0, x1 = axis_ctx.bin_range
    x_grid = np.linspace(x0, x1, 500)

    finite_vals = sigma_x[np.isfinite(sigma_x)]
    if finite_vals.size == 0:
        print("No finite σ_x values to plot.")
        return

    with plt.style.context(str(style_path)):
        h = 4
        aspect_ratio = 4 / 3
        w = h * aspect_ratio
        fig, ax = plt.subplots(figsize=(w, h), constrained_layout=True)
        try:
            ax.hist(
                finite_vals,
                bins=bins,
                range=(x0, x1),
                density=True,
                alpha=0.20,
                color="#1f77b4",
                edgecolor="none",
                label=r"$\sigma_x$",
            )

            try:
                kde = gaussian_kde(finite_vals)
            except (ValueError, np.linalg.LinAlgError) as exc:
                # Degenerate data (one sample, or all values equal) has no
                # usable bandwidth; keep the histogram and skip the curve.
                print(f"Skipping σ_x KDE overlay: {exc}")
            else:
                kde_vals = kde(x_grid)
                # Pin the curve to zero at the lower edge of the plotted range.
                kde_vals[x_grid <= x0] = 0.0
                ax.plot(x_grid, kde_vals, color="#1f77b4", linewidth=2.2, label="KDE")

            ax.set_xlabel(axis_ctx.label)
            ax.set_ylabel("Density")

            ax.grid(
                True, which="major", axis="both", linestyle="--", linewidth=0.6, alpha=0.5
            )
            ax.legend(loc="upper right", frameon=True, borderaxespad=0.3, handlelength=2.2)
            for spine in ax.spines.values():
                spine.set_linewidth(1.0)
                spine.set_color("black")
            # The axis always starts at 0 since σ_x (a standard deviation) is
            # non-negative; ``x0`` only bounds the histogram and KDE grid.
            ax.set_xlim(left=0.0, right=x1)

            save_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, bbox_inches="tight")
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)


def plot_power_law(
    sigma_x: np.ndarray,
    fit: dict[str, float | np.ndarray] | None,
    bins: int,
    x_range: tuple[float, float],
    save_path: Path,
) -> None:
    """Save a log-log scatter of the σ_x density histogram with the tail fit.

    Args:
        sigma_x: σ_x samples; non-finite entries are ignored.
        fit: Result of ``_fit_power_law``, whose keys ``fit_x``, ``fit_y``,
            ``slope``, ``r2`` and ``cutoff`` are read. Pass ``None`` to plot
            the histogram alone.
        bins: Number of linearly spaced histogram bins.
        x_range: Histogram range. The lower edge is raised to at least 1e-6
            so it is valid on a log axis, and a non-increasing range is
            widened to one decade above the lower edge.
        save_path: Output file; parent directories are created as needed.

    When there are no finite values, prints a message and returns without
    writing a file.
    """
    style_path = PROJECT_ROOT / Path("icml_like.mplstyle")
    finite = sigma_x[np.isfinite(sigma_x)]
    if finite.size == 0:
        print("No finite σ_x values for power-law plot.")
        return

    start, end = x_range
    start = max(start, 1e-6)
    if end <= start:
        end = start * 10.0

    hist, bin_edges = np.histogram(finite, bins=bins, range=(start, end), density=True)
    centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    # Empty bins (and non-positive centres) cannot be drawn on log axes.
    valid = (hist > 0.0) & (centers > 0.0)

    with plt.style.context(str(style_path)):
        h = 4
        aspect_ratio = 4 / 3
        w = h * aspect_ratio
        fig, ax = plt.subplots(figsize=(w, h), constrained_layout=True)
        try:
            ax.scatter(
                centers[valid],
                hist[valid],
                color="#1f77b4",
                s=16,
                alpha=0.8,
                label="Histogram",
            )

            if fit is not None:
                slope = fit["slope"]
                r2 = fit["r2"]
                cutoff = fit["cutoff"]
                ax.plot(
                    fit["fit_x"],
                    fit["fit_y"],
                    color="#d62728",
                    linewidth=2.0,
                    label=rf"Tail fit: slope={slope:.2f}, $R^2$={r2:.3f}",
                )
                # ``cutoff`` is a σ_x value (the tail-percentile threshold),
                # so label it in σ_x units rather than as a percentile.
                ax.axvline(
                    cutoff,
                    color="0.35",
                    linestyle="--",
                    linewidth=1.2,
                    label=rf"cutoff $\sigma_x$={cutoff:.3g}",
                )

            ax.set_xscale("log")
            ax.set_yscale("log")
            ax.set_xlabel(r"$\sigma_x$")
            ax.set_ylabel("Density")
            ax.grid(
                True, which="both", axis="both", linestyle="--", linewidth=0.6, alpha=0.5
            )
            ax.legend(
                loc="lower left", frameon=True, borderaxespad=0.3, handlelength=2.2
            )
            for spine in ax.spines.values():
                spine.set_linewidth(1.0)
                spine.set_color("black")

            save_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, bbox_inches="tight")
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)


def _dataset_for_run(script_cfg: SigmaXConfig) -> DatasetConfig:
    """Return a shallow copy of ``script_cfg.dataset`` with the run overrides set.

    Copying leaves the caller's config unmodified. This matters because
    ``main`` defaults to the shared module-level ``DEFAULT_CONFIG``.
    """
    ds = copy.copy(script_cfg.dataset)
    ds.batch_size = script_cfg.batch_size  # type: ignore[attr-defined]
    # NOTE(review): the target attribute is ``s_samples_per_epoch`` while the
    # script field is ``samples_per_epoch``. Confirm DatasetConfig uses this
    # name; otherwise the override has no effect.
    ds.s_samples_per_epoch = script_cfg.samples_per_epoch  # type: ignore[attr-defined]
    return ds


def _resolve_device(requested: str | None) -> torch.device:
    """Return the requested device, or CUDA when available and CPU otherwise."""
    if requested is not None:
        return torch.device(requested)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_cached_sigma_x(cache_path: Path, use_cache: bool) -> np.ndarray:
    """Return cached σ_x values, or an empty float64 array when none are usable."""
    if not (use_cache and cache_path.is_file()):
        return np.array([], dtype=np.float64)
    try:
        values = load_sigma_cache(cache_path)
    except Exception as exc:  # any unreadable cache falls back to recomputation
        print(f"Failed to load σ_x cache: {exc}. Recomputing.")
        return np.array([], dtype=np.float64)
    print(f"Loaded cached σ_x samples from {cache_path.name}")
    return values


def main(script_cfg: SigmaXConfig = DEFAULT_CONFIG) -> None:
    """Collect (or load cached) σ_x samples, plot them, and fit the tail.

    Steps:

    1. Load ``sigma_x.npz`` from ``<save_dir>/sigma_x_cache`` when
       ``use_cache`` is set. Recompute from the training loader (and
       overwrite the cache) when caching is off, or when the cache is
       missing, unreadable or empty.
    2. Write the histogram and power-law figures to ``<save_dir>/sigma_x``.

    ``script_cfg`` is not modified: the batch-size and samples-per-epoch
    overrides go onto a shallow copy of ``script_cfg.dataset``.
    """
    start_time = time.perf_counter()
    print("σ_x distribution config:")
    print(asdict(script_cfg))

    ds = _dataset_for_run(script_cfg)
    device = _resolve_device(script_cfg.device)

    cache_root = script_cfg.save_dir / "sigma_x_cache"
    cache_root.mkdir(parents=True, exist_ok=True)
    figure_root = script_cfg.save_dir / "sigma_x"
    figure_root.mkdir(parents=True, exist_ok=True)

    # NOTE(review): the cache location depends only on ``save_dir``. It is
    # reused even if dataset/batch settings changed since it was written, so
    # delete it or set ``use_cache=False`` after changing those settings.
    cache_path = sigma_cache_path(cache_root)
    sigma_x_vals = _load_cached_sigma_x(cache_path, script_cfg.use_cache)

    if sigma_x_vals.size == 0:
        sigma_x_vals = collect_sigma_x(
            cfg=ds,
            device=device,
            max_batches=script_cfg.max_batches,
            loader=build_train_loader(ds),
        )
        save_sigma_cache(sigma_x_vals, cache_path)
        print(f"Saved σ_x samples to {cache_path.name}")

    print(f"Collected {sigma_x_vals.size} σ_x samples.")

    bin_range = _determine_bin_range(sigma_x_vals)
    axis_ctx = AxisContext(
        label=r"$\sigma_x = \sqrt{\mathrm{Var}(x)}$",
        symbol=r"$\sigma_x$",
        bin_range=bin_range,
        default_xlim=bin_range,
        prefix="sigma_x",
    )

    hist_path_pdf = figure_root / f"{axis_ctx.prefix}_histogram.pdf"
    plot_sigma_histogram(
        sigma_x_vals,
        axis_ctx,
        bins=script_cfg.bins,
        save_path=hist_path_pdf,
    )
    print(f"Saved σ_x histogram to {hist_path_pdf}")

    # Log axes need a strictly positive lower edge.
    log_bin_range = (max(bin_range[0], 1e-6), bin_range[1])
    fit = _fit_power_law(
        sigma_x_vals,
        bins=script_cfg.bins,
        x_range=log_bin_range,
        tail_percentile=script_cfg.tail_percentile,
        min_tail_points=script_cfg.min_tail_points,
    )
    if fit is None:
        print("Power-law fit skipped (insufficient tail points or no finite samples).")
    else:
        print(
            f"Power-law tail slope: {fit['slope']:.3f}, intercept: {fit['intercept']:.3f}, R^2: {fit['r2']:.3f}, cutoff: {fit['cutoff']:.3f}"
        )

    power_path_pdf = figure_root / f"{axis_ctx.prefix}_power_law.pdf"
    plot_power_law(
        sigma_x_vals,
        fit=fit,
        bins=script_cfg.bins,
        x_range=log_bin_range,
        save_path=power_path_pdf,
    )
    print(f"Saved σ_x power-law plot to {power_path_pdf}")

    elapsed = time.perf_counter() - start_time
    print(f"Total elapsed time: {elapsed:.2f} seconds")


if __name__ == "__main__":
    # Script entry point: run the full analysis with the module defaults.
    main(script_cfg=DEFAULT_CONFIG)
