import torch
import torch.nn as nn
from pycox.models import DeepHitSingle, LogisticHazard, PMF
from pycox.models.data import pair_rank_mat
from algorithms.backbone import backbone


class DeepHit(nn.Module):
    """Backbone + MLP head trained with the pycox ``DeepHitSingle`` loss.

    The backbone and the head are registered on this module, so
    ``self.parameters()`` covers both. The same two sub-modules are also
    chained into an ``nn.Sequential`` that is handed to ``DeepHitSingle``,
    which supplies the loss and the survival-curve prediction.
    """

    def __init__(self, hparams, time_grid, device):
        """Build the network, the pycox wrapper and the optimizer.

        Args:
            hparams: Mapping read for 'feature_dim', 'hidden_dim', 'alpha',
                'sigma', 'lr' and 'decay'; also forwarded to ``backbone``.
            time_grid: Sized collection of discrete time points. Its length
                is the head's output width and it is passed to pycox as
                ``duration_index``.
            device: Device handed to pycox and used for the rank matrix.
        """
        super().__init__()
        self.hparams = hparams
        self.feature_extractor = backbone(hparams)
        self.classifier = nn.Sequential(
            nn.Linear(hparams['feature_dim'], hparams['hidden_dim']),
            nn.ReLU(),
            nn.Linear(hparams['hidden_dim'], len(time_grid)),
        )
        self.surv_model = DeepHitSingle(
            nn.Sequential(self.feature_extractor, self.classifier),
            alpha=hparams['alpha'],
            sigma=hparams['sigma'],
            device=device,
            duration_index=time_grid,
        )
        self.loss = self.surv_model.loss
        self.optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams['lr'],
            weight_decay=self.hparams['decay'],
        )
        self.device = device

    def forward(self, input):
        """Return the head's output for ``input`` (one value per grid point)."""
        return self.classifier(self.feature_extractor(input))

    def get_repr(self, input):
        """Return the backbone's representation of ``input``."""
        return self.feature_extractor(input)

    def get_loss(self, output, Y, D):
        """Compute the DeepHit loss for one batch.

        Args:
            output: Network output, as returned by ``forward``.
            Y: Tensor of durations for the batch (presumably already
                discretized to indices on the time grid -- confirm with the
                caller).
            D: Tensor of event indicators for the batch.

        Returns:
            dict with the single key 'total_loss'.
        """
        # pair_rank_mat works on numpy arrays, so the labels are copied to
        # host memory and the resulting matrix is moved back to self.device.
        rank_mat = pair_rank_mat(Y.cpu().numpy(), D.cpu().numpy())
        rank_mat = torch.tensor(rank_mat, dtype=torch.int, device=self.device)
        loss_batch = self.loss(output, Y, D, rank_mat)
        return {'total_loss': loss_batch}

    def predict_surv(self, input, batch_size, to_cpu, numpy):
        """Predict survival curves through the pycox wrapper.

        Args:
            input: Covariates to predict for.
            batch_size: Batch size used by pycox while predicting.
            to_cpu: Forwarded to pycox's ``to_cpu`` option.
            numpy: Forwarded to pycox's ``numpy`` option.

        Returns:
            Whatever ``DeepHitSingle.predict_surv`` returns.
        """
        # pycox's positional order is (input, batch_size, numpy, eval_,
        # to_cpu, ...). Passing to_cpu/numpy by position would put them in
        # the numpy/eval_ slots, so they are passed by keyword instead.
        return self.surv_model.predict_surv(
            input, batch_size=batch_size, numpy=numpy, to_cpu=to_cpu
        )

    def update(self, input, Y, D, S):
        """Run one optimizer step on a batch and report the loss.

        ``S`` is accepted for signature compatibility and is not used.

        Returns:
            dict with the single key 'loss' holding a Python float.
        """
        output = self.forward(input)
        loss = self.get_loss(output, Y, D)['total_loss']
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {'loss': loss.item()}


