#!/usr/bin/env python3
"""
Compare (approximate) stationary distributions of SGD vs Bayesian sampling (SGLD)
on a simple toy classification problem (two moons), and estimate a diffusion
exponent from SGD trajectories via mean squared displacement (MSD).

Additions in this version
-------------------------
- During each SGD run, track the **mean squared displacement (MSD)** in weight
  space from the initialization as training progresses (per epoch).
- Normalize time for each run to [0,1] based on the number of epochs used
  (after early stopping). Resample all runs onto a common time grid and compute
  the average MSD curve across runs.
- Let R(t) = sqrt( MSD(t) ). Fit log R vs log t to estimate the **walk dimension** d_w
  from R(t) ~ t^(1/d_w). Then compute the **diffusion exponent** 2 - d_w.
- Save the MSD/R curves and statistics; annotate the diffusion exponent on the plot.
- NEW: Save a second plot showing the **average displacement R(t)** over normalized time.

Outputs (under ExperimentConfig.output_root; ./outputs by default)
------------------------------------------------------------------
- ./outputs/sgd_weights/ : converged SGD weights (state_dicts + flattened vectors)
- ./outputs/bayesian_weights/ : Bayesian (SGLD) samples that met the loss threshold
  (state_dicts + flattened vectors)
- ./outputs/parameter_order.txt : state_dict key order used when flattening weights
- ./outputs/embedding.png / .pdf : PCA overlay of SGD vs Bayesian samples
- ./outputs/embedding_points.csv : 2D points used to make the plot
- ./outputs/flattened_weights.npz : raw flattened vectors + labels + losses
- ./outputs/summary.json : configuration & summary stats including d_w and diffusion exponent
- ./outputs/sgd_msd_curve.csv : aggregated MSD/R(t) curves for SGD
- ./outputs/sgd_msd_stats.json : stats used in the diffusion fit (slope, d_w, 2-d_w, etc.)
- ./outputs/sgd_displacement.png / .pdf : plot of average displacement R(t) vs normalized time

Note: the `__main__` block below sets a different output_root.

How to run
----------
Set any desired configuration values in the `if __name__ == "__main__":` block
near the end of this file, then run:

    python sgd_vs_bayes_stationary.py

Dependencies (install if needed)
--------------------------------
- torch
- numpy
- scikit-learn
- matplotlib
- tqdm
- pandas

Notes & Caveats
---------------
- SGLD here provides an *approximate* posterior sampler. With small networks
  and toy data, it suffices to visualize qualitative differences.
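- The acceptance criterion for "sufficiently low loss" is a quantile of the
  final SGD losses (`accept_quantile` in `run_experiment`, 0.75 by default):
  SGLD samples with loss <= that threshold are kept. If no sample qualifies,
  a second set of chains is run and every thinned sample is kept instead.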
- MSD is computed in the **full weight space** (Euclidean). We use **per-epoch**
  snapshots (after each epoch) for stability and portability.
"""

import os
import json
import math
import time
import random
from dataclasses import dataclass, asdict
from typing import Tuple, List, Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
from sklearn.datasets import make_moons
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import pandas as pd


# ------------------------------
# Utilities
# ------------------------------

def set_seed(seed: int) -> None:
    """Seed Python's `random`, NumPy's global RNG, and torch (CPU + every CUDA device) with `seed`."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


def ensure_dir(path: str) -> None:
    """Create directory `path` (including parents); no-op for an empty path or an existing one."""
    if not path or os.path.exists(path):
        return
    os.makedirs(path, exist_ok=True)


def flatten_state_dict(state_dict: Dict[str, torch.Tensor], key_order: List[str]) -> np.ndarray:
    """
    Concatenate the tensors of `state_dict`, visited in `key_order`, into one 1D numpy array.

    Each tensor is detached, moved to CPU, and raveled (row-major) before concatenation,
    so the result is a fresh array that does not alias the model's parameters.
    """
    pieces = [state_dict[name].detach().cpu().numpy().ravel() for name in key_order]
    return np.concatenate(pieces, axis=0)


def state_dict_key_order(model: nn.Module) -> List[str]:
    """
    Return the model's state_dict keys as a list, in the order `state_dict()` yields them.

    This order is used everywhere a state_dict is flattened, so vectors stay comparable.
    """
    return [key for key in model.state_dict()]


# ------------------------------
# Model & Data
# ------------------------------

class TinyMLP(nn.Module):
    """Two-layer perceptron: Linear(in_dim -> hidden) -> ReLU -> Linear(hidden -> out_dim)."""

    def __init__(self, in_dim: int = 2, hidden: int = 16, out_dim: int = 2):
        super().__init__()
        # Registration order (fc1 before fc2) fixes both initialization RNG use and state_dict key order.
        self.fc1 = nn.Linear(in_dim, hidden)
        self.fc2 = nn.Linear(hidden, out_dim)

    def forward(self, x):
        """Return raw logits (no softmax) for the input batch `x`."""
        hidden_act = F.relu(self.fc1(x))
        return self.fc2(hidden_act)


def make_toy_data(n_samples: int, noise: float, seed: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build the two-moons dataset as torch tensors.

    Returns (features as float32, integer labels as int64), generated by sklearn's
    `make_moons` with `random_state=seed`.
    """
    features, targets = make_moons(n_samples=n_samples, noise=noise, random_state=seed)
    return (
        torch.from_numpy(features.astype(np.float32)),
        torch.from_numpy(targets.astype(np.int64)),
    )


# ------------------------------
# Training (SGD) + MSD tracking
# ------------------------------

@dataclass
class SGDConfig:
    """
    Hyperparameters for the plain-SGD phase (consumed by `train_sgd_once` and `run_experiment`).
    """
    lr: float = 0.1                 # SGD learning rate
    weight_decay: float = 0.0       # L2, passed as torch.optim.SGD(weight_decay=...)
    max_epochs: int = 1000          # hard cap on epochs per run
    batch_size: int = 128           # DataLoader batch size (clipped to [1, n_samples]); the same loader also feeds SGLD
    patience: int = 50              # for early stopping: stop after this many consecutive non-improving epochs
    min_delta: float = 1e-6         # for early stopping: minimum drop in mean epoch loss that counts as improvement
    num_inits: int = 50             # number of random initializations
    init_seed_offset: int = 0       # added to base_seed for SGD runs (run i is seeded with base_seed + offset + i)


def train_sgd_once(model: nn.Module,
                   data_loader: DataLoader,
                   device: torch.device,
                   cfg: SGDConfig,
                   global_pbar: tqdm,
                   key_order: List[str]) -> Tuple[float, int, List[float]]:
    """
    Train `model` in place with vanilla SGD, stopping early when the epoch loss plateaus.

    Early stopping: an epoch "improves" when its mean training loss is more than
    `cfg.min_delta` below the best seen so far. Training stops after `cfg.patience`
    consecutive non-improving epochs, or after `cfg.max_epochs` epochs. The model is
    left at its final (not best) weights.

    Args:
        model: network to train; its parameters are updated in place.
        data_loader: training batches (inputs, integer class labels).
        device: device batches are moved to.
        cfg: SGD hyperparameters.
        global_pbar: progress bar advanced by one per epoch run.
        key_order: state_dict key order used to flatten weights for displacement.

    Returns:
        (final_loss, steps_used, msd_sq) where
          - final_loss is the sample-weighted mean training loss of the last epoch run
            (accumulated while the weights were being updated), or NaN if
            `cfg.max_epochs` is 0 and no epoch ran;
          - steps_used is the number of epochs actually run;
          - msd_sq[j] = ||theta_j - theta_0||^2, the squared Euclidean displacement of
            the flattened weights at the END of epoch j, with msd_sq[0] == 0.0, so
            len(msd_sq) == steps_used + 1.
    """
    opt = torch.optim.SGD(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    criterion = nn.CrossEntropyLoss()

    # Capture the initial parameter vector
    init_vec = flatten_state_dict(model.state_dict(), key_order)
    msd_sq: List[float] = [0.0]  # epoch 0 displacement

    best_loss = float("inf")
    epochs_no_improve = 0
    # NaN rather than None so the float return contract holds even when max_epochs == 0
    # (previously float(None) raised TypeError).
    final_loss = float("nan")
    steps_used = 0

    for _ in range(cfg.max_epochs):
        model.train()
        epoch_loss = 0.0
        total = 0
        for xb, yb in data_loader:
            xb = xb.to(device)
            yb = yb.to(device)
            opt.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            opt.step()

            batch_sz = yb.size(0)
            epoch_loss += loss.item() * batch_sz
            total += batch_sz

        epoch_loss /= max(total, 1)
        final_loss = epoch_loss
        steps_used += 1

        # Record displacement from the initialization at the end of the epoch
        cur_vec = flatten_state_dict(model.state_dict(), key_order)
        msd_sq.append(float(np.sum((cur_vec - init_vec) ** 2)))

        global_pbar.update(1)

        # Early stopping
        if best_loss - epoch_loss > cfg.min_delta:
            best_loss = epoch_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= cfg.patience:
                break

    return float(final_loss), steps_used, msd_sq


# ------------------------------
# Bayesian Sampling via SGLD
# ------------------------------

@dataclass
class SGLDConfig:
    """
    Hyperparameters for the SGLD (approximate Bayesian) sampling phase; see `sgld_step`
    and `run_sgld_chains`.
    """
    lr: float = 1e-3               # Langevin step size
    temperature: float = 1.0       # "T" in SGLD; noise scales with sqrt(2*lr*T)
    weight_decay: float = 1e-4     # Gaussian prior precision equivalent (adds weight_decay * theta to the gradient)
    total_steps: int = 5000        # total SGLD updates per chain
    burn_in: int = 1000            # steps to discard at start
    thinning: int = 20             # save one sample every 'thinning' steps after burn-in (must be > 0 if total_steps > burn_in)
    num_chains: int = 1            # independent chains
    init_seed_offset: int = 10     # added to base_seed for chains (chain c is seeded with base_seed + offset + c)


@torch.no_grad()
def eval_loss(model: nn.Module, data_loader: DataLoader, device: torch.device) -> float:
    """
    Return the sample-weighted mean cross-entropy of `model` over every batch of `data_loader`.

    Gradients are disabled and the model is switched to eval mode (and left there).
    An empty loader yields 0.0.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss(reduction="mean")
    weighted_sum, count = 0.0, 0
    for inputs, targets in data_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        batch_mean = loss_fn(model(inputs), targets)
        n = targets.size(0)
        weighted_sum += batch_mean.item() * n
        count += n
    return float(weighted_sum / max(count, 1))


def sgld_step(model: nn.Module,
              lr: float,
              temperature: float,
              weight_decay: float,
              data_loader: DataLoader,
              device: torch.device):
    """
    Perform one in-place SGLD update of `model`'s parameters.

    The gradient is taken on the FIRST batch of a fresh iterator over `data_loader`.
    That is the full dataset only if the loader's batch size covers it; with a
    shuffled, smaller-batch loader (as built in `run_experiment`) it is a fresh
    random minibatch on every call, i.e. genuinely *stochastic*-gradient LD.

    Update rule, with g = gradient of the mean cross-entropy on that batch:
        theta <- theta - lr * (g + weight_decay * theta) + sqrt(2 * lr * T) * N(0, I)
    The `weight_decay * theta` term is the gradient of an L2 (Gaussian) prior.

    NOTE(review): the likelihood term is the *mean* batch loss, not the sum over the
    dataset, so compared to the standard posterior scaling the likelihood is
    effectively tempered by roughly the dataset size — confirm this is intended.
    """
    # Fresh iterator each call: for a shuffled loader this reshuffles and yields a new first batch.
    xb, yb = next(iter(data_loader))
    xb = xb.to(device)
    yb = yb.to(device)

    # Compute gradients
    model.train()
    model.zero_grad()
    loss = F.cross_entropy(model(xb), yb, reduction="mean")
    loss.backward()

    noise_std = math.sqrt(2.0 * lr * temperature)  # loop-invariant Langevin noise scale
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is None:
                continue
            # L2 gradient (Gaussian prior): grad += weight_decay * p
            grad = p.grad + weight_decay * p
            # Langevin update; one randn_like draw per parameter, in parameters() order
            noise = torch.randn_like(p) * noise_std
            p.add_(-lr * grad + noise)


def run_sgld_chains(make_model_fn,
                    data_loader: DataLoader,
                    device: torch.device,
                    cfg: SGLDConfig,
                    global_pbar: tqdm,
                    accept_loss_threshold: float,
                    base_seed: int,
                    key_order: List[str],
                    save_dir: str,
                    chain_prefix: str = "chain") -> Tuple[List[np.ndarray], List[float]]:
    """
    Run `cfg.num_chains` independent SGLD chains and keep low-loss thinned samples.

    Chain c is seeded with base_seed + cfg.init_seed_offset + c. After burn-in, every
    `cfg.thinning`-th state is evaluated with `eval_loss`. States whose loss is
    <= `accept_loss_threshold` are saved to `save_dir` as
    "<prefix><chain:02d>_sample<idx:05d>.pt" (state_dict) and ".npy" (flattened vector),
    where idx counts kept samples across all chains.

    Returns:
        (flattened_vectors, losses) of the kept samples, in acceptance order.
    """
    ensure_dir(save_dir)
    kept_vecs: List[np.ndarray] = []
    kept_losses: List[float] = []

    def _is_sample_step(step_no: int) -> bool:
        # step_no is 1-based; keep every `thinning`-th step strictly after burn-in.
        past_burn_in = step_no - cfg.burn_in
        return past_burn_in > 0 and past_burn_in % cfg.thinning == 0

    for chain in range(cfg.num_chains):
        set_seed(base_seed + cfg.init_seed_offset + chain)
        model = make_model_fn().to(device)

        for step_no in range(1, cfg.total_steps + 1):
            sgld_step(model, cfg.lr, cfg.temperature, cfg.weight_decay, data_loader, device)
            global_pbar.update(1)

            if not _is_sample_step(step_no):
                continue

            loss = eval_loss(model, data_loader, device)
            if loss <= accept_loss_threshold:
                # len(kept_vecs) is the running sample index across all chains.
                stem = os.path.join(save_dir, f"{chain_prefix}{chain:02d}_sample{len(kept_vecs):05d}")
                torch.save(model.state_dict(), stem + ".pt")
                flat = flatten_state_dict(model.state_dict(), key_order)
                np.save(stem + ".npy", flat)
                kept_vecs.append(flat)
                kept_losses.append(float(loss))

    return kept_vecs, kept_losses


# ------------------------------
# Diffusion / MSD helpers
# ------------------------------

def _fit_log_log(log_t: np.ndarray, log_R: np.ndarray) -> Tuple[float, float, float, float, float]:
    """
    Least-squares fit of log_R = slope * log_t + intercept over the finite points.

    Returns (slope, intercept, d_w, diffusion_exponent, alpha), where alpha == slope
    and d_w = 1 / slope. All five are NaN when fewer than two finite points exist.
    When slope <= 0 (no growth), d_w = inf and diffusion_exponent = -inf.
    """
    finite = np.isfinite(log_t) & np.isfinite(log_R)
    if np.count_nonzero(finite) < 2:
        nan = float("nan")
        return nan, nan, nan, nan, nan
    m, b = np.polyfit(log_t[finite], log_R[finite], 1)
    slope, intercept = float(m), float(b)
    # R ~ t^(1/d_w) => slope = 1/d_w = alpha
    if slope <= 0:
        return slope, intercept, float("inf"), float("-inf"), slope
    d_w = 1.0 / slope
    return slope, intercept, float(d_w), float(2.0 - d_w), slope


def aggregate_msd_and_fit(msd_sq_runs: List[List[float]],
                          steps_used_list: List[int],
                          n_grid: int = 50,
                          fit_t_min: float = 1e-2) -> Dict[str, object]:
    """
    Average per-run squared displacements on a normalized time grid and fit a power law.

    Each run's displacement sequence (epoch 0 = 0.0, then one value per epoch) is put
    on a time axis normalized to [0, 1] by its steps_used, linearly interpolated onto
    the common grid t = 1/n_grid, 2/n_grid, ..., 1 (t = 0 excluded so logs stay
    finite), and averaged across runs to give MSD(t). R(t) = sqrt(MSD(t)). A line is
    fit to log R vs log t for t >= fit_t_min; its slope alpha = 1/d_w gives the walk
    dimension d_w and the diffusion exponent 2 - d_w.

    Returns:
        Dict with arrays "t_grid", "msd", "R", "log_t", "log_R" and floats "slope",
        "intercept", "d_w", "diffusion_exponent", "alpha". With no runs, the arrays
        are empty and all floats are NaN.

    Raises:
        ValueError: if the two input lists differ in length, or n_grid < 1.
    """
    if len(msd_sq_runs) != len(steps_used_list):
        raise ValueError(
            f"msd_sq_runs ({len(msd_sq_runs)}) and steps_used_list "
            f"({len(steps_used_list)}) must have the same length"
        )
    if n_grid < 1:
        raise ValueError(f"n_grid must be >= 1, got {n_grid}")

    if not msd_sq_runs:
        nan = float("nan")
        return {
            "t_grid": np.array([]),
            "msd": np.array([]),
            "R": np.array([]),
            "log_t": np.array([]),
            "log_R": np.array([]),
            "slope": nan,
            "intercept": nan,
            "d_w": nan,
            "diffusion_exponent": nan,
            "alpha": nan,
        }

    # Common time grid in (0, 1]; t = 0 is excluded to keep logs finite.
    t_grid = np.linspace(1.0 / n_grid, 1.0, n_grid)

    # Interpolate each run's squared displacement over normalized time.
    disp_sq_interp = np.empty((len(msd_sq_runs), n_grid), dtype=np.float64)
    for row, (disp_sq, steps) in enumerate(zip(msd_sq_runs, steps_used_list)):
        # `steps` epochs => steps + 1 displacement values (epochs 0..steps).
        t_run = np.linspace(0.0, 1.0, steps + 1)
        y_run = np.asarray(disp_sq, dtype=np.float64)
        n = min(len(t_run), len(y_run))  # tolerate length mismatches by truncation
        disp_sq_interp[row] = np.interp(t_grid, t_run[:n], y_run[:n])

    msd = np.mean(disp_sq_interp, axis=0)   # MSD(t) across runs
    R = np.sqrt(np.maximum(msd, 1e-30))     # root mean squared displacement (floored for log)

    mask = t_grid >= fit_t_min
    log_t = np.log(t_grid[mask])
    log_R = np.log(R[mask])
    slope, intercept, d_w, diffusion_exponent, alpha = _fit_log_log(log_t, log_R)

    return {
        "t_grid": t_grid,
        "msd": msd,
        "R": R,
        "log_t": log_t,
        "log_R": log_R,
        "slope": slope,
        "intercept": intercept,
        "d_w": d_w,
        "diffusion_exponent": diffusion_exponent,
        "alpha": alpha,
    }

# ------------------------------
# Experiment Runner
# ------------------------------

@dataclass
class ExperimentConfig:
    """
    Output locations and plotting options for `run_experiment`.

    Every subdirectory / filename field below is joined onto `output_root`.
    """
    output_root: str = "./outputs"                  # root directory for all artifacts
    sgd_subdir: str = "sgd_weights"                 # SGD state_dicts (.pt) and flattened vectors (.npy)
    bayes_subdir: str = "bayesian_weights"          # accepted SGLD samples (.pt and .npy)
    plot_filename_png: str = "embedding.png"        # PCA embedding plot (PNG)
    plot_filename_pdf: str = "embedding.pdf"        # PCA embedding plot (PDF)
    data_csv_filename: str = "embedding_points.csv"  # 2D PCA coordinates + method + loss
    weights_npz_filename: str = "flattened_weights.npz"  # all flattened vectors, labels, losses, key order
    summary_json_filename: str = "summary.json"     # configs and summary statistics
    param_order_filename: str = "parameter_order.txt"  # state_dict key order, one key per line
    pca_seed: int = 0                               # random_state passed to sklearn PCA
    msd_curve_csv: str = "sgd_msd_curve.csv"        # aggregated MSD / R(t) curve
    msd_stats_json: str = "sgd_msd_stats.json"      # log-log fit statistics
    displacement_plot_png: str = "sgd_displacement.png"  # R(t) plot (PNG)
    displacement_plot_pdf: str = "sgd_displacement.pdf"  # R(t) plot (PDF)


def _run_fallback_sgld(make_model_fn,
                       data_loader: DataLoader,
                       device: torch.device,
                       cfg: SGLDConfig,
                       base_seed: int,
                       key_order: List[str]) -> Tuple[List[np.ndarray], List[float]]:
    """
    Re-run `cfg.num_chains` SGLD chains with shifted seeds and keep EVERY thinned
    post-burn-in sample, regardless of loss.

    Used when no sample passed the acceptance threshold. Samples are returned in
    memory only; nothing is written to disk.
    """
    set_seed(base_seed + 12345)
    vectors: List[np.ndarray] = []
    losses: List[float] = []
    with tqdm(total=cfg.num_chains * cfg.total_steps, desc="Fallback SGLD pass", ncols=100) as pbar:
        for c in range(cfg.num_chains):
            # +999 so each fallback chain uses a different seed than first-pass chain c.
            set_seed(base_seed + cfg.init_seed_offset + c + 999)
            model = make_model_fn().to(device)
            for step in range(cfg.total_steps):
                sgld_step(model, cfg.lr, cfg.temperature, cfg.weight_decay, data_loader, device)
                pbar.update(1)
                if step + 1 > cfg.burn_in and ((step + 1 - cfg.burn_in) % cfg.thinning == 0):
                    loss = eval_loss(model, data_loader, device)
                    vectors.append(flatten_state_dict(model.state_dict(), key_order))
                    losses.append(float(loss))
    return vectors, losses


def _plot_displacement(msd_fit: Dict[str, object], png_path: str, pdf_path: str) -> None:
    """
    Plot the run-averaged displacement R(t) over normalized time, with a t^(1/2)
    Brownian reference scaled to R(1), annotate the fitted exponent, and save the
    figure as PNG and PDF.
    """
    t_grid = msd_fit["t_grid"]
    R = msd_fit["R"]
    fig, ax = plt.subplots(figsize=(6.0, 4.2), dpi=150)

    # Main curve: measured average displacement
    ax.plot(t_grid, R, linewidth=1.8, label="Average R(t)")

    # Brownian reference: R_brown(t) ~ t^{1/2}, scaled to match R(1) for visual comparability
    R_end = R[-1] if len(R) > 0 else 1.0
    ax.plot(t_grid, R_end * np.power(t_grid, 0.5), linestyle="--", color="black", linewidth=1.2,
            label="Brownian reference (α = 0.5)")

    ax.set_title("Average SGD displacement over normalized time")
    ax.set_xlabel("Normalized time (t)")
    ax.set_ylabel("R(t) = sqrt(MSD(t))")
    ax.grid(alpha=0.3)

    # Annotate the fitted exponent α (slope of log R vs log t); .get tolerates fit
    # results without an "alpha" entry.
    fitted_alpha = msd_fit.get("alpha", float("nan"))
    if np.isfinite(fitted_alpha):
        alpha_text = f"Fitted displacement exponent α: {fitted_alpha:.3f}"
    else:
        alpha_text = "Fitted displacement exponent α: n/a"
    ax.text(0.02, 0.02, alpha_text, transform=ax.transAxes, fontsize=9,
            bbox=dict(facecolor="white", alpha=0.75, boxstyle="round,pad=0.3"))

    ax.legend(loc="best", frameon=True)

    fig.tight_layout()
    fig.savefig(png_path)
    fig.savefig(pdf_path)
    plt.close(fig)


def _plot_embedding(coords: np.ndarray,
                    labels: np.ndarray,
                    diffusion_exponent: float,
                    png_path: str,
                    pdf_path: str) -> None:
    """
    Scatter the 2D PCA coordinates, colored by method ("SGD" vs "Bayes"), annotate the
    diffusion exponent, and save the figure as PNG and PDF.
    """
    fig, ax = plt.subplots(figsize=(6.8, 5.2), dpi=150)
    # Plot SGD first, then Bayes, so the default color cycle assignment is stable.
    for method, legend_label in (("SGD", "SGD"), ("Bayes", "Bayesian (SGLD)")):
        sel = labels == method
        ax.scatter(coords[sel, 0], coords[sel, 1], s=18, alpha=0.75, label=legend_label)

    ax.set_title("Stationary-like distributions in weight space (PCA-2D)")
    ax.set_xlabel("PC 1")
    ax.set_ylabel("PC 2")
    ax.legend(loc="best", frameon=True)
    ax.grid(alpha=0.2)

    if np.isfinite(diffusion_exponent):
        text = f"Diffusion exponent (2 - d_w): {diffusion_exponent:.3f}"
    else:
        text = "Diffusion exponent (2 - d_w): n/a"
    ax.text(0.02, 0.02, text, transform=ax.transAxes, fontsize=9,
            bbox=dict(facecolor="white", alpha=0.75, boxstyle="round,pad=0.3"))

    fig.tight_layout()
    fig.savefig(png_path)
    fig.savefig(pdf_path)
    plt.close(fig)


def run_experiment(
    base_seed: int,
    data_n: int,
    data_noise: float,
    sgd_cfg: SGDConfig,
    sgld_cfg: SGLDConfig,
    exp_cfg: ExperimentConfig,
    model_hidden: int = 16,
    msd_grid_points: int = 50,
    fit_t_min: float = 1e-2,
    accept_quantile: float = 0.75,
):
    """
    Run the full SGD-vs-SGLD experiment and write every artifact under `exp_cfg.output_root`.

    Phases:
      1. Train `sgd_cfg.num_inits` TinyMLPs with SGD (early stopping), saving each final
         state_dict / flattened vector and its per-epoch squared displacement.
      2. Run SGLD chains, keeping thinned post-burn-in samples whose loss is at most the
         `accept_quantile` quantile of the final SGD losses. If none qualify, a fresh set
         of chains is run and every thinned sample is kept (those fallback samples are
         not written to the Bayesian weights directory).
      3. Aggregate the SGD displacement curves and fit the walk dimension.
      4. Plot R(t), project all weight vectors to 2D with standardized PCA, plot the
         embedding, and save CSV / NPZ / JSON outputs.

    Args:
        base_seed: seed for data generation and (with offsets) every run/chain.
        data_n, data_noise: two-moons dataset size and noise.
        sgd_cfg, sgld_cfg, exp_cfg: phase hyperparameters and output locations.
        model_hidden: hidden width of TinyMLP.
        msd_grid_points, fit_t_min: normalized-time grid size and lower fit cutoff.
        accept_quantile: quantile of final SGD losses used as the SGLD acceptance
            threshold (0.75 reproduces the previous hard-coded behavior).

    Raises:
        ValueError: if `sgd_cfg.num_inits < 1` (the threshold needs at least one SGD
            loss) or `accept_quantile` is outside [0, 1].
    """
    if sgd_cfg.num_inits < 1:
        raise ValueError(
            "sgd_cfg.num_inits must be >= 1: the SGLD acceptance threshold is a "
            "quantile of the final SGD losses."
        )
    if not 0.0 <= accept_quantile <= 1.0:
        raise ValueError(f"accept_quantile must lie in [0, 1], got {accept_quantile!r}")

    set_seed(base_seed)

    device = torch.device("cpu")  # keep CPU for portability
    # Prepare output dirs
    ensure_dir(exp_cfg.output_root)
    sgd_dir = os.path.join(exp_cfg.output_root, exp_cfg.sgd_subdir)
    bayes_dir = os.path.join(exp_cfg.output_root, exp_cfg.bayes_subdir)
    ensure_dir(sgd_dir)
    ensure_dir(bayes_dir)

    # Data & loader. The same shuffled loader drives SGD epochs, SGLD minibatches and loss evaluation.
    X, y = make_toy_data(n_samples=data_n, noise=data_noise, seed=base_seed)
    ds = TensorDataset(X, y)
    dl = DataLoader(ds, batch_size=max(1, min(sgd_cfg.batch_size, data_n)), shuffle=True)

    # Reference model & flattening order (shared by every saved vector)
    ref_model = TinyMLP(in_dim=2, hidden=model_hidden, out_dim=2).to(device)
    key_order = state_dict_key_order(ref_model)
    # Save parameter order for reproducibility
    with open(os.path.join(exp_cfg.output_root, exp_cfg.param_order_filename), "w") as f:
        f.writelines(k + "\n" for k in key_order)

    # Progress bar total assumes every SGD run uses all max_epochs; early-stopped runs leave it short.
    total_steps = sgd_cfg.num_inits * sgd_cfg.max_epochs + sgld_cfg.num_chains * sgld_cfg.total_steps
    global_pbar = tqdm(total=total_steps, desc="Experiment progress", ncols=100)

    # ---- Phase 1: SGD runs ----
    sgd_vectors: List[np.ndarray] = []
    sgd_losses: List[float] = []
    msd_sq_runs: List[List[float]] = []   # per-run squared displacement seq (len steps_used+1)
    steps_used_list: List[int] = []

    for i in range(sgd_cfg.num_inits):
        # Seed each run deterministically
        set_seed(base_seed + sgd_cfg.init_seed_offset + i)
        model = TinyMLP(in_dim=2, hidden=model_hidden, out_dim=2).to(device)

        final_loss, steps_used, msd_sq = train_sgd_once(
            model, dl, device, sgd_cfg, global_pbar, key_order
        )

        # Save weights (state_dict + flattened vector)
        torch.save(model.state_dict(), os.path.join(sgd_dir, f"sgd_model_{i:04d}.pt"))
        vec = flatten_state_dict(model.state_dict(), key_order)
        np.save(os.path.join(sgd_dir, f"sgd_weights_{i:04d}.npy"), vec)
        sgd_vectors.append(vec)
        sgd_losses.append(final_loss)

        # MSD tracking
        msd_sq_runs.append(msd_sq)
        steps_used_list.append(steps_used)

    # Acceptance threshold for SGLD samples: a quantile of the final SGD losses
    sgd_loss_array = np.array(sgd_losses, dtype=np.float64)
    accept_threshold = float(np.quantile(sgd_loss_array, accept_quantile))

    # ---- Phase 2: Bayesian sampling via SGLD ----
    def mk_model():
        return TinyMLP(in_dim=2, hidden=model_hidden, out_dim=2)

    bayes_vectors, bayes_losses = run_sgld_chains(
        make_model_fn=mk_model,
        data_loader=dl,
        device=device,
        cfg=sgld_cfg,
        global_pbar=global_pbar,
        accept_loss_threshold=accept_threshold,
        base_seed=base_seed,
        key_order=key_order,
        save_dir=bayes_dir,
        chain_prefix="c"
    )
    global_pbar.close()

    # If no Bayesian samples were accepted (threshold too strict), fall back to keeping all thinned samples.
    if len(bayes_vectors) == 0:
        print("[WARN] No SGLD samples met the loss threshold; relaxing criterion to keep all thinned samples.")
        bayes_vectors, bayes_losses = _run_fallback_sgld(
            mk_model, dl, device, sgld_cfg, base_seed, key_order
        )

    # ---- MSD aggregation + diffusion fit ----
    msd_fit = aggregate_msd_and_fit(msd_sq_runs, steps_used_list,
                                    n_grid=msd_grid_points, fit_t_min=fit_t_min)

    # Save MSD curve
    msd_curve_df = pd.DataFrame({
        "t": msd_fit["t_grid"],
        "MSD": msd_fit["msd"],
        "R": msd_fit["R"],
        "log_t": np.log(msd_fit["t_grid"] + 1e-30),
        "log_R": np.log(msd_fit["R"] + 1e-30),
    })
    msd_curve_csv_path = os.path.join(exp_cfg.output_root, exp_cfg.msd_curve_csv)
    msd_curve_df.to_csv(msd_curve_csv_path, index=False)

    # Save MSD stats
    msd_stats = {
        "slope_logR_logt": msd_fit["slope"],
        "intercept_logR_logt": msd_fit["intercept"],
        "walk_dimension_d_w": msd_fit["d_w"],
        "diffusion_exponent_2_minus_d_w": msd_fit["diffusion_exponent"],
        "fit_t_min": fit_t_min,
        "grid_points": msd_grid_points,
        "num_runs": len(msd_sq_runs),
        "avg_steps_used": float(np.mean(steps_used_list)) if steps_used_list else 0.0,
    }
    msd_stats_path = os.path.join(exp_cfg.output_root, exp_cfg.msd_stats_json)
    with open(msd_stats_path, "w") as f:
        json.dump(msd_stats, f, indent=2)

    # Plot average displacement R(t) over normalized time
    disp_png = os.path.join(exp_cfg.output_root, exp_cfg.displacement_plot_png)
    disp_pdf = os.path.join(exp_cfg.output_root, exp_cfg.displacement_plot_pdf)
    _plot_displacement(msd_fit, disp_png, disp_pdf)

    # ---- Dimensionality reduction (PCA) ----
    # SGD vectors first, then Bayes; labels/losses follow the same order.
    all_vectors = np.stack(sgd_vectors + bayes_vectors, axis=0)
    labels = np.array(["SGD"] * len(sgd_vectors) + ["Bayes"] * len(bayes_vectors))
    losses = np.array(list(sgd_losses) + list(bayes_losses))

    # Standardize before PCA
    all_vectors_std = StandardScaler().fit_transform(all_vectors)
    pca = PCA(n_components=2, random_state=exp_cfg.pca_seed)
    coords = pca.fit_transform(all_vectors_std)  # shape (N, 2)

    # ---- Embedding plot ----
    png_path = os.path.join(exp_cfg.output_root, exp_cfg.plot_filename_png)
    pdf_path = os.path.join(exp_cfg.output_root, exp_cfg.plot_filename_pdf)
    _plot_embedding(coords, labels, msd_fit["diffusion_exponent"], png_path, pdf_path)

    # ---- Save data used to make the plot ----
    df = pd.DataFrame({
        "x": coords[:, 0],
        "y": coords[:, 1],
        "method": labels,
        "loss": losses
    })
    csv_path = os.path.join(exp_cfg.output_root, exp_cfg.data_csv_filename)
    df.to_csv(csv_path, index=False)

    # Save raw flattened weights and metadata
    npz_path = os.path.join(exp_cfg.output_root, exp_cfg.weights_npz_filename)
    np.savez_compressed(
        npz_path,
        weights=all_vectors,
        labels=labels,
        losses=losses,
        key_order=np.array(key_order, dtype=object)
    )

    # Save summary JSON (includes diffusion stats too)
    has_bayes = len(bayes_losses) > 0
    summary = {
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
        "base_seed": base_seed,
        "data": {"n_samples": data_n, "noise": data_noise},
        "sgd_cfg": asdict(sgd_cfg),
        "sgld_cfg": asdict(sgld_cfg),
        "exp_cfg": asdict(exp_cfg),
        "sgd": {
            "num_models": int(len(sgd_vectors)),
            "loss_mean": float(np.mean(sgd_losses)),
            "loss_std": float(np.std(sgd_losses)),
            "loss_min": float(np.min(sgd_losses)),
            "loss_max": float(np.max(sgd_losses)),
            # Report the quantile actually used rather than a duplicated literal.
            "accept_threshold_quantile": accept_quantile,
        },
        "bayes": {
            "num_samples_kept": int(len(bayes_vectors)),
            "loss_mean": float(np.mean(bayes_losses)) if has_bayes else None,
            "loss_std": float(np.std(bayes_losses)) if has_bayes else None,
            "loss_min": float(np.min(bayes_losses)) if has_bayes else None,
            "loss_max": float(np.max(bayes_losses)) if has_bayes else None,
            "accept_loss_threshold": accept_threshold,
        },
        "diffusion": {
            "slope_logR_logt": msd_fit["slope"],
            "intercept_logR_logt": msd_fit["intercept"],
            "walk_dimension_d_w": msd_fit["d_w"],
            "diffusion_exponent_2_minus_d_w": msd_fit["diffusion_exponent"],
            "fit_t_min": fit_t_min,
            "grid_points": msd_grid_points,
            "msd_curve_csv": os.path.abspath(msd_curve_csv_path),
            "msd_stats_json": os.path.abspath(msd_stats_path),
        },
        "outputs": {
            "plot_png": os.path.abspath(png_path),
            "plot_pdf": os.path.abspath(pdf_path),
            "embedding_csv": os.path.abspath(csv_path),
            "weights_npz": os.path.abspath(npz_path),
            "sgd_dir": os.path.abspath(sgd_dir),
            "bayes_dir": os.path.abspath(bayes_dir),
            "displacement_png": os.path.abspath(disp_png),
            "displacement_pdf": os.path.abspath(disp_pdf),
        }
    }
    summary_path = os.path.join(exp_cfg.output_root, exp_cfg.summary_json_filename)
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2)

    print("\n=== Experiment complete ===")
    print(f"Saved plot to: {png_path}")
    print(f"Saved data CSV to: {csv_path}")
    print(f"Saved weights NPZ to: {npz_path}")
    print(f"SGD weights: {sgd_dir}")
    print(f"Bayesian (SGLD) weights: {bayes_dir}")
    print(f"MSD curve CSV: {msd_curve_csv_path}")
    print(f"MSD stats JSON: {msd_stats_path}")
    print(f"Displacement plot PNG: {disp_png}")
    print(f"Displacement plot PDF: {disp_pdf}")
    print(f"Summary JSON: {summary_path}")
    if np.isfinite(msd_fit["d_w"]):
        print(f"Estimated walk dimension d_w: {msd_fit['d_w']:.4f}")
        print(f"Diffusion exponent (2 - d_w): {msd_fit['diffusion_exponent']:.4f}")
    else:
        print("Could not estimate walk dimension (check MSD fit).")


# ------------------------------
# Main (configure here)
# ------------------------------

if __name__ == "__main__":
    # ---- Experiment-wide constants ----
    BASE_SEED = 42
    DATA_N = 500            # number of points in the two-moons dataset
    DATA_NOISE = 0.2        # make_moons noise level
    MODEL_HIDDEN = 16       # TinyMLP hidden width
    MSD_GRID_POINTS = 50    # points on the (0, 1] normalized-time grid used for MSD averaging
    FIT_T_MIN = 1e-2        # smallest normalized time included in the log-log fit

    # ---- SGD phase ----
    sgd_cfg = SGDConfig(
        lr=0.1, weight_decay=0.0,
        max_epochs=2500, batch_size=128,
        patience=2500, min_delta=1e-6,
        num_inits=50,           # number of random initializations for SGD
        init_seed_offset=0,
    )

    # ---- SGLD (Bayesian) phase ----
    sgld_cfg = SGLDConfig(
        lr=1e-3, temperature=1.0, weight_decay=1e-4,
        total_steps=5000, burn_in=1000, thinning=20,
        num_chains=10, init_seed_offset=10,
    )

    # ---- Outputs & plotting ----
    # Windows drive path; on other OSes it resolves relative to the working
    # directory as a folder literally named "D:".
    exp_cfg = ExperimentConfig(
        output_root="D:/SGDExperiments/sgd_bayes_msd",
        sgd_subdir="sgd_weights",
        bayes_subdir="bayesian_weights",
        plot_filename_png="embedding.png",
        plot_filename_pdf="embedding.pdf",
        data_csv_filename="embedding_points.csv",
        weights_npz_filename="flattened_weights.npz",
        summary_json_filename="summary.json",
        param_order_filename="parameter_order.txt",
        pca_seed=0,
        msd_curve_csv="sgd_msd_curve.csv",
        msd_stats_json="sgd_msd_stats.json",
        displacement_plot_png="sgd_displacement.png",
        displacement_plot_pdf="sgd_displacement.pdf",
    )
    ensure_dir(exp_cfg.output_root)

    run_experiment(
        base_seed=BASE_SEED, data_n=DATA_N, data_noise=DATA_NOISE,
        sgd_cfg=sgd_cfg, sgld_cfg=sgld_cfg, exp_cfg=exp_cfg,
        model_hidden=MODEL_HIDDEN,
        msd_grid_points=MSD_GRID_POINTS, fit_t_min=FIT_T_MIN,
    )
