import math
import torch
import torch.nn.functional as F
from copy import deepcopy
from models._base import Base

class AFTNLL(torch.nn.Module):
    """Negative log-likelihood of a parametric accelerated failure time (AFT) model.

    The model is ``log T = mu + sigma * eps`` where ``eps`` follows the standard
    baseline distribution selected by ``baseline``:

    * ``'lognormal'``   -> ``eps ~ Normal(0, 1)``
    * ``'loglogistic'`` -> ``eps ~ Logistic(0, 1)``

    Rows with an event contribute ``log f_T(t)``; censored rows contribute
    ``log S_T(t)``. The loss is the (optionally weighted) mean of the negated
    per-row log-likelihoods.

    Args:
        baseline: ``'lognormal'`` or ``'loglogistic'``.
        min_sigma: Lower bound added to learned scales for numerical stability.
        fixed_sigma: Scale used when the network predicts only ``mu`` and
            ``learn_scale`` is False.
        learn_scale: When True and the network predicts only ``mu``, the scale
            is ``softplus(global_log_sigma) + min_sigma``.
        global_log_sigma: Tensor (typically a shared ``Parameter``) holding the
            raw, pre-softplus global scale. Required when ``learn_scale`` is
            True and the network predicts only ``mu``.
    """

    def __init__(self, baseline='loglogistic', min_sigma=1e-3, fixed_sigma=1.0,
                 learn_scale=False, global_log_sigma=None):
        super(AFTNLL, self).__init__()
        self.baseline = baseline
        self.min_sigma = min_sigma
        self.fixed_sigma = float(fixed_sigma)
        # Read by _unpack_mu_sigma when the network predicts only mu.
        self.learn_scale = bool(learn_scale)
        self.global_log_sigma = global_log_sigma

    def forward(self, y_hat, time, event, weights=None):
        """Return the weighted mean negative log-likelihood as a scalar tensor.

        Args:
            y_hat: Network output of shape ``(B, 2)`` (mu, raw scale), or
                ``(B, 1)`` / ``(B,)`` (mu only).
            time: Observed times, one per row. Values are clamped to at least
                ``1e-8`` before taking the log.
            event: Event indicator per row (1/True = event, 0/False = censored).
            weights: Optional per-row weights; defaults to all ones.
        """
        eps = 1e-8
        # reshape(-1) rather than squeeze(): keeps a batch of one 1-D.
        time = time.reshape(-1)
        event = event.reshape(-1)
        if not time.is_floating_point():
            time = time.float()
        if not event.is_floating_point():
            event = event.float()  # bool / integer indicators
        weights = torch.ones_like(time) if weights is None else weights.reshape(-1)

        mu, sigma = self._unpack_mu_sigma(y_hat)
        # Evaluate the likelihood in at least float32: under fp16 autocast the
        # survival term underflows and produces -inf / NaN losses.
        work_dtype = torch.promote_types(mu.dtype, torch.float32)
        mu, sigma = mu.to(work_dtype), sigma.to(work_dtype)

        log_t = torch.clamp(time, min=eps).log()
        z = (log_t - mu) / sigma

        # Change of variables T = exp(mu + sigma * eps):
        # f_T(t) = f_eps(z) / (sigma * t)
        log_pdf_T = self._log_pdf_epsilon(z) - log_t - sigma.log()
        log_S_T = self._log_survival_epsilon(z)

        loglik = event * log_pdf_T + (1.0 - event) * log_S_T
        return -(loglik * weights).sum() / (weights.sum() + eps)

    def _unpack_mu_sigma(self, y_hat):
        """Split ``y_hat`` into location ``mu`` and positive scale ``sigma``.

        Raises:
            ValueError: If ``y_hat`` has an unsupported shape, or a learned
                scale is requested without ``global_log_sigma``.
        """
        if y_hat.dim() == 2 and y_hat.size(1) == 2:
            mu = y_hat[:, 0]
            sigma = F.softplus(y_hat[:, 1]) + self.min_sigma
            return mu, sigma

        if y_hat.dim() == 2 and y_hat.size(1) == 1:
            mu = y_hat[:, 0]
        elif y_hat.dim() == 1:
            mu = y_hat
        else:
            raise ValueError("net(x) must return shape (B,1) or (B,2)")

        if self.learn_scale:
            if self.global_log_sigma is None:
                raise ValueError(
                    "learn_scale=True requires global_log_sigma when net(x) predicts only mu")
            # .to(mu.device) is differentiable, so gradients still reach the
            # parameter even if it lives on a different device than mu.
            raw = F.softplus(self.global_log_sigma).to(mu.device)
            sigma = raw.expand_as(mu) + self.min_sigma
        else:
            sigma = torch.full_like(mu, float(self.fixed_sigma))
        return mu, sigma

    def _log_pdf_epsilon(self, z):
        """
        log f_ε(z) for baseline ε.
        """
        if self.baseline == 'lognormal':
            # ε ~ N(0,1): log φ(z) = -0.5*z^2 - 0.5*log(2π)
            return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)
        elif self.baseline == 'loglogistic':
            # ε ~ Logistic(0,1): f(z) = exp(-z) / (1 + exp(-z))^2
            # log f(z) = -z - 2*softplus(-z)
            return -z - 2.0 * F.softplus(-z)
        else:
            raise RuntimeError("Unsupported baseline")

    def _log_survival_epsilon(self, z):
        """
        log S_ε(z) for baseline ε.
        """
        if self.baseline == 'lognormal':
            # S(z) = Φ(-z). log_ndtr stays accurate deep in the tail, where
            # log(0.5 * erfc(z / sqrt(2))) underflows and loses its gradient.
            return torch.special.log_ndtr(-z)
        elif self.baseline == 'loglogistic':
            # S(z) = 1 / (1 + exp(z)) ; log S = -log(1 + exp(z)) = -softplus(z)
            return -F.softplus(z)
        else:
            raise RuntimeError("Unsupported baseline")

