# %%
import torch
import matplotlib.pyplot as plt

from experiments.utils import make_coordinates
from experiments.data import INRDataset

from einops import rearrange


# https://openreview.net/pdf?id=cFuMmbWiN6


# %%
def batch_to_graphs(batch):
    """Convert a batch of MLP weights/biases into dense graph tensors.

    Every neuron of the MLP becomes one node, ordered as: the input neurons,
    then the output neurons of each layer in turn. For a 2 -> 32 -> 32 -> 1
    MLP that is 2 + 32 + 32 + 1 = 67 nodes. Biases are not separate nodes:
    each bias is stored as the feature of the neuron it belongs to (input
    nodes get zeros). The weight ``weights[k][:, i, j]`` becomes the feature
    of the directed edge from input neuron ``i`` to output neuron ``j`` of
    layer ``k``; all other edges are zero.

    Args:
        batch: object exposing ``weights`` and ``biases`` sequences with one
            entry per layer. ``weights[k]`` is indexed as
            ``(bsz, num_in, num_out, f)`` and ``biases[k]`` as
            ``(bsz, num_out, f_b)``; ``weights`` must be non-empty.

    Returns:
        ``(node_features, edge_features)`` of shapes
        ``(bsz, num_nodes, f_b)`` and ``(bsz, num_nodes, num_nodes, f)``,
        on the device and with the dtype of ``weights[0]``.
    """
    first = batch.weights[0]
    bsz, num_inputs = first.shape[0], first.shape[1]
    device, dtype = first.device, first.dtype
    num_nodes = num_inputs + sum(w.shape[2] for w in batch.weights)

    # Feature widths follow the inputs (1 for the usual (..., 1) tensors).
    node_dim = batch.biases[0].shape[-1] if len(batch.biases) else 1
    edge_dim = first.shape[-1]

    node_features = torch.zeros(
        bsz, num_nodes, node_dim, device=device, dtype=dtype
    )
    edge_features = torch.zeros(
        bsz, num_nodes, num_nodes, edge_dim, device=device, dtype=dtype
    )

    # Layer k maps nodes [row, row + num_in) onto [col, col + num_out). Its
    # outputs are the next layer's inputs, so `row` trails `col` by one layer.
    row = 0
    col = num_inputs  # input nodes have no incoming edges
    for w in batch.weights:
        _, num_in, num_out, _ = w.shape
        edge_features[:, row : row + num_in, col : col + num_out] = w
        row += num_in
        col += num_out

    # Each bias vector covers exactly one layer's output nodes, so the offset
    # must advance by that layer's *output* width.
    offset = num_inputs  # input nodes carry no bias
    for b in batch.biases:
        num_out = b.shape[1]
        node_features[:, offset : offset + num_out] = b
        offset += num_out

    return node_features, edge_features


# %%
# Loading one INR checkpoint directly instead (kept for reference):
# data_path = "mnist-inrs/mnist_png_training_7_34890/checkpoints/model_final.pth"
# state_dict = torch.load(data_path, map_location="cpu")

# Training split of the MNIST INR dataset, with normalization and
# augmentation switched on and permutation switched off (the exact meaning
# of these flags lives in INRDataset).
train_set = INRDataset(
    split="train",
    path="mnist_splits.json",
    statistics_path="experiments/mnist/dataset/statistics.pth",
    normalize=True,
    augmentation=True,
    permutation=False,
)

# Single-process, shuffled loader; large batches for quick inspection.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=512,
    shuffle=True,
    pin_memory=True,
    num_workers=0,
)

# Grab one batch to experiment with in the following cells.
batch = next(iter(train_loader))

# %%
