import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision import transforms

from medmnist import PathMNIST

# ---- PathMNIST class label → risk score Y (CRC-informed mapping) ----
# Y drives the mean of the simulated failure time, base_mean * (Y + 1), so a
# smaller Y means a shorter expected time, i.e. a higher risk.
# Entries are listed from highest risk to lowest risk.
# Label ids: 0 ADI, 1 BACK, 2 DEB, 3 LYM, 4 MUC, 5 MUS, 6 NORM, 7 STR, 8 TUM
PATHMNIST_LABEL_TO_RISK_Y = {
    2: 0.5,   # DEB  - debris / necrosis
    5: 1.0,   # MUS  - smooth muscle
    7: 1.5,   # STR  - stroma
    4: 2.0,   # MUC  - mucin
    0: 4.0,   # ADI  - adipose
    8: 7.0,   # TUM  - tumor epithelium
    3: 8.0,   # LYM  - lymphocytes
    6: 10.0,  # NORM - normal mucosa
    1: 12.0,  # BACK - background
}

# PathMNIST class label → (abbreviation, human-readable tissue name).
PATHMNIST_LABEL_TO_NAME = {
    0: ("ADI", "Adipose"),
    1: ("BACK", "Background"),
    2: ("DEB", "Debris/Necrosis"),
    3: ("LYM", "Lymphocytes"),
    4: ("MUC", "Mucus/Mucin"),
    5: ("MUS", "Smooth muscle"),
    6: ("NORM", "Normal colon mucosa"),
    7: ("STR", "Cancer-associated stroma"),
    8: ("TUM", "Tumor epithelium"),
}


def make_class_risk_map(base_mean: float = 10.0) -> pd.DataFrame:
    """Return a per-class summary table indexed by ``class_label``.

    Columns, in order: ``abbr``, ``name``, ``risk_Y`` and ``mean_time``, where
    ``mean_time = base_mean * (risk_Y + 1)``. Rows are sorted by class label.
    """
    class_ids = sorted(PATHMNIST_LABEL_TO_NAME)

    abbreviations = []
    full_names = []
    risk_scores = []
    for class_id in class_ids:
        short, long_name = PATHMNIST_LABEL_TO_NAME[class_id]
        abbreviations.append(short)
        full_names.append(long_name)
        risk_scores.append(PATHMNIST_LABEL_TO_RISK_Y[class_id])

    table = pd.DataFrame(
        {"abbr": abbreviations, "name": full_names, "risk_Y": risk_scores},
        index=pd.Index(class_ids, name="class_label"),
    )
    # Mean of the simulated failure-time distribution for each class.
    return table.assign(mean_time=lambda frame: base_mean * (frame["risk_Y"] + 1))


def labels_to_riskY(y: np.ndarray) -> np.ndarray:
    """Map PathMNIST class labels to their float risk score Y.

    Args:
        y: Array-like of class labels; values are cast to ``int`` before the
            lookup in ``PATHMNIST_LABEL_TO_RISK_Y``.

    Returns:
        A ``float64`` array with the same shape as ``y``. An empty input
        yields an empty array.

    Raises:
        KeyError: If a label has no entry in ``PATHMNIST_LABEL_TO_RISK_Y``.
    """
    labels = np.asarray(y).astype(int)
    # Fixing `otypes` keeps the output dtype independent of the first element
    # and lets np.vectorize accept size-0 inputs (it raises ValueError otherwise).
    lookup = np.vectorize(PATHMNIST_LABEL_TO_RISK_Y.__getitem__, otypes=[np.float64])
    return lookup(labels)


class SurvivalTimeGeneratorXCal:
    """
    Synthetic survival-time sampler in the X-CAL style.

      Failure time:   T | Y ~ Gamma with mean base_mean * (Y + 1) and variance var
      Censoring time: C ~ Uniform(min(T_split), 90% quantile of T_split), per split
    """
    def __init__(self, base_mean: float = 10.0, var: float = 0.05):
        self.base_mean, self.var = float(base_mean), float(var)

    def gen_failure_times_gamma(self, risk_Y: np.ndarray, seed: int) -> np.ndarray:
        """Draw one Gamma failure time per entry of ``risk_Y``, seeded by ``seed``."""
        target_mean = self.base_mean * (risk_Y.astype(np.float64) + 1.0)
        # Moment matching: shape * scale = mean and shape * scale**2 = variance.
        gamma_shape = (target_mean ** 2) / self.var
        gamma_scale = self.var / target_mean
        return np.random.RandomState(seed).gamma(shape=gamma_shape, scale=gamma_scale)

    @staticmethod
    def gen_uniform_censoring_for_split(T_split: np.ndarray, seed: int) -> np.ndarray:
        """Draw one uniform censoring time per row of ``T_split``, seeded by ``seed``."""
        lower = float(np.min(T_split))
        upper = float(np.quantile(T_split, 0.9))
        n_draws = T_split.shape[0]
        return np.random.RandomState(seed).uniform(lower, upper, size=n_draws)


