# retinamnist_survival.py
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from medmnist import RetinaMNIST

# ----------------------------
# 2D-friendly transforms (no torchvision)
# ----------------------------
class ToTensor2D(object):
    """Convert a 2D/3D numpy/PIL image [H,W,(C)] to a float32 tensor [C,H,W].

    The input is converted with ``np.asarray`` (which also accepts PIL images),
    copied to float32, given a trailing channel axis when it is 2D, and divided
    by 255 when its maximum exceeds 1.0 (heuristic for 0..255 data).

    The caller's array is never modified and the returned tensor does not share
    memory with it.

    Raises:
        ValueError: if the input is neither 2D nor 3D.
    """
    def __call__(self, x):
        # astype() always returns a fresh array, so the in-place division below
        # can never write into the caller's buffer (the previous code divided
        # the input itself when it was already float32).
        arr = np.asarray(x).astype(np.float32)
        if arr.ndim == 2:          # [H,W] -> [H,W,1]
            arr = arr[..., None]
        if arr.ndim != 3:
            raise ValueError(
                f"ToTensor2D expects a 2D or 3D image, got shape {arr.shape}"
            )
        # If it looks like 0..255 data, scale to [0,1]; skip max() on empty input
        # because reducing a zero-size array raises.
        if arr.size > 0 and arr.max() > 1.0:
            arr /= 255.0
        arr = np.transpose(arr, (2, 0, 1))  # [H,W,C] -> [C,H,W]
        return torch.from_numpy(arr)

class MinMaxScale2D(object):
    """Rescale every channel of one image by its own spatial min and max.

    The reduction runs over the last two dimensions, so each leading index
    (channel) is scaled independently. A small epsilon in the denominator
    keeps constant channels from dividing by zero (they map to 0).
    """
    def __call__(self, x: torch.Tensor):
        # x: [C,H,W]
        spatial_dims = (-2, -1)
        lo = x.amin(dim=spatial_dims, keepdim=True)
        hi = x.amax(dim=spatial_dims, keepdim=True)
        shifted = x - lo
        span = hi - lo + 1e-8
        return shifted / span

class Normalize(object):
    """Standardize a [C,H,W] tensor channel-wise as (x - mean) / std.

    Statistics are stored as float32 tensors of shape [K,1,1]. Two channel
    mismatches are reconciled before normalizing:
      - 3-channel stats, 1-channel input: the input is repeated to 3 channels.
      - 1-channel stats, 3-channel input: the input is averaged to 1 channel.
    Any other mismatch is left to tensor broadcasting.
    """
    def __init__(self, mean, std):
        self.mean = self._as_channel_stat(mean)
        self.std = self._as_channel_stat(std)

    @staticmethod
    def _as_channel_stat(values):
        # Shape [K,1,1] so the statistic broadcasts over H and W.
        return torch.tensor(values, dtype=torch.float32).view(-1, 1, 1)

    def __call__(self, x):
        n_stats = self.mean.shape[0]
        n_input = x.shape[0]
        if n_input == 1 and n_stats == 3:
            # single-channel data, 3-channel stats requested
            x = x.repeat(3, 1, 1)
        elif n_input == 3 and n_stats == 1:
            # collapse RGB to gray, then normalize as one channel
            x = x.mean(dim=0, keepdim=True)
        shift = self.mean.to(x.device)
        scale = self.std.to(x.device) + 1e-12  # guard against std == 0
        return (x - shift) / scale

