import torch
import torch.nn as nn
import math
import numpy as np
from .autodecoder import decoder
from torch.nn.modules import MultiheadAttention, Linear, Dropout, BatchNorm1d, TransformerEncoderLayer, LayerNorm
import torch.nn.functional as F

class RnnClassifier(nn.Module):
    """Recurrent encoder followed by a linear classification head.

    The input sequence is run through a GRU, LSTM or vanilla RNN. The output
    of the last time step is projected to ``encoding_size`` by ``self.fc`` and
    then to ``output_size`` by ``self.nn``.

    Args:
        encoding_size: Feature size of each input step; also the size of the
            intermediate encoding produced by ``self.fc``.
        hidden_size: Hidden-state size of the recurrent module (per direction).
        output_size: Size of the final output vector.
        cell_type: One of ``'GRU'``, ``'LSTM'`` or ``'RNN'``.
        num_layers: Number of stacked recurrent layers.
        dropout: Dropout value forwarded to the recurrent module.
        bidirectional: Whether the recurrent module runs in both directions.

    Raises:
        ValueError: If ``cell_type`` is not one of the supported names.
    """

    def __init__(self,  encoding_size, hidden_size, output_size, cell_type='GRU', num_layers=1, dropout=0, bidirectional=True):
        super(RnnClassifier, self).__init__()
        self.hidden_size = hidden_size
        self.in_channel = encoding_size
        self.num_layers = num_layers
        self.cell_type = cell_type
        self.encoding_size = encoding_size
        self.bidirectional = bidirectional
        self.output_size = output_size

        num_directions = int(self.bidirectional) + 1
        self.fc = torch.nn.Sequential(torch.nn.Linear(self.hidden_size * num_directions, self.encoding_size))
        self.nn = torch.nn.Sequential(torch.nn.Linear(self.encoding_size, self.output_size))

        # All three recurrent modules share the same constructor keywords, so a
        # name -> class table keeps the supported set and the error message in sync.
        cells = {'GRU': torch.nn.GRU, 'LSTM': torch.nn.LSTM, 'RNN': torch.nn.RNN}
        if cell_type not in cells:
            raise ValueError('Cell type not defined, must be one of the following {GRU, LSTM, RNN}')
        self.rnn = cells[cell_type](input_size=self.in_channel, hidden_size=self.hidden_size,
                                    num_layers=num_layers, batch_first=False, dropout=dropout,
                                    bidirectional=bidirectional)

    def forward(self, x):
        """Classify a batch of sequences.

        Args:
            x: Tensor indexed as [batch_size, seq_len, encoding_size]; it is
                permuted to sequence-first because the recurrent module is
                built with ``batch_first=False``.

        Returns:
            Tensor of shape [batch_size, output_size].
        """
        x = x.permute(1, 0, 2)
        # No explicit initial state is passed: the recurrent module then starts
        # from zeros on the input's own device, so nothing is allocated here.
        out, _ = self.rnn(x)  # out shape = [seq_len, batch_size, num_directions*hidden_size]
        encodings = self.fc(out[-1])
        return self.nn(encodings)

class Rnnforecast(nn.Module):
    """Bidirectional GRU followed by a small MLP applied to every GRU output.

    Args:
        encoding_dim: Feature size of each input step.
        hidden_size: GRU hidden-state size (per direction).
        output_dim: Size of the final output per step.
        middle_dim: Width of the hidden layer in the MLP head.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout value forwarded to the GRU.
    """

    def __init__(self, encoding_dim, hidden_size, output_dim, middle_dim, num_layers=1, dropout=0.05):
        super().__init__()
        self.input_size = encoding_dim
        self.hidden_size = hidden_size
        self.output_dim = output_dim
        self.bidirectional = True

        # The GRU concatenates both directions, so the head reads 2 * hidden_size features.
        num_directions = int(self.bidirectional) + 1
        head_layers = [
            nn.Linear(self.hidden_size * num_directions, middle_dim),
            nn.ReLU(),
            nn.Linear(middle_dim, output_dim),
        ]
        self.nn = torch.nn.Sequential(*head_layers)
        self.rnn = torch.nn.GRU(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            num_layers=num_layers,
            batch_first=False,
            dropout=dropout,
            bidirectional=self.bidirectional,
        )

    def forward(self, x):
        """Run the GRU over ``x`` and map every output step through the MLP head.

        The GRU is built with ``batch_first=False``, so ``x`` is consumed
        sequence-first. The hidden state returned by the GRU is discarded.
        """
        features, _ = self.rnn(x)
        return self.nn(features)

