from tensorflow.python.keras.engine.base_layer import Layer
import numpy as np
from collections.abc import Sequence, Mapping
from einops import rearrange
import tensorflow as tf
import tensorflow_probability as tfp
from typing import TYPE_CHECKING, Union, Any

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

if TYPE_CHECKING:  # Workaround for VS Code intellisense
    # from tensorflow.python import keras
    from keras.api._v2 import keras
    from tensorflow_probability.python.distributions import Normal, Bernoulli
else:
    keras = tf.keras
    Normal = tfp.distributions.Normal
    Bernoulli = tfp.distributions.Bernoulli

# All hyperparameters follow the original PyTorch code: https://github.com/WenjieDu/SAITS/blob/235bec1409dd6323100f33502e94d8af1dd2dd47/modeling/layers.py#L23

class PositionalEncoding(keras.layers.Layer):
    """Add a fixed (non-trainable) sinusoidal position table to its input.

    Args:
        d_feature: Number of columns of the table; it has to match the last
            axis of the tensors passed to `call`.
        n_position: Number of rows (positions) of the table. Defaults to 300.
    """

    def __init__(self, d_feature, n_position=300):
        # Initialise the Keras base class before assigning any attribute.
        super().__init__()
        # Plain configuration values, not trainable parameters.
        self.d_feature = d_feature
        self.n_position = n_position

    def get_position_angle_vec(self, position):
        """Return the list of angles (one per feature index) for one position."""
        return [
            position / np.power(10000, 2 * (feature // 2) / self.d_feature)
            for feature in range(self.d_feature)
        ]

    def _get_sinusoid_encoding_table(self, n_position=None):
        """Build the ``[n_position, d_feature]`` float32 encoding table.

        Positions are numbered 1..n_position. Even feature indices hold the
        sine of the angle, odd feature indices hold the cosine.

        Args:
            n_position: Number of rows to generate. ``None`` falls back to
                ``self.n_position``. (The argument used to be silently
                ignored in favour of ``self.n_position``.)
        """
        if n_position is None:
            n_position = self.n_position

        angles = tf.constant(
            [self.get_position_angle_vec(pos_i) for pos_i in range(1, n_position + 1)],
            dtype=tf.float32,
        )  # [n_position, d_feature]

        # True on even feature indices (sin), False on odd ones (cos).
        even_feature = [True, False] * (self.d_feature // 2) + [True] * (self.d_feature % 2)
        use_sin = tf.tile(tf.expand_dims(tf.constant(even_feature), 0), [n_position, 1])
        return tf.where(use_sin, tf.sin(angles), tf.cos(angles))

    def call(self, x):
        """Return ``x`` plus the position table.

        The table is added by plain broadcasting, so the last two axes of
        ``x`` must be broadcast-compatible with ``[n_position, d_feature]``.
        """
        pos_table = self._get_sinusoid_encoding_table(self.n_position)
        return x + pos_table


class ScaledDotProductAttention(keras.layers.Layer):
    """Scaled dot-product attention with a diagonal and a padding mask.

    Args:
        temperature: Value the queries are divided by before the dot product.
        attn_dropout: Dropout rate applied to the attention weights.
    """

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = keras.layers.Dropout(rate=attn_dropout)
        self.softmax = keras.layers.Softmax(axis=-1)

    def call(self, q, k, v, mask):
        """Compute the attention output and the attention weights.

        Args:
            q, k, v: Rank-4 tensors; the last two axes of ``k`` are swapped
                for the dot product (the caller in this file passes
                ``(batch, n_head, time, d)``).
            mask: Per-time-step validity mask; it is expanded on axis 1 and
                on the time axes below, so it is assumed to be
                ``(batch, time)`` — TODO confirm against the caller.

        Returns:
            ``(output, attn)`` where ``attn`` are the post-dropout weights.
        """
        scale = tf.cast(self.temperature, q.dtype)
        attn = tf.matmul(q / scale, tf.transpose(k, perm=[0, 1, 3, 2]))  # batch, n_head, t, t

        valid = tf.cast(mask, attn.dtype)
        # Validity along the query axis: (batch, 1, t, 1).
        query_valid = tf.expand_dims(tf.expand_dims(valid, 1), -1)
        # Validity along the key axis: (batch, 1, 1, t). The original code only
        # masked the query axis, which let valid steps attend to padded keys.
        key_valid = tf.expand_dims(tf.expand_dims(valid, 1), 1)

        # 1 everywhere except the main diagonal, so a step cannot attend to
        # itself. `ones_like` (rather than `attn * 0 + 1`) keeps the mask
        # finite even when `attn` contains inf/NaN.
        off_diagonal = 1.0 - tf.linalg.band_part(tf.ones_like(attn), 0, 0)
        attn_mask = off_diagonal * query_valid * key_valid

        # The Keras Softmax layer takes the mask and suppresses the positions
        # where it is 0.
        attn = self.dropout(self.softmax(attn, attn_mask))
        output = tf.matmul(attn, v)  # batch, n_head, t, d
        return output, attn


class MultiHeadAttention(keras.layers.Layer):
    """Original Transformer multi-head attention.

    Args:
        n_head: Number of attention heads.
        d_model: Output width of the final projection.
        d_k: Per-head width of the query/key projections.
        d_v: Per-head width of the value projection.
        attn_dropout: Dropout rate forwarded to the attention layer.
    """

    def __init__(self, n_head, d_model, d_k, d_v, attn_dropout):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # Dense acts on the last axis, so (B, T, F) inputs give (B, T, units).
        self.w_qs = keras.layers.Dense(n_head * d_k, use_bias=False)
        self.w_ks = keras.layers.Dense(n_head * d_k, use_bias=False)
        self.w_vs = keras.layers.Dense(n_head * d_v, use_bias=False)

        self.attention = ScaledDotProductAttention(d_k ** 0.5, attn_dropout)
        self.fc = keras.layers.Dense(d_model, use_bias=False)

    def call(self, q, k, v, mask):
        """Project q/k/v per head, attend, and merge the heads again.

        Args:
            q, k, v: Rank-3 tensors, indexed as ``(batch, length, features)``.
            mask: Forwarded unchanged to the attention layer.

        Returns:
            ``(output, attn_weights)``; ``output`` has ``d_model`` features.
        """
        sz_b, len_q = tf.shape(q)[0], tf.shape(q)[1]

        # Each input goes through its OWN projection (k and v used to be sent
        # through `w_qs`, leaving `w_ks` / `w_vs` unused), then the heads are
        # split off and moved in front of the length axis: b x n x l x d.
        q = rearrange(self.w_qs(q), "b l (n d) -> b n l d", n=self.n_head)
        k = rearrange(self.w_ks(k), "b l (n d) -> b n l d", n=self.n_head)
        v = rearrange(self.w_vs(v), "b l (n d) -> b n l d", n=self.n_head)

        v, attn_weights = self.attention(q, k, v, mask)

        # Move the head axis back and concatenate all heads: b x lq x (n*dv).
        # The last size is spelled out (instead of -1) so that it stays
        # statically known for the Dense layer that follows.
        v = tf.reshape(
            tf.transpose(v, perm=[0, 2, 1, 3]),
            shape=[sz_b, len_q, self.n_head * self.d_v],
        )
        v = self.fc(v)
        return v, attn_weights
    
class PositionWiseFeedForward(keras.layers.Layer):
    """Pre-LN position-wise feed-forward block with a residual connection.

    Args:
        d_in: Width of the input and of the output projection.
        d_hid: Width of the hidden projection.
        dropout: Dropout rate applied before the residual is added.
    """

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = keras.layers.Dense(d_hid)
        self.w_2 = keras.layers.Dense(d_in)
        # Constructed with Keras' default axis, i.e. the last (feature) axis.
        self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout = keras.layers.Dropout(rate=dropout)

    def call(self, x):
        """Return ``x + dropout(w_2(relu(w_1(layer_norm(x)))))``.

        Keras dispatches ``layer(x)`` to `call`; the computation used to live
        in a PyTorch-style `forward` method, which Keras never invokes.
        """
        residual = x
        x = self.layer_norm(x)
        x = self.w_2(tf.nn.relu(self.w_1(x)))
        x = self.dropout(x)
        return x + residual

    def forward(self, x):
        """Backward-compatible alias of `call` for code that calls `forward`."""
        return self.call(x)

# `measurements` corresponds to the missing-value mask. For SAITS, artificial masking has to be applied on top of it.
# The input data here is also shaped (batch, seq, feature).
# The input feature dimension is first embedded to d_model, so the encoder input (the query) is shaped (batch, seq, d_model).

class EncoderLayer(keras.layers.Layer):
    """One encoder block: layer norm, self-attention, dropout, residual, feed-forward.

    Args:
        d_feature: Stored on the instance; not otherwise used in this class.
        d_model, n_head, d_k, d_v, attn_dropout: Forwarded to the
            multi-head self-attention sublayer.
        d_inner: Hidden width forwarded to the feed-forward sublayer.
        dropout: Rate of the dropout after attention; also forwarded to the
            feed-forward sublayer.
    """

    def __init__(self, d_feature, d_model, d_inner, n_head, d_k, d_v, dropout=0.1, attn_dropout=0.1):
        super().__init__()

        # Set but not read anywhere in this class.
        self.diagonal_attention_mask = True
        self.d_feature = d_feature

        self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-6)
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, attn_dropout)
        self.dropout = keras.layers.Dropout(rate=dropout)
        self.pos_ffn = PositionWiseFeedForward(d_model, d_inner, dropout)

    def call(self, enc_input, mask):
        """Run the block and return ``(enc_output, attn_weights)``.

        `mask` is passed through to the self-attention sublayer.
        """
        # Pre-LN: normalise before the attention computation,
        # see https://arxiv.org/abs/2002.04745
        normalized = self.layer_norm(enc_input)
        attended, attn_weights = self.slf_attn(normalized, normalized, normalized, mask)

        # Residual connection around the attention sublayer.
        attended = self.dropout(attended) + enc_input

        return self.pos_ffn(attended), attn_weights
        

class SAITSMODEL(keras.Model):
    """SAITS-style imputation model followed by a GRU classifier.

    `call` pads the series to length 300, imputes it with two stacks of
    `EncoderLayer` blocks, feeds the imputed series to a GRU classifier,
    stores the resulting loss on `self.loss` and returns the probability the
    classifier assigns to the given labels.
    """

    def __init__(
        self,
        output_activation,
        output_dims,
        n_groups:int = 20 , 
        n_hidden: int = 128,
        n_group_inner_layers:int = 128,
        d_model:int = 256, 
        d_inner:int = 512, 
        n_head:int = 1, 
        d_k:int = 32, 
        d_v:int = 32, 
        dropout: float = 0.0,
        n_units: int = 128,
        train_type: str = 'pretrain',
        MIT: bool = True,
        input_with_mask: bool = True,
        MIT_missing_rate: float = 0.2
        ): # TODO: more parameters still need to be added
        """Store the hyperparameters; sublayers are created later in `build`.

        Args:
            output_activation: Activation of the classifier's Dense layer.
            output_dims: Number of units of the classifier's Dense layer.
            n_groups: Number of encoder layers in each of the two stacks.
            n_hidden: Units of the first Dense layer of the initial encoder.
            n_group_inner_layers: Stored only; not read in this class.
            d_model, d_inner, n_head, d_k, d_v: Forwarded to every `EncoderLayer`.
            dropout: Dropout rate for the input dropout, the encoder layers and the GRU.
            n_units: Units of the GRU and of the second initial-encoder Dense layer.
            train_type: 'pretrain', 'classifier' or 'joint'; selects how `call`
                assembles the loss.
            MIT: When true (and training), observed values are artificially
                masked and the masked-imputation loss is added.
            input_with_mask: When true, the measurement mask is concatenated
                to the values before each embedding.
            MIT_missing_rate: Artificial masking rate passed to `preprocess`.
        """
        super().__init__()
        # Snapshot of the constructor arguments; returned by `get_config`.
        self._config = {
            name: val for name, val in locals().items()
            if name not in ['self', '__class__']
        }
        self.output_activation = output_activation
        self.output_dims = output_dims
        self.n_groups = n_groups
        self.n_group_inner_layers = n_group_inner_layers
        self.d_model = d_model
        self.d_inner = d_inner
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        self.n_units=n_units
        self.dropout_rate=dropout
        self.input_with_mask = input_with_mask
        self.param_sharing_strategy = False  # not read anywhere in this class
        self.MIT = MIT
        self.MIT_missing_rate = MIT_missing_rate
        self.transformer_blocks = []  # first encoder stack, filled in `build`
        self.add = keras.layers.Add()
        self.transformer_blocks2 = []  # second encoder stack, filled in `build`
        self.train_type = train_type
        self.n_units = n_units  # NOTE(review): duplicate of the assignment above
        self.n_hidden = n_hidden
        self.positional_encoding = PositionalEncoding(d_feature=self.d_model,n_position=300)

    # The string literal below holds a disabled NumPy version of the
    # positional encoding; it is never executed.
    '''
    def get_position_angle_vec(self,position):
        return [position / np.power(10000, 2 * (feature // 2) / self.d_model) for feature in range(self.d_model)] # feature 축에 대한 함수 list
    
    def _get_sinusoid_encoding_table(self, n_position):

        sinusoid_table = np.array([self.get_position_angle_vec(pos_i) for pos_i in range(n_position)],dtype=float) # [n_position , feature] 의 tensor
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
        return tf.constant(sinusoid_table,dtype=tf.float32)
    
    def positional_encoding2(self,x):
        table = self._get_sinusoid_encoding_table(300)
        return x + table
    '''

    def preprocess(self,measurements,artificial_mis): # artificial masking is applied here
        """Sample an artificial mask and derive the two masks used for training.

        Args:
            measurements: Observation mask (cast to float32 here).
            artificial_mis: Probability of artificially dropping a position;
                each position is kept with probability ``1 - artificial_mis``.

        Returns:
            ``(masks, indicator_masks)``, both float32 with the shape of
            `measurements`.
        """
        artificial_mask = tfp.distributions.Bernoulli(probs=1-artificial_mis).sample(sample_shape=tf.shape(measurements))
        masks = tf.cast(measurements, tf.float32)*tf.cast(artificial_mask,tf.float32) # corresponds to M-hat: a mask that also treats artificially masked positions as missing
        indicator_masks = (1.-masks)*tf.cast(measurements, tf.float32) ## indicator mask: 1 only at the artificially masked positions
        return masks,indicator_masks


    def data_preprocessing_fn(self):  # author's note: pad everything to length 300; there is no batch dimension here
        """Return a function that appends a trailing axis to `times`."""
        def add_time_dim(inputs, label):
            demo, times, values, measurements, lengths = inputs # (none,16)
            times = tf.expand_dims(times, -1)
            return (demo, times, values, measurements, lengths), label
        return add_time_dim
    def build(self, input_shape):        
        """Create all sublayers.

        `input_shape` is unpacked as a 5-tuple; only the last entry of its
        third element (the shape of `values`) is used, as the feature count.
        """
        _, times_shape, values_shape, _, _ = input_shape
        self.d_feature = values_shape[-1]
        self.actual_d_feature = self.d_feature * 2 if self.input_with_mask else self.d_feature
        self.dropout = keras.layers.Dropout(rate=self.dropout_rate)
        # self.position_enc = PositionalEncoding(d_model, n_position=200) <- TODO: implement properly and re-enable
        # for operation on time dim
        self.embedding_1 = keras.layers.Dense(self.d_model,name='embed1')
        self.reduce_dim_z =  keras.layers.Dense(self.d_feature,name='reduce_z_1')
        # for operation on measurement dim
        self.embedding_2 =  keras.layers.Dense(self.d_model,name='embed2')
        self.reduce_dim_beta = keras.layers.Dense(self.d_feature,name='reduce_beta')
        self.reduce_dim_gamma = keras.layers.Dense(self.d_feature,name='reduce_gamma')
        # for delta decay factor
        self.weight_combine = keras.layers.Dense(self.d_feature,name='weight_combine')
        # self.gen_list = [self.embedding_1,self.reduce_dim_z,self.embedding_2,self.reduce_dim_beta,self.reduce_dim_gamma,self.weight_combine]     
        #self.add = tf.keras.layers.Add()
        # self.positional_encoding.build(times_shape)

        # First encoder stack; each block is also set as an attribute.
        for i in range(self.n_groups): # since param_sharing_strategy = inner
            
            transformer_block = EncoderLayer(self.actual_d_feature, self.d_model, self.d_inner, self.n_head, self.d_k, self.d_v, self.dropout_rate, 0)

            self.transformer_blocks.append(transformer_block)
            setattr(self, f'transformer_{i}', transformer_block)
            
        # Second encoder stack, built the same way.
        for i in range(self.n_groups): # since param_sharing_strategy = inner
            
            transformer_block = EncoderLayer(self.actual_d_feature, self.d_model, self.d_inner, self.n_head, self.d_k, self.d_v, self.dropout_rate, 0)

            self.transformer_blocks2.append(transformer_block)
            setattr(self, f'transformer2_{i}', transformer_block)
            
        # NOTE(review): `call` uses these layers unconditionally, so any other
        # `train_type` value would leave them undefined.
        if self.train_type in ['pretrain','joint','classifier']: # for now the classifier is fixed to a GRU
            self.classifier_gru = keras.layers.GRU(
                self.n_units, return_sequences=False, dropout=self.dropout_rate,
                name="classifier/gru",
            )
            self.classifier_dense = keras.layers.Dense(self.output_dims, activation=self.output_activation, name="classifier/dense")
            self.initial_encoder = keras.Sequential([
                keras.layers.Dense(self.n_hidden, activation=tf.nn.tanh, name="initial_encoder/dense1"),
                keras.layers.Dense(self.n_units,  activation=tf.nn.tanh, name="initial_encoder/dense2"),
            ], name="initial_encoder")
            
    def call(self,inputs,training=False, output=None,return_y=True):  # equivalent to impute and calc loss
        """Impute the series, classify it and store the loss on `self.loss`.

        Args:
            inputs: 5-tuple ``(statics, times, values, measurements, lengths)``.
            training: Enables the artificial masking when `self.MIT` is set.
            output: Labels; when ``None`` a tensor of ones is used instead.
            return_y: Not read in this method.

        Returns:
            ``exp(log_prob(labels))`` under a Bernoulli parameterised by the
            classifier output.
        """
        statics, times, values, measurements, lengths = inputs
        pad_vm = [[0,0],[0, 300 - tf.shape(values)[1]],[0,0]] # pad only axis 1 (time) up to length 300
        values = tf.pad(values,pad_vm,mode='CONSTANT')                                                  # [n_batch, 300, x_dim]
        # X = values
        
        measurements = tf.pad(measurements,pad_vm,mode='CONSTANT')                                      # [n_batch, 300, x_dim]
        masks,indicator_masks = self.preprocess(measurements,artificial_mis=self.MIT_missing_rate)
        # Outside MIT training the plain observation mask is used.
        masks = masks if self.MIT and training else tf.cast(measurements, tf.float32)
        if training and self.MIT:
            X = values*masks
        else:
            X= values
        input_X_for_first = tf.concat([X, tf.cast(measurements, tf.float32)], 2) if self.input_with_mask else X                     # [n_batch, 300, x_dim*2 or x_dim]
        input_X_for_first = self.embedding_1(input_X_for_first)                                                              # [n_batch, 300, d_model]
        input_X_for_first = self.positional_encoding(input_X_for_first)
        
        if len(lengths.get_shape()) == 2:
            lengths = tf.squeeze(lengths, -1)

        mask = tf.sequence_mask(lengths) # batch, max len
        paddings = [[0, 0,], [0, 300-tf.shape(mask)[1]]] # author's note: this matters right now -> the data preprocessing itself should be changed
        mask = tf.pad(mask,paddings,mode='CONSTANT')
        enc_output = self.dropout(input_X_for_first)  # [n_batch, 300, d_model]; the positional encoding was already added above
        
        # First encoder stack -> first estimate X_tilde_1.
        for encoder_layer in self.transformer_blocks:
            enc_output,_ = encoder_layer(enc_output,mask)                                                   # [n_batch,300,d_model]                               
        
        X_tilde_1 = self.reduce_dim_z(enc_output)                                                           # [n_batch,300,x_dim] 
        X_prime = masks * X + (1 - masks) * X_tilde_1                                                       # [n_batch,300,x_dim] 
        
        input_X_for_second = tf.concat([X_prime, tf.cast(measurements, tf.float32)], 2) if self.input_with_mask else X_prime            # [n_batch, 300, x_dim*2 or x_dim]
        input_X_for_second = self.embedding_2(input_X_for_second)                                           # [n_batch, 300, d_model]
        enc_output = self.positional_encoding(input_X_for_second)
        
        # Second encoder stack -> second estimate X_tilde_2; the attention
        # weights of the last block are kept.
        for encoder_layer in self.transformer_blocks2:
            enc_output,attn_weights = encoder_layer(enc_output,mask)
        
        X_tilde_2 = self.reduce_dim_gamma(tf.nn.relu(self.reduce_dim_beta(enc_output)))                     # [n_batch,300,x_dim]
        
        # Average the attention weights over the head axis.
        attn_weights = tf.transpose(attn_weights,perm=[0,3,2,1]) # batch,t,t,head
        attn_weights = tf.reduce_mean(attn_weights,3) # batch,t,t
        attn_weights = tf.transpose(attn_weights,perm=[0,2,1]) 
        combining_weights = keras.activations.sigmoid(self.weight_combine(tf.concat([masks, attn_weights], 2))) # the input width of weight_combine is d_feature + d_time
        # Author's note: sequence length differs per mini-batch in our data, so applying this
        # directly would change the network input size from batch to batch and training would
        # not work (presumably the reason for the fixed padding to 300).
        X_tilde_3 = (1 - combining_weights) * X_tilde_2 + combining_weights * X_tilde_1
        X_c = masks * X + (1 - masks) * X_tilde_3  # replace non-missing part with original data - this corresponds to x_gen
        
        ########## loss computation #############
        def reconstruct_loss(x_tilde,x,masks):
            # Masked mean absolute error; the epsilon guards against an all-zero mask.
            return tf.reduce_sum(abs(x_tilde-x)*masks)/(tf.reduce_sum(masks)+1e-09)
        
        L_ORT = (reconstruct_loss(X_tilde_1,X,masks)+
                 reconstruct_loss(X_tilde_2,X,masks)+
                 reconstruct_loss(X_tilde_3,X,masks))/3        
        # tf.print(L_ORT)
        if self.MIT and training:
            L_MIT = reconstruct_loss(X_c,X,indicator_masks)
            loss=L_MIT
        else:
            loss=0

        # if self.train_type in ['joint','classifier']:
        padding_mask = mask
        initial_state = self.initial_encoder(statics)
        
        # In 'classifier' mode no gradient flows back into the imputation part.
        if self.train_type=='classifier':
            X_c = tf.stop_gradient(X_c)
        else:
            pass

        v=self.classifier_gru(X_c, mask=padding_mask, initial_state=initial_state)
        logits = self.classifier_dense(v)   
        labels = tf.ones_like(logits) if output is None else tf.cast(output, tf.float32)
        # NOTE(review): `classifier_dense` applies `output_activation`, yet its output is
        # passed as `logits` — confirm the activation is linear/None.
        p_y = Bernoulli(logits=logits) 
        log_p_y = p_y.log_prob(labels) # [batch,1]
    
        y_prob = tf.exp(log_p_y)   
        # tf.print(tf.shape(y_prob))
        if self.train_type == 'pretrain': # fix classification part
            loss += L_ORT
            # tf.print(loss)
        elif self.train_type == 'classifier': # fix generative part
            # assert self.MIT == False
            loss = tf.reduce_mean(-log_p_y)
            
        elif self.train_type == 'joint': # train both classifier and generative
            loss += (L_ORT + tf.reduce_mean(-log_p_y)) # scalar
        
        # `train_step` / `test_step` read the loss back from this attribute.
        self.loss = loss 
        return y_prob
         
    def train_step(self, data):
        """Run one optimisation step on the loss that `call` stored on `self.loss`."""
        from tensorflow.python.keras.engine import data_adapter
        from tensorflow.python.eager import backprop

        data = data_adapter.expand_1d(data)
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
        # Run forward pass.
        with backprop.GradientTape() as tape:
            y_prob = self(x,output=y,training=True)
            # Run backwards pass.
        self.optimizer.minimize(self.loss, self.trainable_variables, tape=tape)

        return_metrics = {"loss": self.loss}

        return return_metrics

    def test_step(self, data):
        """Run a forward pass with ``training=False`` and report `self.loss`."""
        from tensorflow.python.keras.engine import data_adapter

        # NOTE(review): unlike `train_step`, `expand_1d` is not applied here.
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)

        y_prob = self(x,output=y,training=False)
 
        return {"loss": self.loss}  # self.compute_metrics(x, y, y_pred, sample_weight)


    def get_config(self):
        """Return the constructor arguments captured in `__init__`."""
        return self._config

        # problem 1.         
        