import numpy as np
import torch
import torch.nn as nn

# -------------------------------------------------
# Generator for one step
# -------------------------------------------------
class GenCFModel(nn.Module):
    """Counterfactual generator with one small MLP head per treatment option.

    Every head is ``Linear -> ELU -> Linear`` and maps a feature vector of
    size ``multitask_input_dim`` to a single scalar.  Four heads are always
    created (``module_list_w0`` .. ``module_list_w3``); ``forward`` uses the
    first ``num_treatments`` of them.
    """

    # -------------------------------------------------------------
    # INIT
    # -------------------------------------------------------------
    def __init__(self, num_treatments, num_dosage_samples, dict_wd, multitask_input_dim, multitask_hidden_dim):
        """Build the per-treatment heads.

        Args:
            num_treatments: Number of treatment options iterated in
                ``forward`` (must not exceed the four heads built here).
            num_dosage_samples: Size of the last axis of the returned
                dose-response tensor.
            dict_wd: Mapping ``treatment index -> iterable of dosages``; only
                its per-treatment length is used (in ``forward`` when ``d``
                is given).
            multitask_input_dim: Input width of every head.  It has to match
                the concatenated feature width built in ``forward``.
            multitask_hidden_dim: Hidden width of every head.
        """
        super().__init__()
        self.num_treatments        = num_treatments
        self.num_dosage_samples    = num_dosage_samples
        self.dict_wd               = dict_wd

        # multitask layer
        self.multitask_input_dim   = multitask_input_dim
        self.multitask_hidden_dim  = multitask_hidden_dim

        # -------------------------------------------------------------------------
        # Counterfactual Layer (Multi-task Layer)
        # -------------------------------------------------------------------------
        # Heads are created in the order w0..w3 so that random initialisation
        # consumes the RNG in a fixed order; the attribute names are part of
        # the state_dict keys and must not change.
        self.module_list_w0 = self._build_head(multitask_input_dim, multitask_hidden_dim)  # no treatment (w = 0)
        self.module_list_w1 = self._build_head(multitask_input_dim, multitask_hidden_dim)  # radio (w = 1)
        self.module_list_w2 = self._build_head(multitask_input_dim, multitask_hidden_dim)  # chemo (w = 2)
        self.module_list_w3 = self._build_head(multitask_input_dim, multitask_hidden_dim)  # radio + chemo (w = 3)

        # Plain list used only for indexing by treatment id; the heads are
        # already registered as submodules through the attributes above.
        self.module_list = [self.module_list_w0,
                            self.module_list_w1,
                            self.module_list_w2,
                            self.module_list_w3]

        self.gencf_head_params = \
        ['module_list_w0', 'module_list_w1', 'module_list_w2', 'module_list_w3']

    @staticmethod
    def _build_head(input_dim, hidden_dim):
        """Return one ``Linear -> ELU -> Linear(…, 1)`` head as a ModuleList."""
        return nn.ModuleList([
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ELU(),
            torch.nn.Linear(hidden_dim, 1),
        ])

    @staticmethod
    def _apply_head(head, features):
        """Feed ``features`` through the layers of ``head`` in order."""
        for layer in head:
            features = layer(features)
        return features

    # -------------------------------------------------------------
    # FORWARD
    # -------------------------------------------------------------
    def forward(self, br, z, d = None):
        """Generate dose-response values for every treatment.

        Args:
            br: 3-D tensor ``(num_patients, max_length, features)``.
            z: Tensor indexed as ``z[:, :, w, i_d, :]`` (treatment, dosage
                slot, features on the trailing axes).
            d: Optional tensor indexed like ``z``.  When given, its features
                are appended to the head input for every dosage slot listed
                in ``dict_wd[w]``.

        Returns:
            Tensor of shape ``(num_patients, max_length, num_treatments,
            num_dosage_samples)``.  Without ``d`` only dosage slot 0 of each
            treatment is filled; with ``d`` the first ``len(dict_wd[w])``
            slots are filled.  All remaining entries stay zero.
        """
        num_patients, max_length, _ = br.shape

        # NOTE(review): allocated with torch's default device/dtype, exactly as
        # before; head outputs are copied into it.
        dr_curves = torch.zeros([num_patients,
                                 max_length,
                                 self.num_treatments,
                                 self.num_dosage_samples])

        # -------------------------------------------------------------------------
        # Counterfactual Layer (Multi-task Layer)
        # -------------------------------------------------------------------------
        for w in range(self.num_treatments):
            head = self.module_list[w]

            if d is None:
                # Only dosage slot 0 is produced in this mode, so read the
                # noise from slot 0 as well.  (The previous code indexed with
                # an undefined ``i_d`` here and raised NameError.)
                inp_x = torch.cat([br, z[:, :, w, 0, :]], dim = -1)
                out = self._apply_head(head, inp_x)
                dr_curves[:, :, w, 0] = out.reshape(num_patients, max_length)

            else:
                for i_d in range(len(self.dict_wd[w])):
                    inp_x = torch.cat([br, z[:, :, w, i_d, :], d[:, :, w, i_d, :]], dim = -1)
                    out = self._apply_head(head, inp_x)
                    dr_curves[:, :, w, i_d] = out.reshape(num_patients, max_length)

        return dr_curves
    