class param_weight(nn.Module):
    """Two-layer MLP that turns features into softmax-normalised weights.

    Args:
        input_size: Size of the last dimension of the input.
        embedding_dim: Number of weights produced per input vector.

    The softmax is pinned to the last dimension so that each output vector of
    length ``embedding_dim`` sums to one. Constructing ``nn.Softmax()`` without
    ``dim`` relies on a deprecated fallback that picks the dimension from the
    input's rank; for 3-D inputs that fallback does not normalise over the
    last axis. Behaviour for 2-D inputs is unchanged.
    """

    def __init__(self, input_size, embedding_dim):
        super(param_weight, self).__init__()
        self.param = nn.Linear(input_size, 256)
        self.param1 = nn.Linear(256, embedding_dim)
        self.para_function = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return weights with the same leading dims as ``x`` and last dim ``embedding_dim``."""
        hidden = torch.relu(self.param(x))
        logits = self.param1(hidden)
        return self.para_function(logits)

class Model(nn.Module):
    """Wraps the project ``decoder`` together with a ``param_weight`` head.

    Args:
        args: Configuration object; the attributes read here are ``input_dim``,
            ``theta_size``, ``middle_size``, ``map_layers``,
            ``degree_of_polynomial``, ``harmonics``, ``trend_layers``,
            ``season_layers``, ``decomp_layers`` and ``norm``.
        length: Accepted for interface compatibility; not read by this class.
    """

    def __init__(self, args, length):
        super().__init__()
        self.decoder = decoder(
            args.input_dim,
            args.theta_size,
            args.middle_size,
            args.map_layers,
            args.degree_of_polynomial,
            args.harmonics,
            args.trend_layers,
            args.season_layers,
            args.decomp_layers,
            args.norm,
        )
        self.para_w = param_weight(args.input_dim, args.decomp_layers)

    def forward(self, x):
        """Run the decoder and the weighting head on the same input.

        Returns:
            Tuple ``(res, kl_q, latent_variables, w)``. The first three come
            from the decoder (its second return value is dropped) and ``w`` is
            the output of ``self.para_w``.
        """
        # Unpacking the shape into three names rejects inputs that are not 3-D.
        _batch, _, _ = x.shape

        # The decoder must return exactly four values; the second is unused here.
        res, _unused, kl_q, latent_variables = self.decoder(x)
        weights = self.para_w(x)
        return res, kl_q, latent_variables, weights

class lr(nn.Module):
    """Single linear layer followed by an element-wise sigmoid.

    Args:
        args: Configuration object providing ``encoding_dim`` (input features)
            and ``input_dim`` (output features).
    """

    def __init__(self, args):
        super().__init__()
        self.logic = nn.Linear(args.encoding_dim, args.input_dim)
        self.sm = nn.Sigmoid()

    def forward(self, x):
        """Project ``x`` with the linear layer and squash the result into (0, 1)."""
        return self.sm(self.logic(x))

def _get_activation_fn(activation):
    """Return the functional activation that matches ``activation``.

    Args:
        activation: Either ``"relu"`` or ``"gelu"``.

    Returns:
        ``F.relu`` or ``F.gelu`` respectively.

    Raises:
        ValueError: If ``activation`` is any other value.
    """
    if activation not in ("relu", "gelu"):
        raise ValueError("activation should be relu/gelu, not {}".format(activation))
    return F.relu if activation == "relu" else F.gelu

class FixedPositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal positional encoding to the input, then dropout.

    Even feature columns hold sines and odd columns hold cosines of
    ``position * div_term``. The table is stored as a non-trainable buffer
    ``pe`` of shape [max_len, 1, d_model].

    Args:
        d_model: Size of the feature (last) dimension; may be even or odd.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Number of positions precomputed in the table.
        scale_factor: Multiplier applied to the whole encoding table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=1024, scale_factor=1.0):
        super(FixedPositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)  # positional encoding
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine column, so
        # trim the angles to the number of odd-indexed columns. For even d_model
        # this slice is the full tensor and the values are unchanged.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        pe = scale_factor * pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)  # this stores the variable in the state_dict (used for non-trainable variables)

    def forward(self, x):
        """Add the encoding for the first ``x.size(0)`` positions and apply dropout.

        ``x`` is indexed by position along dimension 0; the buffer's singleton
        second dimension broadcasts over dimension 1 of ``x``.
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerEncoder(nn.Module):
    """Projects inputs to ``d_model``, adds positions, encodes, and maps back.

    Args:
        feat_dim: Size of the input feature dimension and of the ``target`` output.
        max_len: Number of positions in the fixed positional-encoding table.
        d_model: Internal model width.
        n_heads: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Feed-forward width handed to each encoder layer.
        dropout: Used for the positional-encoding dropout (scaled by
            ``1 - freeze``) and for ``self.dropout1``. The encoder layers are
            built with their own default dropout.
        pos_encoding: Accepted but not read by this class.
        activation: Name resolved into ``self.act``. The encoder layers are
            built with their own default activation.
        norm: Accepted but not read by this class.
        freeze: When truthy, the positional-encoding dropout becomes 0.

    Note:
        ``self.norm2``, ``self.act`` and ``self.dropout1`` are created here but
        are not used in ``forward``.
    """

    def __init__(self, feat_dim, max_len, d_model, n_heads, num_layers, dim_feedforward, dropout=0.1,
                 pos_encoding='fixed', activation='gelu', norm='BatchNorm', freeze=False):
        super().__init__()

        self.max_len = max_len
        self.d_model = d_model
        self.n_heads = n_heads

        self.project_inp = nn.Linear(feat_dim, d_model)
        positional_dropout = dropout * (1.0 - freeze)
        self.pos_enc = FixedPositionalEncoding(d_model, dropout=positional_dropout, max_len=max_len)

        # Only width, head count and feed-forward size are forwarded to the layer.
        layer_template = TransformerBatchNormEncoderLayer(d_model, self.n_heads, dim_feedforward)
        self.transformer_encoder = nn.TransformerEncoder(layer_template, num_layers)

        self.norm1 = LayerNorm(d_model)  # normalizes over the last (d_model) dimension
        self.norm2 = BatchNorm1d(d_model, eps=1e-5)

        self.output_layer = nn.Linear(d_model, feat_dim)
        self.act = _get_activation_fn(activation)
        self.dropout1 = nn.Dropout(dropout)
        self.feat_dim = feat_dim

    def forward(self, X, padding_masks=None):
        """Encode ``X`` and project the encoding back to ``feat_dim``.

        Args:
            X: Tensor indexed as [batch_size, seq_length, feat_dim].
            padding_masks: Optional mask passed through unchanged as
                ``src_key_padding_mask`` to the encoder stack.

        Returns:
            Tuple ``(output, target)``: ``output`` is the layer-normalised
            encoding, [batch_size, seq_length, d_model]; ``target`` is
            ``output`` mapped through ``self.output_layer``,
            [batch_size, seq_length, feat_dim].
        """
        # The encoder stack works sequence-first: [seq_length, batch_size, ...].
        scale = math.sqrt(self.d_model)
        seq_first = self.project_inp(X.permute(1, 0, 2)) * scale
        seq_first = self.pos_enc(seq_first)

        encoded = self.transformer_encoder(seq_first, src_key_padding_mask=padding_masks)

        # Back to batch-first before the final normalisation and projection.
        output = self.norm1(encoded.transpose(1, 0))
        target = self.output_layer(output)
        return output, target

