"""
Estimate feature-space "radius" induced by dataset-calibrated noise levels.

For each dataset in {MNIST, CIFAR-10, CIFAR-100, Tiny-ImageNet} and each tau in a list:
  1) Compute dataset scale statistics on a (fast) subset:
       - r  : average L2 distance between random images (after normalization)
       - v̂  : global per-pixel variance (after normalization)
       - d  : input dimensionality C*H*W after transforms (here 3x224x224)
       - c  : r^2 / (2 d v̂)
  2) Compute mf(tau) from theory:
       mf = sqrt( (alpha^2/(tau*(1-alpha)^2) - 1) / (2 c) ), clipped at >=0
  3) Add per-image Gaussian noise with L2 norm mf * r to the CLEAN images (no mixup),
     then clamp to the valid pixel range: denormalize, clamp to [0,1], renormalize.
  4) Extract ResNet trunk features after layer3 (pretrained):
       - resnet18 for MNIST, CIFAR-10
       - resnet50 for CIFAR-100, Tiny-ImageNet
     Flatten the feature maps and compute the L2 distance between clean and noisy features.
  5) Average these distances over many samples (implemented in estimate_feature_space_radius).

Outputs a concise table per dataset & tau:
   tau | mf | r | mf*r | avg_feat_radius

Run:
  python feature_radius.py --datasets mnist cifar10 cifar100 tiny-imagenet --taus 0.1 0.05 0.02 0.01 --alpha 0.7 --device cuda
"""

import os
import argparse
import random
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset

# ---------------------------
# Normalization helpers
# ---------------------------

# Per-channel (R, G, B) statistics used by Normalize in get_transforms() and
# inverted in clamp_imagenet_normalized(). NOTE(review): assumed to match the
# preprocessing of the ImageNet-pretrained ResNet weights — confirm if changed.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

def clamp_imagenet_normalized(x: torch.Tensor) -> torch.Tensor:
    """Project an ImageNet-normalized batch back into the valid pixel range.

    The batch is mapped to [0, 1] pixel space with IMAGENET_MEAN / IMAGENET_STD,
    clamped element-wise to [0, 1] and normalized again.

    Args:
        x: batch normalized with IMAGENET_MEAN / IMAGENET_STD. The statistics are
           reshaped to (1, 3, 1, 1), so an NCHW layout with 3 channels is assumed.

    Returns:
        Tensor of the same shape whose values all correspond to pixels in [0, 1].
    """
    channel_shape = (1, -1, 1, 1)
    mu = torch.tensor(IMAGENET_MEAN, device=x.device).view(channel_shape)
    sigma = torch.tensor(IMAGENET_STD, device=x.device).view(channel_shape)
    pixels = torch.clamp(x * sigma + mu, min=0.0, max=1.0)
    return (pixels - mu) / sigma

# ---------------------------
# Datasets & transforms
# ---------------------------