class NnetSurv(nn.Module):
    """Backbone + MLP head trained with the pycox ``LogisticHazard`` loss.

    The backbone and the head are registered on this module, so
    ``self.parameters()`` covers both. The same two sub-modules are also
    chained into an ``nn.Sequential`` that is handed to ``LogisticHazard``,
    which supplies the loss and the survival-curve prediction.
    """

    def __init__(self, hparams, time_grid, device):
        """Build the network, the pycox wrapper and the optimizer.

        Args:
            hparams: Mapping read for 'feature_dim', 'hidden_dim', 'lr' and
                'decay'; also forwarded to ``backbone``.
            time_grid: Sized collection of discrete time points. Its length
                is the head's output width and it is passed to pycox as
                ``duration_index``.
            device: Device handed to pycox; also stored on the instance.
        """
        super().__init__()
        self.hparams = hparams
        self.feature_extractor = backbone(hparams)
        self.classifier = nn.Sequential(
            nn.Linear(hparams['feature_dim'], hparams['hidden_dim']),
            nn.ReLU(),
            nn.Linear(hparams['hidden_dim'], len(time_grid)),
        )
        self.surv_model = LogisticHazard(
            nn.Sequential(self.feature_extractor, self.classifier),
            device=device,
            duration_index=time_grid,
        )
        self.loss = self.surv_model.loss
        self.optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams['lr'],
            weight_decay=self.hparams['decay'],
        )
        self.device = device

    def forward(self, input):
        """Return the head's output for ``input`` (one value per grid point)."""
        return self.classifier(self.feature_extractor(input))

    def get_repr(self, input):
        """Return the backbone's representation of ``input``."""
        return self.feature_extractor(input)

    def get_loss(self, output, Y, D):
        """Compute the pycox loss for one batch.

        Args:
            output: Network output, as returned by ``forward``.
            Y: Tensor of durations for the batch (presumably already
                discretized to indices on the time grid -- confirm with the
                caller).
            D: Tensor of event indicators for the batch.

        Returns:
            dict with the single key 'total_loss'.
        """
        loss_batch = self.loss(output, Y, D)
        return {'total_loss': loss_batch}

    def predict_surv(self, input, batch_size, to_cpu, numpy):
        """Predict survival curves through the pycox wrapper.

        Args:
            input: Covariates to predict for.
            batch_size: Batch size used by pycox while predicting.
            to_cpu: Forwarded to pycox's ``to_cpu`` option.
            numpy: Forwarded to pycox's ``numpy`` option.

        Returns:
            Whatever ``LogisticHazard.predict_surv`` returns.
        """
        # pycox's positional order is (input, batch_size, numpy, eval_,
        # to_cpu, ...). Passing to_cpu/numpy by position would put them in
        # the numpy/eval_ slots, so they are passed by keyword instead.
        return self.surv_model.predict_surv(
            input, batch_size=batch_size, numpy=numpy, to_cpu=to_cpu
        )

    def update(self, input, Y, D, S):
        """Run one optimizer step on a batch and report the loss.

        ``S`` is accepted for signature compatibility and is not used.

        Returns:
            dict with the single key 'loss' holding a Python float.
        """
        output = self.forward(input)
        loss = self.get_loss(output, Y, D)['total_loss']
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {'loss': loss.item()}


class PMFSurv(nn.Module):
    """Backbone + MLP head trained with the pycox ``PMF`` loss.

    The backbone and the head are registered on this module, so
    ``self.parameters()`` covers both. The same two sub-modules are also
    chained into an ``nn.Sequential`` that is handed to ``PMF``, which
    supplies the loss and the survival-curve prediction.
    """

    def __init__(self, hparams, time_grid, device):
        """Build the network, the pycox wrapper and the optimizer.

        Args:
            hparams: Mapping read for 'feature_dim', 'hidden_dim', 'lr' and
                'decay'; also forwarded to ``backbone``.
            time_grid: Sized collection of discrete time points. Its length
                is the head's output width and it is passed to pycox as
                ``duration_index``.
            device: Device handed to pycox; also stored on the instance.
        """
        super().__init__()
        self.hparams = hparams
        self.feature_extractor = backbone(hparams)
        self.classifier = nn.Sequential(
            nn.Linear(hparams['feature_dim'], hparams['hidden_dim']),
            nn.ReLU(),
            nn.Linear(hparams['hidden_dim'], len(time_grid)),
        )
        self.surv_model = PMF(
            nn.Sequential(self.feature_extractor, self.classifier),
            device=device,
            duration_index=time_grid,
        )
        self.loss = self.surv_model.loss
        self.optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams['lr'],
            weight_decay=self.hparams['decay'],
        )
        self.device = device

    def forward(self, input):
        """Return the head's output for ``input`` (one value per grid point)."""
        return self.classifier(self.feature_extractor(input))

    def get_repr(self, input):
        """Return the backbone's representation of ``input``."""
        return self.feature_extractor(input)

    def get_loss(self, output, Y, D):
        """Compute the pycox loss for one batch.

        Args:
            output: Network output, as returned by ``forward``.
            Y: Tensor of durations for the batch (presumably already
                discretized to indices on the time grid -- confirm with the
                caller).
            D: Tensor of event indicators for the batch.

        Returns:
            dict with the single key 'total_loss'.
        """
        loss_batch = self.loss(output, Y, D)
        return {'total_loss': loss_batch}

    def predict_surv(self, input, batch_size, to_cpu, numpy):
        """Predict survival curves through the pycox wrapper.

        Args:
            input: Covariates to predict for.
            batch_size: Batch size used by pycox while predicting.
            to_cpu: Forwarded to pycox's ``to_cpu`` option.
            numpy: Forwarded to pycox's ``numpy`` option.

        Returns:
            Whatever ``PMF.predict_surv`` returns.
        """
        # pycox's positional order is (input, batch_size, numpy, eval_,
        # to_cpu, ...). Passing to_cpu/numpy by position would put them in
        # the numpy/eval_ slots, so they are passed by keyword instead.
        return self.surv_model.predict_surv(
            input, batch_size=batch_size, numpy=numpy, to_cpu=to_cpu
        )

    def update(self, input, Y, D, S):
        """Run one optimizer step on a batch and report the loss.

        ``S`` is accepted for signature compatibility and is not used.

        Returns:
            dict with the single key 'loss' holding a Python float.
        """
        output = self.forward(input)
        loss = self.get_loss(output, Y, D)['total_loss']
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {'loss': loss.item()}
