#!/usr/bin/env python3
import argparse
import logging
import math
import os
import random
from dataclasses import dataclass
from typing import Dict, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


# -----------------------------
# Utils
# -----------------------------
def seed_all(seed: int):
    """Seed every RNG used by this script and configure cuDNN.

    Seeds Python's `random`, NumPy's legacy global RNG, and PyTorch (CPU plus
    all CUDA devices) with the same value.

    NOTE(review): cuDNN autotuning is switched on and deterministic mode is
    switched off below, so cuDNN-backed kernels may pick different algorithms
    between runs despite the seeding -- confirm this speed/reproducibility
    trade-off is intended.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False


def pick_device() -> torch.device:
    """Return the preferred available device: CUDA first, then MPS, else CPU."""
    # Older PyTorch builds have no `torch.backends.mps` attribute at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        name = "cuda"
    elif mps_backend is not None and mps_backend.is_available():
        name = "mps"
    else:
        name = "cpu"
    return torch.device(name)


@dataclass
class GMMData:
    """Bundle of arrays describing one sampled Gaussian-mixture dataset."""
    X: np.ndarray        # sample coordinates, shape (n, 2)
    y: np.ndarray        # per-sample labels, shape (n,)
    z: np.ndarray        # per-sample component index, 1-based (1..K); kept for reference
    means: np.ndarray    # component means, shape (K, 2)


def generate_gmm_dataset(n: int, K: int, sigma: float, seed: int, means: np.ndarray = None) -> GMMData:
    """
    Sample a 2D Gaussian mixture with K equally likely, isotropic components.

    Args:
        n: number of samples to draw.
        K: number of mixture components.
        sigma: standard deviation of the isotropic noise added to each mean.
        seed: seed for the NumPy Generator used for all draws.
        means: optional (K, 2) array (or array-like) of component means. When
            omitted, component z in {1..K} gets mean (pi*z, 0), i.e. the means
            sit on the x-axis at pi, 2*pi, ..., K*pi.

    Returns:
        GMMData with X of shape (n, 2), labels y (1 if the component index z is
        odd, else 0), the 1-based component indices z, and the means used.

    Raises:
        ValueError: if `means` is given but does not have shape (K, 2).
    """
    rng = np.random.default_rng(seed)
    if means is None:
        # k = 0..K-1 maps to means (pi, 2*pi, ..., K*pi) on the x-axis.
        means = np.stack([(math.pi * (k + 1), 0.0) for k in range(K)], axis=0)
    else:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        means = np.asarray(means)
        if means.shape != (K, 2):
            raise ValueError(f"means must have shape ({K}, 2), got {means.shape}")

    z = rng.integers(low=1, high=K + 1, size=n)  # component index in 1..K
    mu = means[z - 1]                            # (n, 2) mean of each sample's component
    X = mu + rng.normal(loc=0.0, scale=sigma, size=(n, 2))
    y = (z % 2 == 1).astype(np.int64)
    return GMMData(X=X, y=y, z=z, means=means)


# -----------------------------
# Model (periodic features MLP)
# -----------------------------
class FourierFeatures(nn.Module):
    """Augment a 2D input with sin/cos encodings of its first coordinate."""
    def __init__(self, freqs=(0.5, 1.0, 2.0, 4.0, 8.0)):
        super().__init__()
        freq_tensor = torch.tensor(freqs, dtype=torch.float32)
        # Registered as a buffer so it follows the module across devices
        # without being treated as a trainable parameter.
        self.register_buffer("freqs", freq_tensor)

    def forward(self, x):
        """Map (B, 2) inputs to [x, y, sin(f1*x), cos(f1*x), sin(f2*x), cos(f2*x), ...]."""
        first_coord = x[:, :1]   # slice (not index) so the column keeps shape (B, 1)
        columns = [x]            # raw (x, y) pass through unchanged
        for omega in self.freqs:
            phase = omega * first_coord
            columns.extend((torch.sin(phase), torch.cos(phase)))
        return torch.cat(columns, dim=1)


class MLP(nn.Module):
    """Fourier-feature front end followed by a small fully connected classifier."""
    def __init__(self, hidden=128, freqs=(0.5, 1.0, 2.0, 4.0, 8.0), out_dim=2):
        super().__init__()
        self.ff = FourierFeatures(freqs=freqs)
        # Raw (x, y) plus one sin and one cos column per frequency.
        n_features = 2 * (1 + len(freqs))
        layers = [
            nn.Linear(n_features, hidden),
            nn.GELU(),
            nn.LayerNorm(hidden),
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Linear(hidden, out_dim),
        ]
        self.net = nn.Sequential(*layers)

        # Re-initialise every Linear layer: Xavier-uniform weights (gain 0.8), zero bias.
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight, gain=0.8)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return logits of shape (B, out_dim) for inputs of shape (B, 2)."""
        return self.net(self.ff(x))


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> Tuple[float, float]:
    """
    Compute mean cross-entropy and accuracy of `model` over `loader`.

    The model is switched to eval mode and left there. Per-batch losses are
    summed (not averaged) so the final mean is exact even when the last batch
    is smaller. An empty loader results in a ZeroDivisionError.

    Returns:
        (avg_cross_entropy, accuracy), both as Python floats.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        scores = model(inputs)
        loss_sum += F.cross_entropy(scores, targets, reduction="sum").item()
        n_correct += (scores.argmax(dim=1) == targets).sum().item()
        n_seen += targets.numel()
    return loss_sum / n_seen, n_correct / n_seen


def train_model(
    train_ds: TensorDataset,
    val_ds: TensorDataset,
    device: torch.device,
    epochs: int = 15,
    batch_size: int = 1024,
    lr: float = 1e-3,
) -> nn.Module:
    """
    Train a fresh MLP with Adam on cross-entropy and return it.

    After every epoch the sample-weighted training loss/accuracy and the
    validation loss/accuracy are written to the root logger at INFO level.

    Args:
        train_ds: training dataset of (features, labels) pairs; shuffled every epoch.
        val_ds: validation dataset, evaluated in order after every epoch.
        device: device the model and batches are moved to.
        epochs: number of passes over the training set.
        batch_size: mini-batch size for both loaders.
        lr: Adam learning rate.
    """
    # Pinned host memory is only requested when the target device is CUDA.
    loader_kwargs = dict(batch_size=batch_size, num_workers=0, pin_memory=(device.type == "cuda"))
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    model = MLP().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(1, epochs + 1):
        # evaluate() leaves the model in eval mode, so re-enable train mode each epoch.
        model.train()
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad(set_to_none=True)
            scores = model(inputs)
            batch_loss = criterion(scores, targets)
            batch_loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by batch size so the epoch mean is exact.
            n_batch = targets.size(0)
            loss_sum += batch_loss.item() * n_batch
            n_correct += (scores.argmax(dim=1) == targets).sum().item()
            n_seen += n_batch

        train_ce = loss_sum / n_seen
        train_acc = n_correct / n_seen
        val_ce, val_acc = evaluate(model, val_loader, device)

        logging.info(f"Epoch {epoch:02d} | train CE {train_ce:.4f} | train Acc {train_acc:.4f} | "
                     f"val CE {val_ce:.4f} | val Acc {val_acc:.4f}")

    return model


@torch.no_grad()
def get_per_sample_losses(model: nn.Module, X: np.ndarray, y: np.ndarray, device: torch.device) -> Dict[str, np.ndarray]:
    """
    Evaluate `model` on all of (X, y) in a single forward pass.

    Returns:
        Dict with two arrays of length len(X):
        "ce" holds the per-sample cross-entropy, "01" holds the per-sample
        0-1 error (1.0 where the argmax prediction differs from the label).
    """
    model.eval()
    inputs = torch.from_numpy(X).float().to(device)
    targets = torch.from_numpy(y).long().to(device)

    scores = model(inputs)
    per_sample_ce = F.cross_entropy(scores, targets, reduction="none")
    misclassified = scores.argmax(dim=1).ne(targets).float()
    return {
        "ce": per_sample_ce.detach().cpu().numpy(),
        "01": misclassified.detach().cpu().numpy(),
    }


# -----------------------------
# Clustering strategies
# -----------------------------
def cluster_T1_grid(X: np.ndarray, K_grid: int = 10) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Partition the bounding box of X into a regular K_grid x K_grid grid.

    Cell (ix, iy) gets cluster id ix * K_grid + iy, where ix indexes the first
    coordinate and iy the second.

    Returns:
        (cluster_ids, counts_per_cluster, cell_centers), with
        cluster_ids of shape (n,), counts of shape (K_grid**2,) including
        empty cells, and cell_centers of shape (K_grid**2, 2) ordered by
        cluster id.
    """
    xmin, ymin = X.min(axis=0)
    xmax, ymax = X.max(axis=0)
    edges_x = np.linspace(xmin, xmax, K_grid + 1)
    edges_y = np.linspace(ymin, ymax, K_grid + 1)

    def cell_index(values, edges):
        # digitize is 1-based; clipping folds points lying exactly on the
        # upper bounding-box edge into the last cell.
        return np.clip(np.digitize(values, edges, right=False) - 1, 0, K_grid - 1)

    col = cell_index(X[:, 0], edges_x)
    row = cell_index(X[:, 1], edges_y)
    cluster_ids = col * K_grid + row  # 0..K_grid*K_grid-1
    counts = np.bincount(cluster_ids, minlength=K_grid * K_grid)

    mid_x = (edges_x[:-1] + edges_x[1:]) / 2.0
    mid_y = (edges_y[:-1] + edges_y[1:]) / 2.0
    # Same ordering as the ids: x index varies slowest, y index fastest.
    centers = np.column_stack((np.repeat(mid_x, K_grid), np.tile(mid_y, K_grid)))
    return cluster_ids, counts, centers


