import numpy as np
from omegaconf.errors import MissingMandatoryValue
from omegaconf import DictConfig

# pytorch
import torch
from torch import nn

# transformer
from utils_transformer import TransformerEncoderBlock, TransformerDecoderBlock
from utils_transformer import AbsolutePositionalEncoding, RelativePositionalEncoding
from utils_transformer import TransformerMultiInputBlock, LayerNorm

class BuildBR(nn.Module):
    """Multi-input transformer encoder that produces a balanced representation (BR).

    The previous outcomes, previous treatments, previous dosages and static
    features are each linearly projected to ``seq_hidden_units``. The three
    sequence streams (treatments, outcomes, dosages) are passed jointly through
    a stack of ``TransformerMultiInputBlock`` layers, conditioned on the
    projected static features. They are then averaged, passed through dropout,
    and mapped to ``br_size`` by a linear layer followed by an ELU.
    """

    def __init__(self, args):
        """Read dataset / encoder hyper-parameters from ``args`` and build the layers.

        Args:
            args: configuration object exposing ``args.dataset.*`` and
                ``args.model.encoder.*`` via attribute access (presumably an
                OmegaConf ``DictConfig`` -- confirm against the caller).

        Raises:
            MissingMandatoryValue: if any of ``num_heads``, ``seq_hidden_units``,
                ``num_layer``, ``br_size`` or ``dropout_rate`` is ``None``.
        """
        super().__init__()

        # -------------------------------------------------------------------------------
        # dataset parameters
        # -------------------------------------------------------------------------------
        dataset = args.dataset
        self.dim_treatments      = dataset.dim_treatments
        self.dim_static_features = dataset.dim_static_features
        self.dim_outcome         = dataset.dim_outcomes
        self.dim_dosages         = dataset.dim_dosages
        self.n_inputs            = dataset.dim_n_inputs
        # total width of all raw input features (not used by the layers built here)
        self.input_size = self.dim_treatments + self.dim_dosages + self.dim_static_features + self.dim_outcome

        # -------------------------------------------------------------------------------
        # model hyper parameters
        # -------------------------------------------------------------------------------
        encoder = args.model.encoder
        self.num_heads        = encoder.num_heads
        self.br_size          = encoder.br_size           # balanced representation size
        self.seq_hidden_units = encoder.seq_hidden_units
        self.num_layer        = encoder.num_layer
        self.dropout_rate     = encoder.dropout_rate

        # Validate BEFORE deriving anything from these values, so that a missing
        # entry raises MissingMandatoryValue instead of a TypeError from
        # `None // ...` (head_size) or `range(None)` (num_layer).
        mandatory = {
            "num_heads": self.num_heads,
            "seq_hidden_units": self.seq_hidden_units,
            "num_layer": self.num_layer,
            "br_size": self.br_size,
            "dropout_rate": self.dropout_rate,
        }
        missing = [name for name, value in mandatory.items() if value is None]
        if missing:
            raise MissingMandatoryValue(
                "Missing mandatory model.encoder value(s): " + ", ".join(missing)
            )

        # per-head width handed to the attention blocks and positional encodings
        self.head_size = self.seq_hidden_units // self.num_heads

        # model setup
        self._init_model(args)

    def _init_model(self, args):
        """Create the input projections, transformer stack and output head.

        Args:
            args: the same configuration object passed to ``__init__``; reads
                ``args.model.encoder.self_positional_encoding.*``,
                ``attn_dropout``, ``disable_cross_attention`` and
                ``isolate_subnetwork``.
        """
        # ----------------------------------------------------------
        # Linear Layer (input transformation): every input -> seq_hidden_units
        # ----------------------------------------------------------
        self.outputs_input_transformation    = nn.Linear(self.dim_outcome,         self.seq_hidden_units) # dynamic feature
        self.static_input_transformation     = nn.Linear(self.dim_static_features, self.seq_hidden_units) # static feature
        self.treatments_input_transformation = nn.Linear(self.dim_treatments,      self.seq_hidden_units) # Treatment
        self.dosages_input_transformation    = nn.Linear(self.dim_dosages,         self.seq_hidden_units) # Dosage

        # ----------------------------------------------------------
        # Transformer blocks
        # ----------------------------------------------------------
        # Relative positional encodings for self-attention keys (k) and values (v).
        # The same two instances are passed to every block below, so they are
        # shared across all layers.
        pe_cfg = args.model.encoder.self_positional_encoding
        self.self_positional_encoding_k = RelativePositionalEncoding(
            pe_cfg.max_relative_position, self.head_size, pe_cfg.trainable
        )
        self.self_positional_encoding_v = RelativePositionalEncoding(
            pe_cfg.max_relative_position, self.head_size, pe_cfg.trainable
        )

        # attention dropout is only enabled when the config flag is truthy
        attn_dropout_rate = self.dropout_rate if args.model.encoder.attn_dropout else 0.0

        self.basic_block_cls = TransformerMultiInputBlock
        self.transformer_blocks = nn.ModuleList(
            [self.basic_block_cls(self.seq_hidden_units,
                                  self.num_heads,
                                  self.head_size,
                                  self.seq_hidden_units * 4,  # presumably the feed-forward width
                                  self.dropout_rate,
                                  attn_dropout_rate,
                                  self_positional_encoding_k = self.self_positional_encoding_k,
                                  self_positional_encoding_v = self.self_positional_encoding_v,
                                  n_inputs                   = self.n_inputs,
                                  disable_cross_attention    = args.model.encoder.disable_cross_attention,
                                  isolate_subnetwork         = args.model.encoder.isolate_subnetwork)
             for _ in range(self.num_layer)]
        )

        # output head: dropout -> linear(seq_hidden_units -> br_size) -> ELU
        self.output_dropout = nn.Dropout(self.dropout_rate)
        self.linear1 = nn.Linear(self.seq_hidden_units, self.br_size)
        self.elu1    = nn.ELU()

    def forward(self, batch) -> torch.Tensor:
        """Encode a batch into its balanced representation.

        Args:
            batch: mapping with tensors under ``'inp_x'`` (previous outcomes),
                ``'inp_v'`` (static features), ``'inp_w_prev'`` (previous
                treatments), ``'inp_d_prev'`` (previous dosages) and
                ``'active_entries'`` (mask). Sequence inputs are presumably
                (batch, seq_len, features) -- TODO confirm against the data
                loader.

        Returns:
            Tensor whose last dimension is ``br_size`` (ELU-activated).
        """
        prev_outputs    = batch['inp_x'].float()
        static_features = batch['inp_v'].float()
        prev_treatments = batch['inp_w_prev'].float()
        prev_dosages    = batch['inp_d_prev'].float()

        # Two independent copies of the mask, one for each mask argument of
        # the blocks (presumably in case a block modifies its mask in place).
        active_entries = batch['active_entries']
        active_entries_treat_outcomes  = torch.clone(active_entries)
        active_entries_dosage_outcomes = torch.clone(active_entries)

        x_o = self.outputs_input_transformation(prev_outputs)
        x_s = self.static_input_transformation(static_features)
        x_w = self.treatments_input_transformation(prev_treatments)
        x_d = self.dosages_input_transformation(prev_dosages)

        # each block updates the three sequence streams jointly
        for block in self.transformer_blocks:
            x_w, x_o, x_d = block((x_w, x_o, x_d), x_s,
                                  active_entries_treat_outcomes,
                                  active_entries_dosage_outcomes)

        # fuse the three streams by averaging
        x = (x_o + x_w + x_d) / 3
        seq_output = self.output_dropout(x)

        br = self.elu1(self.linear1(seq_output))
        return br