import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from medmnist import OrganMNIST3D

# ----------------------------
# 3D-friendly transforms
# ----------------------------
class ToTensor3D(object):
    """Convert a volume [D,H,W] or [C,D,H,W] to a float32 torch tensor [C,D,H,W].

    Accepts a torch.Tensor, a numpy array, or any array-like that
    ``torch.as_tensor`` understands (e.g. nested lists).

    Intensity handling:
      - uint8 input is always divided by 255, so it lands in [0,1].
      - Other dtypes are divided by 255 only when their maximum exceeds 1.0
        (a heuristic for 0..255 data stored as float). Data already in [0,1]
        is left unchanged.
    No clamping is applied, so values outside [0,255] are not forced into [0,1].

    Raises:
      ValueError: if the input is not 3D or 4D.
    """
    def __call__(self, x):
        # as_tensor (unlike from_numpy) also accepts lists / non-ndarray inputs;
        # like from_numpy it shares memory with numpy arrays where possible.
        vol = x if isinstance(x, torch.Tensor) else torch.as_tensor(x)
        if vol.ndim == 3:
            vol = vol.unsqueeze(0)  # add channel axis -> [1,D,H,W]
        elif vol.ndim != 4:
            raise ValueError(f"Expected 3D/4D volume, got shape {tuple(vol.shape)}")
        # Record the source dtype before casting: uint8 is unambiguously 0..255,
        # even when a particular volume's max happens to be <= 1.
        is_uint8 = vol.dtype == torch.uint8
        vol = vol.to(torch.float32)
        # numel() guard: max() of an empty tensor raises a RuntimeError.
        if is_uint8 or (vol.numel() > 0 and vol.max() > 1.0):
            vol = vol / 255.0
        return vol

class MinMaxScale3D(object):
    """Rescale every channel of a [C,D,H,W] tensor to [0,1] using its own min/max.

    The reduction runs over the last three axes (D,H,W). A 1e-8 epsilon in the
    denominator keeps constant channels finite; such channels map to 0.
    """
    def __call__(self, x: torch.Tensor):
        spatial_axes = (-3, -2, -1)
        lo = x.amin(dim=spatial_axes, keepdim=True)
        span = x.amax(dim=spatial_axes, keepdim=True) - lo
        return (x - lo) / (span + 1e-8)

class Compose3D(object):
    """Apply a sequence of callables in order; each receives the previous output."""

    def __init__(self, transforms):
        # Stored as given (not copied), so the caller's sequence object is reused.
        self.transforms = transforms

    def __call__(self, x):
        result = x
        for step in self.transforms:
            result = step(result)
        return result

# Default pipeline: numpy/torch volume -> float32 [C,D,H,W], then per-channel min-max scaling.
default_3d_transform = Compose3D(transforms=[ToTensor3D(), MinMaxScale3D()])

# ----------------------------
# Label utilities & risk map
# ----------------------------
def _abbr_from_name(name: str) -> str:
    """Build a short upper-case abbreviation (at most 6 chars) from a label name.

    Hyphens count as word separators and the first letter of each word is taken,
    e.g. "a b c" -> "ABC". When that yields fewer than 3 characters, the first
    three letters of the first word are appended and the result is cut to 4
    characters: "right kidney" -> "RKRI", "kidney-right" -> "KRKI",
    "liver" -> "LLIV". A name with no words falls back to its first four
    characters, upper-cased.
    """
    words = name.replace('-', ' ').split()
    if len(words) == 0:
        return name[:4].upper()
    initials = ''.join(word[0] for word in words).upper()
    if len(initials) >= 3:
        return initials[:6]
    # Too short to be distinctive: pad with the start of the first word.
    return (initials + words[0][:3]).upper()[:4]

