import torch
import torch.nn.functional as F
import numpy as np
from copy import deepcopy
from models._base import Base

class MTLRLoss(torch.nn.Module):
    """Masked per-bin binary cross-entropy for discrete-time survival logits.

    Column ``k`` of the logits ``z`` (shape ``(B, K)``) is trained against a
    binary target that is 1 for every bin strictly before the sample's observed
    bin (the argmax of ``t_onehot``) and 0 otherwise. Bins before the observed
    bin always contribute to the loss. Bins at or after it contribute only for
    rows whose ``event`` value is > 0.5.

    Two optional penalties act on adjacent logit differences
    ``z[:, k + 1] - z[:, k]``:

    * isotonic (``lambda_iso``, only when ``enforce_monotone == "soft"``): mean
      squared positive part, penalising logits that increase across bins;
    * smoothness (``lambda_smooth``): mean squared difference.

    Args:
        enforce_monotone: "soft" enables the isotonic penalty; any other value
            disables it.
        lambda_iso: weight of the isotonic penalty.
        lambda_smooth: weight of the smoothness penalty.
        eps: added to denominators to avoid division by zero.
    """

    def __init__(self, enforce_monotone="soft", lambda_iso=0.0, lambda_smooth=1e-3, eps=1e-8):
        super().__init__()
        self.enforce_monotone = enforce_monotone
        self.lambda_iso = lambda_iso
        self.lambda_smooth = lambda_smooth
        self.eps = eps

    def _make_targets(self, t_onehot, event):
        """Build per-bin targets ``y`` and loss mask ``m``, both float32 ``(B, K)``.

        Args:
            t_onehot: ``(B, K)`` tensor; the argmax of each row is taken as the
                sample's observed bin.
            event: tensor with ``B`` elements; values > 0.5 mark observed events.

        Returns:
            ``(y, m)`` where ``y[i, k] = 1`` iff ``k`` is before row ``i``'s bin,
            and ``m`` is 1 on those bins for every row and on all bins for event
            rows.
        """
        num_bins = t_onehot.shape[1]
        observed_bin = t_onehot.argmax(dim=1)  # (B,)
        bins = torch.arange(num_bins, device=t_onehot.device).unsqueeze(0)  # (1, K)
        before_bin = bins < observed_bin.unsqueeze(1)  # (B, K) bool

        y = before_bin.float()
        m = before_bin.float()
        # For event rows the bins at/after the observed bin are also supervised
        # (target 0), so together with the "before" bins the whole row is masked in.
        is_event = (event > 0.5).reshape(-1)
        m[is_event] = 1.0
        return y, m

    def forward(self, z, t_onehot, event, weights=None):
        """Return the scalar loss.

        Args:
            z: ``(B, K)`` logits.
            t_onehot: ``(B, K)`` tensor whose row-wise argmax is the observed bin.
            event: tensor with ``B`` elements; > 0.5 marks an observed event.
            weights: optional per-sample weights with ``B`` elements (any shape,
                e.g. ``(B,)`` or ``(B, 1)``); the main term becomes a weighted mean.
        """
        y, mask = self._make_targets(t_onehot, event)
        bce = F.binary_cross_entropy_with_logits(z, y, reduction='none')
        # Average only over supervised bins of each sample.
        loss_per_sample = (bce * mask).sum(dim=1) / (mask.sum(dim=1) + self.eps)

        if weights is not None:
            # Flatten so (B, 1) weights don't broadcast into a (B, B) product.
            weights = weights.reshape(-1)
            loss_main = (loss_per_sample * weights).sum() / (weights.sum() + self.eps)
        else:
            loss_main = loss_per_sample.mean()

        total_loss = loss_main
        # With a single bin there are no adjacent pairs; mean() of an empty
        # tensor is NaN, so the penalties are skipped entirely.
        if z.shape[1] > 1:
            diffs = z[:, 1:] - z[:, :-1]
            if self.enforce_monotone == "soft" and self.lambda_iso > 0.0:
                total_loss = total_loss + (F.relu(diffs) ** 2).mean() * self.lambda_iso
            if self.lambda_smooth > 0.0:
                total_loss = total_loss + (diffs ** 2).mean() * self.lambda_smooth
        return total_loss