class TransformerBatchNormEncoderLayer(nn.modules.Module):
    """Encoder layer that applies self-attention with a residual connection.

    ``forward`` only runs multi-head self-attention, dropout and the residual
    add. The feed-forward modules (``linear1``, ``linear2``, ``dropout``,
    ``dropout2``) and ``activation`` are still created so that the module's
    parameters and state-dict keys stay the same, but ``forward`` does not
    use them.

    Args:
        d_model: Model width.
        nhead: Number of attention heads.
        dim_feedforward: Width of the (currently unused) feed-forward layer.
        dropout: Dropout probability for attention and for the residual branch.
        activation: ``"relu"`` or ``"gelu"``; resolved and stored on the module.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.05, activation="relu"):
        super(TransformerBatchNormEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feed-forward modules: kept for parameter/state-dict compatibility.
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # Pickles without an 'activation' entry fall back to ReLU.
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerBatchNormEncoderLayer, self).__setstate__(state)

    def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=False):
        """Apply self-attention to ``src`` and add the result back to ``src``.

        Args:
            src: Input sequence, (seq_len, batch_size, d_model).
            src_mask: Optional attention mask forwarded as ``attn_mask``.
            src_key_padding_mask: Optional padding mask forwarded as
                ``key_padding_mask``.
            is_causal: Accepted and ignored. Newer ``nn.TransformerEncoder``
                versions pass this keyword to every layer; any masking is
                governed by ``src_mask`` alone.

        Returns:
            Tensor with the same shape as ``src``.
        """
        attn_out = self.self_attn(src, src, src, attn_mask=src_mask,
                                  key_padding_mask=src_key_padding_mask)[0]
        return src + self.dropout1(attn_out)  # (seq_len, batch_size, d_model)