import torch.nn as nn
from torchinfo import summary


def has_uneven_padding(model: nn.Module, input_size: tuple, *, verbose: bool = True) -> bool:
    """Return True if some strided ``nn.Conv2d`` in *model* leaves padded input unused.

    Runs ``torchinfo.summary`` on *model* with *input_size* to obtain every
    layer's input and output shapes. For each ``nn.Conv2d`` whose stride is
    greater than 1 in at least one spatial dimension, the padded input extent
    ``in + 2 * padding`` is compared with the extent actually swept by the
    sliding window, ``stride * (out - 1) + dilation * (kernel - 1) + 1``.
    If they differ, trailing rows/columns of the padded input are never read,
    so the effective padding is uneven.

    Args:
        model: Network to inspect.
        input_size: Input shape forwarded to ``torchinfo.summary``.
        verbose: When True (the default, matching the previous behaviour),
            print the model summary and, for the first offending layer, its
            name and the compared extents.

    Returns:
        True as soon as one offending layer is found, otherwise False.
    """
    summary_info = summary(model, input_size, verbose=0)
    if verbose:
        print(summary_info)
    for layer_info in summary_info.summary_list:
        conv_module = layer_info.module
        if not isinstance(conv_module, nn.Conv2d):
            continue
        if not any(s > 1 for s in conv_module.stride):
            continue

        # Conv2d may store padding as the string 'valid' (meaning no padding);
        # 'same' is rejected by PyTorch for strided convs, so it cannot reach here.
        padding = (0, 0) if conv_module.padding == 'valid' else conv_module.padding

        *_, h_i1, w_i1 = layer_info.input_size
        *_, h_i, w_i = layer_info.output_size
        h_i_line = h_i1 + 2 * padding[0]
        w_i_line = w_i1 + 2 * padding[1]

        # A dilated kernel spans dilation * (k - 1) + 1 input pixels, not k.
        h_i_hat = (conv_module.stride[0] * (h_i - 1)
                   + conv_module.dilation[0] * (conv_module.kernel_size[0] - 1) + 1)
        w_i_hat = (conv_module.stride[1] * (w_i - 1)
                   + conv_module.dilation[1] * (conv_module.kernel_size[1] - 1) + 1)

        if h_i_line != h_i_hat or w_i_line != w_i_hat:
            if verbose:
                print('layer', layer_info.var_name)
                print(f'{h_i_line = }', f'{h_i_hat = }', f'{w_i_line = }', f'{w_i_hat = }')
            return True
    return False


def main():
    """Try square input sizes 28 and 29 and report the first one without uneven padding.

    Builds a model with ``build_model('zeros')``, then for each candidate size
    checks a ``(1, 1, size, size)`` input with ``has_uneven_padding``. Each
    rejected size is reported as not good; the first accepted size is reported
    as good and the search stops there.
    """
    from mind_the_pad.train_mnist import build_model

    model = build_model('zeros')
    for size in range(28, 30):
        if has_uneven_padding(model, (1, 1, size, size)):
            print(size, 'is not a good solution')
            continue
        print(size, 'is a good solution')
        return


if __name__ == "__main__":
    # Run the input-size search only when executed as a script, not on import.
    main()