class DeepMTLR(Base):
    """MTLR-style discrete-time survival model trained with ``MTLRLoss``.

    ``net`` produces one logit per time bin. ``sigmoid`` of each (optionally
    monotonised) logit is returned as that bin's survival probability by
    ``_survival_probability_at_times``.

    Model-specific args (all others are forwarded to ``Base``):
        enforce_monotone: "hard" reparametrises logits to be strictly
            decreasing over bins; "soft" adds an isotonic penalty to the loss;
            "none" does neither.
        lambda_iso: weight of the isotonic penalty (only used with "soft").
        lambda_smooth: weight of the smoothness penalty on adjacent logits.
    """

    def __init__(self,
                 net,
                 opt,
                 sch=None,
                 mixup=None,
                 discretizer=None,
                 train_transform=None,
                 test_transform=None,
                 epochs=100,
                 batch_size=128,
                 device=None,
                 enforce_monotone="soft",
                 lambda_iso=1e-3, lambda_smooth=0.0):

        super().__init__(
            net, opt, sch, mixup, discretizer, train_transform, test_transform, epochs, batch_size, device)

        assert enforce_monotone in ("hard", "soft", "none"), \
            "enforce_monotone must be 'hard', 'soft', or 'none'"

        # Model-specific hyperparameters
        self.enforce_monotone = enforce_monotone
        self.lambda_iso = float(lambda_iso)
        self.lambda_smooth = float(lambda_smooth)

    def _to_monotone_logits(self, raw_scores):
        """Return logits that are strictly decreasing over bins when mode is "hard".

        In other modes, or with a single bin, ``raw_scores`` is returned as is.
        Otherwise column 0 is kept as a free baseline ``u`` and column ``k > 0``
        becomes ``u`` minus the cumulative sum of ``softplus(raw) + 1e-6``.
        """
        if self.enforce_monotone != "hard":
            return raw_scores

        if raw_scores.shape[1] <= 1:
            return raw_scores  # Monotonicity is trivial for a single bin

        # First logit is a free baseline `u`
        u = raw_scores[:, :1]
        # Subsequent logits are `u` minus a cumulative sum of positive values
        deltas = F.softplus(raw_scores[:, 1:]) + 1e-6  # Ensure strictly positive deltas
        cumulative_deltas = torch.cumsum(deltas, dim=1)

        return torch.cat([u, u - cumulative_deltas], dim=1)

    def _device_type(self):
        """Return the bare device type (e.g. "cuda", "cpu") expected by ``torch.amp``."""
        # autocast/GradScaler take a device *type*; this normalises values such
        # as "cuda:0" or a torch.device object.
        return torch.device(self.device).type

    def _autocast(self):
        """Mixed-precision context that is active only when ``self.use_amp`` is set."""
        return torch.amp.autocast(self._device_type(), enabled=bool(self.use_amp))

    def _to_device(self, *tensors):
        """Move each tensor to ``self.device`` and return them as a tuple."""
        return tuple(t.to(self.device, non_blocking=self.non_blocking) for t in tensors)

    def _forward_logits(self, batch_x):
        """Run the network and apply the configured monotonicity transform."""
        return self._to_monotone_logits(self.net(batch_x))

    def _validation_loss(self, val_loader, loss_fn):
        """Return the sum of per-batch losses over ``val_loader`` (no gradients)."""
        self.net.eval()
        total = 0.0
        with torch.no_grad():
            for batch_x, batch_o, batch_e in val_loader:
                batch_x, batch_o, batch_e = self._to_device(batch_x, batch_o, batch_e)
                t_onehot = self.discretizer.transform_one_hot(batch_o)
                with self._autocast():
                    loss = loss_fn(self._forward_logits(batch_x), t_onehot, batch_e)
                total += loss.item()
        return total

    def _fit(self, train_loader, val_loader=None):
        """Train ``self.net`` for ``self.epochs`` epochs and return ``self``.

        If ``val_loader`` is given, the validation loss is computed after every
        epoch. At the end, the weights with the lowest loss are loaded back into
        ``self.net``. If no epoch beats ``inf``, the starting weights are restored.
        Weights are restored in place, so ``self.opt`` keeps referencing the
        parameters of the returned network.
        """
        self.net.to(self.device)
        loss_fn = MTLRLoss(self.enforce_monotone, self.lambda_iso, self.lambda_smooth)
        scaler = torch.amp.GradScaler(self._device_type()) if self.use_amp else None

        best_loss = float('inf')
        best_state = deepcopy(self.net.state_dict()) if val_loader is not None else None

        for _ in range(self.epochs):
            self.net.train()
            for batch_x, batch_o, batch_e in train_loader:
                batch_x, batch_o, batch_e = self._to_device(batch_x, batch_o, batch_e)

                w = None
                if self.mixup is not None:
                    batch_x, batch_o, batch_e, w = self.mixup(batch_x, batch_o, batch_e)

                self.opt.zero_grad()
                t_onehot = self.discretizer.transform_one_hot(batch_o)

                with self._autocast():
                    z = self._forward_logits(batch_x)
                    loss = loss_fn(z, t_onehot, batch_e, weights=w)

                if scaler is not None:
                    scaler.scale(loss).backward()
                    scaler.step(self.opt)
                    scaler.update()
                else:
                    loss.backward()
                    self.opt.step()

                # The scheduler is stepped once per batch.
                if self.sch is not None:
                    self.sch.step()

            if val_loader is not None:
                val_loss = self._validation_loss(val_loader, loss_fn)
                if val_loss < best_loss:
                    best_loss = val_loss
                    best_state = deepcopy(self.net.state_dict())

        if best_state is not None:
            self.net.load_state_dict(best_state)

        return self

    def _survival_probability_at_times(self, dataloader, times):
        """Return ``sigmoid(logits)`` sampled at the bins of ``times``.

        Args:
            dataloader: yields batches whose first element is the input tensor.
            times: values passed to ``self.discretizer.transform`` to obtain
                bin indices.

        Returns:
            NumPy float32 array of shape ``(N, len(times_idx))``.
        """
        self.net.eval()
        probs = []
        with torch.no_grad():
            for batch in dataloader:
                batch_x = batch[0].to(self.device, non_blocking=self.non_blocking)
                with self._autocast():
                    S = torch.sigmoid(self._forward_logits(batch_x))
                probs.append(S)
        probs = torch.cat(probs, dim=0).detach().float().cpu().numpy()

        num_bins = probs.shape[1]
        times_idx = np.asarray(self.discretizer.transform(times))
        # Valid column indices are 0..num_bins-1: clipping to num_bins would raise
        # IndexError, and a negative index would silently wrap to the last bin.
        times_idx = np.clip(times_idx, 0, num_bins - 1).astype(int)
        return probs[:, times_idx]