def get_organ3d_label_map(ds: OrganMNIST3D) -> dict[int, tuple[str, str]]:
    """Return ``{class_id: (abbreviation, name)}`` for an OrganMNIST3D dataset.

    Names are read from ``ds.info['label']`` (MedMNIST metadata; its keys are
    often strings such as "0", so they are converted to int). If that metadata
    is missing or empty, classes ``0..max(ds.labels)`` get generic names
    "Class <i>". Abbreviations come from ``_abbr_from_name``.
    """
    meta = getattr(ds, 'info', {}) or {}
    raw_labels = meta.get('label', {})
    names = {}
    for key, value in raw_labels.items():
        names[int(key)] = str(value)
    if not names:
        # No metadata: derive the class count from the label array itself.
        n_classes = int(np.max(ds.labels)) + 1
        names = {cid: f"Class {cid}" for cid in range(n_classes)}
    return {cid: (_abbr_from_name(label), label) for cid, label in names.items()}

def make_organ3d_class_risk_map(ds_for_meta: OrganMNIST3D,
                                base_mean: float = 10.0,
                                custom_risk_map: dict[int, float] | None = None) -> pd.DataFrame:
    """
    Build the class -> risk table used to simulate survival times.

    Args:
      ds_for_meta: dataset whose metadata supplies class ids and names
        (via get_organ3d_label_map).
      base_mean: scale for the ``mean_time`` column.
      custom_risk_map: optional {class_id: risk_Y}. It must contain every class
        id; extra keys are ignored. Lower risk_Y => shorter mean time
        (higher hazard, earlier events).

    Returns:
      DataFrame indexed by ``class_label`` (sorted class ids) with columns
      ``abbr``, ``name``, ``risk_Y`` and ``mean_time = base_mean * (risk_Y + 1)``.
      That mean matches SurvivalTimeGeneratorXCal's Gamma mean when the same
      base_mean is used there.

    Default risks (no custom map): ``np.linspace(0.5, 12.0, n_cls)`` reversed,
    assigned in sorted class-id order. The smallest class id therefore gets
    risk_Y = 12.0 (latest events) and the largest gets 0.5 (earliest events).

    Raises:
      KeyError: if custom_risk_map has no entry for some class id.
    """
    label_to_name = get_organ3d_label_map(ds_for_meta)
    labels = sorted(label_to_name.keys())
    n_cls = len(labels)

    if custom_risk_map is not None:
        # Report every missing id at once instead of failing on the first one.
        missing = [i for i in labels if i not in custom_risk_map]
        if missing:
            raise KeyError(f"custom_risk_map is missing risk_Y for class ids {missing}")
        risk_Y = [float(custom_risk_map[i]) for i in labels]
    else:
        # Reversed grid: the first (smallest) class id gets the LARGEST risk_Y,
        # i.e. the lowest hazard / latest events.
        # NOTE(review): if the intent is "class 0 = highest hazard", drop [::-1].
        # Kept as-is so previously simulated datasets stay reproducible.
        risk_grid = np.linspace(0.5, 12.0, n_cls)[::-1]
        risk_Y = [float(r) for r in risk_grid]

    df = pd.DataFrame({
        "abbr":   [label_to_name[i][0] for i in labels],
        "name":   [label_to_name[i][1] for i in labels],
        "risk_Y": risk_Y,
    }, index=pd.Index(labels, name="class_label"))
    df["mean_time"] = base_mean * (df["risk_Y"] + 1.0)
    return df

def labels_to_riskY_from_map(y: np.ndarray, class_risk_map: pd.DataFrame) -> np.ndarray:
    """Map integer class labels to their ``risk_Y`` values.

    Args:
      y: array-like of class labels (any shape); values are cast to int.
      class_risk_map: DataFrame indexed by class label with a ``risk_Y`` column.

    Returns:
      float64 array with the same shape as ``y``. An empty input gives an empty output.

    Raises:
      KeyError: if a label is not in ``class_risk_map``'s index.
    """
    labels = np.asarray(y).astype(int)
    mapping = class_risk_map["risk_Y"].to_dict()  # {label: risk_Y}
    # otypes is required: without it np.vectorize calls the function on the first
    # element to infer the output dtype and raises ValueError on size-0 inputs
    # (e.g. an empty split).
    lookup = np.vectorize(mapping.__getitem__, otypes=[np.float64])
    return lookup(labels)

