import torch
import numpy as np
from copy import deepcopy
from models._base import Base

class DeepHitLoss(torch.nn.Module):
    """
    Weighted DeepHit Loss Function.

    This loss function is a combination of two parts:
    1. The negative log-likelihood of the observed (possibly censored) event time.
    2. A ranking loss that penalizes comparable pairs of samples whose predicted
       cumulative incidence is ordered incorrectly: if sample i had an event at
       t_i and sample j was still observed after t_i, the model should predict
       F(t_i | x_i) > F(t_i | x_j).

    Both components have been adapted to handle sample weights.

    Args:
        alpha (float): The blending parameter for the two loss components.
                       The final loss is `(1 - alpha) * nll_loss + alpha * ranking_loss`.
                       Must be between 0.0 and 1.0. With ``alpha == 0`` the
                       (quadratic-in-batch-size) ranking loss is skipped entirely.
        sigma (float): Temperature of the ranking penalty
                       ``exp((F(t_i | x_j) - F(t_i | x_i)) / sigma)``. Must be > 0.

    Raises:
        ValueError: if ``alpha`` is outside [0, 1] or ``sigma`` is not positive.
    """
    def __init__(self, alpha=0.5, sigma=1.0):
        super().__init__()
        if not 0.0 <= alpha <= 1.0:
            raise ValueError("alpha must be between 0.0 and 1.0")
        if not sigma > 0:
            raise ValueError("sigma must be strictly positive")
        self.alpha = alpha
        self.sigma = sigma
        # Floor used before log() and as a minimum denominator.
        self.eps = 1e-6

    def _nll_loss(self, event_prob, cumul_prob, t_one_hot, e, weights=None):
        """Calculate the (optionally weighted) mean negative log-likelihood.

        Args:
            event_prob: per-bin probabilities, time bins along dim 1.
            cumul_prob: running sum of ``event_prob`` along dim 1.
            t_one_hot: indicator of each sample's observed-time bin, same
                layout as ``event_prob``.
            e: event indicator (1 = event observed, 0 = censored). Bool and
                integer tensors are accepted.
            weights: optional per-sample weights; when given, the result is a
                weighted average instead of a plain mean.

        Returns:
            Scalar loss tensor.
        """
        # ``1 - e`` is not defined for bool tensors in torch; work in floats.
        e = e.float()

        # Probability of event at the observed time t
        event_prob_at_t = (event_prob * t_one_hot).sum(dim=1)
        # Cumulative probability up to (and including) the observed time t
        cumul_prob_at_t = (cumul_prob * t_one_hot).sum(dim=1)

        # For uncensored samples (e=1), use the event probability at t.
        # For censored samples (e=0), use the survival probability 1 - F(t).
        log_likelihood = e * torch.log(torch.clamp(event_prob_at_t, min=self.eps)) \
                         + (1 - e) * torch.log(torch.clamp(1.0 - cumul_prob_at_t, min=self.eps))

        if weights is not None:
            return -(log_likelihood * weights).sum() / torch.clamp(weights.sum(), min=self.eps)
        return -log_likelihood.mean()

    def _ranking_loss(self, cumul_prob, t_one_hot, t_raw, e, weights=None):
        """Calculate the (optionally weighted) DeepHit pairwise ranking loss.

        A pair (i, j) is comparable if sample i had an event (e_i == 1) and its
        time is strictly less than sample j's (t_i < t_j). For such a pair the
        penalty is ``exp((F(t_i | x_j) - F(t_i | x_i)) / sigma)``, which grows
        when sample j (still at risk after t_i) is predicted to have failed by
        t_i more likely than sample i (which actually did).

        Args:
            cumul_prob: cumulative per-bin probabilities, time bins along dim 1.
            t_one_hot: indicator of each sample's observed-time bin.
            t_raw: raw observed times used to decide comparability; assumed to
                be 1-D (one time per sample) — TODO confirm with callers.
            e: event indicator.
            weights: optional per-sample weights; pair (i, j) takes sample i's
                weight.

        Returns:
            Scalar loss tensor (0 when the batch has no comparable pair).
        """
        pair_mask = (e.unsqueeze(1) == 1) & (t_raw.unsqueeze(1) < t_raw.unsqueeze(0))

        num_pairs = pair_mask.sum()
        if num_pairs < 1:
            return torch.tensor(0.0, device=cumul_prob.device)

        # C_matrix[i, j] = F(t_j | x_i): sample i's cumulative probability
        # evaluated at sample j's observation bin.
        C_matrix = cumul_prob @ t_one_hot.to(cumul_prob.dtype).T
        # Column vector with F(t_i | x_i) in row i, broadcast across j.
        F_i_at_ti = torch.diagonal(C_matrix).unsqueeze(1)

        # diff[i, j] = F(t_i | x_j) - F(t_i | x_i); note the transpose, which
        # evaluates the *other* sample j at sample i's event time.
        diff_matrix = C_matrix.T - F_i_at_ti

        mask = pair_mask.to(diff_matrix.dtype)
        # Zero the exponent on non-comparable entries so a large value there
        # cannot overflow to inf and turn into NaN when multiplied by the mask.
        ranking_terms = torch.exp(diff_matrix.masked_fill(~pair_mask, 0.0) / self.sigma) * mask

        if weights is not None:
            # The weight for a pair (i, j) is the weight of sample i (the one with the event).
            weight_matrix = weights.unsqueeze(1)
            denominator = (mask * weight_matrix).sum()
            return (ranking_terms * weight_matrix).sum() / torch.clamp(denominator, min=self.eps)

        return ranking_terms.sum() / torch.clamp(num_pairs.to(ranking_terms.dtype), min=self.eps)

    def forward(self, event_prob, time_onehot, time_raw, event, weights=None):
        """Compute ``(1 - alpha) * nll + alpha * ranking``.

        Args:
            event_prob: per-bin probabilities, time bins along dim 1.
            time_onehot: indicator of each sample's observed-time bin.
            time_raw: raw observed times (used for pair comparability).
            event: event indicator (1 = event, 0 = censored).
            weights: optional per-sample weights.

        Returns:
            Scalar loss tensor.
        """
        cumul_prob = torch.cumsum(event_prob, dim=1)

        nll = self._nll_loss(event_prob, cumul_prob, time_onehot, event, weights)
        if self.alpha <= 0:
            # The ranking term has zero weight; skip the pairwise computation.
            return nll

        ranking = self._ranking_loss(cumul_prob, time_onehot, time_raw, event, weights)
        return (1.0 - self.alpha) * nll + self.alpha * ranking
    