# --------------------------------------------
# discriminator (Treatment)
# --------------------------------------------
class DiscTreatModel(nn.Module):
    """Treatment discriminator producing one logit per treatment option.

    Each treatment's (dose-response value, dosage) pairs go through a
    dedicated ``Linear -> ELU`` head and are summed over the dosage axis.
    The per-treatment summaries are concatenated with an embedding of ``br``
    and passed through a two-layer fully connected classifier.
    """

    # -------------------------------------------------------------
    # INIT
    # -------------------------------------------------------------
    def __init__(self,
                 dict_wd,
                 br_size,
                 fc_hidden_dim,
                 i_inv_eqv_dim,
                 h_inv_eqv_dim
                 ):
        """Create the representation layer, invariant heads and classifier.

        Args:
            dict_wd: Mapping ``treatment index -> dosages``; its length sets
                ``num_treatments``.
            br_size: Feature width of ``br``.
            fc_hidden_dim: Hidden width of the final classifier.
            i_inv_eqv_dim: Input width of every invariant head.
            h_inv_eqv_dim: Output width of the heads and of the ``br``
                embedding.
        """
        super().__init__()

        # treatment / dosage bookkeeping
        self.dict_wd        = dict_wd
        self.num_treatments = len(dict_wd)

        # layer sizes
        self.i_inv_eqv_dim = i_inv_eqv_dim
        self.h_inv_eqv_dim = h_inv_eqv_dim
        self.br_size       = br_size
        self.fc_hidden_dim = fc_hidden_dim

        # -------------------------------------------------------------------------
        # Representation layer
        # -------------------------------------------------------------------------
        self.l1 = nn.Linear(br_size, h_inv_eqv_dim)
        self.a1 = nn.ELU()

        # -------------------------------------------------------------------------
        # Invariant layer: one head per treatment, created in order w0..w3
        # -------------------------------------------------------------------------
        self.module_list_w0 = self._invariant_head()   # no treatment (w = 0)
        self.module_list_w1 = self._invariant_head()   # radio (w = 1)
        self.module_list_w2 = self._invariant_head()   # chemo (w = 2)
        self.module_list_w3 = self._invariant_head()   # radio + chemo (w = 3)

        # index -> head lookup used by forward
        self.module_list = [
            self.module_list_w0,
            self.module_list_w1,
            self.module_list_w2,
            self.module_list_w3,
        ]

        # -------------------------------------------------------------------------
        # FC layer: all treatment summaries plus the br embedding
        # -------------------------------------------------------------------------
        fc_input_dim = h_inv_eqv_dim * (self.num_treatments + 1)
        self.f_l1 = nn.Linear(fc_input_dim, fc_hidden_dim)
        self.f_a1 = nn.ELU()
        self.f_l2 = nn.Linear(fc_hidden_dim, self.num_treatments)

        self.treatment_head_params = [
            'module_list_w0', 'module_list_w1', 'module_list_w2', 'module_list_w3', 'f_l1', 'f_l2',
        ]

    def _invariant_head(self):
        """Return a fresh ``Linear(i_inv_eqv_dim, h_inv_eqv_dim) -> ELU`` head."""
        linear = nn.Linear(self.i_inv_eqv_dim, self.h_inv_eqv_dim)
        return nn.ModuleList([linear, nn.ELU()])

    # -------------------------------------------------------------
    # FORWARD
    # -------------------------------------------------------------
    def forward(self, br, dr_curves, d):
        """Compute treatment logits.

        Args:
            br: Tensor whose last axis has ``br_size`` features.
            dr_curves: 4-D tensor indexed ``[patient, step, treatment, dosage slot]``.
            d: Tensor indexed ``d[:, :, w, slot, :]``.  For the shapes to line
                up in the ``w == 0`` branch its last axis must have size 1.

        Returns:
            Logits with ``num_treatments`` entries on the last axis.
        """
        # Unpacking is kept purely as a rank check: dr_curves must be 4-D.
        _patients, _steps, _treatments, _slots = dr_curves.shape

        # -------------------------------------------------------------------------
        # Representation layer
        # -------------------------------------------------------------------------
        patient_embedding = self.a1(self.l1(br))

        # ------------------------------------------------
        # Invariant layer
        # ------------------------------------------------
        summaries = []
        for w in range(self.num_treatments):
            # Lookup retained so a treatment missing from dict_wd still fails here.
            _dosages = self.dict_wd[w]

            if w == 0:
                # Treatment 0 contributes a single (value, dosage) pair from slot 0.
                curve = dr_curves[:, :, w, 0].unsqueeze(-1).unsqueeze(-1)
                doses = d[:, :, w, 0, :].unsqueeze(-1)
            else:
                curve = dr_curves[:, :, w, :].unsqueeze(-1)
                doses = d[:, :, w, :, :]

            hidden = torch.cat((curve, doses), dim = -1)
            for layer in self.module_list[w]:
                hidden = layer(hidden)

            # Summing over the dosage axis makes the summary order-independent.
            summaries.append(hidden.sum(dim = 2))

        # ------------------------------------------------
        # FC layer
        # ------------------------------------------------
        treatment_block = torch.cat(summaries, dim = -1)
        shared = torch.cat((treatment_block, patient_embedding), dim = -1)

        return self.f_l2(self.f_a1(self.f_l1(shared)))

    
