import torch
import numpy as np
from copy import deepcopy
from models._base import Base


# ---------- KM for censoring on the discrete grid ----------
@torch.no_grad()
def km_censoring_survival_from_loader(discretizer, train_loader, device=None, eps=1e-8):
    """
    Kaplan–Meier estimate of the censoring survival function on the discrete time grid.

    The censoring process is treated as the "event" of interest: a sample counts as
    a censoring event when ``1 - event > 0.5``. Observed times are mapped to bin
    indices through ``discretizer.transform_one_hot`` (argmax over the bin axis).

    Args:
        discretizer: object exposing ``transform_one_hot(O)`` that returns an
            ``[N, T]`` one-hot array or tensor.
        train_loader: iterable of list/tuple batches whose first three entries are
            ``(X, O, E)``; only ``O`` (observed time) and ``E`` (event flag) are read.
        device: device on which the counts are computed; CPU when falsy.
        eps: lower bound applied to the at-risk denominator and to the result.

    Returns:
        Float tensor of shape ``[T]``: the running product over bins of
        ``1 - censored_in_bin / at_risk_at_bin``, floored at ``eps``.

    Raises:
        ValueError: if a batch is not a list/tuple with at least three entries.
    """
    observed_chunks = []
    event_chunks = []
    for batch in train_loader:
        if not (isinstance(batch, (list, tuple)) and len(batch) >= 3):
            raise ValueError("train_loader must yield (X, O, E).")
        observed_chunks.append(batch[1].detach().cpu())
        event_chunks.append(batch[2].detach().cpu())

    observed = torch.cat(observed_chunks, dim=0).float()   # [N]
    events = torch.cat(event_chunks, dim=0).float()        # [N]

    bins_onehot = discretizer.transform_one_hot(observed)  # [N, T]
    if not torch.is_tensor(bins_onehot):
        bins_onehot = torch.as_tensor(bins_onehot)
    n_bins = bins_onehot.shape[1]

    target = device or torch.device("cpu")
    bin_idx = bins_onehot.argmax(dim=1).to(target).long()  # [N], values in {0..T-1}
    is_censored = (1.0 - events).to(target) > 0.5          # [N] boolean mask

    # Per-bin tallies: censored samples only, and every sample.
    censored_per_bin = torch.bincount(bin_idx[is_censored], minlength=n_bins).float()
    total_per_bin = torch.bincount(bin_idx, minlength=n_bins).float()

    # At risk in bin j = number of samples whose bin index is >= j (reverse cumsum).
    at_risk = total_per_bin.flip(0).cumsum(0).flip(0)

    # Discrete censoring hazard; bins with nobody at risk keep a hazard of zero.
    censor_hazard = torch.zeros_like(at_risk)
    populated = at_risk > 0
    censor_hazard[populated] = censored_per_bin[populated] / at_risk[populated].clamp_min(eps)

    # KM product-limit estimator over the grid.
    survival_steps = (1.0 - censor_hazard).clamp_min(0.0)
    return survival_steps.cumprod(dim=0).clamp_min(eps)    # [T]


