import torch
import torch.nn as nn
from einops import rearrange


class VQ_model(nn.Module):
    """Convolutional autoencoder skeleton with a stack of vector-quantization codebooks.

    The constructor builds the following sub-modules:

    * ``encoder``: two unpadded ``Conv1d`` layers (``enc_in`` -> ``enc_in``
      channels) separated by a ReLU. They use ``kernel[0]``/``stride[0]``
      followed by ``kernel[1]``/``stride[1]``.
    * ``decoder``: two ``ConvTranspose1d`` layers (``enc_in`` -> ``enc_in``
      channels) separated by a ReLU. They use the kernels/strides in reverse
      order.
    * ``encoder_out_layer`` / ``decoder_input_layer``: linear maps from the
      encoder's output length to ``d_code`` and back.
    * ``multi_embedding``: ``vq_layers`` codebooks. Each holds ``num_code``
      vectors of size ``d_code``, initialised uniformly in
      ``[-1/num_code, 1/num_code)``.
    * ``mix_zq``: a linear map from ``d_code * vq_layers`` to ``d_code``.
      Presumably it fuses the per-codebook quantized vectors; confirm once
      ``forward`` is implemented.

    Args:
        enc_in: Number of input channels; every conv layer keeps this count.
        seq_len: Input sequence length, used to size the linear layers.
        d_ff, dropout, n_heads, activation, e_layers: Accepted for interface
            compatibility but not used by this constructor.
        num_code: Number of entries in each codebook.
        d_code: Dimensionality of each code vector.
        kernel: Kernel sizes of the two encoder convolutions (two entries).
        stride: Strides of the two encoder convolutions (two entries).
        vq_layers: Number of codebooks.

    Raises:
        ValueError: If ``seq_len`` is too short for ``kernel``/``stride`` to
            yield at least one encoder output step.
    """

    def __init__(self, enc_in=7, seq_len=96, d_ff=16, dropout=0.1, n_heads=4, activation='relu', e_layers=4,
                 num_code=5, d_code=256, kernel=(4, 3), stride=(1, 1), vq_layers=3):
        super().__init__()
        self.encoder = nn.Sequential(nn.Conv1d(enc_in, enc_in, kernel_size=kernel[0], stride=stride[0]),
                                     nn.ReLU(),
                                     nn.Conv1d(enc_in, enc_in, kernel_size=kernel[1], stride=stride[1]))
        # The decoder mirrors the encoder, so the kernel/stride order is reversed.
        self.decoder = nn.Sequential(nn.ConvTranspose1d(enc_in, enc_in, kernel_size=kernel[1], stride=stride[1]),
                                     nn.ReLU(),
                                     nn.ConvTranspose1d(enc_in, enc_in, kernel_size=kernel[0], stride=stride[0]))

        # Temporal length after both encoder convolutions. It is computed once
        # with integer floor division and shared by the two linear layers below.
        enc_len = self._conv_out_len(self._conv_out_len(seq_len, kernel[0], stride[0]), kernel[1], stride[1])
        if enc_len < 1:
            raise ValueError(
                f"seq_len={seq_len} is too short for kernel={tuple(kernel)} and "
                f"stride={tuple(stride)}: encoder output length would be {enc_len}"
            )

        self.encoder_out_layer = nn.Linear(enc_len, d_code)
        self.decoder_input_layer = nn.Linear(d_code, enc_len)

        self.vq_layers = vq_layers

        self.multi_embedding = nn.ModuleList([nn.Embedding(num_code, d_code) for _ in range(self.vq_layers)])

        # Small uniform init scaled by the codebook size.
        for codebook in self.multi_embedding:
            nn.init.uniform_(codebook.weight, -1.0 / num_code, 1.0 / num_code)

        self.mix_zq = nn.Linear(d_code * self.vq_layers, d_code)

    @staticmethod
    def _conv_out_len(length, kernel_size, stride):
        """Return the output length of an unpadded, undilated ``Conv1d``."""
        return (length - kernel_size) // stride + 1

    def forward(self, enc_in):
        """Forward pass; not implemented in this source.

        Raising here makes accidental calls fail loudly instead of silently
        returning ``None``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "VQ_model.forward is not implemented in this version of the code."
        )



