import torch
from torch import nn
from torch.autograd import Function
from utils_lstm      import VariationalLSTM

class ReverseGrad(Function):
    """
    Gradient reversal layer.

    Forward is the identity; backward multiplies the incoming gradient by
    ``-scale``. The class is defined once at module level (instead of being
    rebuilt inside every ``grad_reverse`` call) and receives ``scale``
    through ``ctx``.
    """

    @staticmethod
    def forward(ctx, x, scale):
        # Keep the multiplier around for the backward pass.
        ctx.scale = scale
        # Hand back a view of the input so the output is a distinct tensor
        # object carrying this Function's backward node.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input: the reversed, scaled gradient
        # for ``x`` and no gradient for ``scale``.
        return ctx.scale * grad_output.neg(), None


def grad_reverse(x, scale=1.0):
    """
    Pass ``x`` through unchanged while reversing its gradient.

    Args:
        x: input tensor.
        scale: factor applied to the negated gradient in the backward pass
            (``grad_in = scale * -grad_out``).

    Returns:
        A tensor with the same values as ``x``.
    """
    return ReverseGrad.apply(x, scale)

class BuildBr(nn.Module):
    """
    Sequence encoder producing the ``br`` representation
    (presumably "balanced representation" -- confirm against the caller).

    A ``VariationalLSTM`` consumes the concatenated history features and its
    output is projected to ``br_size`` by a linear layer followed by ELU.
    """

    def __init__(self, input_size, seq_hidden_units, num_layer, dropout_rate, br_size):
        """
        Args:
            input_size: first argument forwarded to ``VariationalLSTM``.
            seq_hidden_units: second argument forwarded to ``VariationalLSTM``;
                also the input width of the projection layer.
            num_layer: third argument forwarded to ``VariationalLSTM``.
            dropout_rate: fourth argument forwarded to ``VariationalLSTM``.
            br_size: output width of the projection layer.
        """
        super().__init__()
        # Recurrent encoder.
        self.lstm = VariationalLSTM(input_size, seq_hidden_units, num_layer, dropout_rate)
        # Projection from the LSTM output to the representation.
        self.linear1 = nn.Linear(seq_hidden_units, br_size)
        self.elu1 = nn.ELU()

    def forward(self, prev_treatments, prev_dosages, prev_outputs, static_features,
                init_states=None):
        """
        Encode the history tensors into the representation.

        All four feature tensors are joined along their last dimension, in the
        order treatments, dosages, outputs, static features, so they must
        agree on every other dimension.

        Args:
            prev_treatments: previous-treatment features.
            prev_dosages: previous-dosage features.
            prev_outputs: previous-outcome features.
            static_features: static covariates.
            init_states: forwarded as-is to the LSTM (``None`` by default).

        Returns:
            ``ELU(linear1(lstm(features)))``.
        """
        history = torch.cat(
            (prev_treatments, prev_dosages, prev_outputs, static_features), dim=-1
        )
        encoded = self.lstm(history, init_states=init_states)
        projected = self.linear1(encoded)
        return self.elu1(projected)
    
class BROutcomeHead(nn.Module):
    """
    Two-layer MLP that predicts the outcome from the representation ``br``
    together with the current treatment and dosage.
    """

    def __init__(self, br_size, fc_hidden_units, dim_treatments, dim_dosages, dim_outcome):
        """
        Args:
            br_size: width of the representation input.
            fc_hidden_units: width of the hidden layer.
            dim_treatments: width of the treatment input.
            dim_dosages: width of the dosage input.
            dim_outcome: width of the predicted outcome.
        """
        super().__init__()

        # Keep the sizes around as plain attributes.
        self.br_size = br_size
        self.fc_hidden_units = fc_hidden_units
        self.dim_treatments = dim_treatments
        self.dim_dosages = dim_dosages
        self.dim_outcome = dim_outcome

        # The first layer sees [br | treatment | dosage].
        joint_width = br_size + dim_treatments + dim_dosages
        self.linear1 = nn.Linear(joint_width, fc_hidden_units)
        self.elu1 = nn.ELU()
        self.linear2 = nn.Linear(fc_hidden_units, dim_outcome)

    def forward(self, br, current_treatment, current_dosage):
        """
        Concatenate the three inputs along the last dimension and run them
        through ``linear1 -> ELU -> linear2``.

        Returns:
            The outcome prediction (no output activation is applied).
        """
        joint = torch.cat((br, current_treatment, current_dosage), dim=-1)
        hidden = self.elu1(self.linear1(joint))
        return self.linear2(hidden)

