
import numpy as np
import pandas as pd
import torch
from copy import deepcopy
from models._base import Base

class CoxNPLL(torch.nn.Module):
    """Negative partial log-likelihood loss for a Cox proportional hazards model.

    Samples are ordered by decreasing time, so the running sum of
    ``weights * exp(logh)`` at position ``i`` is the (weighted) total hazard
    of every sample placed at or before ``i`` in that ordering.  Samples that
    share the same time are not grouped into a common risk set; their relative
    order is whatever the sort produces.
    """

    def __init__(self):
        super(CoxNPLL, self).__init__()

    def forward(self, logh, time, event, weights=None):
        """Compute the weighted, event-normalised negative partial log-likelihood.

        Args:
            logh: predicted log-hazards, one value per sample (any shape that
                flattens to ``N`` elements, e.g. ``(N,)`` or ``(N, 1)``).
            time: observed times, ``N`` elements.
            event: event indicators (bool or numeric), ``N`` elements.
            weights: optional per-sample weights, ``N`` elements.  Defaults to
                all ones.

        Returns:
            A 0-d tensor: minus the sum of weighted log-likelihood terms over
            samples with an event, divided by ``sum(weights * event) + eps``.

        Raises:
            ValueError: if the flattened inputs do not all have the same
                number of elements.
        """
        eps = 1e-8

        # reshape(-1) instead of squeeze(): squeeze() turns a batch of a single
        # sample into a 0-d tensor, which cannot be indexed by the sort
        # permutation computed below.
        logh = logh.reshape(-1)
        time = time.reshape(-1)
        event = event.reshape(-1)

        if weights is None:
            weights = torch.ones_like(logh)
        else:
            weights = weights.reshape(-1)

        n = logh.numel()
        if time.numel() != n or event.numel() != n or weights.numel() != n:
            raise ValueError(
                "logh, time, event and weights must have the same number of "
                f"elements, got {n}, {time.numel()}, {event.numel()} and "
                f"{weights.numel()}"
            )

        # Order samples by decreasing time so a forward cumulative sum
        # accumulates the hazards of the samples still at risk.
        sorted_indices = time.sort(descending=True)[1]
        logh = logh[sorted_indices]
        event = event[sorted_indices]
        weights = weights[sorted_indices]

        if event.dtype is torch.bool:
            event = event.float()

        # Subtract the maximum before exponentiating (log-sum-exp trick) and
        # add it back after the log to avoid overflow in exp().
        gamma = logh.max()
        weighted_logh_exp = weights * torch.exp(logh - gamma)
        cumsum_weighted_h = torch.cumsum(weighted_logh_exp, dim=0)
        log_cumsum_h = torch.log(cumsum_weighted_h + eps) + gamma

        # Only samples with an observed event contribute a likelihood term.
        log_likelihood = weights * (logh - log_cumsum_h)
        log_likelihood = log_likelihood * event
        return -torch.sum(log_likelihood) / (torch.sum(weights * event) + eps)