class Compose2D(object):
    """Chain callables: each one receives the previous one's output."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        result = x
        for step in self.transforms:
            result = step(result)
        return result

# Default pipeline: tensor conversion followed by per-image min-max scaling.
default_2d_transform = Compose2D([
    ToTensor2D(),
    MinMaxScale2D(),
])

# ----------------------------
# Label utilities & risk map
# ----------------------------
def _abbr_from_name(name: str) -> str:
    toks = [w.strip() for w in name.replace('-', ' ').split() if w.strip()]
    if not toks: return name[:4].upper()
    abbr = ''.join(tok[0] for tok in toks).upper()
    if len(abbr) < 3: abbr = (abbr + toks[0][:3]).upper()[:4]
    return abbr[:6]

def get_retina_label_map(ds: RetinaMNIST) -> dict[int, tuple[str,str]]:
    """Return ``{class_id: (abbreviation, full_name)}`` for a dataset.

    Names come from ``ds.info['label']`` when that mapping is present and
    non-empty; otherwise one ``"Class <id>"`` entry is synthesized for every
    distinct value found in ``ds.labels``.
    """
    meta = getattr(ds, 'info', {}) or {}
    declared = meta.get('label', {})

    if declared:
        names_by_id = {int(key): str(value) for key, value in declared.items()}
    else:
        distinct = np.unique(ds.labels.squeeze())
        names_by_id = {int(cid): f"Class {int(cid)}" for cid in distinct}

    label_map = {}
    for cid, full_name in names_by_id.items():
        label_map[cid] = (_abbr_from_name(full_name), full_name)
    return label_map

def _severity_rank_from_name(nm: str) -> int:
    """Lower rank => more severe (earlier hazard)."""
    s = nm.lower()
    # common DR grading tokens
    if "prolif" in s or "pdr" in s: return 0
    if "severe" in s:               return 1
    if "mod" in s:                  return 2
    if "mild" in s:                 return 3
    if "no" in s or "normal" in s:  return 4
    # default by lexical fallback (put in middle)
    return 2

def make_retina_class_risk_map(ds_for_meta: RetinaMNIST,
                               base_mean: float = 10.0,
                               custom_risk_map: dict[int, float] | None = None) -> pd.DataFrame:
    """
    Build the class -> risk table as a DataFrame indexed by ``class_label``.

    Columns: ``abbr``, ``name``, ``risk_Y`` and ``mean_time``
    (``base_mean * (risk_Y + 1)``), so a smaller ``risk_Y`` means a shorter
    mean time.

      - With ``custom_risk_map``: its value is used for every class id
        (a missing id raises KeyError).
      - Without it: classes are ordered by the severity rank inferred from
        their names and receive evenly spaced values from 0.5 to 8.0, the
        most severe class getting the smallest value. Classes with equal rank
        keep ascending class-id order because the sort is stable.
    """
    label_to_name = get_retina_label_map(ds_for_meta)
    class_ids = sorted(label_to_name)

    if custom_risk_map is None:
        severity = {
            cid: _severity_rank_from_name(label_to_name[cid][1])
            for cid in class_ids
        }
        # Evenly spaced risk values; rank 0 (most severe) takes the smallest.
        grid = np.linspace(0.5, 8.0, num=len(class_ids))
        by_severity = sorted(class_ids, key=severity.get)
        assigned = {cid: float(value) for cid, value in zip(by_severity, grid)}
        risk_values = [assigned[cid] for cid in class_ids]
    else:
        risk_values = [float(custom_risk_map[cid]) for cid in class_ids]

    abbreviations = [label_to_name[cid][0] for cid in class_ids]
    full_names = [label_to_name[cid][1] for cid in class_ids]
    table = pd.DataFrame(
        {"abbr": abbreviations, "name": full_names, "risk_Y": risk_values},
        index=pd.Index(class_ids, name="class_label"),
    )
    table["mean_time"] = base_mean * (table["risk_Y"] + 1.0)
    return table

def labels_to_riskY_from_map(y: np.ndarray, class_risk_map: pd.DataFrame) -> np.ndarray:
    """Translate class labels into their ``risk_Y`` values.

    Args:
        y: array of class labels; cast to int before lookup.
        class_risk_map: DataFrame indexed by class label with a ``risk_Y`` column.

    Returns:
        float64 array with the same shape as ``y``. An empty ``y`` yields an
        empty array.

    Raises:
        KeyError: if ``y`` contains a label absent from ``class_risk_map``.
    """
    mapping = class_risk_map["risk_Y"].to_dict()
    labels = np.asarray(y).astype(int)

    # np.vectorize cannot infer an output dtype from zero elements, so handle
    # the empty case explicitly.
    if labels.size == 0:
        return np.empty(labels.shape, dtype=np.float64)

    # Look each distinct label up once, then scatter back through the inverse.
    distinct, inverse = np.unique(labels.ravel(), return_inverse=True)
    unknown = [int(lab) for lab in distinct if int(lab) not in mapping]
    if unknown:
        raise KeyError(f"labels missing from class_risk_map: {unknown}")

    values = np.array([mapping[int(lab)] for lab in distinct], dtype=np.float64)
    return values[inverse].reshape(labels.shape)

# ----------------------------
# Time generator (Gamma) + censoring
# ----------------------------
class SurvivalTimeGeneratorXCal:
    """Simulate failure and censoring times.

    Failure: T|Y ~ Gamma with mean ``base_mean * (Y + 1)`` and variance ``var``.
    Censoring: C ~ Uniform(min(T_split), 90th percentile of T_split).
    Both draws use a ``np.random.RandomState`` seeded per call, so a given
    seed always reproduces the same sample.
    """
    def __init__(self, base_mean: float = 10.0, var: float = 0.05):
        self.base_mean = float(base_mean)
        self.var = float(var)

    def gen_failure_times_gamma(self, risk_Y: np.ndarray, seed: int) -> np.ndarray:
        """Draw one Gamma failure time per entry of ``risk_Y``."""
        mean_time = self.base_mean * (risk_Y.astype(np.float64) + 1.0)
        # Moment matching: shape * scale = mean, shape * scale**2 = var.
        shape_param = (mean_time**2) / self.var
        scale_param = self.var / mean_time
        sampler = np.random.RandomState(seed)
        return sampler.gamma(shape=shape_param, scale=scale_param)

    @staticmethod
    def gen_uniform_censoring_for_split(T_split: np.ndarray, seed: int) -> np.ndarray:
        """Draw one uniform censoring time per failure time in ``T_split``."""
        lower = float(np.min(T_split))
        upper = float(np.quantile(T_split, 0.9))
        n_samples = T_split.shape[0]
        sampler = np.random.RandomState(seed)
        return sampler.uniform(lower, upper, size=n_samples)

# ----------------------------
# Dataset wrapper (2D)
# ----------------------------
class _RescaleToSigned(object):
    """Channel-agnostic affine map (x - 0.5) / 0.5, i.e. [0,1] -> [-1,1]."""
    def __call__(self, x):
        return (x - 0.5) / 0.5


class RetinaMNISTSurvivalSplit(Dataset):
    """
    Wrap a RetinaMNIST split; each item is (img, time, event).
      - img: output of ``transform`` (ToTensor2D when ``transform`` is None),
        followed by the optional normalization
      - time: float32
      - event: float32 in {0,1}

    Args:
        base_ds: dataset whose items are (image, label); the label is ignored.
        times: one observed time per item of ``base_ds``.
        events: one event indicator per item of ``base_ds``.
        transform: callable applied to the raw image, or None.
        normalize: "imagenet" applies ImageNet mean/std (a 1-channel image is
            repeated to 3 channels); "rescale" maps [0,1] -> [-1,1] and keeps
            the channel count unchanged; any other value disables normalization.

    Raises:
        ValueError: if ``base_ds``, ``times`` and ``events`` differ in length.
    """
    def __init__(self,
                 base_ds: RetinaMNIST,
                 times: np.ndarray,
                 events: np.ndarray,
                 transform = default_2d_transform,
                 normalize: str | None = None):
        # Explicit check rather than assert: asserts are stripped under -O.
        if not (len(base_ds) == len(times) == len(events)):
            raise ValueError(
                "base_ds, times and events must have the same length, got "
                f"{len(base_ds)}, {len(times)}, {len(events)}"
            )
        self.base = base_ds
        self.times  = torch.as_tensor(times,  dtype=torch.float32)
        self.events = torch.as_tensor(events, dtype=torch.float32)
        self.transform = transform
        # optional normalization on top of the transform
        if normalize == "imagenet":
            self.norm = Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
        elif normalize == "rescale":   # map [0,1] -> [-1,1]
            # Not Normalize(mean=[0.5], std=[0.5]): with one-channel stats
            # Normalize averages a 3-channel image down to one channel, which
            # would silently discard colour.
            self.norm = _RescaleToSigned()
        else:
            self.norm = None

    def __len__(self): return len(self.base)

    def __getitem__(self, idx: int):
        img, _ = self.base[idx]    # PIL/numpy image; class label unused here
        img = self.transform(img) if self.transform is not None else ToTensor2D()(img)
        if self.norm is not None:
            img = self.norm(img)
        t = self.times[idx]
        e = self.events[idx]
        return img, t, e

# ----------------------------
# Split loader
# ----------------------------
def _load_retinamnist_splits(size: int = 224, root: str | None = None, download: bool = True):
    """Instantiate the train, val and test RetinaMNIST splits, in that order.

    ``root`` is forwarded only when it is not None, so the dataset class keeps
    its own default location otherwise.
    """
    common = {"size": size, "download": download}
    if root is not None:
        common["root"] = root
    return tuple(
        RetinaMNIST(split=split_name, **common)
        for split_name in ("train", "val", "test")
    )

# ----------------------------
# Top-level builder (parallel to OCT)
# ----------------------------
def make_retinamnist_survival(seed: int = 77,
                              size: int = 224,
                              root: str | None = None,
                              base_mean: float = 10.0,
                              var: float = 0.05,
                              normalize: str | None = None,
                              custom_risk_map: dict[int, float] | None = None):
    """
    Build survivalized RetinaMNIST from the train/val/test splits.

    For every split: labels are mapped to risk values, Gamma failure times T
    and uniform censoring times C are simulated, the event flag is ``T <= C``
    and the observed time is T when the event occurred, otherwise C.

    Seeds: failure times use ``seed + 0/1/2`` and censoring times use
    ``seed + 10/11/12`` for train/val/test respectively.

    Prints the censoring rate of each split and returns
    (train_ds, val_ds, test_ds, class_risk_map).
    """
    ds_tr, ds_va, ds_te = _load_retinamnist_splits(size=size, root=root, download=True)

    # Class risk table, built from the training split (can be overridden).
    class_risk_map = make_retina_class_risk_map(ds_tr, base_mean=base_mean, custom_risk_map=custom_risk_map)
    gen = SurvivalTimeGeneratorXCal(base_mean=base_mean, var=var)

    def _simulate(split_ds, offset):
        # labels -> risk -> failure / censoring times -> (observed time, event)
        labels = split_ds.labels.squeeze().astype(np.int64)
        risk = labels_to_riskY_from_map(labels, class_risk_map)
        failure = gen.gen_failure_times_gamma(risk, seed=seed + offset)
        censor = gen.gen_uniform_censoring_for_split(failure, seed=seed + 10 + offset)
        event = (failure <= censor).astype(np.float32)
        observed = np.where(event > 0, failure, censor).astype(np.float32)
        return observed, event

    Yobs_tr, E_tr = _simulate(ds_tr, 0)
    Yobs_va, E_va = _simulate(ds_va, 1)
    Yobs_te, E_te = _simulate(ds_te, 2)

    # Datasets
    shared = dict(transform=default_2d_transform, normalize=normalize)
    train_ds = RetinaMNISTSurvivalSplit(ds_tr, Yobs_tr, E_tr, **shared)
    val_ds   = RetinaMNISTSurvivalSplit(ds_va, Yobs_va, E_va, **shared)
    test_ds  = RetinaMNISTSurvivalSplit(ds_te, Yobs_te, E_te, **shared)

    print(f"Train censoring rate: {1 - E_tr.mean():.2%} (N={len(ds_tr)})")
    print(f"Valid censoring rate: {1 - E_va.mean():.2%} (N={len(ds_va)})")
    print(f"Test  censoring rate: {1 - E_te.mean():.2%} (N={len(ds_te)})")

    return train_ds, val_ds, test_ds, class_risk_map
