import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import lightning as L

class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention with learned projections.

    Queries, keys and values are each passed through their own linear layer,
    split into ``nhead`` heads of width ``d_model // nhead``, attended
    independently per head, merged back and passed through an output linear
    layer.
    """

    def __init__(self, d_model, nhead):
        super().__init__()

        assert d_model % nhead == 0, "d_model must be divisible by nhead"
        self.d_model = d_model
        self.nhead = nhead
        self.head_dim = d_model // nhead

        # Input projections (one per role) followed by the output projection.
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)

    def _split_heads(self, projected):
        """Reshape (batch, seq, d_model) into (batch, nhead, seq, head_dim)."""
        batch_size = projected.size(0)
        return projected.view(batch_size, -1, self.nhead, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: Tensor of shape (batch_size, query_len, d_model)
            key, value: Tensors of shape (batch_size, key_len, d_model)
            mask: Optional tensor broadcastable to the score shape
                  (batch_size, nhead, query_len, key_len). Positions where
                  ``mask == 0`` are set to -inf before the softmax, so they
                  receive zero attention weight.
        Returns:
            output: Tensor of shape (batch_size, query_len, d_model)
        """
        q = self._split_heads(self.q_linear(query))
        k = self._split_heads(self.k_linear(key))
        v = self._split_heads(self.v_linear(value))

        # Scaled dot-product scores, one (query_len, key_len) matrix per head.
        scale = math.sqrt(self.head_dim)
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))

        attn_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, v)

        # Put the head axis back next to head_dim and flatten both into d_model.
        batch_size = context.size(0)
        merged = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.out_linear(merged)

class PositionwiseFeedforward(nn.Module):
    """Two-layer feed-forward network applied to the last dimension.

    The hidden layer is four times as wide as ``d_model``; ReLU and dropout
    are applied between the two linear layers.
    """

    def __init__(self, d_model, dropout=0.1):
        super().__init__()

        hidden_dim = 4 * d_model
        self.linear1 = nn.Linear(d_model, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(hidden_dim, d_model)

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model)
        Returns:
            Tensor with the same shape as ``x``: linear2(dropout(relu(linear1(x)))).
        """
        hidden = self.dropout(F.relu(self.linear1(x)))
        return self.linear2(hidden)
   
class SublayerConnection(nn.Module):
    """Residual wrapper that normalizes the input before calling a sublayer.

    Computes ``x + dropout(sublayer(layer_norm(x)))``.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """
        Args:
            x: Input tensor whose last dimension equals ``size``
            sublayer: Callable applied to the normalized input
                      (e.g., self-attention, feedforward)
        Returns:
            ``x`` plus the dropped-out sublayer output.
        """
        normalized = self.norm(x)
        branch = self.dropout(sublayer(normalized))
        return x + branch

class EncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by an optional feed-forward.

    Both sublayers are applied through the same ``SublayerConnection``
    instance, so they share a single LayerNorm and dropout module.
    """

    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()

        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = SublayerConnection(size, dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model)
            mask: Optional attention mask forwarded to ``self_attn`` as its
                  fourth positional argument (e.g., for padding). When it is
                  None, ``self_attn`` is called with three arguments only, so
                  attention modules without a mask parameter keep working.
        Returns:
            Output tensor after applying self-attention and, if configured,
            the position-wise feedforward.
        """
        def attend(normed):
            # Self-attention: the same tensor is used as query, key and value.
            if mask is None:
                return self.self_attn(normed, normed, normed)
            return self.self_attn(normed, normed, normed, mask)

        x = self.sublayer(x, attend)
        if self.feed_forward is None:
            return x
        return self.sublayer(x, self.feed_forward)