# ---------- IPCW-Integrated Brier Loss ----------
class IPCWIntegratedBrierLoss(torch.nn.Module):
    """
    IPCW-normalized Integrated Brier Score (IBS) over the discrete grid.

    For each time index j, computes:
        BS(j) = sum_i w_i(j) * ( I(T_i > t_j) - S_hat_i(j) )^2  / sum_i w_i(j),
    with IPCW weights (Graf et al. style):
        w_i(j) = I(T_i > t_j) / G_hat(t_j) + I(T_i <= t_j, E_i=1) / G_hat(T_i-).

    On the grid, ``G_hat(T_i-)`` is taken as ``G_t`` at the bin preceding the
    observed bin, and as 1 when the observed bin is the first one.

    The loss is the (optionally weighted) average of BS(j) over the bins j whose
    total weight ``sum_i w_i(j)`` is strictly positive. Bins where every sample
    has zero weight (e.g. all samples censored at or before j) are excluded
    instead of being counted as a Brier score of zero.

    Args:
        G_t: 1D [T] censoring KM on the model grid (from training data).
        time_weights: Optional [T] weights over time bins (e.g., uniform or user-defined).
            Negative entries are clamped to zero.
        eps: small constant for numerical stability.

    Raises:
        ValueError: if ``G_t`` is not 1D, or ``time_weights`` is not 1D of length T.
    """
    def __init__(self, G_t: torch.Tensor, time_weights = None, eps: float = 1e-8):
        super().__init__()
        G_t = torch.as_tensor(G_t)
        if G_t.ndim != 1:
            raise ValueError("G_t must be 1D [T].")
        # Floor G so that the inverse-probability weights stay finite.
        self.register_buffer("G_t", G_t.float().clamp_min(eps))
        if time_weights is not None:
            if not torch.is_tensor(time_weights):
                time_weights = torch.as_tensor(time_weights, dtype=torch.float32)
            if time_weights.ndim != 1 or time_weights.shape[0] != G_t.shape[0]:
                raise ValueError("time_weights must be 1D with length T.")
            self.register_buffer("time_weights", time_weights.float().clamp_min(0.0))
        else:
            self.time_weights = None
        self.eps = eps

    def forward(self, event_prob, time_onehot, time_raw, event, weights=None):
        """
        event_prob : [N, T] per-bin event mass (ReLU is applied here as well)
        time_onehot: [N, T] one-hot bin of observed time
        time_raw   : [N] kept for API compatibility (unused)
        event      : [N] 1=event, 0=censored
        weights    : Optional [N] sample weights (e.g., from mixup)

        Returns a scalar tensor. When no bin carries positive weight the result
        is a zero that is still attached to ``event_prob``'s graph, so calling
        ``backward()`` on it is safe.
        """
        _, T = event_prob.shape
        device = event_prob.device

        # Predicted survival S_hat(j) = 1 - cumulative event mass, kept in [0, 1].
        cum = torch.cumsum(torch.relu(event_prob), dim=1)
        S_hat = (1.0 - cum).clamp(0.0, 1.0)                      # [N, T]

        # Observed bin index per sample.
        if not torch.is_tensor(time_onehot):
            time_onehot = torch.as_tensor(time_onehot, device=device)
        t_idx = time_onehot.to(device).argmax(dim=1)             # [N]
        t_col = t_idx.view(-1, 1)                                # [N, 1]
        j = torch.arange(T, device=device).view(1, T)            # [1, T]

        # Target y_ij = I(T_i > t_j); also the "still at risk" indicator.
        y = (j < t_col).float()                                  # [N, T]

        G = self.G_t.to(device)                                  # [T]
        G_j = G.view(1, T).clamp_min(self.eps)

        # (a) samples still at risk at j are weighted by 1 / G(t_j).
        w_alive = y / G_j

        # (b) samples with an event at or before j are weighted by 1 / G(T_i-):
        #     G at the previous bin, or 1 when the event falls in the first bin.
        e = event.to(device).view(-1, 1).float()
        G_prev = G[torch.clamp(t_idx - 1, min=0)]
        G_event_minus = torch.where(t_idx > 0, G_prev, torch.ones_like(G_prev))
        w_event = (e * (j >= t_col).float()) / G_event_minus.view(-1, 1).clamp_min(self.eps)

        W = w_alive + w_event                                    # [N, T]

        # Optional per-sample weights.
        if weights is not None:
            W = W * weights.view(-1, 1).to(device)

        sq_err = (y - S_hat) ** 2                                # [N, T]
        num = (W * sq_err).sum(dim=0)                            # [T]
        den = W.sum(dim=0)                                       # [T], NOT clamped yet

        # Validity must be decided on the raw denominator; clamping first would
        # make every bin look valid and drag the average toward zero.
        valid = den > 0
        if not valid.any():
            # Graph-connected zero so the training loop can still backpropagate.
            return S_hat.sum() * 0.0

        bs_t = num / den.clamp_min(self.eps)                     # [T]

        if self.time_weights is None:
            return bs_t[valid].mean()

        tw = self.time_weights.to(device)
        tw = torch.where(valid, tw, torch.zeros_like(tw))
        return (bs_t * tw).sum() / tw.sum().clamp_min(self.eps)