# --------------------------------------------
# Discriminator (Dosage)
# --------------------------------------------
class DiscDosageModel(nn.Module):
    """Dosage discriminator producing one logit per dosage slot.

    Two stacked layers of the form ``gamma(x) - lambda(sum over slots of x)``
    process the (dose-response value, dosage) pairs, so permuting the dosage
    slots permutes the output logits in the same way.  An embedding of ``br``
    is added inside the first layer.
    """

    # -------------------------------------------------------------
    # INIT
    # -------------------------------------------------------------
    def __init__(self, dosages, br_size, i_inv_eqv_dim, h_inv_eqv_dim):
        """Create the representation, equivariant and logit layers.

        Args:
            dosages: Stored on the instance; not used by the layers built here.
            br_size: Feature width of ``br``.
            i_inv_eqv_dim: Input width of the first equivariant layer
                (``forward`` builds pairs, i.e. width 2).
            h_inv_eqv_dim: Hidden width shared by all layers.
        """
        super().__init__()

        # information Treatment dosage pairs
        self.dosages        = dosages
        self.br_size        = br_size
        self.i_inv_eqv_dim  = i_inv_eqv_dim
        self.h_inv_eqv_dim  = h_inv_eqv_dim

        # -------------------------------------------------------------------------
        # Representation layer
        # -------------------------------------------------------------------------
        self.l1 = torch.nn.Linear(self.br_size, self.h_inv_eqv_dim)
        self.a1 = torch.nn.ELU()

        # -------------------------------------------------------------------------
        # Equivariant layer
        # -------------------------------------------------------------------------
        # Equivariant layer 1
        # gamma   γ : applied to every slot
        # lambda  λ : applied to the sum over slots (no bias)
        # NOTE: the attribute is spelled ``el_l_lambda`` (letter "l", not the
        # digit "1").  The name is kept because it is part of the state_dict
        # keys of existing checkpoints.
        self.e1_l_gamma  = torch.nn.Linear(self.i_inv_eqv_dim, self.h_inv_eqv_dim)
        self.el_l_lambda = torch.nn.Linear(self.i_inv_eqv_dim, self.h_inv_eqv_dim, bias = False)
        self.e1_a        = torch.nn.ELU()

        # Equivariant layer 2
        self.e2_l_gamma  = torch.nn.Linear(self.h_inv_eqv_dim, self.h_inv_eqv_dim)
        self.e2_l_lambda = torch.nn.Linear(self.h_inv_eqv_dim, self.h_inv_eqv_dim, bias = False)
        self.e2_a        = torch.nn.ELU()

        # -------------------------------------------------------------------------
        # Logits layer
        # -------------------------------------------------------------------------
        self.D_logits_dosage = torch.nn.Linear(self.h_inv_eqv_dim, 1)

        # Every entry must name a real attribute of this module.  The list
        # previously contained 'el_2_lambda', which matched nothing and left
        # ``e2_l_lambda`` out of any name-based parameter selection.
        self.dosage_head_params = \
        ['l1', 'a1', 'e1_l_gamma', 'el_l_lambda', 'e2_l_gamma', 'e2_l_lambda', 'D_logits_dosage']

    # -------------------------------------------------------------
    # FORWARD
    # -------------------------------------------------------------
    def forward(self, br, dr_curve, tile_dosages):
        """Compute one logit per dosage slot.

        Args:
            br: Tensor whose last axis has ``br_size`` features; its leading
                axes must match the first two axes of ``dr_curve``.
            dr_curve: 3-D tensor; axis 2 enumerates the dosage slots.
            tile_dosages: Tensor with the same shape as ``dr_curve``.

        Returns:
            Tensor with the same shape as ``dr_curve`` holding the logits.

        Raises:
            ValueError: If ``dr_curve`` is not 3-D.
        """
        if dr_curve.dim() != 3:
            raise ValueError(
                "dr_curve must be a 3-D tensor, got shape %s" % (tuple(dr_curve.shape),)
            )

        # -------------------------------------------------------------------------
        # Representation Layer
        # -------------------------------------------------------------------------
        # Extra axis at -2 lets the embedding broadcast over the dosage slots.
        patient_features_representation = torch.unsqueeze(self.a1(self.l1(br)), dim = -2)

        # -------------------------------------------------------------------------
        # Equivariant Layer
        # -------------------------------------------------------------------------
        # Equivariant input: (value, dosage) pair for every slot
        eqiv_inputs = torch.cat([torch.unsqueeze(dr_curve, dim = -1),
                                 torch.unsqueeze(tile_dosages, dim = -1)], dim = -1)

        # Equivariant Layer 1
        e1_slot_term = self.e1_l_gamma(eqiv_inputs)
        e1_pool_term = self.el_l_lambda(torch.sum(eqiv_inputs, dim = 2, keepdim = True))
        D_h1         = self.e1_a(e1_slot_term - e1_pool_term + patient_features_representation)

        # Equivariant Layer 2
        e2_slot_term = self.e2_l_gamma(D_h1)
        e2_pool_term = self.e2_l_lambda(torch.sum(D_h1, dim = 2, keepdim = True))
        e2_out       = self.e2_a(e2_slot_term - e2_pool_term)

        # -------------------------------------------------------------------------
        # Logits Layer
        # -------------------------------------------------------------------------
        D_logits_dosage = self.D_logits_dosage(e2_out)
        return torch.squeeze(D_logits_dosage, dim = -1)