import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pycox.models import DeepHitSingle
from pycox.models.deephit_DomainInd import DeepHit_DomainInd
from pycox.models.data import pair_rank_mat
from algorithms.backbone import backbone


class BaseModel(nn.Module):
    """Common scaffold for the survival models in this file.

    Holds a feature extractor (from ``backbone``), a two-layer classifier head,
    a pycox survival wrapper built around both, and an Adam optimizer over this
    module's parameters. Subclasses supply the training step (``update``).
    """

    def __init__(self, hparams, time_grid, device):
        """Build the network, the pycox wrapper and the optimizer.

        Args:
            hparams: mapping read for 'sen_attr_n_class', 'feature_dim',
                'hidden_dim', 'alpha', 'sigma', 'lr' and 'decay'; it is also
                passed whole to ``backbone``.
            time_grid: sized sequence passed to pycox as ``duration_index``;
                its length sets the number of outputs per group.
            device: device handed to pycox and used for tensors created here.
        """
        super(BaseModel, self).__init__()
        self.device = device
        self.hparams = hparams
        self.sen_attr_n_class = self.hparams['sen_attr_n_class']
        self.feature_extractor = backbone(hparams)

        # Domain-independent variants get one block of len(time_grid) outputs
        # per sensitive group; every other model gets a single block.
        domain_independent = isinstance(self, DomainInd)
        n_heads = self.sen_attr_n_class if domain_independent else 1
        self.num_class = len(time_grid) * n_heads
        self.classifier = nn.Sequential(
            nn.Linear(hparams['feature_dim'], hparams['hidden_dim']),
            nn.ReLU(),
            nn.Linear(hparams['hidden_dim'], self.num_class),
        )

        # The pycox wrapper shares the very same sub-modules, so training
        # through self.optimizer also changes what predict_surv sees.
        network = nn.Sequential(self.feature_extractor, self.classifier)
        if domain_independent:
            self.is_aggregated = isinstance(self, DomainIndAggregated)
            self.surv_model = DeepHit_DomainInd(
                network, self.num_class, self.sen_attr_n_class, self.is_aggregated,
                alpha=hparams['alpha'], sigma=hparams['sigma'],
                device=device, duration_index=time_grid,
            )
        else:
            self.surv_model = DeepHitSingle(
                network,
                alpha=hparams['alpha'], sigma=hparams['sigma'],
                device=device, duration_index=time_grid,
            )

        self.loss = self.surv_model.loss
        # Created last so that every sub-module registered above is included.
        self.optimizer = torch.optim.Adam(
            self.parameters(), lr=self.hparams['lr'], weight_decay=self.hparams['decay']
        )

    def forward(self, input):
        """Return the classifier output for ``input`` (features -> head)."""
        return self.classifier(self.feature_extractor(input))

    def get_repr(self, input):
        """Return the feature-extractor output for ``input``, without the head."""
        return self.feature_extractor(input)

    def get_loss(self, output, Y, D):
        """Evaluate the pycox loss on a batch.

        Args:
            output: network output, as returned by ``forward``.
            Y: time-to-event tensor.
            D: event-indicator tensor.

        Returns:
            ``{'total_loss': loss}`` where ``loss`` is whatever ``self.loss`` returns.
        """
        # pair_rank_mat is fed NumPy arrays, so Y and D are copied to host first.
        rank_mat = pair_rank_mat(Y.cpu().numpy(), D.cpu().numpy())
        rank_mat = torch.tensor(rank_mat, dtype=torch.int, device=self.device)
        loss_batch = self.loss(output, Y, D, rank_mat)
        return {'total_loss': loss_batch}

    def predict_surv(self, input, batch_size, to_cpu, numpy):
        """Delegate survival prediction to the pycox wrapper.

        The options are forwarded by keyword, as the DomainInd variants in this
        file already do. Upstream pycox declares ``predict_surv`` as
        ``(input, batch_size, numpy, eval_, to_cpu, ...)``, so forwarding
        ``to_cpu`` and ``numpy`` positionally would route them to the wrong
        parameters.
        """
        return self.surv_model.predict_surv(
            input, batch_size=batch_size, to_cpu=to_cpu, numpy=numpy
        )
    
    