class DeepCox(Base):
    """Neural-network Cox proportional hazards model.

    The network is trained to emit a log-hazard per sample by minimising
    ``CoxNPLL``.  After training, a Breslow-type baseline hazard is estimated
    from the training loader and stored in ``baseline_hazard_`` as a DataFrame
    with columns ``time``, ``baseline_hazard`` and
    ``cumulative_baseline_hazard``.

    NOTE(review): ``self.net``, ``self.opt``, ``self.sch``, ``self.mixup``,
    ``self.epochs``, ``self.device``, ``self.use_amp``, ``self.amp_ctx`` and
    ``self.non_blocking`` are not set in this class; they are presumably
    provided by ``Base`` -- confirm there.
    """

    def __init__(self,
                 net,
                 opt,
                 sch=None,
                 mixup=None,
                 discretizer=None,
                 train_transform=None,
                 test_transform=None,
                 epochs=100,
                 batch_size=128,
                 device=None):
        """Forward every argument unchanged to ``Base`` and reset the baseline hazard."""
        super(DeepCox, self).__init__(
            net, opt, sch, mixup, discretizer, train_transform, test_transform, epochs, batch_size, device)

        # Filled in by _compute_baseline_hazard() at the end of _fit().
        self.baseline_hazard_ = None

    def _fit(self, train_loader, val_loader=None):
        """Train the network, then estimate the baseline hazard.

        Args:
            train_loader: iterable yielding ``(x, time, event)`` batches.
            val_loader: optional iterable of the same form.  When given, the
                weights with the lowest validation loss (summed over batches)
                are restored once all epochs have run.

        Returns:
            ``self``.
        """
        loss_fn = CoxNPLL()
        scaler = torch.amp.GradScaler(self.device) if self.use_amp else None

        track_best = val_loader is not None
        best_loss = float('inf')
        # Snapshot the weights instead of deep-copying the module and rebinding
        # self.net afterwards: rebinding would leave self.opt (and any other
        # holder of the module) pointing at parameters that are no longer
        # those of self.net.
        best_state = deepcopy(self.net.state_dict()) if track_best else None

        for epoch in range(self.epochs):
            self.net.train()
            for batch_x, batch_o, batch_e in train_loader:
                # A batch without any event has no likelihood term to optimise.
                if batch_e.sum() == 0:
                    continue

                batch_x = batch_x.to(self.device, non_blocking=self.non_blocking)
                batch_o = batch_o.to(self.device, non_blocking=self.non_blocking)
                batch_e = batch_e.to(self.device, non_blocking=self.non_blocking)

                w = None
                if self.mixup is not None:
                    batch_x, batch_o, batch_e, w = self.mixup(batch_x, batch_o, batch_e)

                self.opt.zero_grad()
                with self.amp_ctx:
                    batch_logh = self.net(batch_x)
                    loss = loss_fn(batch_logh, batch_o, batch_e, weights=w)
                if self.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(self.opt)
                    scaler.update()
                else:
                    loss.backward()
                    self.opt.step()
                # The scheduler is advanced once per processed batch.
                if self.sch is not None:
                    self.sch.step()

            if track_best:
                self.net.eval()
                val_loss = 0.
                with torch.no_grad():
                    for batch_x, batch_o, batch_e in val_loader:
                        batch_x = batch_x.to(self.device, non_blocking=self.non_blocking)
                        batch_o = batch_o.to(self.device, non_blocking=self.non_blocking)
                        batch_e = batch_e.to(self.device, non_blocking=self.non_blocking)
                        with self.amp_ctx:
                            batch_logh = self.net(batch_x)
                            loss = loss_fn(batch_logh, batch_o, batch_e)

                        val_loss += loss.item()

                if val_loss < best_loss:
                    best_loss = val_loss
                    best_state = deepcopy(self.net.state_dict())

        if track_best:
            self.net.load_state_dict(best_state)

        self._compute_baseline_hazard(train_loader)
        return self

    def _survival_probability_at_times(self, dataloader, times):
        """Predict survival probabilities ``exp(-H0(t) * exp(logh))``.

        Args:
            dataloader: iterable of batches whose first element is the input
                tensor; any further elements are ignored.
            times: time point(s) at which to evaluate survival.

        Returns:
            A NumPy array produced by broadcasting the per-sample risk column
            against the cumulative baseline hazard at ``times`` (one row per
            sample and one column per time when the network emits a single
            value per sample).
        """
        self.net.eval()
        risk_chunks = []
        with torch.no_grad():
            for batch in dataloader:
                batch_x = batch[0].to(self.device, non_blocking=self.non_blocking)
                with self.amp_ctx:
                    risk_chunks.append(self.net(batch_x).exp())

        risk_score = torch.cat(risk_chunks).detach().float().cpu().numpy()
        if risk_score.ndim == 1:
            risk_score = risk_score[:, np.newaxis]

        # The cumulative hazard is linearly interpolated between event times.
        # left=0.0: before the first event time no hazard has accumulated, so
        # survival there must be 1 rather than the value at the first event.
        cumulative_hazard = np.interp(
            times,
            self.baseline_hazard_['time'],
            self.baseline_hazard_['cumulative_baseline_hazard'],
            left=0.0,
        )
        return np.exp(-cumulative_hazard * risk_score)

    def _compute_baseline_hazard(self, dataloader):
        """Estimate the Breslow baseline hazard and store it in ``baseline_hazard_``.

        For every distinct time ``t`` at which an event occurred, the hazard
        increment is ``d(t) / R(t)`` where ``d(t)`` is the number of events at
        ``t`` and ``R(t)`` is the summed ``exp(logh)`` of all samples whose
        time is ``>= t`` (0 when that sum is not positive).

        Args:
            dataloader: iterable yielding ``(x, time, event)`` batches.
        """
        self.net.eval()
        risk_chunks, time_chunks, event_chunks = [], [], []
        with torch.no_grad():
            for batch_x, batch_o, batch_e in dataloader:
                batch_x = batch_x.to(self.device, non_blocking=self.non_blocking)
                with self.amp_ctx:
                    batch_h = self.net(batch_x).exp()
                risk_chunks.append(batch_h)
                time_chunks.append(batch_o)
                event_chunks.append(batch_e)

        # Flatten everything to 1-D so that (N,) and (N, 1) inputs mask and
        # sort identically; accumulate the risk sums in float64.
        risk = (torch.cat(risk_chunks).detach().float().cpu().numpy()
                .reshape(-1).astype(np.float64))
        obs_times = torch.cat(time_chunks).detach().cpu().numpy().reshape(-1)
        events = torch.cat(event_chunks).detach().cpu().numpy().reshape(-1)

        # Distinct event times (ascending) and the number of events at each.
        event_times, event_counts = np.unique(obs_times[events == 1], return_counts=True)

        # Sort once and take suffix sums: at_risk[i] is the total risk of all
        # samples whose time is >= sorted_times[i].  This replaces a full scan
        # of the data set for every distinct event time.
        order = np.argsort(obs_times, kind='stable')
        sorted_times = obs_times[order]
        at_risk = np.cumsum(risk[order][::-1])[::-1]

        # First position of each event time in the sorted array, i.e. the
        # start of its risk set.
        first_at_risk = np.searchsorted(sorted_times, event_times, side='left')
        risk_set_sum = at_risk[first_at_risk]

        hazard = np.divide(
            event_counts,
            risk_set_sum,
            out=np.zeros(len(event_times), dtype=np.float64),
            where=risk_set_sum > 0,
        )

        baseline_hazard_df = pd.DataFrame({'time': event_times, 'baseline_hazard': hazard})
        baseline_hazard_df['cumulative_baseline_hazard'] = baseline_hazard_df['baseline_hazard'].cumsum()

        self.baseline_hazard_ = baseline_hazard_df

