import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms

class SurvivalDataset(Dataset):
    """
    A Dataset that holds all data in memory and returns items as (img, time, event),
    mimicking the output of SurvivalImageDataset.

    Inputs are converted with ``torch.as_tensor``: labels to int64, times and
    events to float32. ``labels`` is kept on the instance for inspection but is
    NOT part of the tuple returned by ``__getitem__``.

    Raises:
        ValueError: if labels, times or events do not have one entry per image.
    """
    def __init__(self, images, labels, times, events):
        self.images = torch.as_tensor(images)
        self.labels = torch.as_tensor(labels, dtype=torch.long)
        self.times = torch.as_tensor(times, dtype=torch.float32)
        self.events = torch.as_tensor(events, dtype=torch.float32)

        n = len(self.images)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        for field_name, values in (
            ("labels", self.labels),
            ("times", self.times),
            ("events", self.events),
        ):
            if len(values) != n:
                raise ValueError(
                    f"{field_name} has length {len(values)}, expected {n} (one per image)"
                )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # Labels are intentionally omitted: consumers expect (img, time, event).
        img = self.images[index]
        time = self.times[index]
        event = self.events[index]
        return img, time, event


# --- Exact class→risk mapping from X-CAL (Appendix D.2, Table 5) ---
# Digit label -> "risk" value. make_risk_score_for_groups maps each sample to
# this value, and SurvivalTimeGenerator.gen_failure_times_gamma uses it as the
# MEAN of the Gamma failure-time distribution. Values are transcribed verbatim
# from the cited table; keep them in sync with it.
# NOTE(review): because the value is a mean failure time, a larger "risk" here
# means a LATER expected event -- confirm the naming matches the intended reading.
_DIGIT_TO_RISK = {
    0: 11.25, 1:  2.25, 2:  5.25, 3:  5.00, 4:  4.75,
    5:  8.00, 6:  2.00, 7: 11.00, 8:  1.75, 9: 10.75,
}
# Digit label -> coarse bucket, consistent with the values above:
#   "most"   : digits 0, 7, 9  (10.75 - 11.25)
#   "higher" : digit  5        (8.00)
#   "lower"  : digits 2, 3, 4  (4.75 - 5.25)
#   "least"  : digits 1, 6, 8  (1.75 - 2.25)
_DIGIT_TO_GROUP = {
    0: "most", 1: "least", 2: "lower", 3: "lower", 4: "lower",
    5: "higher", 6: "least", 7: "most", 8: "least", 9: "most"
}

def make_risk_score_for_groups(y: np.ndarray):
    """
    Look up the X-CAL Survival-MNIST risk score for every digit label in ``y``.

    Returns:
        assignment: DataFrame indexed by the sorted distinct labels of ``y``
            (index name ``class_label``) with columns ``risk_score`` (float)
            and ``risk_group`` (str).
        risk_scores: float64 array, one risk score per element of ``y``.

    Raises:
        KeyError: if ``y`` contains a label outside the digit tables (0-9).
    """
    risk_per_class = {}
    group_assignment = {}
    for label in np.unique(y):
        digit = int(label)
        risk_per_class[label] = float(_DIGIT_TO_RISK[digit])
        group_assignment[label] = _DIGIT_TO_GROUP[digit]

    risk_column = pd.Series(risk_per_class, name="risk_score")
    group_column = pd.Series(group_assignment, name="risk_group")
    assignment = pd.concat([risk_column, group_column], axis=1)
    assignment = assignment.rename_axis("class_label")

    per_sample = (risk_per_class[int(value)] for value in y)
    risk_scores = np.fromiter(per_sample, dtype=np.float64, count=len(y))
    return assignment, risk_scores

class SurvivalTimeGenerator:
    """
    Survival-MNIST generator per X-CAL:
      - Failure time T | y ~ Gamma(mean = risk(y), variance = var).
      - For EACH split, censoring C ~ Uniform(min(T_split), quantile_0.9(T_split)).

    Raises:
        ValueError: if ``var`` is not a positive finite number.
    """
    def __init__(self, var: float = 1e-3):
        var = float(var)
        # The Gamma shape/scale below divide by var; zero/negative/inf/nan
        # would otherwise surface later as obscure numpy errors or nan times.
        if not (var > 0.0 and np.isfinite(var)):
            raise ValueError(f"var must be a positive finite number, got {var}")
        self.var = var

    def gen_failure_times_gamma(self, risk_score: np.ndarray, seed: int) -> np.ndarray:
        """
        Draw one Gamma failure time per entry of ``risk_score``.

        Moment matching: shape k = mu^2 / var and scale theta = var / mu give
        mean k * theta = mu and variance k * theta^2 = var.

        Args:
            risk_score: per-sample Gamma means; must all be strictly positive.
            seed: seed for a fresh ``np.random.RandomState``.

        Returns:
            float64 array with the same shape as ``risk_score``.

        Raises:
            ValueError: if any risk score is not strictly positive (or is nan).
        """
        rng = np.random.RandomState(seed)
        mu = np.asarray(risk_score, dtype=np.float64)   # mean per point
        if not np.all(mu > 0.0):
            raise ValueError("risk_score values must be strictly positive")
        var = self.var
        shape = (mu ** 2) / var
        scale = var / mu
        return rng.gamma(shape=shape, scale=scale)

    @staticmethod
    def gen_uniform_censoring_for_split(t_split: np.ndarray, seed: int) -> np.ndarray:
        """
        Draw one censoring time per failure time of a split.

        C ~ Uniform(min(t_split), 0.9-quantile(t_split)), so the censoring
        window is derived from the split's own failure-time distribution.

        Args:
            t_split: failure times of one split (1-D).
            seed: seed for a fresh ``np.random.RandomState``.

        Returns:
            float64 array of length ``len(t_split)``; empty if ``t_split`` is empty.
        """
        t_split = np.asarray(t_split)
        n = t_split.shape[0]
        if n == 0:
            # min/quantile are undefined on an empty split.
            return np.empty(0, dtype=np.float64)
        rng = np.random.RandomState(seed)
        lo = float(np.min(t_split))
        hi = float(np.quantile(t_split, 0.9))
        return rng.uniform(lo, hi, size=n)


