import numpy as np
from omegaconf.errors import MissingMandatoryValue
from omegaconf import DictConfig

# pytorch
import torch
from torch import nn

# transformer
from utils_transformer import TransformerEncoderBlock, TransformerDecoderBlock
from utils_transformer import AbsolutePositionalEncoding, RelativePositionalEncoding
from utils_transformer import LayerNorm

class BuildEDBR(nn.Module):
    """Transformer encoder or decoder producing a per-time-step representation ("br").

    ``forward`` concatenates previous treatments, previous outputs, previous
    dosages and static features along the last axis, projects them to
    ``seq_hidden_units`` with a linear layer, runs them through a stack of
    transformer blocks that use relative positional encodings, and maps the
    result to ``br_size`` features through dropout, a linear layer and an ELU.

    In decoder mode (``isDecoder=True``) the blocks additionally receive the
    encoder output ``batch['encoder_br']`` for cross-attention. For that reason
    the decoder's hidden width is taken from ``args.model.encoder.br_size``, so
    that the decoder sequence and the encoder output share one feature width.

    Args:
        args: Config object exposing ``args.dataset.*`` dimensions and the
            ``args.model.encoder`` / ``args.model.decoder`` hyper-parameters.
        isDecoder: Build the cross-attending decoder instead of the encoder.

    Raises:
        MissingMandatoryValue: If a required hyper-parameter is ``None``.
    """

    def __init__(self, args, isDecoder=False):
        super().__init__()

        self.isDecoder = isDecoder

        self.dim_treatments = args.dataset.dim_treatments
        self.dim_static_features = args.dataset.dim_static_features
        self.dim_outcome = args.dataset.dim_outcomes
        self.dim_dosages = args.dataset.dim_dosages

        # Feature width of the concatenated input built in ``forward``.
        self.input_size = (self.dim_treatments + self.dim_dosages
                           + self.dim_static_features + self.dim_outcome)

        # -------------------------------------------------------------------------------
        # model hyper parameters
        # -------------------------------------------------------------------------------
        # NOTE(review): the number of heads is read from the encoder config in
        # both modes — confirm this is intended for the decoder as well.
        self.num_heads = args.model.encoder.num_heads

        if self.isDecoder:
            cfg = args.model.decoder
            self.br_size = cfg.br_size
            # The decoder's hidden width is aligned with the encoder's output
            # width (its br_size) so it can cross-attend to it (CRN, EDCT).
            self.seq_hidden_units = args.model.encoder.br_size
        else:
            cfg = args.model.encoder
            self.seq_hidden_units = cfg.seq_hidden_units
            self.br_size = cfg.br_size
        self.num_layer = cfg.num_layer
        self.dropout_rate = cfg.dropout_rate

        # Validate before deriving anything from these values; otherwise a None
        # would surface as an opaque TypeError in the division below (or later
        # inside range()) instead of a MissingMandatoryValue.
        required = {
            'num_heads': self.num_heads,
            'seq_hidden_units': self.seq_hidden_units,
            'br_size': self.br_size,
            'num_layer': self.num_layer,
            'dropout_rate': self.dropout_rate,
        }
        missing = [name for name, value in required.items() if value is None]
        if missing:
            role = 'decoder' if self.isDecoder else 'encoder'
            raise MissingMandatoryValue(
                f"BuildEDBR ({role}) is missing required hyper-parameter(s): "
                f"{', '.join(missing)}")

        self.head_size = self.seq_hidden_units // self.num_heads

        # model setup
        self._init_model(args)

    def _relative_pe(self, pe_cfg, **kwargs):
        """Build a RelativePositionalEncoding sized for one attention head.

        ``pe_cfg`` must provide ``max_relative_position`` and ``trainable``;
        extra keyword arguments (e.g. ``cross_attn=True``) are forwarded.
        """
        return RelativePositionalEncoding(pe_cfg.max_relative_position,
                                          self.head_size,
                                          pe_cfg.trainable,
                                          **kwargs)

    def _init_model(self, args):
        """Create the input projection, transformer blocks and output head."""
        cfg = args.model.decoder if self.isDecoder else args.model.encoder

        # ----------------------------------------------------------
        # Linear Layer (input transformation)
        # ----------------------------------------------------------
        self.input_transformation = nn.Linear(self.input_size, self.seq_hidden_units)

        # ----------------------------------------------------------
        # Transformer blocks
        # ----------------------------------------------------------
        # Relative positional encodings for self-attention (keys / values).
        # One instance of each is created here and handed to every block.
        self.self_positional_encoding_k = self._relative_pe(cfg.self_positional_encoding)
        self.self_positional_encoding_v = self._relative_pe(cfg.self_positional_encoding)
        block_kwargs = {
            'self_positional_encoding_k': self.self_positional_encoding_k,
            'self_positional_encoding_v': self.self_positional_encoding_v,
        }

        if self.isDecoder:
            # Decoder blocks also cross-attend to the encoder output.
            self.cross_positional_encoding_k = self._relative_pe(
                cfg.cross_positional_encoding, cross_attn=True)
            self.cross_positional_encoding_v = self._relative_pe(
                cfg.cross_positional_encoding, cross_attn=True)
            block_kwargs['cross_positional_encoding_k'] = self.cross_positional_encoding_k
            block_kwargs['cross_positional_encoding_v'] = self.cross_positional_encoding_v
            self.basic_block_cls = TransformerDecoderBlock
        else:
            self.basic_block_cls = TransformerEncoderBlock

        # Positional args follow the utils_transformer block signature:
        # hidden width, heads, head size, 4x hidden (presumably the feed-forward
        # width) and dropout_rate twice (presumably attention / feed-forward
        # dropout — see utils_transformer).
        self.transformer_blocks = nn.ModuleList(
            [self.basic_block_cls(self.seq_hidden_units,
                                  self.num_heads,
                                  self.head_size,
                                  self.seq_hidden_units * 4,
                                  self.dropout_rate,
                                  self.dropout_rate,
                                  **block_kwargs)
             for _ in range(self.num_layer)]
        )

        # Output head: dropout -> linear -> ELU.
        self.output_dropout = nn.Dropout(self.dropout_rate)
        self.linear1 = nn.Linear(self.seq_hidden_units, self.br_size)
        self.elu1 = nn.ELU()

    def forward(self, batch):
        """Compute the representation for every position of the input sequence.

        Args:
            batch: Mapping with tensors ``'inp_x'`` (previous outputs),
                ``'inp_v'`` (static features), ``'inp_w_prev'`` (previous
                treatments), ``'inp_d_prev'`` (previous dosages) and
                ``'active_entries'``; in decoder mode also ``'encoder_br'`` and
                ``'active_encoder_br'``. The four input tensors are assumed to
                share all leading dimensions (e.g. batch, time) — TODO confirm.

        Returns:
            Tensor whose last dimension is ``br_size``.

        Raises:
            ValueError: In decoder mode, if the projected input width does not
                match the last dimension of ``batch['encoder_br']``.
        """
        prev_outputs = batch['inp_x'].float()
        static_features = batch['inp_v'].float()
        prev_treatments = batch['inp_w_prev'].float()
        prev_dosages = batch['inp_d_prev'].float()

        # active entries
        active_entries = batch['active_entries']

        if self.isDecoder:
            encoder_br = batch['encoder_br']
            active_encoder_br = batch['active_encoder_br']

        # Concatenate all inputs along the feature axis, then project them.
        x = torch.cat((prev_treatments, prev_outputs, prev_dosages, static_features), dim=-1)
        x = self.input_transformation(x)

        if self.isDecoder:
            # Both self-attention and cross-attention: widths must agree.
            if x.shape[-1] != encoder_br.shape[-1]:
                raise ValueError(
                    f"decoder hidden width {x.shape[-1]} does not match "
                    f"encoder_br width {encoder_br.shape[-1]}")
            for block in self.transformer_blocks:
                x = block(x, encoder_br, active_entries, active_encoder_br)
        else:
            # Only self-attention
            for block in self.transformer_blocks:
                x = block(x, active_entries)

        seq_output = self.output_dropout(x)
        return self.elu1(self.linear1(seq_output))
                