import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

from einops import repeat, rearrange

class Conv1D_BN(nn.Module):
    """1-D convolution followed by batch normalisation and a ReLU.

    The convolution is created without a bias term; the ``BatchNorm1d`` that
    follows it carries its own learnable shift.

    Args:
        num_features_in: Number of input channels.
        num_features_out: Number of output channels.
        kernel_size: Convolution kernel width.
        stride: Convolution stride.
        padding: Zero padding applied on both sides of the length axis.
    """

    def __init__(self, num_features_in, num_features_out, kernel_size=3, stride=1, padding=1):
        super().__init__()
        convolution = nn.Conv1d(
            num_features_in,
            num_features_out,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        normalisation = nn.BatchNorm1d(num_features_out)
        activation = nn.ReLU()
        # Kept as a Sequential named ``net`` so checkpoint keys stay
        # ``net.0.*`` / ``net.1.*``.
        self.net = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x`` and return the result."""
        out = self.net(x)
        return out
    


class Skip_Connections(nn.Module):
    """Three chained conv blocks whose outputs are concatenated and added to
    a learnably scaled copy of the input.

    The three branches produce ``num_features // 2``, ``num_features // 4``
    and the remaining channels, so their concatenation always has exactly
    ``num_features`` channels. When ``num_features`` is divisible by 4 the
    remainder equals ``num_features // 4``.

    Args:
        num_features: Number of input (and output) channels; must be >= 4.

    Raises:
        ValueError: If ``num_features`` is smaller than 4, because the
            narrower branches would then have zero channels.
    """

    def __init__(self, num_features):
        super().__init__()
        if num_features < 4:
            raise ValueError(
                f"num_features must be at least 4, got {num_features}"
            )

        half = num_features // 2
        quarter = num_features // 4
        # Whatever is left after the first two branches, so that the
        # concatenation in forward() matches the residual's channel count
        # even when num_features is not a multiple of 4.
        remainder = num_features - half - quarter

        self.first_conv = Conv1D_BN(num_features_in=num_features,
                                    num_features_out=half)

        self.second_conv = Conv1D_BN(num_features_in=half,
                                     num_features_out=quarter)

        self.third_conv = Conv1D_BN(num_features_in=quarter,
                                    num_features_out=remainder)

        # Per-channel residual gate, initialised to zero so the skip path
        # contributes nothing until it is trained.
        self.alphas = nn.Parameter(torch.zeros(1, num_features, 1))

    def forward(self, x):
        """Return ``x * alphas + concat(h1, h2, h3)`` along the channel axis.

        Assumes ``x`` is laid out as (batch, num_features, length) - TODO
        confirm against callers.
        """
        res = x
        h1 = self.first_conv(x)
        h2 = self.second_conv(h1)
        h3 = self.third_conv(h2)

        h = torch.cat([h1, h2, h3], dim=1)
        return res * self.alphas + h
    

class CNN_Emb_Net(nn.Module):
    """Convolutional embedder: a strided stem, three skip blocks with max
    pooling in between, and a final adaptive max pool to length 1.

    Args:
        model_dim: Channel width used by the stem output and every skip block.
        c_in: Number of input channels.
    """

    def __init__(self, model_dim, c_in):
        super().__init__()
        # Stem: kernel 7, stride 2, padding 3.
        self.first_conv = Conv1D_BN(
            num_features_in=c_in,
            num_features_out=model_dim,
            kernel_size=7,
            stride=2,
            padding=3,
        )
        self.skip_1 = Skip_Connections(model_dim)
        self.skip_2 = Skip_Connections(model_dim)
        self.skip_3 = Skip_Connections(model_dim)

        # The same (parameter-free) pooling layer is reused between stages.
        self.max_pool = nn.MaxPool1d(2)
        self.last_pool = nn.AdaptiveMaxPool1d(1)

    def forward(self, x):
        """Embed ``x``; the returned tensor has size 1 on its last axis.

        Presumably ``x`` is (batch, c_in, length) - verify against callers.
        """
        stem = self.first_conv(x)
        stage_1 = self.max_pool(self.skip_1(stem))
        stage_2 = self.max_pool(self.skip_2(stage_1))
        stage_3 = self.skip_3(stage_2)
        return self.last_pool(stage_3)
    

class Patch_Embeddings(nn.Module):
    """Split a 1-D signal into non-overlapping patches and embed each one.

    Every patch of ``patch_size`` samples is layer-normalised, linearly
    projected (no bias) to ``model_dim`` and layer-normalised again.

    Args:
        patch_size: Number of samples per patch; the signal length must be a
            multiple of it for the einops pattern to apply.
        model_dim: Embedding size of each patch.
        in_channels: Not used by this module; kept so existing callers that
            pass it keep working.
    """

    def __init__(self, patch_size, model_dim, in_channels):
        super().__init__()
        self.ps = patch_size

        self.re1 = Rearrange("b (l ps) -> (b l) ps", ps=patch_size)
        self.strip_2_emb = nn.Sequential(
            nn.LayerNorm(patch_size),
            nn.Linear(in_features=patch_size, out_features=model_dim, bias=False),
            nn.LayerNorm(model_dim),
        )

    def forward(self, x):
        """Embed a batch of signals.

        Args:
            x: 2-D tensor unpacked as (batch, length).

        Returns:
            Tensor of shape (batch, length // patch_size, model_dim).
        """
        # Unpacking also rejects inputs that are not 2-D.
        _, length = x.shape

        patches = self.re1(x)
        # LayerNorm and Linear act on the last axis only, so the flattened
        # (batch * patches, patch_size) tensor is fed in directly. An earlier
        # unsqueeze(1) followed by a bare squeeze() also dropped the patch or
        # embedding axis whenever it had size 1 (e.g. a single patch from a
        # single sample), which broke the rearrange below.
        embedded = self.strip_2_emb(patches)
        return rearrange(embedded, '(b l) c -> b l c', l=length // self.ps)


class PositionalEmbedding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence, then dropout.

    Even embedding columns hold ``sin(pos * freq)`` and odd columns hold
    ``cos(pos * freq)``, with ``freq = 10000 ** (-2i / model_dim)``. The table
    is stored as a (non-trainable) buffer named ``positional_encodings``.

    Args:
        model_dim: Embedding size; both even and odd values are supported.
        do_prob: Dropout probability applied after the addition.
        max_seq_len: Number of positions to precompute. Sequences passed to
            ``forward`` should not be longer than this.
    """

    def __init__(self, model_dim, do_prob, max_seq_len=26):
        super().__init__()
        self.model_dim = model_dim

        position_id = torch.arange(0, max_seq_len).unsqueeze(1)
        frequencies = torch.pow(10000., -torch.arange(0, model_dim, 2, dtype=torch.float) / model_dim)
        # (max_seq_len, ceil(model_dim / 2)) table of pos * freq.
        angles = position_id * frequencies

        positional_encodings = torch.zeros(max_seq_len, model_dim)
        positional_encodings[:, 0::2] = torch.sin(angles)
        # There are only floor(model_dim / 2) odd columns; for an odd
        # model_dim the last frequency has no cosine slot, so slice it off
        # instead of failing on a shape mismatch.
        positional_encodings[:, 1::2] = torch.cos(angles[:, :model_dim // 2])
        self.register_buffer('positional_encodings', positional_encodings)

        self.dropout = nn.Dropout(do_prob)

    def forward(self, x):
        """Return ``dropout(x + encodings[:seq_len])``.

        The sequence length is read from ``x.shape[1]``; the encodings are
        broadcast over the leading (batch) axis.
        """
        tmp_encoding = self.positional_encodings[:x.shape[1]]
        return self.dropout(x + tmp_encoding)
    

class MultiHead_Attention(nn.Module):
    """Scaled dot-product attention over several heads.

    Queries, keys and values are each projected with their own linear layer,
    split into ``number_heads`` heads, attended, merged back and passed
    through an output projection. Dropout is applied to the attention
    weights and to the final output.

    Args:
        model_dim: Embedding size of the inputs and outputs.
        number_heads: Number of heads the embedding axis is split into.
        do_prob: Dropout probability for both dropout layers.
    """

    def __init__(self, model_dim, number_heads, do_prob):
        super().__init__()
        self.number_heads = number_heads

        # Scores are scaled by 1 / sqrt(per-head width).
        per_head_width = model_dim / number_heads
        self.scale_factor = 1 / (per_head_width ** 0.5)

        self.att_drop_out = nn.Dropout(p=do_prob)
        self.output_drop_out = nn.Dropout(p=do_prob)

        self.block_output = nn.Linear(in_features=model_dim, out_features=model_dim)

        # Parameter-free reshaping helpers.
        self.split_head = Rearrange('b l (h d) -> b h l d', h=number_heads)
        self.split_head_t = Rearrange('b l (h d) -> b h d l', h=number_heads)
        self.concat = Rearrange('b h l d -> b l (h d)')

        self.x_to_q = nn.Linear(in_features=model_dim, out_features=model_dim)
        self.x_to_k = nn.Linear(in_features=model_dim, out_features=model_dim)
        self.x_to_v = nn.Linear(in_features=model_dim, out_features=model_dim)

    def forward(self, q, k, v):
        """Attend ``q`` over ``k``/``v``.

        All three inputs are (batch_size, seq_len, embedding_dimension); the
        result has the same layout as ``q``.
        """
        queries = self.split_head(self.x_to_q(q))      # b h l d
        keys_t = self.split_head_t(self.x_to_k(k))     # b h d l (already transposed)
        values = self.split_head(self.x_to_v(v))       # b h l d

        scores = torch.matmul(queries, keys_t) * self.scale_factor
        weights = self.att_drop_out(torch.softmax(scores, dim=-1))

        merged = self.concat(torch.matmul(weights, values))
        projected = self.block_output(merged)
        return self.output_drop_out(projected)



class FeedForwardNet(nn.Module):
    """Two-layer position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        model_dim: Size of the input and output features.
        do_prob: Dropout probability used after the activation and after the
            second linear layer.
        wide_factor: Multiplier giving the hidden width
            (``model_dim * wide_factor``).
    """

    def __init__(self, model_dim, do_prob, wide_factor=4):
        super().__init__()
        hidden_dim = model_dim * wide_factor
        layers = [
            nn.Linear(model_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(do_prob),
            nn.Linear(hidden_dim, model_dim),
            nn.Dropout(do_prob),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP; the last axis keeps size ``model_dim``."""
        out = self.net(x)
        return out



class Add_and_Norm(nn.Module):
    """Residual addition followed by layer normalisation.

    Args:
        model_dim: Size of the last axis that ``LayerNorm`` normalises over.
    """

    def __init__(self, model_dim):
        super().__init__()
        self.norm = nn.LayerNorm(model_dim)

    def forward(self, x, res):
        """Return ``LayerNorm(x + res)``."""
        summed = x + res
        return self.norm(summed)



class EncoderBlock(nn.Module):
    """One transformer encoder layer: self-attention and a feed-forward net,
    each followed by a residual add and layer normalisation.

    Args:
        number_heads: Number of attention heads.
        model_dim: Embedding size used throughout the block.
        do_prob: Dropout probability passed to both sub-layers.
    """

    def __init__(self, number_heads, model_dim, do_prob):
        super().__init__()
        self.mh_atten_block = MultiHead_Attention(
            model_dim=model_dim,
            number_heads=number_heads,
            do_prob=do_prob,
        )
        self.add_norm_mh = Add_and_Norm(model_dim=model_dim)

        self.ffn = FeedForwardNet(model_dim=model_dim, do_prob=do_prob)
        self.add_norm_ffn = Add_and_Norm(model_dim=model_dim)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers to ``x``."""
        # Self-attention: the same tensor serves as query, key and value.
        attended = self.add_norm_mh(self.mh_atten_block(x, x, x), x)
        # The feed-forward residual branches off the normalised attention output.
        return self.add_norm_ffn(self.ffn(attended), attended)


class Encoder(nn.Module):
    """Stack of ``num_blocks`` encoder blocks applied one after another.

    Args:
        num_blocks: How many ``EncoderBlock`` layers to stack.
        num_heads: Attention heads per block.
        model_dim: Embedding size used by every block.
        do_prob: Dropout probability passed to every block.
    """

    def __init__(self, num_blocks, num_heads, model_dim, do_prob):
        super().__init__()
        self.num_blocks, self.num_heads = num_blocks, num_heads
        self.model_dim, self.do_prob = model_dim, do_prob

        self.net = self.create_net()

    def create_net(self):
        """Build a fresh ``ModuleList`` of blocks from the stored settings."""
        blocks = [
            EncoderBlock(number_heads=self.num_heads,
                         model_dim=self.model_dim,
                         do_prob=self.do_prob)
            for _ in range(self.num_blocks)
        ]
        return nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through every block in order and return the result."""
        out = x
        for block in self.net:
            out = block(out)
        return out
    

class ViT(nn.Module):
    """Transformer encoder over patches of a 1-D signal with a class token.

    The input is cut into patches, a learnable class token is prepended,
    positional encodings are added, and the encoder output at the class-token
    position is returned.

    Args:
        num_blocks: Number of encoder blocks.
        num_heads: Attention heads per block.
        model_dim: Embedding size.
        do_prob: Dropout probability used throughout.
        patch_size: Samples per patch.
        in_channels: Forwarded to ``Patch_Embeddings``.
        fs: Used with ``l`` to size the positional table
            (``int(fs * l / patch_size) + 1`` positions). Presumably the
            sampling frequency - confirm with the data pipeline.
        l: See ``fs``. Presumably the signal duration - confirm.
    """

    def __init__(self, num_blocks=6, num_heads=8, model_dim=128, do_prob=0.1, patch_size=40, in_channels=1, fs=100, l=10):
        super().__init__()
        self.class_token = nn.Parameter(torch.randn(1, 1, model_dim), requires_grad=True)
        self.patch_embedding = Patch_Embeddings(
            patch_size=patch_size,
            model_dim=model_dim,
            in_channels=in_channels,
        )
        # One extra position for the class token.
        positions = int(fs * l / patch_size) + 1
        self.pos_enc = PositionalEmbedding(
            model_dim=model_dim,
            max_seq_len=positions,
            do_prob=do_prob,
        )
        self.encoder = Encoder(
            num_blocks=num_blocks,
            num_heads=num_heads,
            model_dim=model_dim,
            do_prob=do_prob,
        )

    def forward(self, x):
        """Return the encoded class-token vector for every item in the batch."""
        batch_size = x.shape[0]
        # Copy the single learned token once per batch item.
        cls_token = repeat(self.class_token, '() p d -> b p d', b=batch_size)

        patches = self.patch_embedding(x)
        tokens = torch.cat([cls_token, patches], dim=1)
        encoded = self.encoder(self.pos_enc(tokens))
        # Position 0 is the class token.
        return encoded[:, 0, :]

    def count_parameters(self):
        """Return the total number of elements in all trainable parameters."""
        total = 0
        for param in self.parameters():
            if param.requires_grad:
                total += param.numel()
        return total