class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` independent ``EncoderLayer`` modules."""

    def __init__(self, num_layers, d_model, nhead, dropout=0.1,feed_forward=PositionwiseFeedforward):
        """
        Args:
            num_layers: Number of encoder layers to stack
            d_model: Feature width of the layers
            nhead: Number of attention heads per layer
            dropout: Dropout probability used by each layer
            feed_forward: Factory called as ``feed_forward(d_model, dropout)``
                          to build each layer's feed-forward module. Any falsy
                          value (e.g. None) builds attention-only layers.
        """
        super(TransformerEncoder, self).__init__()
        # Attention is constructed before the feed-forward module in each
        # layer so parameter initialization order stays the same as before.
        # A falsy factory is normalized to None because EncoderLayer only
        # skips the feed-forward sublayer when it is exactly None.
        self.layers = nn.ModuleList([
            EncoderLayer(
                d_model,
                MultiheadAttention(d_model, nhead),
                feed_forward(d_model, dropout) if feed_forward else None,
                dropout,
            )
            for _ in range(num_layers)
        ])

    def forward(self, x, mask=None):
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model)
            mask: Optional attention mask passed unchanged to every layer
                  (e.g., for padding). With this file's MultiheadAttention it
                  must broadcast to (batch_size, nhead, seq_len, seq_len) and
                  zeros mark positions that must not be attended to.
        Returns:
            Output tensor after applying the specified number of encoder layers.
        """
        for layer in self.layers:
            x = layer(x, mask)
        return x

class Transformer_block(nn.Module):
    """Linear input projection, sinusoidal positions, encoder stack, linear output.

    The positional table is stored as a non-persistent buffer named ``pe``:
    it moves with the module across devices but is not written to the
    ``state_dict``.
    """

    def __init__(self, d_model, multi_head_num, num_layers,seq_len,d_input,d_output,feed_forward=True):
        """
        Args:
            d_model: Feature width used inside the encoder
            multi_head_num: Number of attention heads per encoder layer
            num_layers: Number of encoder layers
            seq_len: Number of positions in the positional-encoding table
            d_input: Feature width of the input
            d_output: Feature width of the output
            feed_forward: If truthy, each encoder layer gets the encoder's
                          default feed-forward sublayer; otherwise layers are
                          attention-only.
        """
        super(Transformer_block, self).__init__()

        self.linear_project = nn.Linear(d_input, d_model)

        # persistent=False keeps checkpoints identical to the ones produced
        # when ``pe`` was a plain attribute.
        self.register_buffer(
            "pe", self._build_positional_encoding(seq_len, d_model), persistent=False
        )

        if feed_forward:
            self.transformer_encoder = TransformerEncoder(num_layers, d_model, multi_head_num)
        else:
            self.transformer_encoder = TransformerEncoder(num_layers, d_model, multi_head_num, feed_forward=None)

        # Output layer
        self.fc = nn.Linear(d_model, d_output)

    @staticmethod
    def _build_positional_encoding(seq_len, d_model):
        """Return the (seq_len, d_model) sinusoidal table: sin on even columns, cos on odd."""
        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return pe

    def forward(self, src):
        """
        Args:
            src: Input tensor of shape (batch_size, seq_len, d_input). Its
                 sequence length may be shorter than the table built at
                 construction time.
        Returns:
            Tensor of shape (batch_size, seq_len, d_output)
        """
        enc_out = self.linear_project(src)
        _, seq_len, _ = enc_out.shape

        # Match the input's device and dtype without mutating module state;
        # the addition broadcasts the table over the batch dimension.
        pe = self.pe[:seq_len].to(src)
        enc_out = enc_out + pe.unsqueeze(0)

        memory = self.transformer_encoder(enc_out)
        return self.fc(memory)


class Trans_Encoder(nn.Module):
    """Transformer block followed by a three-layer MLP that yields one vector per sample.

    The transformer output is flattened to ``seq_len * enc_in`` features and
    reduced through 1024 and 512 hidden units to ``trans_embed_size``.
    """

    def __init__(self,d_model=256, seq_len=25, enc_in=17, trans_embed_size=128,multi_head_num=8,trans_encoder_layers=6,feed_forward=True):
        super().__init__()

        # The block maps enc_in features back to enc_in features.
        self.transformer = Transformer_block(
            d_model,
            multi_head_num,
            trans_encoder_layers,
            seq_len,
            enc_in,
            enc_in,
            feed_forward=feed_forward,
        )

        flat_features = seq_len * enc_in
        self.linear_min1 = nn.Linear(flat_features, 1024)
        self.linear_min2 = nn.Linear(1024, 512)
        self.linear_min3 = nn.Linear(512, trans_embed_size)

    def forward(self, x):
        """
        Args:
            x: Input tensor passed to the transformer block; its output must
               flatten to ``seq_len * enc_in`` values per sample.
        Returns:
            Tensor of shape (batch_size, trans_embed_size)
        """
        encoded = self.transformer(x)
        flat = encoded.view(encoded.shape[0], -1)
        hidden = F.relu(self.linear_min1(flat))
        hidden = F.relu(self.linear_min2(hidden))
        return self.linear_min3(hidden)



class Trans_Decoder(nn.Module):
    """Three-layer MLP that expands an embedding to a sequence, then a transformer block.

    The embedding is widened through 512 and 1024 hidden units to
    ``seq_len * enc_in`` values, reshaped to (batch, seq_len, enc_in) and
    passed through the transformer block.
    """

    def __init__(self,d_model=256, seq_len=25, enc_in=17, trans_embed_size=256,multi_head_num=8,trans_encoder_layers=6,feed_forward=True):
        super().__init__()
        self.seq_len = seq_len
        self.enc_in = enc_in

        # The block maps enc_in features back to enc_in features.
        self.transformer = Transformer_block(
            d_model=d_model,
            multi_head_num=multi_head_num,
            num_layers=trans_encoder_layers,
            seq_len=seq_len,
            d_input=enc_in,
            d_output=enc_in,
            feed_forward=feed_forward,
        )

        self.linear_min1 = nn.Linear(trans_embed_size, 512)
        self.linear_min2 = nn.Linear(512, 1024)
        self.linear_min3 = nn.Linear(1024, seq_len * enc_in)

    def forward(self, x):
        """
        Args:
            x: Embedding tensor of shape (batch_size, trans_embed_size)
        Returns:
            Output of the transformer block applied to the expanded
            (batch_size, seq_len, enc_in) tensor.
        """
        hidden = F.relu(self.linear_min1(x))
        hidden = F.relu(self.linear_min2(hidden))
        expanded = self.linear_min3(hidden)

        sequence = expanded.view(expanded.shape[0], self.seq_len, self.enc_in)
        return self.transformer(sequence)


class SimLOBEmbedding(L.LightningModule):
    """Autoencoder-style module: a ``Trans_Encoder`` followed by a
    ``nn.TransformerDecoder`` and a linear projection back to
    (seq_len, enc_in).

    NOTE(review): the decoder is always built with ``nhead=8`` and
    ``num_layers=6``; ``multi_head_num`` and ``trans_encoder_layer`` only
    configure the encoder. Confirm this is intended.
    """

    def __init__(self,
                 trans_embed_size,
                 multi_head_num,
                 trans_encoder_layer,
                 feed_forward,
                 d_model,
                 seq_len,
                 enc_in,
                 **kwargs):
        """
        Args:
            trans_embed_size: Width of the vector produced by the encoder
            multi_head_num: Number of attention heads in the encoder
            trans_encoder_layer: Number of encoder layers
            feed_forward: Whether encoder layers include a feed-forward sublayer
            d_model: Feature width of the encoder stack and of the decoder
            seq_len: Sequence length of the input and of the reconstruction
            enc_in: Feature count per time step
            **kwargs: Forwarded to the base-class constructor
        """
        super().__init__(**kwargs)
        self.seq_len = seq_len
        self.d_model = d_model
        self.enc_in = enc_in
        self.trans_embed_size = trans_embed_size
        self.encoder = Trans_Encoder(d_model=d_model,
                                     enc_in=enc_in,
                                     seq_len=seq_len,
                                     trans_embed_size=trans_embed_size,
                                     multi_head_num=multi_head_num,
                                     trans_encoder_layers=trans_encoder_layer,
                                     feed_forward=feed_forward)

        # Used by encode() only; forward() does not apply it.
        self.linear_encoding = nn.Linear(in_features=trans_embed_size, out_features=d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead=8, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        self.projection = nn.Linear(d_model, self.seq_len * self.enc_in, bias=True)

    def encode(self,x_enc):
        """Return the encoder output projected to ``d_model`` features.

        Args:
            x_enc: Input accepted by the encoder
        Returns:
            Tensor of shape (batch_size, d_model)
        """
        enc_out = self.encoder(x_enc)
        enc_out = self.linear_encoding(enc_out.view((-1, self.trans_embed_size)))
        return enc_out

    def forward(self, x):
        """Encode ``x`` and reconstruct a (batch_size, seq_len, enc_in) tensor.

        NOTE(review): the encoder output (width ``trans_embed_size``) is fed
        to the decoder directly, without ``linear_encoding``, while the
        decoder is built for ``d_model`` features. This path therefore looks
        like it requires ``trans_embed_size == d_model`` — confirm.

        Args:
            x: Input accepted by the encoder
        Returns:
            Tensor of shape (batch_size, seq_len, enc_in)
        """
        enc_out = self.encoder(x)
        # Treat each embedding as a length-1 target sequence.
        enc_out = enc_out.unsqueeze(1)
        # All-zero memory. zeros_like matches the activations' dtype as well
        # as their device, so non-float32 precision does not cause a
        # dtype mismatch inside the decoder.
        memory = torch.zeros_like(enc_out)
        enc_out = self.decoder(enc_out, memory)
        out = self.projection(enc_out)
        out = out.view(out.shape[0], self.seq_len, self.enc_in)
        return out