import torch
import hydra

from experiments.utils import set_seed

# Fix the global seed once at import time so the random parameters and
# permutations drawn by test_model_invariance are reproducible across runs.
# (set_seed comes from experiments.utils; presumably it seeds torch's RNG —
# confirm there.)
set_seed(42)


def test_model_invariance(
    model,
    layer_layout=(2, 32, 32, 32, 32, 3),
    batch_size=4,
    atol=1e-5,
):
    """Check that ``model`` is invariant to hidden-neuron permutations of MLPs.

    Random weights and biases are drawn for a batch of MLPs described by
    ``layer_layout``. Every hidden layer then gets a random permutation. The
    permutation is applied consistently to:

    * the outgoing axis of that layer's weights,
    * that layer's bias,
    * the incoming axis of the next layer's weights.

    For an MLP with elementwise activations, this relabelling leaves the
    represented function unchanged. A permutation-invariant ``model`` must
    therefore give the same output for both parameter sets.

    Args:
        model: Callable taking a ``(weights, biases)`` tuple, where:

            * ``weights[i]`` has shape
              ``(batch_size, layer_layout[i], layer_layout[i + 1], 1)``;
            * ``biases[i]`` has shape ``(batch_size, layer_layout[i + 1], 1)``.

        layer_layout: Width of every layer, input first and output last. The
            input and output layers are never permuted. The default matches
            the layout previously hard-coded here.
        batch_size: Number of random MLPs drawn per call.
        atol: Absolute tolerance used when comparing the two outputs.

    Returns:
        Tuple ``(out, out_perm)``: the model outputs on the original and on
        the permuted parameters.

    Raises:
        ValueError: If ``layer_layout`` has fewer than two entries.
        AssertionError: If the two outputs differ by more than ``atol``.
    """
    layer_layout = tuple(layer_layout)
    if len(layer_layout) < 2:
        raise ValueError(
            f"layer_layout needs at least an input and an output width, got {layer_layout}"
        )
    n_layers = len(layer_layout) - 1

    # RNG call order (all weights, all biases, first forward pass, then the
    # permutations) matches the former hard-coded version, so results under a
    # fixed seed are unchanged for the default layout. The trailing size-1
    # axis is presumably a channel axis the model expects — confirm.
    weights = tuple(
        torch.randn(batch_size, d_in, d_out, 1)
        for d_in, d_out in zip(layer_layout[:-1], layer_layout[1:])
    )
    biases = tuple(torch.randn(batch_size, d_out, 1) for d_out in layer_layout[1:])

    out = model((weights, biases))

    # One permutation per hidden layer; perms[i] relabels the outputs of layer i.
    perms = [torch.randperm(width) for width in layer_layout[1:-1]]

    perm_weights = []
    perm_biases = []
    for i, (w, b) in enumerate(zip(weights, biases)):
        if i > 0:
            # Incoming axis follows the previous hidden layer's relabelling.
            w = w[:, perms[i - 1]]
        if i < n_layers - 1:
            # Outgoing axis and bias follow this layer's relabelling; the last
            # (output) layer is left untouched.
            w = w[:, :, perms[i]]
            b = b[:, perms[i]]
        perm_weights.append(w)
        perm_biases.append(b)

    out_perm = model((tuple(perm_weights), tuple(perm_biases)))

    # The message is only built on failure, so the diff is not computed otherwise.
    assert torch.allclose(out, out_perm, atol=atol, rtol=0), (
        "model is not permutation invariant: max abs diff "
        f"{(out - out_perm).abs().max().item():.3e} > atol {atol}"
    )
    return out, out_perm


# Called directly from main(), not a pytest test. Its `model` parameter is not
# a fixture, so pytest collection of this `test_*` name would error out.
test_model_invariance.__test__ = False


@hydra.main(config_path="configs", config_name="base", version_base=None)
def main(cfg):
    """Instantiate the configured model and run the permutation-invariance check.

    The model node of the composed Hydra config (``configs/base``) is built
    with an explicit ``layer_layout`` override. That override is then checked
    by ``test_model_invariance``.
    """
    # Must agree with the layer widths test_model_invariance draws parameters for.
    layout = (2, 32, 32, 32, 32, 3)
    net = hydra.utils.instantiate(cfg.model, layer_layout=layout)
    test_model_invariance(net)


if __name__ == "__main__":
    # The @hydra.main decorator composes the config and passes it as `cfg`,
    # so main() is invoked here without arguments.
    main()
