import torch
from torch.distributions import Beta, Bernoulli

class SurvMixup:
    """Mixup augmentation for batches of (features, time, event) survival data.

    Every sample ``i`` is paired with a partner ``j`` drawn from a random
    permutation of the batch and a coefficient ``lam ~ Beta(alpha, alpha)``
    clamped to ``[eps, 1 - eps]``.  Features are always interpolated as
    ``lam * x_i + (1 - lam) * x_j``; the strategy decides how the time and
    event labels are combined:

    * ``'smix'``  -- with probability ``lam`` keep ``(t_i, e_i)``, otherwise
      take ``(t_j, e_j)``.
    * ``'chmix'`` -- time is the ``lam``-weighted harmonic mean of ``t_i`` and
      ``t_j``; the event flag is ``e_i`` with probability ``lam``, else ``e_j``.
    * ``'hmix'``  -- compares ``t_i / lam`` with ``t_j / (1 - lam)`` and keeps
      the smaller time together with the event flag of that sample.
    * ``'omix'``  -- time is the linear interpolation of ``t_i`` and ``t_j``;
      the event flag is ``e_i`` with probability ``lam``, else ``e_j``.

    Calling the instance returns ``(x_mix, t_mix, e_mix, weights)``.  The first
    three are cast back to the dtypes of the corresponding inputs (so integer
    time tensors receive truncated mixed times).  ``weights`` is always a
    floating-point tensor with one entry per sample.

    Args:
        alpha: Concentration of the symmetric Beta distribution for ``lam``.
        strategy: One of ``'hmix'``, ``'chmix'``, ``'smix'``, ``'omix'``.
        keep_prev: If true, the returned weights re-balance the mixed batch so
            that its weighted event rate equals the event rate of the input
            batch; otherwise all weights are one.
        device: Fallback device used by ``_get_mix_params`` when no device is
            passed explicitly.  Defaults to ``'cuda'`` when available, else
            ``'cpu'``.  The mixing methods themselves always work on the
            device of ``batch_x``.

    Raises:
        ValueError: If ``strategy`` is not one of the supported names.
    """

    # Strategy name -> name of the method implementing it.
    _STRATEGIES = {
        'hmix': '_h_mixup',
        'chmix': '_ch_mixup',
        'smix': '_s_mixup',
        'omix': '_o_mixup',
    }

    def __init__(self, alpha=0.4, strategy='hmix', keep_prev=True, device=None):
        self.alpha = alpha
        self.strategy = strategy
        self.eps = 1e-4
        self.keep_prev = keep_prev
        self.device = device if device is not None else ('cuda' if torch.cuda.is_available() else 'cpu')

        # Guard against unhashable values so the error type stays ValueError.
        method_name = self._STRATEGIES.get(strategy) if isinstance(strategy, str) else None
        if method_name is None:
            raise ValueError(f"Unknown mixup strategy: {strategy}")
        self.forward = getattr(self, method_name)

    def __call__(self, batch_x, batch_t, batch_e):
        """Apply the configured strategy; see the class docstring for outputs."""
        return self.forward(batch_x, batch_t, batch_e)

    def _get_mix_params(self, batch_size, device=None):
        """Draw per-sample mixing coefficients and a partner permutation.

        Args:
            batch_size: Number of samples in the batch.
            device: Device for the returned tensors; ``self.device`` if None.

        Returns:
            ``(lam, perm_idx)``: ``lam`` has shape ``(batch_size,)`` with
            values in ``[eps, 1 - eps]``; ``perm_idx`` is a random permutation
            of ``range(batch_size)``.
        """
        if device is None:
            device = self.device
        lam = Beta(self.alpha, self.alpha).sample((batch_size,))
        # Clamping keeps both lam and (1 - lam) strictly positive, which the
        # divisions in the time-mixing rules rely on.
        lam = torch.clamp(lam, self.eps, 1.0 - self.eps).to(device)
        perm_idx = torch.randperm(batch_size, device=device)
        return lam, perm_idx

    @staticmethod
    def _mix_features(batch_x, lam, perm_idx):
        """Interpolate each sample's features with those of its partner."""
        # Reshape lam to (B, 1, ..., 1) so it broadcasts over feature dims.
        lam_x = lam.view(-1, *([1] * (batch_x.ndim - 1)))
        return lam_x * batch_x + (1.0 - lam_x) * batch_x[perm_idx]

    def _prevalence_weights(self, batch_e, e_mix, t_mix):
        """Return per-sample weights for the mixed batch.

        With ``keep_prev`` enabled, event samples get ``p_target / p_hat`` and
        non-event samples ``(1 - p_target) / (1 - p_hat)``, where ``p_target``
        is the event rate of the input batch and ``p_hat`` that of the mixed
        batch (clamped away from 0 and 1).  Otherwise all weights are one.
        """
        # Weights are fractional, so never inherit an integer dtype from the
        # time tensor: that would truncate them to 0 or 1.
        dtype = t_mix.dtype if t_mix.is_floating_point() else torch.get_default_dtype()
        if not self.keep_prev:
            return torch.ones_like(t_mix, dtype=dtype)

        # Stay in tensor arithmetic to avoid host synchronisations.
        p_target = batch_e.float().mean()
        p_hat = e_mix.float().mean().clamp(1e-6, 1.0 - 1e-6)
        w_event = p_target / p_hat
        w_censored = (1.0 - p_target) / (1.0 - p_hat)
        return torch.where(e_mix.bool(), w_event, w_censored).to(dtype)

    def _finalize(self, batch_x, batch_t, batch_e, x_mix, t_mix, e_mix):
        """Compute weights and cast the mixed tensors back to the input dtypes."""
        weights = self._prevalence_weights(batch_e, e_mix, t_mix)
        return (
            x_mix.to(batch_x.dtype),
            t_mix.to(batch_t.dtype),
            e_mix.to(batch_e.dtype),
            weights,
        )

    def _s_mixup(self, batch_x, batch_t, batch_e):
        """Keep ``(t_i, e_i)`` with probability ``lam``, else ``(t_j, e_j)``."""
        lam, perm_idx = self._get_mix_params(batch_x.size(0), device=batch_x.device)
        x_mix = self._mix_features(batch_x, lam, perm_idx)

        # One draw per sample so time and event come from the same source.
        z = torch.rand_like(lam) < lam
        t_mix = torch.where(z, batch_t, batch_t[perm_idx])
        e_mix = torch.where(z, batch_e, batch_e[perm_idx])

        return self._finalize(batch_x, batch_t, batch_e, x_mix, t_mix, e_mix)

    def _ch_mixup(self, batch_x, batch_t, batch_e):
        """Harmonic-mean time; event flag is ``e_i`` with probability ``lam``."""
        lam, perm_idx = self._get_mix_params(batch_x.size(0), device=batch_x.device)
        x_mix = self._mix_features(batch_x, lam, perm_idx)

        z = torch.rand_like(lam) < lam
        t_mix = 1.0 / (lam * (1.0 / batch_t) + (1.0 - lam) * (1.0 / batch_t[perm_idx]))
        e_mix = torch.where(z, batch_e, batch_e[perm_idx])

        return self._finalize(batch_x, batch_t, batch_e, x_mix, t_mix, e_mix)

    def _h_mixup(self, batch_x, batch_t, batch_e):
        """Keep the smaller of ``t_i / lam`` and ``t_j / (1 - lam)`` and its flag."""
        lam, perm_idx = self._get_mix_params(batch_x.size(0), device=batch_x.device)
        x_mix = self._mix_features(batch_x, lam, perm_idx)

        scaled_i = batch_t / lam
        scaled_j = batch_t[perm_idx] / (1.0 - lam)
        # Ties go to sample i.
        i_wins = scaled_i <= scaled_j

        t_mix = torch.where(i_wins, scaled_i, scaled_j)
        e_mix = torch.where(i_wins, batch_e, batch_e[perm_idx])

        return self._finalize(batch_x, batch_t, batch_e, x_mix, t_mix, e_mix)

    def _o_mixup(self, batch_x, batch_t, batch_e):
        """Linearly interpolated time; event flag is ``e_i`` with probability ``lam``."""
        lam, perm_idx = self._get_mix_params(batch_x.size(0), device=batch_x.device)
        x_mix = self._mix_features(batch_x, lam, perm_idx)

        z = torch.rand_like(lam) < lam
        t_mix = lam * batch_t + (1.0 - lam) * batch_t[perm_idx]
        e_mix = torch.where(z, batch_e, batch_e[perm_idx])

        return self._finalize(batch_x, batch_t, batch_e, x_mix, t_mix, e_mix)