def cluster_T2_random_centroids(X: np.ndarray, K: int, seed: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Draw K centroids uniformly over the bounding box of X and assign every
    point to its nearest centroid (Euclidean distance).

    Returns:
        (cluster_ids, counts_per_cluster, centroids) with shapes (n,), (K,)
        and (K, 2); counts includes centroids that attracted no points.
    """
    rng = np.random.default_rng(seed)
    xmin, ymin = X.min(axis=0)
    xmax, ymax = X.max(axis=0)

    # All x coordinates are drawn before all y coordinates.
    centroid_x = rng.uniform(xmin, xmax, size=K)
    centroid_y = rng.uniform(ymin, ymax, size=K)
    centroids = np.column_stack((centroid_x, centroid_y))

    # (n, K) squared distances via broadcasting; argmin needs no sqrt.
    offsets = X[:, np.newaxis, :] - centroids[np.newaxis, :, :]
    sq_dist = np.square(offsets).sum(axis=2)
    cluster_ids = sq_dist.argmin(axis=1)
    counts = np.bincount(cluster_ids, minlength=K)
    return cluster_ids, counts, centroids


def cluster_T3_true_means(X: np.ndarray, means: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Assign each point to the nearest of the given means (Euclidean distance).

    Returns:
        (cluster_ids, counts_per_cluster, means); `means` is passed through
        unchanged so the return shape matches the other clustering functions.
    """
    offsets = X[:, np.newaxis, :] - means[np.newaxis, :, :]
    nearest = np.square(offsets).sum(axis=2).argmin(axis=1)
    sizes = np.bincount(nearest, minlength=means.shape[0])
    return nearest, sizes, means


# -----------------------------
# Pieces of the (newer) bound (5)
# -----------------------------
def compute_A1_global(loss_all: np.ndarray) -> float:
    """Return F(S,h): the mean of the per-sample losses over the whole sample S."""
    mean_loss = loss_all.mean()
    return float(mean_loss)


def compute_F_Si(loss_all: np.ndarray, cluster_ids: np.ndarray, K_clusters: int) -> np.ndarray:
    """
    Return the per-cluster mean loss F(S_i,h) for clusters i = 0..K_clusters-1.

    Empty clusters get 0 rather than NaN.
    """
    totals = np.bincount(cluster_ids, weights=loss_all, minlength=K_clusters)
    sizes = np.bincount(cluster_ids, minlength=K_clusters)
    averages = np.zeros_like(totals)
    occupied = sizes > 0
    # Divide only where a cluster has members, so no 0/0 is ever evaluated.
    averages[occupied] = totals[occupied] / sizes[occupied]
    return averages


def compute_A2(b: float, n: int, F_Si: np.ndarray, counts: np.ndarray) -> float:
    """
    Return A2 = (b/n) * sum_{i in T} F(S_i,h), where T is the set of
    non-empty clusters (those with counts > 0).
    """
    scale = b / n
    occupied_total = F_Si[counts > 0].sum()
    return float(scale * occupied_total)


def compute_u_hat_v2(n: int, counts: np.ndarray, b: float, gamma: float, delta: float) -> float:
    """
    Return û for the newer bound (5):

        û = γ(1+2b)/(2n) + γ|T|b^2/(2n^2) + (γ^2/2) * sum_i (n_i/n)^2
            + γ^2 * sqrt( ln(2/δ) / (2n) )

    where n_i = counts[i] and |T| is the number of non-empty clusters.
    """
    num_nonempty = int(np.count_nonzero(counts > 0))
    sum_sq_fractions = np.sum((counts / n) ** 2)
    gamma_sq = gamma ** 2

    linear_term = gamma * (1.0 + 2.0 * b) / (2.0 * n)
    occupancy_term = gamma * num_nonempty * (b ** 2) / (2.0 * (n ** 2))
    concentration_term = gamma_sq * 0.5 * sum_sq_fractions
    confidence_term = gamma_sq * math.sqrt(math.log(2.0 / delta) / (2.0 * n))
    return float(linear_term + occupancy_term + concentration_term + confidence_term)


def compute_A3(C: float, u_hat: float, alpha: float, gamma: float, n: int, delta: float) -> float:
    """
    Return A3 = C * sqrt( û * α * ln γ ) + C * sqrt( ln(2/δ) / (2n) ).

    A negative û is clamped to 0 before use. The product û·α·ln γ itself is
    not clamped, so math.sqrt raises ValueError if it ends up negative
    (e.g. γ < 1 with positive û and α).
    """
    radicand = max(0.0, u_hat) * alpha * math.log(gamma)
    complexity_part = C * math.sqrt(radicand)
    confidence_part = C * math.sqrt(math.log(2.0 / delta) / (2.0 * n))
    return float(complexity_part + confidence_part)


def compute_C_for_loss(loss_name: str, loss_all: np.ndarray, ce_cap: float = 27.6310211) -> float:
    """
    Return the constant C used as a stand-in for sup_z ℓ(h,z).

    - loss_name == "01": C = 1 (the 0-1 loss is bounded by 1).
    - anything else (cross-entropy): the largest observed loss in `loss_all`
      after capping each value at `ce_cap`; the default cap is ≈ -log(1e-12).
    """
    if loss_name == "01":
        return 1.0
    return float(np.minimum(loss_all, ce_cap).max())


# -----------------------------
# Pieces of the old bound (3):
# F(P,h) <= F(S,h) + C * sqrt(û_old * α * ln γ) + g2_old(δ/2)
# -----------------------------
def compute_u_hat_old(n: int, counts: np.ndarray, gamma: float, delta: float) -> float:
    """
    Return û for the old bound (3):

        û_old = γ/(2n) + (γ^2/2) * sum_i (n_i/n)^2 + γ^2 * sqrt( (2/n) * ln(2K/δ) )

    - K is the total number of cells in the partition (len(counts), empties included).
    - The sum runs over all clusters; empty ones contribute 0.
    - δ is floored at 1e-12 before taking the logarithm.
    """
    num_cells = counts.shape[0]
    safe_delta = max(1e-12, delta)
    gamma_sq = gamma ** 2
    sum_sq_fractions = np.sum((counts / n) ** 2)

    linear_term = gamma / (2.0 * n)
    concentration_term = gamma_sq * 0.5 * sum_sq_fractions
    confidence_term = gamma_sq * math.sqrt((2.0 / n) * math.log((2.0 * num_cells) / safe_delta))
    return float(linear_term + concentration_term + confidence_term)


def compute_g2_old(n: int, counts: np.ndarray, C: float, delta: float) -> float:
    """
    Return g2 for the old bound (3):

        g2_old(δ) = [ C(1+√2) √ln(2K/δ) / n ] * sum_{i∈T} √n_i + [ 4C |T| ln(2K/δ) ] / n

    - K is the total number of cells in the partition (len(counts), empties included).
    - T is the set of non-empty clusters.
    - δ is floored at 1e-12 before taking the logarithm.
    """
    num_cells = counts.shape[0]
    occupied = counts[counts > 0]
    num_nonempty = int(occupied.size)
    sum_root_sizes = float(np.sqrt(occupied).sum())
    log_factor = math.log((2.0 * num_cells) / max(1e-12, delta))

    sqrt_part = (C * (1.0 + math.sqrt(2.0)) * math.sqrt(log_factor) / n) * sum_root_sizes
    linear_part = (4.0 * C * num_nonempty * log_factor) / n
    return float(sqrt_part + linear_part)


# -----------------------------
# Main experiment
# -----------------------------
def run(args):
    """
    Run the whole experiment and write its artefacts into `args.outdir`.

    Steps:
      1. Sample a train and a validation GMM dataset (same means, different seeds).
      2. Train the MLP on the train split with cross-entropy.
      3. Compute per-sample cross-entropy and 0-1 losses on both splits.
      4. For each split, loss, partition (T1 grid, T2 random centroids,
         T3 true means) and alpha, evaluate the pieces of bound (5) and of the
         old bound (3), and write one CSV per quantity/split/loss
         (rows = partitions, columns = alpha values).
      5. Save the model's state dict as `mlp_model.pt`.

    Args:
        args: parsed CLI namespace providing n, K, sigma, epochs, batch_size,
            lr, seed, delta, alphas (comma-separated string), outdir, ce_cap.

    Raises:
        ValueError: if `args.alphas` contains no value, contains a value that
            is not strictly positive, or if `args.delta` is not in (0, 1).
    """
    # Validate the bound parameters up front: otherwise a bad value only
    # surfaces as a math-domain/zero-division error after the full training run.
    alpha_list = [float(tok) for tok in args.alphas.split(",") if tok.strip()]
    if not alpha_list:
        raise ValueError("--alphas must contain at least one value")
    if any(not a > 0.0 for a in alpha_list):
        raise ValueError(f"--alphas must all be > 0, got {alpha_list}")
    delta = args.delta
    if not 0.0 < delta < 1.0:
        raise ValueError(f"--delta must be in (0, 1), got {delta}")

    seed_all(args.seed)
    os.makedirs(args.outdir, exist_ok=True)
    device = pick_device()
    logging.basicConfig(filename=os.path.join(args.outdir, 'experiment.log'),
                        filemode='w',
                        level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s')
    logging.info(f"Using device: {device}")

    # Data: both splits share the same means but use different seeds.
    K = args.K
    sigma = args.sigma
    means = np.stack([(math.pi * (z + 1), 0.0) for z in range(K)], axis=0)

    train = generate_gmm_dataset(n=args.n, K=K, sigma=sigma, seed=args.seed + 1, means=means)
    val   = generate_gmm_dataset(n=args.n, K=K, sigma=sigma, seed=args.seed + 2, means=means)

    # Train MLP (CE)
    train_ds = TensorDataset(torch.from_numpy(train.X).float(), torch.from_numpy(train.y).long())
    val_ds   = TensorDataset(torch.from_numpy(val.X).float(),   torch.from_numpy(val.y).long())
    model = train_model(train_ds, val_ds, device, epochs=args.epochs, batch_size=args.batch_size, lr=args.lr)

    # Per-sample losses for train & val
    losses_train = get_per_sample_losses(model, train.X, train.y, device)
    losses_val   = get_per_sample_losses(model, val.X,   val.y,   device)

    partition_names = ["T1", "T2", "T3"]

    def new_table() -> pd.DataFrame:
        # One row per partition, one column per alpha; cells start as NaN.
        return pd.DataFrame(index=partition_names, columns=alpha_list, dtype=float)

    def process_dataset(name: str, data: GMMData, losses: Dict[str, np.ndarray]):
        """Compute every bound term for one split and write its CSV files."""
        X = data.X
        n = X.shape[0]

        # Clusterings (the T2 centroid seed differs between train and val).
        ids_T1, counts_T1, centers_T1 = cluster_T1_grid(X, K_grid=10)                   # 10x10 = 100 cells
        ids_T2, counts_T2, centers_T2 = cluster_T2_random_centroids(X, K=K, seed=args.seed + (123 if name == "train" else 456))
        ids_T3, counts_T3, centers_T3 = cluster_T3_true_means(X, data.means)

        partitions = {
            "T1": (ids_T1, counts_T1, centers_T1),
            "T2": (ids_T2, counts_T2, centers_T2),
            "T3": (ids_T3, counts_T3, centers_T3),
        }

        for loss_name, loss_all in losses.items():
            A1 = compute_A1_global(loss_all)  # F(S,h)
            C_val = compute_C_for_loss(loss_name, loss_all, ce_cap=args.ce_cap)
            b = math.sqrt(0.5 * n * math.log(2.0 / delta))

            # --- Newer bound (5) pieces ---
            a2_tbl, a3_tbl, bd5_tbl = new_table(), new_table(), new_table()
            # --- Old bound (3) pieces (û_old and g2_old are exported too) ---
            uhat_old_tbl, g2_old_tbl, oldb_tbl = new_table(), new_table(), new_table()

            for T_name, (ids, counts, _) in partitions.items():
                K_clusters = counts.shape[0]
                F_Si = compute_F_Si(loss_all, ids, K_clusters)

                # Neither A2 nor g2_old depends on alpha, so compute them once per partition.
                A2_const = compute_A2(b=b, n=n, F_Si=F_Si, counts=counts)
                g2_old = compute_g2_old(n=n, counts=counts, C=C_val, delta=delta / 2.0)  # note δ/2

                for alpha in alpha_list:
                    # gamma is chosen so that gamma ** (-alpha) == 0.04.
                    gamma = (0.04) ** (-1.0 / alpha)

                    # ------ bound (5): A1 + A2 + A3 ------
                    u_hat_v2 = compute_u_hat_v2(n=n, counts=counts, b=b, gamma=gamma, delta=delta)
                    A3_val = compute_A3(C=C_val, u_hat=u_hat_v2, alpha=alpha, gamma=gamma, n=n, delta=delta)
                    bound5 = A1 + A2_const + A3_val

                    a2_tbl.at[T_name, alpha] = A2_const
                    a3_tbl.at[T_name, alpha] = A3_val
                    bd5_tbl.at[T_name, alpha] = bound5

                    # ------ old bound (3): F(S,h) + C*sqrt(û_old α ln γ) + g2_old(δ/2) ------
                    u_hat_old = compute_u_hat_old(n=n, counts=counts, gamma=gamma, delta=delta)
                    old_bound = A1 + C_val * math.sqrt(max(0.0, u_hat_old) * alpha * math.log(gamma)) + g2_old

                    uhat_old_tbl.at[T_name, alpha] = u_hat_old
                    g2_old_tbl.at[T_name, alpha] = g2_old  # constant across α but repeated for convenience
                    oldb_tbl.at[T_name, alpha] = old_bound

            loss_tag = "01" if loss_name == "01" else "ce"
            # File-name prefix -> table; newer bound (5) first, then old bound (3).
            tables = {
                "A2": a2_tbl,
                "A3": a3_tbl,
                "Bound5": bd5_tbl,
                "OldUhat": uhat_old_tbl,
                "OldG2": g2_old_tbl,
                "OldBound": oldb_tbl,
            }
            for prefix, table in tables.items():
                table.to_csv(os.path.join(args.outdir, f"{prefix}_{name}_{loss_tag}.csv"), index=True)
            print(f"[{name} | {loss_tag}] wrote CSVs to: {args.outdir}")

    process_dataset("train", train, losses_train)
    process_dataset("val",   val,   losses_val)

    torch.save({"state_dict": model.state_dict()}, os.path.join(args.outdir, "mlp_model.pt"))
    print("Done. CSVs and model saved in:", args.outdir)


def main():
    """Parse the command-line options and hand them to run()."""
    parser = argparse.ArgumentParser(description="GMM experiment: partitions (T1/T2/T3), bounds (5) + old (3)")
    # (flag, add_argument keyword arguments), in the order shown by --help.
    options = (
        ("--n", dict(type=int, default=100_000, help="Samples per split (train/val)")),
        ("--K", dict(type=int, default=100, help="Number of components / clusters")),
        ("--sigma", dict(type=float, default=1.0, help="Std dev of Gaussians (section 1.1 uses 1.0)")),
        ("--epochs", dict(type=int, default=15)),
        ("--batch-size", dict(type=int, default=1024)),
        ("--lr", dict(type=float, default=1e-3)),
        ("--seed", dict(type=int, default=1337)),
        ("--delta", dict(type=float, default=0.01)),
        ("--alphas", dict(type=str, default="4,6,8,10,50,100,200,400,800,2000",
                          help="Comma-separated α values")),
        ("--outdir", dict(type=str, default="results")),
        ("--ce-cap", dict(type=float, default=27.6310211,
                          help="Cap for CE to get finite C (≈ -log(1e-12))")),
    )
    for flag, spec in options:
        parser.add_argument(flag, **spec)
    run(parser.parse_args())


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