class BRTreatmentHead(nn.Module):
    """
    Two-layer MLP that predicts treatment logits from the representation
    ``br``, optionally behind a gradient-reversal layer.
    """

    def __init__(self, br_size, fc_hidden_units, dim_treatments,
                 alpha=0.0, update_alpha=True, balancing='grad_reverse'):
        """
        Args:
            br_size: width of the representation input.
            fc_hidden_units: width of the hidden layer.
            dim_treatments: number of treatment logits produced.
            alpha: gradient-reversal scale; only kept when ``update_alpha``
                is false.
            update_alpha: when true, ``self.alpha`` starts at 0.0 regardless
                of ``alpha`` (presumably it is then adjusted from outside
                during training -- confirm with the training loop).
            balancing: gradient reversal is applied in ``forward`` only when
                this equals ``'grad_reverse'``.
        """
        super().__init__()

        # Sizes
        self.br_size = br_size
        self.fc_hidden_units = fc_hidden_units
        self.dim_treatments = dim_treatments

        # Balancing configuration
        self.alpha = 0.0 if update_alpha else alpha
        self.balancing = balancing

        # Layers
        self.linear1 = nn.Linear(br_size, fc_hidden_units)
        self.elu1 = nn.ELU()
        self.linear2 = nn.Linear(fc_hidden_units, dim_treatments)

        # Names of the submodules holding this head's parameters.
        self.treatment_head_params = ['linear1', 'linear2']

    def forward(self, br, detach=False):
        """
        Compute treatment logits.

        Args:
            br: representation tensor.
            detach: when true, ``br`` is detached first so no gradient flows
                back into whatever produced it.

        Returns:
            Raw logits; softmax is encapsulated into F.cross_entropy().
        """
        rep = br.detach() if detach else br

        if self.balancing == 'grad_reverse':
            rep = grad_reverse(rep, self.alpha)

        hidden = self.elu1(self.linear1(rep))
        return self.linear2(hidden)

class BRDosageHead(nn.Module):
    """
    Two-layer MLP that predicts dosage logits from the representation ``br``
    and the current treatment, optionally behind a gradient-reversal layer.
    """

    def __init__(self, br_size, fc_hidden_units, dim_treatments, num_dosage_samples,
                 alpha=0.0, update_alpha=True, balancing='grad_reverse'):
        """
        Args:
            br_size: width of the representation input.
            fc_hidden_units: width of the hidden layer.
            dim_treatments: width of the treatment input.
            num_dosage_samples: number of dosage logits produced.
            alpha: gradient-reversal scale; only kept when ``update_alpha``
                is false.
            update_alpha: when true, ``self.alpha`` starts at 0.0 regardless
                of ``alpha`` (presumably it is then adjusted from outside
                during training -- confirm with the training loop).
            balancing: gradient reversal is applied in ``forward`` only when
                this equals ``'grad_reverse'``.
        """
        super().__init__()

        # Sizes
        self.br_size = br_size
        self.fc_hidden_units = fc_hidden_units
        self.dim_n_dosage_samples = num_dosage_samples
        self.dim_treatments = dim_treatments

        # Balancing configuration
        self.alpha = 0.0 if update_alpha else alpha
        self.balancing = balancing

        # Layers; the first one sees [br | treatment].
        self.linear1 = nn.Linear(br_size + dim_treatments, fc_hidden_units)
        self.elu1 = nn.ELU()
        self.linear2 = nn.Linear(fc_hidden_units, num_dosage_samples)

        # Names of the submodules holding this head's parameters.
        self.dosage_head_params = ['linear1', 'linear2']

    def forward(self, br, current_treatment, detach=False):
        """
        Compute dosage logits.

        Args:
            br: representation tensor.
            current_treatment: treatment tensor, concatenated to ``br`` along
                the last dimension.
            detach: when true, ``br`` is detached first so no gradient flows
                back into whatever produced it.

        Returns:
            Raw logits; softmax is encapsulated into F.cross_entropy().
        """
        rep = br.detach() if detach else br

        if self.balancing == 'grad_reverse':
            rep = grad_reverse(rep, self.alpha)

        joint = torch.cat((rep, current_treatment), dim=-1)
        hidden = self.elu1(self.linear1(joint))
        return self.linear2(hidden)
        
class ROutcomeHead(nn.Module):
    """
    Used by G-Net.

    Projects a sequence-model output to a representation ``r`` and predicts
    the output components one after another: each conditional network sees
    ``r`` plus the predictions of all components before it.
    """

    def __init__(self, seq_hidden_units, r_size, fc_hidden_units, dim_outcome, num_comp, comp_sizes):
        """
        Args:
            seq_hidden_units: width of the sequence-model output.
            r_size: width of the representation ``r``.
            fc_hidden_units: hidden width of every conditional network.
            dim_outcome: stored only; not used to build any layer here.
            num_comp: number of conditional networks.
            comp_sizes: output width of each component, indexed by position.
        """
        super().__init__()

        self.seq_hidden_units = seq_hidden_units
        self.r_size = r_size
        self.fc_hidden_units = fc_hidden_units
        self.dim_outcome = dim_outcome
        self.num_comp = num_comp
        self.comp_sizes = comp_sizes

        # Representation layer
        self.linear1 = nn.Linear(seq_hidden_units, r_size)
        self.elu1 = nn.ELU()

        # One small MLP per component; its input grows by the width of every
        # component predicted before it.
        nets = []
        in_width = r_size
        for idx in range(num_comp):
            out_width = comp_sizes[idx]
            nets.append(
                nn.Sequential(
                    nn.Linear(in_width, fc_hidden_units),
                    nn.ELU(),
                    nn.Linear(fc_hidden_units, out_width),
                )
            )
            in_width += out_width
        self.cond_nets = nn.ModuleList(nets)

    def build_r(self, seq_output):
        """Return ``ELU(linear1(seq_output))``."""
        return self.elu1(self.linear1(seq_output))

    def build_outcome_vitals(self, r):
        """
        Run the conditional networks in order.

        After each network, its prediction is prepended to the running input
        along the last dimension before the next network is evaluated.

        Returns:
            All component predictions concatenated along the last dimension,
            in network order.
        """
        preds = []
        context = r
        for net in self.cond_nets:
            pred = net(context)
            preds.append(pred)
            context = torch.cat((pred, context), dim=-1)
        return torch.cat(preds, dim=-1)