class PathMNISTSurvivalSplit(Dataset):
    """
    Pairs an underlying image dataset split with per-sample survival targets.

    Each item is ``(img, time, event)``: ``img`` is whatever the wrapped dataset
    returns as the first element of its sample, ``time`` and ``event`` are
    0-dim float32 tensors.

    Args:
        base_ds: Indexable dataset whose items are ``(image, label)`` pairs.
        times: Observed time per sample, same length as ``base_ds``.
        events: Event indicator per sample, same length as ``base_ds``.
        normalize: Stored as ``self.normalize``.
            NOTE(review): no normalization is applied to the images here;
            confirm whether that is intended.

    Raises:
        ValueError: If ``base_ds``, ``times`` and ``events`` differ in length.
    """
    def __init__(self, base_ds: "PathMNIST", times: np.ndarray, events: np.ndarray, normalize: bool):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not (len(base_ds) == len(times) == len(events)):
            raise ValueError(
                "base_ds, times and events must have the same length, got "
                f"{len(base_ds)}, {len(times)} and {len(events)}"
            )
        self.base = base_ds
        self.times = torch.as_tensor(times, dtype=torch.float32)
        self.events = torch.as_tensor(events, dtype=torch.float32)
        self.normalize = bool(normalize)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx: int):
        # The image format depends on the transform of the wrapped dataset.
        img, _ = self.base[idx]
        # `.clone()` hands out an independent copy, as before, without the
        # UserWarning that torch.tensor(<Tensor>) emits on every call.
        t = self.times[idx].clone()
        e = self.events[idx].clone()
        return img, t, e


def _load_pathmnist_224_splits(download: bool = True, root: str = './data'):
    """Load the official PathMNIST train/val/test splits at 224x224 with ToTensor().

    Args:
        download: Forwarded to ``PathMNIST``; whether missing data may be downloaded.
        root: Directory that holds (or will hold) the dataset files. It is
            created when missing.

    Returns:
        ``(ds_train, ds_val, ds_test)``.
    """
    import os

    # NOTE(review): medmnist appears to reject a `root` that does not exist
    # yet, so the directory is created up front — confirm against the
    # installed medmnist version.
    os.makedirs(root, exist_ok=True)

    size = 224
    tfm = transforms.ToTensor()
    ds_train, ds_val, ds_test = (
        PathMNIST(split=split, size=size, download=download, transform=tfm, root=root)
        for split in ("train", "val", "test")
    )
    return ds_train, ds_val, ds_test


def _simulate_survival_split(base_ds, gen, time_seed: int, censor_seed: int, normalize: bool):
    """Simulate observed times/events for one split; return (wrapped dataset, event mask)."""
    # reshape(-1) always yields a 1-D label vector; squeeze() would collapse a
    # single-sample split to a 0-d array.
    labels = np.asarray(base_ds.labels).reshape(-1).astype(np.int64)
    risk = labels_to_riskY(labels)  # float risk score per sample

    failure = gen.gen_failure_times_gamma(risk, seed=time_seed)
    censoring = gen.gen_uniform_censoring_for_split(failure, seed=censor_seed)

    events = failure <= censoring  # True where the failure is observed
    observed = np.where(events, failure, censoring).astype(np.float32)
    return PathMNISTSurvivalSplit(base_ds, observed, events, normalize=normalize), events


def make_pathmnist_survival(seed: int = 89,
                            base_mean: float = 10.0,
                            var: float = 0.05,
                            normalize: bool = False):
    """
    Build survival datasets from PathMNIST (224x224) using the OFFICIAL splits & sizes.
      - No subsampling: every split keeps the size of the underlying dataset.
      - Failure times: Gamma draws seeded with seed+0 / seed+1 / seed+2
        for train / val / test.
      - Censoring: drawn independently per split, seeded with
        seed+10 / seed+11 / seed+12.
      - Prints the censoring rate of each split.
      - Returns: (train_ds, val_ds, test_ds, class_risk_map)
    """
    # 1) Load official splits (224x224), in train / val / test order
    base_splits = _load_pathmnist_224_splits(download=True)

    # 2) Simulate failure and censoring times per split and wrap each split
    gen = SurvivalTimeGeneratorXCal(base_mean=base_mean, var=var)
    wrapped = []
    event_masks = []
    for offset, base_ds in enumerate(base_splits):
        split_ds, events = _simulate_survival_split(
            base_ds,
            gen,
            time_seed=seed + offset,
            censor_seed=seed + 10 + offset,
            normalize=normalize,
        )
        wrapped.append(split_ds)
        event_masks.append(events)

    # 3) Helpful summary & class risk map
    for title, events, base_ds in zip(("Train", "Valid", "Test "), event_masks, base_splits):
        print(f"{title} censoring rate: {1 - events.mean():.2%} (N={len(base_ds)})")

    train_ds, val_ds, test_ds = wrapped
    class_risk_map = make_class_risk_map(base_mean=base_mean)
    return train_ds, val_ds, test_ds, class_risk_map