class DeepAFT(Base):
    """Deep parametric accelerated failure time (AFT) survival model.

    ``net(x)`` predicts the location ``mu`` of ``log T`` and optionally a raw
    per-sample scale:

    * ``(B, 2)`` output -> ``sigma = softplus(raw) + min_sigma`` per sample;
    * ``(B, 1)`` or ``(B,)`` output -> ``sigma`` is a single learned global
      value (``learn_scale=True``) or the constant ``fixed_sigma``.

    Training minimises ``AFTNLL``. Survival curves are evaluated from the
    baseline survival function of ``(log t - mu) / sigma``.

    Args (in addition to those forwarded to ``Base``):
        baseline: ``'lognormal'`` or ``'loglogistic'`` (case-insensitive).
        learn_scale: Learn a global sigma when the net predicts only ``mu``.
        fixed_sigma: Sigma used when ``learn_scale`` is False and the net
            predicts only ``mu``.
        init_log_sigma: Initial raw value of the global scale parameter.
            NOTE: sigma is ``softplus(raw) + min_sigma``, so the default 0.0
            gives sigma ~= 0.694 rather than ``exp(0) = 1``.
        min_sigma: Lower bound added to learned sigmas for numerical stability.

    Raises:
        ValueError: If ``baseline`` is not supported.
    """

    def __init__(
        self,
        net,
        opt,
        sch=None,
        mixup=None,
        discretizer=None,
        train_transform=None,
        test_transform=None,
        epochs=100,
        batch_size=128,
        device=None,
        baseline='lognormal',        # 'lognormal' or 'loglogistic'
        learn_scale=True,              # learn σ if net outputs only μ
        fixed_sigma=1.0,               # used if learn_scale=False and net outputs only μ
        init_log_sigma=0.0,            # initial raw (pre-softplus) σ parameter
        min_sigma=1e-3                 # lower bound for σ for numerical stability
    ):
        super(DeepAFT, self).__init__(
            net, opt, sch, mixup, discretizer, train_transform, test_transform, epochs, batch_size, device)
        baseline = baseline.lower()
        # Raise instead of assert: asserts are stripped under `python -O`.
        if baseline not in ('lognormal', 'loglogistic'):
            raise ValueError("baseline must be 'lognormal' or 'loglogistic'")

        self.baseline = baseline
        self.learn_scale = learn_scale
        self.fixed_sigma = float(fixed_sigma)
        self.min_sigma = float(min_sigma)

        if self.learn_scale:
            self.global_log_sigma = torch.nn.Parameter(torch.tensor(float(init_log_sigma)))
            # A freshly created Parameter cannot already be in the optimizer.
            # (A `param in tensor` membership test compares *values*, so it
            # could wrongly match e.g. zero-initialised biases.)
            self.opt.add_param_group({'params': [self.global_log_sigma]})
        else:
            self.global_log_sigma = None

    def _fit(self, train_loader, val_loader=None):
        """Train ``net`` (and the global scale, if learned) on the AFT NLL.

        When ``val_loader`` is given, the network and global scale from the
        epoch with the lowest validation loss are restored at the end. The
        validation loss is the sum of per-batch mean losses.
        """
        self.net.to(self.device)
        if self.global_log_sigma is not None and self.device is not None:
            # Move in place (as Module.to does for its parameters) so the
            # optimizer keeps referencing the same Parameter object.
            self.global_log_sigma.data = self.global_log_sigma.data.to(self.device)

        best_loss = float('inf')
        best_model = deepcopy(self.net)
        best_log_sigma = (None if self.global_log_sigma is None
                          else self.global_log_sigma.detach().clone())

        if self.use_amp:
            scaler = torch.amp.GradScaler(self.device)
        loss_fn = AFTNLL(baseline=self.baseline, min_sigma=self.min_sigma, fixed_sigma=self.fixed_sigma)
        # AFTNLL reads these when net(x) predicts only mu; share the same
        # Parameter so gradients reach the optimizer-registered scale.
        loss_fn.learn_scale = self.learn_scale
        loss_fn.global_log_sigma = self.global_log_sigma

        for _ in range(self.epochs):
            self.net.train()
            for batch_x, batch_o, batch_e in train_loader:
                batch_x = batch_x.to(self.device, non_blocking=self.non_blocking)
                batch_o = batch_o.to(self.device, non_blocking=self.non_blocking)
                batch_e = batch_e.to(self.device, non_blocking=self.non_blocking)

                w = None
                if self.mixup is not None:
                    batch_x, batch_o, batch_e, w = self.mixup(batch_x, batch_o, batch_e)

                self.opt.zero_grad()
                with self.amp_ctx:
                    y_hat = self.net(batch_x)
                    loss = loss_fn(y_hat, batch_o, batch_e, weights=w)
                if self.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(self.opt)
                    scaler.update()
                else:
                    loss.backward()
                    self.opt.step()

                if self.sch is not None:
                    self.sch.step()

            if val_loader is not None:
                self.net.eval()
                val_loss = 0.
                with torch.no_grad():
                    for batch_x, batch_o, batch_e in val_loader:
                        batch_x = batch_x.to(self.device, non_blocking=self.non_blocking)
                        batch_o = batch_o.to(self.device, non_blocking=self.non_blocking)
                        batch_e = batch_e.to(self.device, non_blocking=self.non_blocking)
                        with self.amp_ctx:
                            val_yhat = self.net(batch_x)
                            loss = loss_fn(val_yhat, batch_o, batch_e)

                        val_loss += loss.item()

                if val_loss < best_loss:
                    best_loss = val_loss
                    best_model = deepcopy(self.net)
                    # The global scale lives outside self.net, so snapshot it too.
                    if self.global_log_sigma is not None:
                        best_log_sigma = self.global_log_sigma.detach().clone()

        if val_loader is not None:
            self.net = best_model
            if best_log_sigma is not None:
                with torch.no_grad():
                    self.global_log_sigma.copy_(best_log_sigma)

        return self

    def _survival_probability_at_times(self, dataloader, times):
        """Return S(t | x) for each sample in ``dataloader`` and each t in ``times``.

        Only ``batch[0]`` of each batch is used. ``times`` may be a scalar or a
        sequence; the result is a NumPy array of shape ``(n_samples, n_times)``.
        """
        self.net.eval()
        # Hoisted out of the batch loop; reshape(-1) also accepts a scalar.
        log_times = torch.as_tensor(times, dtype=torch.float32).reshape(-1).log().to(self.device)
        out = []
        with torch.no_grad():
            for batch in dataloader:
                batch_x = batch[0].to(self.device, non_blocking=self.non_blocking)
                with self.amp_ctx:
                    y_hat = self.net(batch_x)
                # Survival math in float32, outside autocast, for accuracy.
                mu, sigma = self._unpack_mu_sigma(y_hat)
                mu, sigma = mu.float(), sigma.float()
                z = (log_times[None, :] - mu[:, None]) / sigma[:, None]
                S = torch.exp(self._log_survival_epsilon(z)).clamp(0.0, 1.0)
                out.append(S.cpu())

        return torch.cat(out, dim=0).numpy()

    def _unpack_mu_sigma(self, y_hat):
        """Split ``y_hat`` into location ``mu`` and positive scale ``sigma``.

        Raises:
            ValueError: If ``y_hat`` is not ``(B, 2)``, ``(B, 1)`` or ``(B,)``.
        """
        if y_hat.dim() == 2 and y_hat.size(1) == 2:
            mu = y_hat[:, 0]
            sigma = F.softplus(y_hat[:, 1]) + self.min_sigma
            return mu, sigma

        if y_hat.dim() == 2 and y_hat.size(1) == 1:
            mu = y_hat[:, 0]
        elif y_hat.dim() == 1:
            mu = y_hat
        else:
            raise ValueError("net(x) must return shape (B,1) or (B,2)")

        if self.learn_scale:
            # .to(mu.device) avoids a CPU/GPU mismatch and stays differentiable.
            raw = F.softplus(self.global_log_sigma).to(mu.device)
            sigma = raw.expand_as(mu) + self.min_sigma
        else:
            sigma = torch.full_like(mu, float(self.fixed_sigma))
        return mu, sigma

    def _log_survival_epsilon(self, z):
        """
        log S_ε(z) for baseline ε.
        """
        if self.baseline == 'lognormal':
            # S(z) = Φ(-z). log_ndtr stays accurate deep in the tail, where
            # log(0.5 * erfc(z / sqrt(2))) underflows.
            return torch.special.log_ndtr(-z)
        elif self.baseline == 'loglogistic':
            # S(z) = 1 / (1 + exp(z)) ; log S = -log(1 + exp(z)) = -softplus(z)
            return -F.softplus(z)
        else:
            raise RuntimeError("Unsupported baseline")