class GroupDRO(BaseModel):
    """Training step that weights per-group losses with an adaptive vector ``q``.

    ``q`` holds one weight per sensitive group. On every step each present
    group's weight is multiplied by ``exp(eta * group_loss)``, the vector is
    renormalised to sum to one, and the optimised objective is the dot product
    of the group losses with ``q``.
    """

    def __init__(self, hparams, time_grid, device):
        super().__init__(hparams, time_grid, device)
        # Registered empty; filled with ones on the first call to update().
        self.register_buffer("q", torch.Tensor())
        self.eta = hparams['eta']

    def update(self, input, Y, D, S):
        """Run one optimisation step and return ``{'loss': float}``.

        Args:
            input: batch of model inputs.
            Y: time-to-event tensor.
            D: event-indicator tensor.
            S: sensitive-group index per sample, compared against
                ``range(sen_attr_n_class)``.
        """
        n_groups = self.sen_attr_n_class
        if len(self.q) == 0:
            self.q = torch.ones(n_groups).to(self.device)

        group_losses = torch.zeros(n_groups).to(self.device)
        for g in range(n_groups):
            members = S == g
            if members.sum() > 0:
                group_out = self.forward(input[members])
                group_losses[g] = self.get_loss(group_out, Y[members], D[members])['total_loss']
                # The weight update uses the loss value only; no gradient flows into q.
                self.q[g] *= (self.eta * group_losses[g].data).exp()
            # Groups absent from the batch keep a zero loss and their current weight.

        self.q /= self.q.sum()
        weighted = torch.dot(group_losses, self.q)

        self.optimizer.zero_grad()
        weighted.backward()
        self.optimizer.step()
        return {'loss': weighted.item()}
    

class DomainInd(BaseModel):
    """Domain-independent model: one block of outputs per sensitive group.

    ``BaseModel`` sizes the classifier to ``len(time_grid) * sen_attr_n_class``
    outputs for this class; training uses, for each sample, only the block
    belonging to that sample's group.
    """

    def __init__(self, hparams, time_grid, device):
        super(DomainInd, self).__init__(hparams, time_grid, device)

    def group_pred(self, output, S):
        """Select each sample's own group block from ``output``.

        Args:
            output: 2-D tensor ``(batch, num_class)``.
            S: group index per sample; cast to ``long`` and flattened.

        Returns:
            Tensor ``(batch, num_class // sen_attr_n_class)`` whose row ``i`` is
            ``output[i, S[i] * n : (S[i] + 1) * n]``.
        """
        n_class = self.num_class // self.sen_attr_n_class
        group = S.long().reshape(-1).to(output.device)
        # One vectorised index replaces a Python loop that called .item() per
        # row; it also returns an empty (0, n_class) result for an empty batch
        # where stacking an empty list would raise.
        rows = torch.arange(output.shape[0], device=output.device).unsqueeze(1)
        cols = group.unsqueeze(1) * n_class + torch.arange(n_class, device=output.device)
        return output[rows, cols]

    def update(self, input, Y, D, S):
        """Run one optimisation step on the group-selected outputs.

        Returns:
            ``{'loss': float}`` with the batch loss value.
        """
        output = self.forward(input)
        pred = self.group_pred(output, S)
        loss = self.get_loss(pred, Y, D)['total_loss']
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        loss = {'loss': loss.item()}
        return loss

    def predict_surv(self, input, S, batch_size, to_cpu, numpy):
        """Delegate to the pycox wrapper, additionally passing the group labels ``S``."""
        return self.surv_model.predict_surv(input, S=S, batch_size=batch_size, to_cpu=to_cpu, numpy=numpy)


class DomainIndAggregated(DomainInd):
    """DomainInd variant whose prediction call takes no group labels.

    ``BaseModel.__init__`` detects this class and passes ``is_aggregated=True``
    to the pycox wrapper.
    """

    def __init__(self, hparams, time_grid, device):
        super().__init__(hparams, time_grid, device)

    def predict_surv(self, input, batch_size, to_cpu, numpy):
        """Delegate to the pycox wrapper without an ``S`` argument."""
        options = {'batch_size': batch_size, 'to_cpu': to_cpu, 'numpy': numpy}
        return self.surv_model.predict_surv(input, **options)