def get_transforms(dataset_type: str):
    """Return the torchvision preprocessing pipeline for ``dataset_type``.

    Every supported dataset goes through Resize(224) -> ToTensor -> Normalize with
    the ImageNet statistics. MNIST additionally gets Grayscale(3) right after the
    resize so that it provides the three input channels the ResNet trunk expects.

    Raises:
        ValueError: if the (lower-cased) name is not a supported dataset.
    """
    key = dataset_type.lower()
    if key == "mnist":
        to_three_channels = [transforms.Grayscale(num_output_channels=3)]
    elif key in ("cifar10", "cifar100", "tiny-imagenet"):
        to_three_channels = []
    else:
        raise ValueError(f"Unsupported dataset type: {key}")

    steps = [
        transforms.Resize(224),
        *to_three_channels,
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(steps)

def get_dataset(dataset_type: str):
    """Load the (train, evaluation) splits of ``dataset_type`` from ./data.

    MNIST and CIFAR-10/100 are created through torchvision with ``download=True``.
    Tiny-ImageNet must already be extracted under ./data/tiny-imagenet-200; its
    ``val`` directory is used for evaluation when it exists, otherwise ``test``.
    Both splits use the transform returned by ``get_transforms``.

    Raises:
        ValueError: unsupported dataset name (matching is case-insensitive).
        FileNotFoundError: the Tiny-ImageNet directories are missing.
    """
    dataset_type = dataset_type.lower()
    tfm = get_transforms(dataset_type)

    torchvision_builtins = {
        "mnist": torchvision.datasets.MNIST,
        "cifar10": torchvision.datasets.CIFAR10,
        "cifar100": torchvision.datasets.CIFAR100,
    }
    if dataset_type in torchvision_builtins:
        dataset_cls = torchvision_builtins[dataset_type]
        train = dataset_cls(root="./data", train=True, download=True, transform=tfm)
        test = dataset_cls(root="./data", train=False, download=True, transform=tfm)
        return train, test

    if dataset_type != "tiny-imagenet":
        raise ValueError(f"Unsupported dataset type: {dataset_type}")

    base_dir = os.path.join("./data", "tiny-imagenet-200")
    train_dir = os.path.join(base_dir, "train")
    val_dir = os.path.join(base_dir, "val")
    test_dir = os.path.join(base_dir, "test")
    eval_dir = val_dir if os.path.isdir(val_dir) else test_dir
    if not os.path.isdir(train_dir) or not os.path.isdir(eval_dir):
        raise FileNotFoundError("Tiny-ImageNet expected at ./data/tiny-imagenet-200/{train,val or test}")

    # NOTE(review): ImageFolder derives labels from sub-directory names; if the
    # eval directory is not split per class, labels there are not meaningful.
    # This script's loops discard labels, so only other callers are affected.
    train = torchvision.datasets.ImageFolder(root=train_dir, transform=tfm)
    test = torchvision.datasets.ImageFolder(root=eval_dir, transform=tfm)
    return train, test

def make_subdataset(dataset, max_images: int, seed: int = 0):
    """Return a reproducible random ``Subset`` of at most ``max_images`` items.

    Indices are drawn without replacement from ``np.random.default_rng(seed)`` and
    stored in ascending order, so the same seed always selects the same items.
    """
    population = len(dataset)
    sample_size = min(max_images, population)
    picked = np.random.default_rng(seed).choice(population, size=sample_size, replace=False)
    return Subset(dataset, [int(i) for i in np.sort(picked)])

# ---------------------------
# Scale statistics: r, v_hat (global mean), c
# ---------------------------

@torch.no_grad()
def estimate_average_distance(
    dataset,
    device="cpu",
    max_images: int = 2048,
    batch_size: int = 64,
    num_workers: int = 2,
) -> float:
    """Estimate the mean L2 distance between random pairs of *distinct* images.

    At most ``max_images`` images are collected from a shuffled DataLoader. They are
    put in a random order and each one is paired with its cyclic neighbour in that
    order, so no image is ever compared with itself.

    Args:
        dataset: map-style dataset yielding ``(image_tensor, label)`` pairs.
        device: device the images are moved to before computing distances.
        max_images: maximum number of images used (the last batch is trimmed).
        batch_size: DataLoader batch size used while collecting images.
        num_workers: DataLoader worker processes.

    Returns:
        Mean L2 distance over the flattened image pairs.

    Raises:
        ValueError: if fewer than two images are available.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=num_workers, drop_last=False)
    xs, total = [], 0
    for imgs, _ in loader:
        xs.append(imgs.to(device))
        total += imgs.size(0)
        if total >= max_images:
            break

    # The last batch can overshoot max_images; trim so the cap is exact.
    X = torch.cat(xs, dim=0)[:max_images] if xs else torch.empty(0)
    n = X.size(0)
    if n < 2:
        raise ValueError(f"need at least two images to estimate a pairwise distance, got {n}")

    # A random permutation may have fixed points (i -> i), which would contribute
    # zero distances. Pairing consecutive entries of one permuted order (a cyclic
    # shift by one) guarantees that every pair holds two different images.
    perm = torch.randperm(n, device=X.device)
    diff = X[perm]
    diff -= X[torch.roll(perm, shifts=1)]
    dists = torch.norm(diff.flatten(1), p=2, dim=1)
    return float(dists.mean().item())

@torch.no_grad()
def estimate_per_pixel_variance_global_mean(
    dataset,
    device="cpu",
    max_images: int = 2048,
    batch_size: int = 64,
    num_workers: int = 2,
) -> Tuple[float, int]:
    """Estimate the variance of all image values around one global mean.

    At most ``max_images`` images are collected from a shuffled DataLoader. Every
    value of every channel is treated as a sample of one distribution. The
    population variance (division by the count) around the single global mean is
    returned together with the per-image dimensionality.

    Args:
        dataset: map-style dataset yielding ``(image_tensor, label)`` pairs with
            image tensors of shape (C, H, W) (unpacked as N, C, H, W after batching).
        device: device the images are moved to.
        max_images: maximum number of images used (the last batch is trimmed).
        batch_size: DataLoader batch size used while collecting images.
        num_workers: DataLoader worker processes.

    Returns:
        ``(v_hat, d)`` where ``d = C * H * W``.

    Raises:
        ValueError: if no images are available.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=num_workers, drop_last=False)
    xs, total = [], 0
    for imgs, _ in loader:
        xs.append(imgs.to(device))
        total += imgs.size(0)
        if total >= max_images:
            break

    # The last batch can overshoot max_images; trim so the cap is exact.
    X = torch.cat(xs, dim=0)[:max_images] if xs else torch.empty(0)
    if X.size(0) == 0:
        raise ValueError("no images available to estimate the per-pixel variance")

    N, C, H, W = X.shape
    d = C * H * W
    Xv = X.reshape(N, -1)
    mu = Xv.mean()                                # one scalar mean over all values
    v_hat = ((Xv - mu) ** 2).mean().item()
    return v_hat, d

