import pandas as pd
import torch
import torchtuples as tt

from pycox import models
from pycox.models.utils import pad_col

class DeepHit_DomainInd(models.pmf.PMFBase):
    """DeepHit (single risk) with one block of output logits per sensitive-attribute group.

    The wrapped ``net`` is expected to emit ``num_class`` logits per sample,
    laid out as ``sen_attr_n_class`` contiguous blocks of
    ``num_class // sen_attr_n_class`` discrete-time logits each (block ``g``
    belongs to group ``g``). At prediction time the blocks are combined in one
    of two ways:

    * ``is_aggregated=True``: the logits of all group blocks are summed, so no
      group labels are needed.
    * ``is_aggregated=False``: each sample uses only the block of its own
      group, given per sample by ``S``.

    NOTE(review): ``predict_surv`` / ``predict_pmf`` / ``predict_surv_df`` take
    ``S`` as their second positional parameter. Any inherited helper that calls
    them positionally with ``batch_size`` second would pass it as ``S``;
    confirm before using such helpers.
    """

    def __init__(self, net, num_class, sen_attr_n_class, is_aggregated,
                 optimizer=None, device=None, alpha=0.2, sigma=0.1, duration_index=None, loss=None):
        """
        Args:
            net: Network producing ``num_class`` logits per sample.
            num_class: Total number of output logits, i.e. the number of groups
                times the number of discrete time bins per group. Columns beyond
                ``sen_attr_n_class * (num_class // sen_attr_n_class)`` are
                ignored at prediction time.
            sen_attr_n_class: Number of sensitive-attribute groups (logit blocks).
            is_aggregated: If True, predictions sum the logits of all groups;
                otherwise callers must pass the per-sample groups ``S``.
            optimizer, device, duration_index: Forwarded to ``PMFBase``.
            alpha, sigma: DeepHit loss hyper-parameters, used only when
                ``loss`` is None.
            loss: Custom loss; defaults to ``DeepHitSingleLoss(alpha, sigma)``.
        """
        if loss is None:
            loss = models.loss.DeepHitSingleLoss(alpha, sigma)
        super().__init__(net, loss, optimizer, device, duration_index)
        self.num_class = num_class
        self.sen_attr_n_class = sen_attr_n_class
        self.is_aggregated = is_aggregated

    def make_dataloader(self, data, batch_size, shuffle, num_workers=0):
        """Build the training dataloader using pycox's ``DeepHitDataset``."""
        dataloader = super().make_dataloader(data, batch_size, shuffle, num_workers,
                                             make_dataset=models.data.DeepHitDataset)
        return dataloader

    def make_dataloader_predict(self, input, batch_size, shuffle=False, num_workers=0):
        """Build a prediction dataloader over inputs only.

        Uses the base-class ``make_dataloader`` (via ``super()``), not the
        ``DeepHitDataset`` variant defined above.
        """
        dataloader = super().make_dataloader(input, batch_size, shuffle, num_workers)
        return dataloader

    def predict_surv(self, input, S=None, batch_size=256, numpy=None, eval_=True, to_cpu=False, num_workers=0):
        """Predict survival curves ``1 - cumsum(pmf)`` over the discrete time grid.

        Arguments are as in ``predict_pmf``. Returns an array or tensor (chosen
        by ``numpy`` / the type of ``input``) with shape
        ``(n_samples, num_class // sen_attr_n_class)``.
        """
        pmf = self.predict_pmf(input, S, batch_size, False, eval_, to_cpu, num_workers)
        surv = 1 - pmf.cumsum(1)
        return tt.utils.array_or_tensor(surv, numpy, input)

    def predict_pmf(self, input, S=None, batch_size=256, numpy=None, eval_=True, to_cpu=False, num_workers=0):
        """Predict the per-time-bin probability mass function.

        Args:
            input: Model input, passed to ``predict``.
            S: Per-sample group index (tensor, array or sequence with one entry
                per sample). Required when ``is_aggregated`` is False; ignored
                otherwise.
            batch_size, eval_, to_cpu, num_workers: Forwarded to ``predict``.
            numpy: Return a numpy array (True), a tensor (False), or match the
                type of ``input`` (None).

        Returns:
            Array or tensor of shape ``(n_samples, num_class // sen_attr_n_class)``.
        """
        preds = self.predict(input, batch_size, False, eval_, False, to_cpu, num_workers)
        if self.is_aggregated:
            preds = self.group_sum_predict(preds)
        else:
            preds = self.group_pred(preds, S)
        # An extra padding column is appended before the softmax and dropped
        # afterwards, so the remaining probabilities may sum to less than 1.
        pmf = pad_col(preds).softmax(1)[:, :-1]
        return tt.utils.array_or_tensor(pmf, numpy, input)

    def _group_blocks(self, logits):
        """Reshape ``logits`` to ``(n_samples, sen_attr_n_class, n_bins_per_group)``."""
        n_class = self.num_class // self.sen_attr_n_class
        used = self.sen_attr_n_class * n_class
        # Only the leading ``used`` columns belong to a group block; reshape
        # copies if the slice is not contiguous.
        return logits[:, :used].reshape(logits.shape[0], self.sen_attr_n_class, n_class)

    def group_pred(self, logits, S):
        """Select, for every sample, the logit block of its own group.

        Args:
            logits: Tensor whose rows are samples and whose columns are the
                concatenated per-group logit blocks.
            S: Group index per sample (tensor, numpy array or sequence). Float
                values are truncated like ``Tensor.long()``. Shapes ``(N,)`` and
                ``(N, 1)`` are both accepted.

        Returns:
            Tensor of shape ``(N, num_class // sen_attr_n_class)`` on the same
            device as ``logits``.

        Raises:
            ValueError: If ``S`` is None, has a different number of entries than
                there are samples, or contains a group index outside
                ``[0, sen_attr_n_class)``.
        """
        if S is None:
            raise ValueError(
                "S (group index per sample) is required when is_aggregated is False."
            )
        n_samples = logits.shape[0]
        # Move to the logits' device so advanced indexing works when the
        # predictions stayed on the GPU and S lives on the CPU.
        groups = torch.as_tensor(S).long().reshape(-1).to(logits.device)
        if groups.numel() != n_samples:
            raise ValueError(
                f"S has {groups.numel()} entries but there are {n_samples} samples."
            )
        if groups.numel() and (groups.min() < 0 or groups.max() >= self.sen_attr_n_class):
            raise ValueError(
                f"S must contain group indices in [0, {self.sen_attr_n_class})."
            )
        blocks = self._group_blocks(logits)
        # Pick block ``groups[i]`` for row ``i`` in a single indexing op.
        return blocks[torch.arange(n_samples, device=logits.device), groups]

    def group_sum_predict(self, logits):
        """Sum the logit blocks of all groups.

        Returns:
            Tensor of shape ``(N, num_class // sen_attr_n_class)``.
        """
        return self._group_blocks(logits).sum(dim=1)

    def predict_surv_df(self, input, S=None, batch_size=256, eval_=True, num_workers=0):
        """Predict survival curves as a DataFrame.

        Returns one row per time point (indexed by ``duration_index``) and one
        column per sample. ``S`` is as in ``predict_pmf``.
        """
        surv = self.predict_surv(input, S, batch_size, True, eval_, True, num_workers)
        return pd.DataFrame(surv.transpose(), self.duration_index)