# ---------------------- Helpers for noise injection ----------------------
def _apply_label_swaps(y: np.ndarray, frac: float, seed: int) -> np.ndarray:
    """
    Randomly permute labels for a fraction 'frac' of indices (within-split), to
    inject noise *before* generating T. The permutation is on the chosen subset.

    ``floor(frac * len(y))`` positions are chosen without replacement and their
    labels are shuffled among themselves, so the multiset of labels is
    preserved. A chosen position can keep its label (e.g. when two chosen
    samples share a digit), so the fraction actually altered may be below
    ``frac``.

    Args:
        y: 1-D label array.
        frac: fraction of positions to shuffle; values <= 0 disable the noise.
        seed: seed for a fresh ``np.random.RandomState``.

    Returns:
        A new array; the input is never modified or aliased.

    Raises:
        ValueError: if ``frac > 1``.
    """
    if frac > 1.0:
        raise ValueError(f"frac must be at most 1, got {frac}")
    # Always return a fresh array so callers never share memory with the input.
    y_noisy = np.array(y, copy=True)
    if frac <= 0.0:
        return y_noisy
    n = len(y_noisy)
    m = int(np.floor(frac * n))
    if m < 2:
        # With 0 or 1 chosen positions every permutation is the identity.
        return y_noisy

    rng = np.random.RandomState(seed)
    idx = rng.choice(n, size=m, replace=False)
    perm = rng.permutation(m)
    # Avoid the identity permutation (a no-op). This matters most for m == 2,
    # where permutation() returns the identity half of the time.
    if np.array_equal(perm, np.arange(m)):
        perm = np.roll(perm, 1)
    y_noisy[idx] = y_noisy[idx][perm]
    return y_noisy


def _add_uniform_noise_to_observed_time(Y: np.ndarray, low: float, high: float, seed: int,
                                        min_time: float = 1e-6) -> np.ndarray:
    """
    Add Uniform(low, high) noise to the *observed* time Y only (not to T or C).
    Clip to be strictly positive.

    Args:
        Y: 1-D observed times.
        low, high: bounds of the per-sample uniform noise; (0, 0) disables it.
        seed: seed for a fresh ``np.random.RandomState``.
        min_time: lower clip applied after adding noise.

    Returns:
        ``Y`` itself when the noise is disabled, otherwise a new float64 array
        ``max(Y + noise, min_time)``.

    Raises:
        ValueError: if ``low > high`` (numpy's uniform is undefined there).
    """
    if low > high:
        raise ValueError(f"noise range must satisfy low <= high, got ({low}, {high})")
    if low == 0.0 and high == 0.0:
        return Y
    rng = np.random.RandomState(seed)
    noise = rng.uniform(low, high, size=Y.shape[0])
    Y_tilde = Y.astype(np.float64) + noise
    # Negative noise can push times to/below zero; keep them positive.
    Y_tilde = np.maximum(Y_tilde, min_time)
    return Y_tilde
# -------------------------------------------------------------------

def _load_mnist_fold(root: str, train: bool):
    """Load one official MNIST fold as (images [N, 1, 28, 28] float32 scaled by 1/255, labels ndarray)."""
    fold = datasets.MNIST(root=root, train=train, download=True)
    # The raw tensors are read directly, so a per-item `transform` would never run.
    images = fold.data.unsqueeze(1).float() / 255.0
    labels = fold.targets.numpy()
    return images, labels


def _simulate_observed_times(gen, risk: np.ndarray, time_seed: int, censor_seed: int):
    """Simulate one split and return (Y, E) with Y = min(T, C) and E = (T <= C)."""
    T = gen.gen_failure_times_gamma(risk, seed=time_seed)
    C = gen.gen_uniform_censoring_for_split(T, seed=censor_seed)
    E = T <= C
    # Observed time is the failure time when the event is seen, else the censoring time.
    Y = np.where(E, T, C)
    return Y, E