# ----------------------------
# Time generator (same as OCT)
# ----------------------------
class SurvivalTimeGeneratorXCal:
    """Simulate failure and censoring times from per-sample risk scores.

    Failure:   T | Y ~ Gamma with mean ``base_mean * (Y + 1)`` and variance ``var``
               (shape = mean**2 / var, scale = var / mean).
    Censoring: C ~ Uniform(min(T_split), 90th percentile of T_split), per split.
    Each draw uses its own ``np.random.RandomState(seed)`` for reproducibility.
    """
    def __init__(self, base_mean: float = 10.0, var: float = 0.05):
        self.base_mean, self.var = float(base_mean), float(var)

    def gen_failure_times_gamma(self, risk_Y: np.ndarray, seed: int) -> np.ndarray:
        """Draw one Gamma failure time per entry of ``risk_Y`` (method-of-moments fit)."""
        mean = self.base_mean * (risk_Y.astype(np.float64) + 1.0)
        shape_k = (mean ** 2) / self.var
        scale_theta = self.var / mean
        return np.random.RandomState(seed).gamma(shape=shape_k, scale=scale_theta)

    @staticmethod
    def gen_uniform_censoring_for_split(T_split: np.ndarray, seed: int) -> np.ndarray:
        """Draw one uniform censoring time per sample between min(T) and q90(T)."""
        low = float(np.min(T_split))
        high = float(np.quantile(T_split, 0.9))
        return np.random.RandomState(seed).uniform(low, high, size=T_split.shape[0])

# ----------------------------
# Dataset wrapper
# ----------------------------
class OrganMNIST3DSurvivalSplit(Dataset):
    """
    Wrap one OrganMNIST3D split; each item is (vol, time, event).
      - vol:   float32 torch tensor. With the default transform it is [C,D,H,W],
               min-max scaled per channel to [0,1]. With ``transform=None`` the raw
               sample is cast to float32, and a 3D [D,H,W] volume gets a leading
               channel axis.
      - time:  float32 scalar tensor (observed time).
      - event: float32 scalar tensor; expected to be 0/1 (event indicator).
    If ``replicate_channels > 1`` and the volume has a single channel, that channel
    is repeated (e.g. for backbones expecting C=3).

    Raises:
      ValueError: if base_ds, times and events have different lengths.
    """
    def __init__(self,
                 base_ds: OrganMNIST3D,
                 times: np.ndarray,
                 events: np.ndarray,
                 transform = default_3d_transform,
                 replicate_channels: int = 1):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not (len(base_ds) == len(times) == len(events)):
            raise ValueError(
                f"Length mismatch: base_ds={len(base_ds)}, "
                f"times={len(times)}, events={len(events)}"
            )
        self.base = base_ds
        self.times  = torch.as_tensor(times,  dtype=torch.float32)
        self.events = torch.as_tensor(events, dtype=torch.float32)
        self.transform = transform
        self.replicate_channels = int(replicate_channels)

    def __len__(self): return len(self.base)

    def __getitem__(self, idx: int):
        # The base label is ignored; survival targets come from times/events.
        # NOTE(review): vol is presumably a numpy array ([1,D,H,W] or [D,H,W]),
        # depending on the base dataset — confirm.
        vol, _ = self.base[idx]
        if self.transform is not None:
            vol = self.transform(vol)
        else:
            vol = torch.as_tensor(vol, dtype=torch.float32)
            if vol.ndim == 3:
                # Add the channel axis so the replication check below sees C, not D.
                vol = vol.unsqueeze(0)
        if self.replicate_channels > 1 and vol.shape[0] == 1:
            vol = vol.repeat(self.replicate_channels, 1, 1, 1)
        return vol, self.times[idx], self.events[idx]

