import torch
import torch.nn as nn

class Decoder(nn.Module):
    """Convolutional decoder that undoes four max-pooling stages.

    A bias-free linear layer (followed by batch norm and ReLU) expands the
    latent vector to ``fc_dim`` values, which are viewed as
    ``(batch, output_channels[0], fc_dim // output_channels[0])``.  Four
    ``MaxUnpool1d -> ConvTranspose1d -> ReLU`` stages then follow; the first
    three transposed convolutions use ``kernel_size=1`` and the last one uses
    the ``kernel_size`` given to the constructor.

    Args:
        input_dim: Size of the last dimension of the latent input ``z``.
        fc_dim: Output width of the linear layer.  Must be divisible by
            ``output_channels[0]`` so the activations can be reshaped per
            sample.
        n_feature: Number of channels produced by the final transposed
            convolution.
        kernel_size: Kernel size of the final transposed convolution.
        output_channels: Channel counts entering each of the four transposed
            convolutions (indices 0..3 are read).  The sequence is copied, so
            instances never share or alias the caller's object.

    Raises:
        ValueError: If ``fc_dim`` is not a multiple of ``output_channels[0]``.
    """

    def __init__(self, input_dim, fc_dim, n_feature, kernel_size,
                 output_channels=(64, 128, 512, 1024)):
        super().__init__()
        # Per-instance copy: the attribute stays a list (as before) but is no
        # longer the shared default object or the caller's own list.
        output_channels = list(output_channels)
        if fc_dim % output_channels[0] != 0:
            # Without this check, forward()'s view(-1, ...) could silently
            # fold leftover elements into the batch dimension.
            raise ValueError(
                "fc_dim ({}) must be divisible by output_channels[0] ({})".format(
                    fc_dim, output_channels[0]
                )
            )

        self.fc_dim = fc_dim
        self.fc1 = nn.Sequential(
            nn.Linear(input_dim, fc_dim, bias=False),
            nn.BatchNorm1d(fc_dim),
            nn.ReLU()
        )
        self.output_channels = output_channels

        self.un1 = nn.MaxUnpool1d(kernel_size=2, stride=2)
        self.deconv1 = nn.Sequential(
            nn.ConvTranspose1d(in_channels=output_channels[0], out_channels=output_channels[1], kernel_size=1),
            nn.ReLU()
        )

        self.un2 = nn.MaxUnpool1d(kernel_size=2, stride=2)
        self.deconv2 = nn.Sequential(
            nn.ConvTranspose1d(in_channels=output_channels[1], out_channels=output_channels[2], kernel_size=1),
            nn.ReLU()
        )

        self.un3 = nn.MaxUnpool1d(kernel_size=2, stride=2)
        self.deconv3 = nn.Sequential(
            nn.ConvTranspose1d(in_channels=output_channels[2], out_channels=output_channels[3], kernel_size=1),
            nn.ReLU()
        )

        self.un4 = nn.MaxUnpool1d(kernel_size=2, stride=2)
        self.deconv4 = nn.Sequential(
            nn.ConvTranspose1d(in_channels=output_channels[3], out_channels=n_feature, kernel_size=kernel_size),
            nn.ReLU()
        )

        self.init_weights()

    def init_weights(self):
        """Apply Xavier-uniform initialisation to the linear and deconv weights."""
        nn.init.xavier_uniform_(self.fc1[0].weight)
        for stage in (self.deconv1, self.deconv2, self.deconv3, self.deconv4):
            nn.init.xavier_uniform_(stage[0].weight)

    def forward(self, z, idxs, sizes):
        """Decode ``z`` using stored pooling indices.

        Args:
            z: Latent input whose last dimension is ``input_dim``.
            idxs: Sequence of four max-pooling index tensors.  They are
                consumed in reverse order (``idxs[3]`` first, ``idxs[0]``
                last).  Presumably the index list returned by
                ``Encoder.forward`` -- confirm against the caller.
            sizes: Sequence of tensor sizes; ``sizes[2]``, ``sizes[1]`` and
                ``sizes[0]`` are passed as ``output_size`` to the first three
                unpooling steps.

        Returns:
            Output of the final ``ConvTranspose1d`` + ReLU stage.
        """
        h = self.fc1(z)
        h = h.view(-1, self.output_channels[0], self.fc_dim // self.output_channels[0])

        h = self.deconv1(self.un1(h, idxs[3], output_size=sizes[2]))
        h = self.deconv2(self.un2(h, idxs[2], output_size=sizes[1]))
        h = self.deconv3(self.un3(h, idxs[1], output_size=sizes[0]))

        # The last unpooling is given no output_size, so MaxUnpool1d falls
        # back to its default output length.
        h = self.un4(h, idxs[0])
        return self.deconv4(h)



class Encoder(nn.Module):
    """Convolutional encoder built from four ``Conv1d -> ReLU -> MaxPool1d`` stages.

    The first convolution uses the ``kernel_size`` given to the constructor;
    the other three use ``kernel_size=1``.  Every pooling layer halves the
    length (``kernel_size=2, stride=2, ceil_mode=True``) and returns its
    indices.  The last pooled output is flattened to
    ``channels * length`` values per sample and fed to two linear heads.

    Args:
        output_dim: Output width of both linear heads.
        input_dim: Number of input channels of the first convolution.
        fc_dim: Input width of both linear heads; must equal the flattened
            size of the last pooled output.
        kernel_size: Kernel size of the first convolution.
        out_channels: Output channel counts of the four convolutions
            (indices 0..3 are read).
    """

    def __init__(self, output_dim, input_dim, fc_dim, kernel_size,
                 out_channels=(1024, 512, 128, 64)):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channels=input_dim, out_channels=out_channels[0], kernel_size=kernel_size),
            nn.ReLU()
        )
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2, return_indices=True, ceil_mode=True)

        self.conv2 = nn.Sequential(
            nn.Conv1d(in_channels=out_channels[0], out_channels=out_channels[1], kernel_size=1),
            nn.ReLU()
        )
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2, return_indices=True, ceil_mode=True)

        self.conv3 = nn.Sequential(
            nn.Conv1d(in_channels=out_channels[1], out_channels=out_channels[2], kernel_size=1),
            nn.ReLU()
        )
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2, return_indices=True, ceil_mode=True)

        self.conv4 = nn.Sequential(
            nn.Conv1d(in_channels=out_channels[2], out_channels=out_channels[3], kernel_size=1),
            nn.ReLU()
        )
        self.pool4 = nn.MaxPool1d(kernel_size=2, stride=2, return_indices=True, ceil_mode=True)

        self.fc11 = nn.Sequential(nn.Linear(fc_dim, output_dim))
        # Softplus keeps the second head's output positive.
        self.fc12 = nn.Sequential(nn.Linear(fc_dim, output_dim), nn.Softplus())

        self.init_weights()

    def init_weights(self):
        """Xavier-initialise conv and linear weights; zero the linear biases."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            nn.init.xavier_uniform_(stage[0].weight)

        for head in (self.fc11, self.fc12):
            nn.init.xavier_uniform_(head[0].weight)
            nn.init.zeros_(head[0].bias)

    def forward(self, x):
        """Encode ``x`` and expose the pooling bookkeeping.

        Args:
            x: Input tensor; it is cast to float before the first
                convolution, so its channel dimension must be ``input_dim``.

        Returns:
            A 4-tuple ``(zy_loc, zy_scale, idxs, sizes)`` where ``zy_loc`` is
            the output of the plain linear head, ``zy_scale`` is the Softplus
            head's output plus ``1e-7`` (hence strictly positive), ``idxs``
            is the list of the four pooling index tensors and ``sizes`` is
            the list of the four pooled outputs' sizes, both in stage order.
        """
        out_conv1 = self.conv1(x.float())
        out1, idx1 = self.pool1(out_conv1)

        out_conv2 = self.conv2(out1)
        out2, idx2 = self.pool2(out_conv2)

        out_conv3 = self.conv3(out2)
        out3, idx3 = self.pool3(out_conv3)

        out_conv4 = self.conv4(out3)
        out4, idx4 = self.pool4(out_conv4)

        # Flatten channels and length into one feature axis per sample.
        out = out4.reshape(-1, out4.shape[1] * out4.shape[-1])

        zy_loc = self.fc11(out)
        # Small offset keeps the value away from exactly zero.
        zy_scale = self.fc12(out) + 1e-7

        idxs = [idx1, idx2, idx3, idx4]
        sizes = [out1.size(), out2.size(), out3.size(), out4.size()]
        return zy_loc, zy_scale, idxs, sizes
    

class MLPEncoder(nn.Module):
    """Fully connected encoder with two output heads.

    The input is flattened per sample to ``input_dim * seq_dim`` values and
    passed through three linear layers (ReLU after the first two).  The
    resulting features feed a plain linear head and a linear + Softplus head,
    both of width ``hidden_dim``.
    """

    def __init__(self, hidden_dim, input_dim, seq_dim):
        super().__init__()
        flat_dim = input_dim * seq_dim

        # Shared trunk.
        self.fc1 = nn.Linear(flat_dim, hidden_dim)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)

        # Output heads; the second one is kept positive by Softplus.
        self.fc11 = nn.Sequential(nn.Linear(hidden_dim, hidden_dim))
        self.fc12 = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.Softplus())

        self.init_weights()

    def init_weights(self):
        """Xavier-initialise every linear weight and zero the head biases."""
        for trunk_layer in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_uniform_(trunk_layer.weight)

        for head in (self.fc11, self.fc12):
            linear = head[0]
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        """Return ``(zy_loc, zy_scale)`` for a batch ``x``.

        ``x`` is reshaped to ``(x.shape[0], -1)`` before the first layer.
        ``zy_scale`` is the Softplus head's output plus ``1e-7``.
        """
        batch = x.size(0)
        flat = x.reshape(batch, -1)

        hidden = self.relu1(self.fc1(flat))
        hidden = self.relu2(self.fc2(hidden))
        features = self.fc3(hidden)

        loc = self.fc11(features)
        scale = self.fc12(features) + 1e-7
        return loc, scale

class MLPDecoder(nn.Module):
    """Fully connected decoder producing ``(-1, input_dim, seq_dim)`` outputs.

    Three linear layers (ReLU after the first two) map a ``hidden_dim``-wide
    input to ``input_dim * seq_dim`` values, which are then reshaped.
    """

    def __init__(self, hidden_dim, input_dim, seq_dim):
        super().__init__()
        # Remembered for the final reshape in forward().
        self.input_dim, self.seq_dim = input_dim, seq_dim

        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_dim, input_dim * seq_dim)

        self.init_weights()

    def init_weights(self):
        """Xavier-initialise the weights of the three linear layers."""
        for layer in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        """Decode ``x`` and reshape the result to ``(-1, input_dim, seq_dim)``."""
        hidden = self.relu1(self.fc1(x))
        hidden = self.relu2(self.fc2(hidden))
        flat = self.fc3(hidden)
        return flat.reshape(-1, self.input_dim, self.seq_dim)