import torch
import hydra
from omegaconf import OmegaConf
import torch_geometric

from experiments.utils import set_seed
from experiments.data_generalization import CNNBatch

# Runs at import time with a fixed seed value.
# NOTE(review): presumably seeds the RNGs behind the torch.randn / torch.randperm
# calls below so the check is repeatable -- confirm in experiments.utils.set_seed.
set_seed(42)

# Register a custom OmegaConf resolver named "prod" that returns the product of
# its two arguments, so configs can use the ``${prod:a,b}`` interpolation.
OmegaConf.register_new_resolver("prod", lambda x, y: x * y)


def wb_to_batch(weights, biases, d0, d1, d2, d3):
    """Pack the weights and biases of a batch of 3-layer networks into a PyG batch.

    Each network in the batch becomes one graph: one node per neuron, with the
    bias as node feature (zeros for the ``d0`` input neurons), and one edge per
    pair of neurons in consecutive layers, with the weight as edge feature.

    Args:
        weights: Sequence of three 4-D tensors indexed as
            ``(batch, fan_in, fan_out, k)``. The last tensor is zero-padded
            along its trailing dimension up to the trailing size of
            ``weights[-2]``; its own trailing size is expected to be 1.
        biases: Sequence of tensors of shape ``(batch, width, 1)``, one per
            non-input layer, concatenated along dim 1 after the input zeros.
        d0: Number of input neurons.
        d1: Width of the first hidden layer.
        d2: Width of the second hidden layer.
        d3: Number of output neurons.

    Returns:
        A ``torch_geometric.data.Batch`` holding one ``Data`` object per
        network, each with ``x``, ``edge_index``, ``edge_attr`` and
        ``layer_layout=[d0, d1, d2, d3]``.
    """
    last = weights[-1]
    device = last.device
    kernel_size = weights[-2].shape[-1]
    # Slot that receives the last layer's single weight. This is index 4 for
    # the 9-element kernels used in this file (presumably the centre tap of a
    # flattened 3x3 kernel); deriving it keeps other kernel sizes working.
    center = kernel_size // 2
    last_padded = torch.zeros(
        (*last.shape[:3], kernel_size), dtype=last.dtype, device=device
    )
    last_padded[..., center:center + 1] = last

    # Input neurons carry no bias, so they get an all-zero node feature.
    first_bias = biases[0]
    input_feats = torch.zeros(
        (first_bias.shape[0], d0, 1),
        dtype=first_bias.dtype,
        device=first_bias.device,
    )
    x = torch.cat([input_feats, *biases], dim=1)

    # bounds[k]:bounds[k + 1] is the range of node ids belonging to layer k.
    bounds = [0]
    for width in (d0, d1, d2, d3):
        bounds.append(bounds[-1] + width)
    # Fully connect consecutive layers. cartesian_prod enumerates the pairs
    # source-major, which matches the order produced by flatten(1, 2) below.
    # The indices are created on the same device as the weights so the graph
    # does not end up split across devices.
    edge_index = torch.cat(
        [
            torch.cartesian_prod(
                torch.arange(bounds[k], bounds[k + 1], device=device),
                torch.arange(bounds[k + 1], bounds[k + 2], device=device),
            ).T
            for k in range(len(bounds) - 2)
        ],
        dim=1,
    )

    edge_attr = torch.cat(
        [
            weights[0].flatten(1, 2),
            weights[1].flatten(1, 2),
            last_padded.flatten(1, 2),
        ],
        dim=1,
    )

    # One graph per batch entry; the count comes from the data itself rather
    # than a fixed constant, so any batch size is handled.
    return torch_geometric.data.Batch.from_data_list(
        [
            torch_geometric.data.Data(
                x=node_feats,
                edge_index=edge_index,
                edge_attr=edge_feats,
                layer_layout=[d0, d1, d2, d3],
            )
            for node_feats, edge_feats in zip(x, edge_attr)
        ]
    )

def test_model_invariance(model):
    """Check that ``model`` is invariant to permutations of hidden neurons.

    Random weights and biases are drawn for a batch of 4 networks with layer
    widths (3, 16, 16, 10). The model is evaluated on them, then evaluated
    again after the neurons of both hidden layers have been consistently
    permuted (rows/columns of the adjacent weight tensors and the matching
    biases). The two outputs must agree within an absolute tolerance of 1e-5.

    If ``model`` is a ``torch.nn.Module`` it is switched to evaluation mode
    for the duration of the check, so that stochastic layers such as dropout
    cannot make the two forward passes differ for reasons unrelated to the
    permutation. Every submodule's previous mode is restored afterwards.

    Args:
        model: Callable applied to the batch built by ``wb_to_batch``.

    Returns:
        Tuple ``(out, out_perm)`` with the model outputs for the original and
        the permuted networks.

    Raises:
        AssertionError: If the two outputs differ by more than the tolerance.
    """
    batch_size = 4
    d0, d1, d2, d3 = 3, 16, 16, 10
    weights = (
        torch.randn(batch_size, d0, d1, 9),
        torch.randn(batch_size, d1, d2, 9),
        torch.randn(batch_size, d2, d3, 1),
    )
    biases = (
        torch.randn(batch_size, d1, 1),
        torch.randn(batch_size, d2, 1),
        torch.randn(batch_size, d3, 1),
    )

    # Remember each submodule's mode so it can be restored exactly, even if
    # some submodules were deliberately left in a different mode.
    if isinstance(model, torch.nn.Module):
        prior_modes = [(module, module.training) for module in model.modules()]
        model.eval()
    else:
        prior_modes = []

    try:
        out = model(wb_to_batch(weights, biases, d0, d1, d2, d3))

        # Independent permutations of the two hidden layers.
        perm1 = torch.randperm(d1)
        perm2 = torch.randperm(d2)

        # A hidden layer's permutation is applied to the outgoing side of the
        # weights feeding it, the incoming side of the weights leaving it, and
        # its bias. Input and output neurons keep their order.
        perm_weights = (
            weights[0][:, :, perm1, :],
            weights[1][:, perm1, :, :][:, :, perm2, :],
            weights[2][:, perm2, :, :],
        )
        perm_biases = (
            biases[0][:, perm1, :],
            biases[1][:, perm2, :],
            biases[2],
        )
        out_perm = model(
            wb_to_batch(perm_weights, perm_biases, d0, d1, d2, d3)
        )
    finally:
        for module, was_training in prior_modes:
            module.training = was_training

    assert torch.allclose(out, out_perm, atol=1e-5, rtol=0), (
        "model output changed under hidden-neuron permutation "
        f"(max abs diff: {(out - out_perm).abs().max().item():.3e})"
    )
    return out, out_perm


@hydra.main(config_path="configs", config_name="base", version_base=None)
def main(cfg):
    """Build the model described by ``cfg.model`` and run the invariance check.

    Hydra composes ``cfg`` from the ``base`` config in the ``configs``
    directory and passes it in.
    """
    test_model_invariance(hydra.utils.instantiate(cfg.model))


# Run the Hydra entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