class Reweighting(BaseModel):
    """Plain training step on the survival loss.

    NOTE(review): ``S`` is accepted by ``update`` but not used in this class;
    presumably any reweighting is applied outside it (e.g. when batches are
    drawn) — confirm against the caller.
    """

    def __init__(self, hparams, time_grid, device):
        super().__init__(hparams, time_grid, device)

    def update(self, input, Y, D, S):
        """Run one optimisation step and return ``{'loss': float}``."""
        batch_loss = self.get_loss(self.forward(input), Y, D)['total_loss']

        self.optimizer.zero_grad()
        batch_loss.backward()
        self.optimizer.step()

        return {'loss': batch_loss.item()}
 
 
class Mean_loss(nn.Module):
    """Penalty on the gap between each group's mean output and the overall mean.

    For every group value present in ``S`` the squared difference between the
    scalar mean of that group's outputs and the scalar mean of all outputs is
    added up. The result is a float32 tensor of shape ``(1,)``.
    """

    def __init__(self, device, num_class):
        super().__init__()
        self.device = device
        self.num_class = num_class  # stored only; not read by this class

    def forward(self, output, S):
        """Return the summed per-group mean gap for ``output`` grouped by ``S``."""
        start = torch.zeros(1, dtype=torch.float32).to(self.device)
        return start + self.compute_mean_gap_group(output, S)

    def compute_mean_gap_group(self, outputs, group):
        """Accumulate ``compute_mean_gap`` over every distinct value in ``group``."""
        total = torch.zeros(1, dtype=torch.float32).to(self.device)
        for label in group.unique():
            in_group = group == label
            total += self.compute_mean_gap(outputs[in_group], outputs)
        return total

    def compute_mean_gap(self, x, y):
        """Return the squared difference of the all-element means of ``x`` and ``y``."""
        gap = x.mean() - y.mean()
        return gap ** 2


class MMD_loss(nn.Module):
    """Penalty on the distance between each group's mean vector and the overall one.

    For every group value present in ``S`` the squared Euclidean distance
    between the group's per-column mean (``mean(dim=0)``) and the per-column
    mean of the whole batch is added up. The result is a float32 tensor of
    shape ``(1,)``.
    """

    def __init__(self, device, num_class):
        super(MMD_loss, self).__init__()
        self.device = device
        self.num_class = num_class  # stored only; not read by this class

    def forward(self, output, S):
        """Return the summed per-group gap for ``output`` grouped by ``S``."""
        return self.compute_mmd_gap_group(output, S)

    def compute_mmd_gap_group(self, outputs, group):
        """Accumulate ``compute_mmd_gap`` over every distinct value in ``group``."""
        # Explicit float32 keeps the accumulator dtype independent of the
        # process-wide default dtype.
        result = torch.zeros(1, dtype=torch.float32, device=self.device)
        for the_group in group.unique():
            result = result + self.compute_mmd_gap(outputs[group == the_group], outputs)
        return result

    def compute_mmd_gap(self, x, y):
        """Return the sum of squared differences between the column means of ``x`` and ``y``."""
        # reduction='sum' is the supported spelling of the deprecated
        # size_average=False, which warned on every call.
        return F.mse_loss(x.mean(dim=0), y.mean(dim=0), reduction='sum')

    
class Regularization(BaseModel):
    """Survival loss plus a weighted fairness penalty on the learned representation.

    The penalty is ``MMD_loss`` evaluated on the feature-extractor output and
    the group labels, scaled by ``hparams['fair_weight']``.
    """

    def __init__(self, hparams, time_grid, device):
        super().__init__(hparams, time_grid, device)
        self.fair_loss_fn = MMD_loss(device=self.device, num_class=self.num_class)
        self.fair_weight = hparams['fair_weight']

    def update(self, input, Y, D, S):
        """Run one optimisation step and return ``{'loss': float}``.

        The representation is computed once and reused for both the classifier
        head and the fairness penalty.
        """
        features = self.get_repr(input)
        logits = self.classifier(features)

        objective = self.get_loss(logits, Y, D)['total_loss']
        penalty = self.fair_loss_fn(features, S).squeeze()
        objective += self.fair_weight * penalty

        self.optimizer.zero_grad()
        objective.backward()
        self.optimizer.step()

        return {'loss': objective.item()}