def make_noisy_mnist_survival(
    seed: int = 89,
    var: float = 1e-3,
    *,
    # (1) Label swapping noise (applied to TRAIN/VAL ONLY, never test)
    label_swap_frac: float = 0.0,
    label_swap_seed: int = 2027,
    # (2) Uniform noise on observed time Y only (applied to TRAIN/VAL ONLY, never test)
    obs_time_noise_range: tuple[float, float] = (0.0, 0.0),
    obs_time_noise_seed: int = 4041,
    min_time_after_noise: float = 1e-6,
    # Directory handed to torchvision's MNIST loader (downloaded if missing).
    root: str = "./data",
):
    """
    Build SurvivalDatasets for MNIST using official folds/sizes:
      - Train set (60k) -> fixed split: train=50k, val=10k
      - Test set (10k)  -> test=10k

    Noise options (TRAIN/VAL ONLY):
      - label_swap_frac: fraction in each split (train/val) whose labels are randomly
                         permuted before generating T (alters risk assignments).
      - obs_time_noise_range: (low, high) for Uniform noise added to Y only (train/val),
                              AFTER computing Y = min(T, C). E is NOT recomputed.

    Other options:
      - root: data directory for torchvision's MNIST (download=True). Default "./data".

    Seeds: ``seed`` drives the train/val shuffle, the failure times
    (seed+0 / +1 / +2 for train / val / test) and the censoring times
    (seed+10 / +11 / +12).

    Returns: (train_ds, val_ds, test_ds, risk_score_assignment)
    """
    # --- Official train fold (60k): reproducible shuffle, then 50k train / 10k val ---
    X_train_all, y_train_all = _load_mnist_fold(root, train=True)   # [60000, 1, 28, 28]
    perm = np.random.RandomState(seed).permutation(len(y_train_all))
    X_train_all = X_train_all[perm]
    y_train_all = y_train_all[perm]

    n_train, n_val = 50000, 10000
    X_train, y_train = X_train_all[:n_train], y_train_all[:n_train]
    X_val = X_train_all[n_train:n_train + n_val]
    y_val = y_train_all[n_train:n_train + n_val]

    # --- Official test fold (10k) ---
    X_test, y_test = _load_mnist_fold(root, train=False)            # [10000, 1, 28, 28]

    # --- Label swapping noise: TRAIN/VAL ONLY (test labels stay clean) ---
    y_train_noisy = _apply_label_swaps(y_train, label_swap_frac, label_swap_seed + 0)
    y_val_noisy = _apply_label_swaps(y_val, label_swap_frac, label_swap_seed + 1)

    # --- Risk scores (per digit) & times per split ---
    risk_score_assignment, _ = make_risk_score_for_groups(np.arange(10))
    _, risk_train = make_risk_score_for_groups(y_train_noisy)
    _, risk_val = make_risk_score_for_groups(y_val_noisy)
    _, risk_test = make_risk_score_for_groups(y_test)

    gen = SurvivalTimeGenerator(var=var)
    Y_train, E_train = _simulate_observed_times(gen, risk_train, seed + 0, seed + 10)
    Y_val, E_val = _simulate_observed_times(gen, risk_val, seed + 1, seed + 11)
    Y_test, E_test = _simulate_observed_times(gen, risk_test, seed + 2, seed + 12)

    # --- Observed-time noise: TRAIN/VAL ONLY (E unchanged; test Y untouched) ---
    low, high = obs_time_noise_range
    Y_train = _add_uniform_noise_to_observed_time(
        Y_train, low, high, seed=obs_time_noise_seed + 0, min_time=min_time_after_noise
    )
    Y_val = _add_uniform_noise_to_observed_time(
        Y_val, low, high, seed=obs_time_noise_seed + 1, min_time=min_time_after_noise
    )

    # --- Pack datasets (exact sizes: 50k / 10k / 10k) ---
    train_ds = SurvivalDataset(
        images=X_train, labels=y_train_noisy, times=Y_train, events=E_train
    )
    val_ds = SurvivalDataset(
        images=X_val, labels=y_val_noisy, times=Y_val, events=E_val
    )
    test_ds = SurvivalDataset(
        images=X_test, labels=y_test, times=Y_test, events=E_test
    )

    print(f"Train size: {len(train_ds):>5} | censoring rate: {1 - E_train.mean():.2%}")
    print(f"Valid size: {len(val_ds):>5} | censoring rate: {1 - E_val.mean():.2%}")
    print(f"Test  size: {len(test_ds):>5} | censoring rate: {1 - E_test.mean():.2%}")
    if label_swap_frac > 0:
        print(f"[Noise] Label swaps applied to TRAIN/VAL only: {label_swap_frac:.2%}")
    if (low, high) != (0.0, 0.0):
        print(f"[Noise] Observed-time jitter Uniform({low}, {high}) on TRAIN/VAL only")
    print("[Sanity] Test set left pristine: labels and observed times unmodified.")

    return train_ds, val_ds, test_ds, risk_score_assignment
