from slot_attention.model.model_utils import Tensor, assert_shape, conv_transpose_out_shape


from torch import nn

# class SlotAttentionDecoder(nn.Module):

def build_decoder(decoder_hidden_dims,
                  kernel_size,
                  decoder_stride,
                  decoder_padding,
                  decoder_output_padding,
                  resolution,
                  out_features,
                  decoder_resolution,
                  slot_size,
                  ):
    """Build the decoder as a stack of transposed convolutions.

    The decoder input is expected to have ``slot_size`` channels and a square
    spatial size of ``decoder_resolution[0]``. One ``ConvTranspose2d`` +
    ``LeakyReLU`` stage is appended per entry of ``decoder_hidden_dims``.
    Then a stride-1 head of two transposed convolutions (kernel 5 / pad 2 and
    kernel 3 / pad 1, which preserve spatial size) maps to 4 output channels.

    Args:
        decoder_hidden_dims: Output channel count of each upsampling stage.
        kernel_size: Kernel size shared by every upsampling stage.
        decoder_stride: Either a sequence with one stride per entry of
            ``decoder_hidden_dims`` (extra entries are ignored) or a single
            int applied to every stage.
        decoder_padding: Padding shared by every upsampling stage.
        decoder_output_padding: Output padding shared by every upsampling stage.
        resolution: Expected output spatial size. It is compared against
            ``(out_size, out_size)`` via ``assert_shape``.
        out_features: Channel count produced by the first head convolution.
        decoder_resolution: Spatial size of the decoder input. Only element 0
            is read, i.e. the input is treated as square.
        slot_size: Channel count of the decoder input.

    Returns:
        An ``nn.Sequential`` containing every decoder stage in order.

    Raises:
        ValueError: If ``decoder_stride`` is a sequence with fewer entries
            than ``decoder_hidden_dims``.
    """
    import logging

    logger = logging.getLogger(__name__)
    logger.debug("decoder params: out_features=%s, slot_size=%s", out_features, slot_size)

    hidden_dims = list(decoder_hidden_dims)
    num_stages = len(hidden_dims)

    # Accept a scalar stride for every stage as well as a per-stage sequence.
    if isinstance(decoder_stride, int):
        strides = [decoder_stride] * num_stages
    else:
        strides = list(decoder_stride)
        if len(strides) < num_stages:
            raise ValueError(
                f"decoder_stride has {len(strides)} entries but "
                f"decoder_hidden_dims has {num_stages}; one stride per stage is required."
            )

    modules = []
    channels = slot_size
    out_size = decoder_resolution[0]
    logger.debug("decoder in_size=%s", out_size)

    # Upsampling stages; zip ignores surplus strides, matching the old indexing.
    for h_dim, stride in zip(hidden_dims, strides):
        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(
                    in_channels=channels,
                    out_channels=h_dim,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=decoder_padding,
                    output_padding=decoder_output_padding,
                ),
                nn.LeakyReLU(),
            )
        )
        channels = h_dim
        out_size = conv_transpose_out_shape(
            out_size,
            stride=stride,
            padding=decoder_padding,
            kernel_size=kernel_size,
            out_padding=decoder_output_padding,
        )
        logger.debug("decoder stage: stride=%s output_padding=%s out_size=%s",
                     stride, decoder_output_padding, out_size)

    assert_shape(
        resolution,
        (out_size, out_size),
        message="Output shape of decoder did not match input resolution. Try changing `decoder_resolution`.",
    )

    # Stride-1 head: kernel 5/pad 2 and kernel 3/pad 1 keep the spatial size,
    # so no further shape check is needed. The first conv must consume the
    # channel count left by the upsampling stack (`channels`), not
    # `out_features`, or the forward pass fails whenever they differ.
    # The 4 output channels are presumably image channels + mask — confirm
    # against the caller.
    modules.append(
        nn.Sequential(
            nn.ConvTranspose2d(
                channels, out_features, kernel_size=5, stride=1, padding=2, output_padding=0,
            ),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(out_features, 4, kernel_size=3, stride=1, padding=1, output_padding=0,),
        )
    )

    return nn.Sequential(*modules)


