import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms

class SurvivalDataset(Dataset):
    """
    In-memory dataset that returns items as ``(img, time, event)``,
    mimicking the output of SurvivalImageDataset.

    Args:
        images: Per-sample inputs; converted with ``torch.as_tensor`` (dtype kept).
        labels: Per-sample class labels, stored as ``torch.long``. They are kept
            on the instance for reference but are NOT part of the returned items.
        times: Observed times, stored as ``torch.float32``.
        events: Event indicators, stored as ``torch.float32``.

    Raises:
        ValueError: If ``labels``, ``times`` or ``events`` do not have the same
            length as ``images``.
    """

    def __init__(self, images, labels, times, events):
        self.images = torch.as_tensor(images)
        self.labels = torch.as_tensor(labels, dtype=torch.long)
        self.times = torch.as_tensor(times, dtype=torch.float32)
        self.events = torch.as_tensor(events, dtype=torch.float32)

        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        n = len(self.images)
        lengths = {
            "labels": len(self.labels),
            "times": len(self.times),
            "events": len(self.events),
        }
        mismatched = {name: size for name, size in lengths.items() if size != n}
        if mismatched:
            raise ValueError(
                f"all inputs must have length {n} (len(images)); got {mismatched}"
            )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(image, time, event)`` at ``index`` (the label is not included)."""
        return self.images[index], self.times[index], self.events[index]


# --- Exact class→risk mapping from X-CAL (Appendix D.2, Table 5) ---
# Risk score per MNIST digit; used below as the mean of that digit's
# failure-time distribution.
_DIGIT_TO_RISK = {
    0: 11.25, 1:  2.25, 2:  5.25, 3:  5.00, 4:  4.75,
    5:  8.00, 6:  2.00, 7: 11.00, 8:  1.75, 9: 10.75,
}
# Coarse risk bucket per digit (reported alongside the score in the assignment table).
_DIGIT_TO_GROUP = {
    0: "most", 1: "least", 2: "lower", 3: "lower", 4: "lower",
    5: "higher", 6: "least", 7: "most", 8: "least", 9: "most"
}


def make_risk_score_for_groups(y: np.ndarray):
    """X-CAL Survival-MNIST: each digit label gets a deterministic risk score.

    Args:
        y: Digit labels (0-9), any array-like; it is flattened before use.

    Returns:
        ``(assignment, risk_scores)`` where ``assignment`` is a DataFrame indexed
        by ``class_label`` (one row per distinct label present in ``y``) with
        columns ``risk_score`` and ``risk_group``, and ``risk_scores`` is a
        float64 array holding the risk score of every element of ``y``.

    Raises:
        KeyError: If ``y`` contains a label outside 0-9.
    """
    y = np.asarray(y).ravel()
    # ``inverse`` maps every element of ``y`` to its position in ``classes``,
    # so per-sample scores are a single vectorised gather (no Python loop).
    classes, inverse = np.unique(y, return_inverse=True)
    risk_per_class = {lbl: float(_DIGIT_TO_RISK[int(lbl)]) for lbl in classes}
    group_assignment = {lbl: _DIGIT_TO_GROUP[int(lbl)] for lbl in classes}
    assignment = pd.concat((
        pd.Series(risk_per_class,   name="risk_score"),
        pd.Series(group_assignment, name="risk_group"),
    ), axis=1).rename_axis("class_label")
    class_risk = np.array([risk_per_class[lbl] for lbl in classes], dtype=np.float64)
    risk_scores = class_risk[inverse.reshape(-1)]
    return assignment, risk_scores

class SurvivalTimeGenerator:
    """
    Survival-MNIST generator per X-CAL:
      - Failure time T | y ~ Gamma(mean = risk(y), variance = var).
      - For EACH split, censoring C ~ Uniform(min(T_split), quantile_0.9(T_split)).

    Every sampling call builds its own ``np.random.RandomState(seed)``, so the
    draws depend only on the arguments and ``seed``.

    Args:
        var: Variance of the gamma failure-time distribution (shared by all
            samples); must be positive and finite.

    Raises:
        ValueError: If ``var`` is not positive and finite.
    """

    def __init__(self, var: float = 1e-3):
        var = float(var)
        # var appears in both gamma shape (mu**2 / var) and scale (var / mu);
        # 0, negative, NaN or inf would silently yield inf/NaN parameters.
        if not (np.isfinite(var) and var > 0):
            raise ValueError(f"var must be positive and finite, got {var!r}")
        self.var = var

    def gen_failure_times_gamma(self, risk_score: np.ndarray, seed: int) -> np.ndarray:
        """Draw one gamma failure time per risk score.

        Moment matching: shape = mu**2 / var and scale = var / mu give a gamma
        with mean mu and variance var.

        Args:
            risk_score: Per-sample means ``mu``; must all be positive and finite.
            seed: Seed for a fresh ``np.random.RandomState``.

        Returns:
            Float64 array of failure times, same shape as ``risk_score``.

        Raises:
            ValueError: If any risk score is non-positive or non-finite.
        """
        rng = np.random.RandomState(seed)
        mu = np.asarray(risk_score, dtype=np.float64)   # mean per point
        if not np.all(np.isfinite(mu)) or np.any(mu <= 0):
            raise ValueError("risk scores (gamma means) must be positive and finite")
        var = self.var
        shape = (mu ** 2) / var
        scale = var / mu
        return rng.gamma(shape=shape, scale=scale)

    @staticmethod
    def gen_uniform_censoring_for_split(t_split: np.ndarray, seed: int) -> np.ndarray:
        """Draw one uniform censoring time per entry of ``t_split``.

        Bounds are ``[min(t_split), quantile_0.9(t_split))`` computed from this
        split's own failure times.

        Args:
            t_split: Failure times of a single split (non-empty).
            seed: Seed for a fresh ``np.random.RandomState``.

        Returns:
            Array of ``t_split.shape[0]`` censoring times.

        Raises:
            ValueError: If ``t_split`` is empty.
        """
        t_split = np.asarray(t_split)
        if t_split.size == 0:
            raise ValueError("t_split must contain at least one failure time")
        rng = np.random.RandomState(seed)
        lo = float(np.min(t_split))
        hi = float(np.quantile(t_split, 0.9))
        return rng.uniform(lo, hi, size=t_split.shape[0])


def _load_mnist_tensors(root, train):
    """Load one official MNIST fold as (float images / 255 shaped [N, 1, 28, 28], numpy labels)."""
    raw = datasets.MNIST(root=root, train=train, download=True)
    images = raw.data.unsqueeze(1).float() / 255.0
    labels = raw.targets.numpy()
    return images, labels


def _simulate_censored_split(gen, risk, failure_seed, censor_seed):
    """Simulate right-censored data for one split; return (observed_time, event_indicator)."""
    failure = gen.gen_failure_times_gamma(risk, seed=failure_seed)
    censor = gen.gen_uniform_censoring_for_split(failure, seed=censor_seed)
    event = failure <= censor
    # Observed time is the failure time if it happened first, else the censoring time.
    return np.where(event, failure, censor), event


def make_mnist_survival(
    seed: int = 89,
    var: float = 1e-3,
    root: str = "./data",
    n_val: int = 10000,
):
    """
    Build SurvivalDatasets for MNIST using official folds/sizes:
      - Train set (60k) -> reproducibly shuffled, then train = 60k - n_val, val = n_val
        (defaults: 50k / 10k)
      - Test set (10k)  -> test=10k

    Args:
        seed: Base seed; the shuffle uses ``seed``, failure times use
            ``seed + 0/1/2`` and censoring uses ``seed + 10/11/12`` for
            train/val/test respectively.
        var: Variance of the gamma failure-time distribution.
        root: Directory where torchvision stores/downloads MNIST.
        n_val: Number of official-train images held out for validation.

    Returns:
        (train_ds, val_ds, test_ds, risk_score_assignment)

    Raises:
        ValueError: If ``n_val`` does not leave at least one train and one
            validation sample.
    """
    # --- Official train fold, shuffled reproducibly, split train / val ---
    X_all, y_all = _load_mnist_tensors(root, train=True)
    n_total = len(y_all)
    if not 0 < n_val < n_total:
        raise ValueError(f"n_val must be in (0, {n_total}), got {n_val}")

    perm = np.random.RandomState(seed).permutation(n_total)
    X_all, y_all = X_all[perm], y_all[perm]
    n_train = n_total - n_val
    X_train, y_train = X_all[:n_train], y_all[:n_train]
    X_val, y_val = X_all[n_train:], y_all[n_train:]

    # --- Official test fold ---
    X_test, y_test = _load_mnist_tensors(root, train=False)

    # Full digit -> (risk, group) table for reporting, built from all 10 digits.
    risk_score_assignment, _ = make_risk_score_for_groups(np.arange(10))

    gen = SurvivalTimeGenerator(var=var)
    Y_train, E_train = _simulate_censored_split(
        gen, make_risk_score_for_groups(y_train)[1], seed + 0, seed + 10)
    Y_val, E_val = _simulate_censored_split(
        gen, make_risk_score_for_groups(y_val)[1], seed + 1, seed + 11)
    Y_test, E_test = _simulate_censored_split(
        gen, make_risk_score_for_groups(y_test)[1], seed + 2, seed + 12)

    train_ds = SurvivalDataset(images=X_train, labels=y_train, times=Y_train, events=E_train)
    val_ds = SurvivalDataset(images=X_val, labels=y_val, times=Y_val, events=E_val)
    test_ds = SurvivalDataset(images=X_test, labels=y_test, times=Y_test, events=E_test)

    print(f"Train size: {len(train_ds):>5} | censoring rate: {1 - E_train.mean():.2%}")
    print(f"Valid size: {len(val_ds):>5} | censoring rate: {1 - E_val.mean():.2%}")
    print(f"Test  size: {len(test_ds):>5} | censoring rate: {1 - E_test.mean():.2%}")

    return train_ds, val_ds, test_ds, risk_score_assignment