def compute_c_factor(r: float, v_hat: float, d: int) -> float:
    """Return c = r^2 / (2 * d * v_hat).

    ``r`` is the average image-pair distance, ``v_hat`` the global per-value variance
    and ``d`` the input dimensionality. A 1e-20 term in the denominator keeps the
    division finite when ``v_hat`` (or ``d``) is zero.
    """
    denom = 2.0 * d * v_hat + 1e-20
    return (r * r) / denom

def mf_from_tau(alpha: float, tau: float, c: float) -> float:
    """Return the noise multiplier mf for the given alpha, tau and scale factor c.

        mf = sqrt( max(alpha^2 / (tau * (1 - alpha)^2) - 1, 0) / (2 c) )

    ``tau`` and ``c`` are floored at 1e-20 to avoid division by zero. The radicand
    is clipped at 0, so mf == 0 whenever alpha^2 / (tau (1 - alpha)^2) <= 1.

    Raises:
        ValueError: if ``alpha == 1``, where (1 - alpha)^2 == 0 and the bound diverges.
    """
    one_minus_alpha = 1 - alpha
    if one_minus_alpha == 0:
        raise ValueError("alpha must differ from 1: the mf bound diverges at alpha == 1")
    a2 = alpha * alpha
    inner = max(a2 / (max(tau, 1e-20) * one_minus_alpha**2) - 1.0, 0.0)
    return float(np.sqrt(inner / (2.0 * max(c, 1e-20))))

# ---------------------------
# Noise
# ---------------------------

def add_noise_with_l2_norm_batch(x: torch.Tensor, target_norm: float) -> torch.Tensor:
    """Add independent Gaussian noise per sample, rescaled to L2 norm ``target_norm``.

    Args:
        x: batch whose first dimension indexes samples. Any number of trailing
           dimensions is supported (e.g. NCHW images or (B, D) vectors).
        target_norm: desired L2 norm of each sample's noise, measured over all
           non-batch elements (up to the 1e-12 guard in the rescaling).

    Returns:
        ``x + noise``, with the same shape as ``x``.
    """
    batch = x.size(0)
    noise = torch.randn_like(x)
    norms = torch.norm(noise.reshape(batch, -1), p=2, dim=1)
    # Shape the per-sample scales as (B, 1, ..., 1) so they broadcast against x of
    # any rank. A fixed (B, 1, 1, 1) would silently broadcast e.g. (B, D) inputs
    # to a larger, wrong shape.
    scale_shape = (batch,) + (1,) * (x.dim() - 1)
    scales = (target_norm / (norms + 1e-12)).view(scale_shape)
    return x + noise * scales

# ---------------------------
# Feature extractor after layer3 (per user spec)
# ---------------------------

def build_resnet_feature_extractor(model_name: str, device: torch.device):
    """
    Build an ImageNet-pretrained ResNet (18/34/50) trunk truncated after layer3.

    The returned module runs conv1 -> bn1 -> relu -> maxpool -> layer1 -> layer2
    -> layer3, i.e. layer4, avgpool and fc are dropped, so the output is a
    spatial feature map taken after layer3. The module is moved to ``device``
    and put in eval mode.

    Args:
        model_name: "resnet18", "resnet34" or "resnet50" (case-insensitive).
        device: target device for the returned module.

    Raises:
        ValueError: for any other model name.
    """
    model_name = model_name.lower()

    if model_name == "resnet18":
        weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
        resnet = torchvision.models.resnet18(weights=weights)
    elif model_name == "resnet34":
        weights = torchvision.models.ResNet34_Weights.IMAGENET1K_V1
        resnet = torchvision.models.resnet34(weights=weights)
    elif model_name == "resnet50":
        weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1
        resnet = torchvision.models.resnet50(weights=weights)
    else:
        raise ValueError("Model type not supported. Use resnet18, resnet34, or resnet50.")

    # Pick the stages by name rather than slicing children(): the children are
    # [conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4, avgpool, fc],
    # and an off-by-one slice such as [:-2] would silently keep layer4.
    feature_extractor = nn.Sequential(
        resnet.conv1,
        resnet.bn1,
        resnet.relu,
        resnet.maxpool,
        resnet.layer1,
        resnet.layer2,
        resnet.layer3,
    ).to(device)
    feature_extractor.eval()

    return feature_extractor