class DeepHit(Base):
    """DeepHit survival model: trains ``net`` with :class:`DeepHitLoss`.

    All constructor arguments are forwarded unchanged to ``Base``. The
    attributes read here (``net``, ``opt``, ``sch``, ``mixup``,
    ``discretizer``, ``epochs``, ``device``, ``use_amp``, ``amp_ctx``,
    ``non_blocking``) are presumably initialised by ``Base`` — not visible
    in this file; verify there.
    """
    def __init__(self,
                 net,
                 opt,
                 sch=None,
                 mixup=None,
                 discretizer=None,
                 train_transform=None,
                 test_transform=None,
                 epochs=100,
                 batch_size=128,
                 device=None):

        super(DeepHit, self).__init__(
            net, opt, sch, mixup, discretizer, train_transform, test_transform, epochs, batch_size, device)

    def _fit(self, train_loader, val_loader=None):
        """Train ``self.net`` for ``self.epochs`` epochs with ``DeepHitLoss``.

        Args:
            train_loader: iterable of ``(x, observed_time, event)`` batches.
            val_loader: optional iterable of the same form. When given, the
                network with the lowest summed validation loss across epochs
                replaces ``self.net`` at the end of training.

        Returns:
            self
        """
        self.net.to(self.device)
        best_loss = float('inf')
        # Only snapshot weights when a validation set drives model selection.
        best_model = deepcopy(self.net) if val_loader is not None else None

        if self.use_amp:
            scaler = torch.amp.GradScaler(self.device)
        loss_fn = DeepHitLoss()
        for _ in range(self.epochs):
            self.net.train()
            for batch_x, batch_o, batch_e in train_loader:
                batch_x = batch_x.to(self.device, non_blocking=self.non_blocking)
                batch_o = batch_o.to(self.device, non_blocking=self.non_blocking)
                batch_e = batch_e.to(self.device, non_blocking=self.non_blocking)

                # Optional mixup also yields per-sample weights for the loss.
                w = None
                if self.mixup is not None:
                    batch_x, batch_o, batch_e, w = self.mixup(batch_x, batch_o, batch_e)

                batch_o_onehot = self.discretizer.transform_one_hot(batch_o)
                self.opt.zero_grad()
                with self.amp_ctx:
                    # relu keeps per-bin outputs non-negative; presumably the
                    # net already normalises them (e.g. softmax) — TODO confirm.
                    batch_event_prob = torch.relu(self.net(batch_x))
                    loss = loss_fn(batch_event_prob, batch_o_onehot, batch_o, batch_e, weights=w)
                if self.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(self.opt)
                    scaler.update()
                else:
                    loss.backward()
                    self.opt.step()

                # NOTE(review): the scheduler is stepped once per batch, not
                # per epoch — confirm this matches the scheduler being used.
                if self.sch is not None:
                    self.sch.step()

            if val_loader is not None:
                self.net.eval()
                val_loss = 0.
                with torch.no_grad():
                    for batch_x, batch_o, batch_e in val_loader:
                        batch_x = batch_x.to(self.device, non_blocking=self.non_blocking)
                        batch_o = batch_o.to(self.device, non_blocking=self.non_blocking)
                        batch_e = batch_e.to(self.device, non_blocking=self.non_blocking)

                        batch_o_onehot = self.discretizer.transform_one_hot(batch_o)
                        with self.amp_ctx:
                            batch_event_prob = torch.relu(self.net(batch_x))
                            loss = loss_fn(batch_event_prob, batch_o_onehot, batch_o, batch_e)

                        val_loss += loss.item()

                if val_loss < best_loss:
                    best_loss = val_loss
                    best_model = deepcopy(self.net)

        if val_loader is not None:
            # best_model is already an independent copy; no need to copy again.
            self.net = best_model

        return self

    def _survival_probability_at_times(self, dataloader, times=None):
        """Predict survival probabilities ``S = 1 - cumsum(p)`` per time bin.

        Args:
            dataloader: iterable of batches whose first element is the input.
            times: optional times to evaluate at; they are mapped to bin
                indices with ``self.discretizer.transform``.

        Returns:
            numpy array of shape (n_samples, n_bins) when ``times`` is None,
            otherwise the columns selected by the bin index of each time.
        """
        self.net.eval()
        survival = []
        with torch.no_grad():
            for batch in dataloader:
                batch_x = batch[0].to(self.device, non_blocking=self.non_blocking)
                with self.amp_ctx:
                    event_prob = torch.relu(self.net(batch_x))
                    survival.append(1. - torch.cumsum(event_prob, 1))

        probs = torch.cat(survival, dim=0).detach().float().cpu().numpy()
        if times is None:
            return probs

        last_idx = probs.shape[1] - 1
        times_idx = self.discretizer.transform(times)
        # Valid column indices are [0, n_bins - 1]: an upper bound of n_bins
        # would raise IndexError, and a negative index would silently wrap
        # around to the last bins.
        times_idx = np.clip(times_idx, a_min=0, a_max=last_idx).astype(int)
        return probs[:, times_idx]