# ----------------------------
# Split loader
# ----------------------------
def _load_organmnist3d_splits(size: int = 28, root: str | None = None, download: bool = True):
    """Instantiate the official OrganMNIST3D splits and return ``(train, val, test)``.

    ``size`` and ``download`` are passed to every split. ``root`` is passed only
    when given, so the library's default location is used otherwise.
    """
    options = {"size": size, "download": download}
    if root is not None:
        options["root"] = root
    return tuple(OrganMNIST3D(split=name, **options) for name in ("train", "val", "test"))

# ----------------------------
# Top-level builder (parallel to OCT)
# ----------------------------
def _survivalize_split(base_ds, class_risk_map: pd.DataFrame, gen, failure_seed: int,
                       censor_seed: int, transform, replicate_channels: int):
    """Simulate right-censored (time, event) targets for one split and wrap it.

    Pipeline: labels -> risk_Y (via class_risk_map) -> Gamma failure times T
    (failure_seed) -> uniform censoring times C (censor_seed). event = 1 where
    T <= C. The observed time is T for events and C otherwise.
    """
    # reshape(-1) instead of squeeze(): MedMNIST labels are [N,1], and squeeze()
    # would collapse a single-sample split to a 0-d array, which breaks the
    # per-sample censoring draw (size=T.shape[0]).
    y = np.asarray(base_ds.labels).reshape(-1).astype(np.int64)
    risk = labels_to_riskY_from_map(y, class_risk_map)
    T = gen.gen_failure_times_gamma(risk, seed=failure_seed)
    C = gen.gen_uniform_censoring_for_split(T, seed=censor_seed)
    E = (T <= C).astype(np.float32)
    Yobs = np.where(E > 0, T, C).astype(np.float32)
    return OrganMNIST3DSurvivalSplit(base_ds, Yobs, E, transform=transform,
                                     replicate_channels=replicate_channels)


def make_organmnist3d_survival(seed: int = 123,
                               size: int = 28,
                               root: str | None = None,
                               base_mean: float = 10.0,
                               var: float = 0.05,
                               replicate_channels: int = 1,
                               transform = default_3d_transform,
                               custom_risk_map: dict[int, float] | None = None,
                               download: bool = True):
    """
    Build survivalized OrganMNIST3D using OFFICIAL train/val/test splits at the requested size.
    Endpoint simulated: time to a generic adverse event (earlier for higher-hazard classes).
    Returns (train_ds, val_ds, test_ds, class_risk_map).

    Args:
      seed: base seed. Failure times use seed+0/1/2 and censoring uses
        seed+10/11/12 for train/val/test respectively.
      size, root, download: forwarded to the OrganMNIST3D loaders.
      base_mean, var: Gamma failure-time parameters (see SurvivalTimeGeneratorXCal).
      replicate_channels, transform: forwarded to OrganMNIST3DSurvivalSplit.
      custom_risk_map: optional {class_id: risk_Y}; smaller risk_Y => higher hazard (earlier failure).
    """
    ds_tr, ds_va, ds_te = _load_organmnist3d_splits(size=size, root=root, download=download)

    # Class risk map from the train split's metadata (or user-specified).
    class_risk_map = make_organ3d_class_risk_map(ds_tr, base_mean=base_mean, custom_risk_map=custom_risk_map)

    gen = SurvivalTimeGeneratorXCal(base_mean=base_mean, var=var)
    # Fixed seed offsets per split; keep them stable so previously generated
    # datasets remain reproducible.
    train_ds, val_ds, test_ds = [
        _survivalize_split(ds, class_risk_map, gen,
                           failure_seed=seed + k, censor_seed=seed + 10 + k,
                           transform=transform, replicate_channels=replicate_channels)
        for k, ds in enumerate((ds_tr, ds_va, ds_te))
    ]
    return train_ds, val_ds, test_ds, class_risk_map