class DeepIBS(Base):
    """
    Standalone deep survival method trained by minimizing IPCW-Integrated Brier Score.

    API mirrors DeepHit:
        - Same constructor signature
        - Uses discretizer.transform_one_hot for (O -> onehot)
        - Supports mixup that returns (X, O, E, weights)
        - Uses AMP (torch.amp) and optional scheduler

    Attributes such as ``net``, ``opt``, ``sch``, ``mixup``, ``discretizer``,
    ``epochs``, ``device``, ``use_amp``, ``amp_ctx`` and ``non_blocking`` are read
    here but not set in this class; presumably ``Base.__init__`` provides them.
    """
    def __init__(self,
                 net,
                 opt,
                 sch=None,
                 mixup=None,
                 discretizer=None,
                 train_transform=None,
                 test_transform=None,
                 epochs=100,
                 batch_size=128,
                 device=None,
                 time_weights=None):
        super(DeepIBS, self).__init__(
            net, opt, sch, mixup, discretizer, train_transform, test_transform, epochs, batch_size, device)
        self.time_weights = time_weights  # Optional [T] weights over bins

    def _fit(self, train_loader, val_loader=None):
        """
        Train ``self.net`` by minimizing the IPCW Integrated Brier loss.

        The censoring survival is estimated once from ``train_loader`` and reused
        for both training and validation. When ``val_loader`` is given, the summed
        validation loss is evaluated after every epoch and the network with the
        lowest value is restored at the end.

        Returns:
            self
        """
        self.net.to(self.device)

        # 1) Estimate censoring survival G_t on training grid (REUSE for validation)
        G_t = km_censoring_survival_from_loader(self.discretizer, train_loader, device=self.device)

        # 2) Create loss (IBS)
        loss_fn = IPCWIntegratedBrierLoss(G_t=G_t, time_weights=self.time_weights)

        if self.use_amp:
            # NOTE(review): torch.amp.GradScaler documents `device` as a string such
            # as "cuda"; confirm that self.device is compatible with that.
            scaler = torch.amp.GradScaler(self.device)

        best_loss = float('inf')
        best_model = deepcopy(self.net)

        for epoch in range(self.epochs):
            self.net.train()
            for batch_x, batch_o, batch_e in train_loader:
                batch_x = batch_x.to(self.device, non_blocking=self.non_blocking)
                batch_o = batch_o.to(self.device, non_blocking=self.non_blocking)
                batch_e = batch_e.to(self.device, non_blocking=self.non_blocking)

                # Per-sample weights exist only when mixup is active.
                w = None
                if self.mixup is not None:
                    batch_x, batch_o, batch_e, w = self.mixup(batch_x, batch_o, batch_e)

                batch_o_onehot = self.discretizer.transform_one_hot(batch_o)
                if not torch.is_tensor(batch_o_onehot):
                    batch_o_onehot = torch.as_tensor(batch_o_onehot, device=self.device)

                self.opt.zero_grad(set_to_none=True)
                with self.amp_ctx:
                    batch_event_prob = torch.relu(self.net(batch_x))
                    loss = loss_fn(batch_event_prob, batch_o_onehot, batch_o, batch_e, weights=w)

                if self.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(self.opt)
                    scaler.update()
                else:
                    loss.backward()
                    self.opt.step()

                # The scheduler is advanced once per batch, not once per epoch.
                if self.sch is not None:
                    self.sch.step()

            if val_loader is not None:
                self.net.eval()
                val_loss = 0.0
                with torch.no_grad():
                    for batch_x, batch_o, batch_e in val_loader:
                        batch_x = batch_x.to(self.device, non_blocking=self.non_blocking)
                        batch_o = batch_o.to(self.device, non_blocking=self.non_blocking)
                        batch_e = batch_e.to(self.device, non_blocking=self.non_blocking)

                        batch_o_onehot = self.discretizer.transform_one_hot(batch_o)
                        if not torch.is_tensor(batch_o_onehot):
                            batch_o_onehot = torch.as_tensor(batch_o_onehot, device=self.device)

                        with self.amp_ctx:
                            batch_event_prob = torch.relu(self.net(batch_x))
                            loss = loss_fn(batch_event_prob, batch_o_onehot, batch_o, batch_e)

                        # Sum of per-batch losses; only used to rank epochs.
                        val_loss += loss.item()

                if val_loss < best_loss:
                    best_loss = val_loss
                    best_model = deepcopy(self.net)

        if val_loader is not None:
            self.net = deepcopy(best_model)

        return self

    def _survival_probability_at_times(self, dataloader, times):
        """
        Predict S(t | x) at user-specified raw times using the same discretizer as training.

        Survival is ``1 - cumsum(relu(net(x)))`` clamped to ``[0, 1]``, the same
        definition the training loss uses. ``times`` is mapped to grid indices
        with ``self.discretizer.transform`` and the indices are clipped to the
        valid column range ``[0, T - 1]`` before lookup.

        Args:
            dataloader: iterable of batches whose first element is the input X.
            times: raw times accepted by ``self.discretizer.transform``.

        Returns:
            numpy array of survival probabilities, one row per sample and one
            column per requested time.
        """
        self.net.eval()
        probs = []
        with torch.no_grad():
            for batch in dataloader:
                batch_x = batch[0]
                batch_x = batch_x.to(self.device, non_blocking=self.non_blocking)
                with self.amp_ctx:
                    event_prob = torch.relu(self.net(batch_x))           # [B, T]
                    # Clamp so that a total event mass above 1 cannot yield a
                    # negative "probability"; matches S_hat in the loss.
                    S = (1. - torch.cumsum(event_prob, dim=1)).clamp(0.0, 1.0)
                    probs.append(S)
        probs = torch.cat(probs, dim=0).detach().float().cpu().numpy()   # [N, T]

        T_max = probs.shape[1]
        times_idx = np.asarray(self.discretizer.transform(times))
        # Valid columns are 0..T_max-1; an upper bound of T_max would raise
        # IndexError for times past the grid. The lower bound stops a negative
        # index from wrapping around to the last column (assumes the discretizer
        # could emit one — TODO confirm).
        times_idx = np.clip(times_idx, 0, T_max - 1).astype(int)
        return probs[:, times_idx]