@torch.no_grad()
def flatten_features_after_layer3(feature_extractor: nn.Module, x_norm_nchw: torch.Tensor) -> torch.Tensor:
    """Run ``feature_extractor`` on a normalized batch and flatten each sample's output.

    The result has shape (B, product of the remaining output dims). The name refers
    to the intended trunk (cut after layer3), but any module that returns a batched
    tensor works.
    """
    return torch.flatten(feature_extractor(x_norm_nchw), start_dim=1)

# ---------------------------
# Core measurement: feature-space radius
# ---------------------------

@torch.no_grad()
def estimate_feature_space_radius(
    dataset,
    feature_extractor: nn.Module,
    image_avg_dist_r: float,
    mf: float,
    num_samples: int = 1024,
    batch_size: int = 64,
    device: str = "cpu",
    num_workers: int = 2,
) -> float:
    """
    Average L2 distance between the features of clean and noisy images.

    Each image receives Gaussian noise with ||noise||_2 = mf * image_avg_dist_r and
    is clamped back into the valid pixel range. The clean and the noisy batch are
    then both passed through ``feature_extractor`` and flattened per sample.

    Args:
        dataset: map-style dataset yielding (normalized image, label) pairs; labels are ignored.
        feature_extractor: module mapping an image batch to per-sample feature maps.
        image_avg_dist_r: pixel-space scale r.
        mf: noise multiplier; the per-image noise norm is mf * r.
        num_samples: number of images to average over (fewer if the dataset is smaller).
        batch_size: DataLoader batch size.
        device: device the images are moved to for the forward passes.
        num_workers: DataLoader worker processes.

    Returns:
        Mean feature-space distance over the processed images; 0.0 if ``num_samples <= 0``.
    """
    if num_samples <= 0:
        return 0.0

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                        pin_memory=("cuda" in str(device)))
    target_norm = float(mf) * float(image_avg_dist_r)

    total, count = 0.0, 0
    for imgs, _ in loader:
        # Trim before the forward passes so surplus images never go through the network.
        imgs = imgs[: num_samples - count].to(device)

        # The clean batch takes the same denormalize/clamp/renormalize round trip as
        # the noisy one. Both inputs then share identical numerics, and a zero-norm
        # perturbation compares equal inputs instead of rounding-level differences.
        clean = clamp_imagenet_normalized(imgs)
        noisy = clamp_imagenet_normalized(add_noise_with_l2_norm_batch(imgs, target_norm=target_norm))

        f_clean = flatten_features_after_layer3(feature_extractor, clean)
        f_noisy = flatten_features_after_layer3(feature_extractor, noisy)

        total += float(torch.norm(f_clean - f_noisy, p=2, dim=1).sum())
        count += imgs.size(0)
        if count >= num_samples:
            break
    return total / max(count, 1)

# ---------------------------
# Orchestration
# ---------------------------

def run_experiment(
    datasets: List[str],
    taus: List[float],
    alpha: float,
    sub_size: int,
    device: str,
    seed: int,
    max_images_for_stats: int,
    batch_size_stats: int,
    num_samples_radius: int,
    batch_size_radius: int,
):
    """Run the full measurement and print one results table per dataset.

    For each dataset name:
      1. Load the evaluation split and draw a seeded subset of ``sub_size`` images.
      2. Estimate the scale statistics (r, v_hat, d, c).
      3. Build a pretrained trunk: resnet18 for MNIST / CIFAR-10, resnet50 otherwise.
      4. For each tau, print mf, r, mf*r and the averaged feature-space radius.

    Args:
        datasets: dataset names accepted by ``get_dataset`` (case-insensitive).
        taus: tau values passed to ``mf_from_tau``.
        alpha: alpha passed to ``mf_from_tau``.
        sub_size: subset size per dataset, used for both the statistics and the radius.
        device: torch device string, e.g. "cuda" or "cpu".
        seed: seeds the python/numpy/torch RNGs and the subset draw. ``None``
            skips global seeding and uses subset seed 0.
        max_images_for_stats, batch_size_stats: image cap and batch size for r and v_hat.
        num_samples_radius, batch_size_radius: image count and batch size for the radius.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    print("\n=== Feature-space radius vs. dataset-calibrated noise ===")
    print(f"alpha    : {alpha}")
    print(f"taus     : {taus}")
    print(f"subset   : {sub_size} samples / dataset (for stats & radius)")
    print(f"device   : {device}\n")

    for ds_name in datasets:
        # get_dataset() lower-cases its argument; normalize here too so the backbone
        # choice below agrees with it (e.g. "MNIST" -> resnet18, not resnet50).
        ds_key = ds_name.lower()
        print(f"\n--- Dataset: {ds_name} ---")
        _, test_ds = get_dataset(ds_key)
        sub_ds = make_subdataset(test_ds, max_images=sub_size, seed=seed or 0)

        # Stats
        r = estimate_average_distance(sub_ds, device=device, max_images=max_images_for_stats,
                                      batch_size=batch_size_stats)
        v_hat, d = estimate_per_pixel_variance_global_mean(sub_ds, device=device,
                                                           max_images=max_images_for_stats,
                                                           batch_size=batch_size_stats)
        c = compute_c_factor(r, v_hat, d)

        # Backbone choice
        backbone = "resnet18" if ds_key in ("mnist", "cifar10") else "resnet50"
        feat_extractor = build_resnet_feature_extractor(backbone, device=torch.device(device))

        print(f"[stats] r={r:.6f} | v_hat={v_hat:.4e} | d={d} | c={c:.3f} | backbone={backbone}")
        print(" tau      |     mf     |       r       |    mf*r     |  avg_feat_radius")
        print("----------+------------+---------------+-------------+------------------")

        for tau in taus:
            mf = mf_from_tau(alpha=alpha, tau=tau, c=c)
            feat_radius = estimate_feature_space_radius(
                sub_ds, feat_extractor, image_avg_dist_r=r, mf=mf,
                num_samples=num_samples_radius, batch_size=batch_size_radius, device=device
            )
            # Scientific notation: the default taus go down to 1e-6, and a fixed
            # ".3f" format prints everything below 5e-4 as 0.000.
            print(f" {tau:>8.2e} | {mf:>10.4f} | {r:>13.4f} | {mf*r:>11.4f} | {feat_radius:>16.4f}")

# ---------------------------
# CLI
# ---------------------------

def parse_args():
    """Parse the command-line options of the feature-radius experiment.

    Returns an ``argparse.Namespace`` whose attribute names match the keyword
    parameters of ``run_experiment``.
    """
    parser = argparse.ArgumentParser(
        description="Estimate feature-space displacement (radius) for dataset-calibrated noise levels."
    )
    add = parser.add_argument

    # What to measure
    add("--datasets", type=str, nargs="+",
        default=["mnist", "cifar10", "cifar100", "tiny-imagenet"])
    add("--taus", type=float, nargs="+",
        default=[1.0, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6])
    add("--alpha", type=float, default=0.7)

    # Sampling and runtime
    add("--sub_size", type=int, default=2048,
        help="Subset size per dataset for stats and radius estimation.")
    add("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    add("--seed", type=int, default=0)

    # Scale-statistics estimation (r, v_hat)
    add("--max_images_for_stats", type=int, default=2048)
    add("--batch_size_stats", type=int, default=64)

    # Feature-radius estimation
    add("--num_samples_radius", type=int, default=1024,
        help="How many images to average for the feature radius.")
    add("--batch_size_radius", type=int, default=64)

    return parser.parse_args()

# ---------------------------
# Entry
# ---------------------------

if __name__ == "__main__":
    # Forward every parsed CLI option to run_experiment by keyword.
    cli = parse_args()
    run_experiment(
        datasets=cli.datasets,
        taus=cli.taus,
        alpha=cli.alpha,
        sub_size=cli.sub_size,
        device=cli.device,
        seed=cli.seed,
        max_images_for_stats=cli.max_images_for_stats,
        batch_size_stats=cli.batch_size_stats,
        num_samples_radius=cli.num_samples_radius,
        batch_size_radius=cli.batch_size_radius,
    )
