from pathlib import Path

import torch

from experiments.data import INRDataset

def summarize_layers(tensors):
    """Return the scalar mean and std of every tensor in *tensors*.

    Each tensor is reduced over *all* of its elements. The std uses torch's
    default (Bessel-corrected) estimator.

    Args:
        tensors: An iterable of tensors (one per layer).

    Returns:
        A ``(means, stds)`` pair of lists of Python floats, one entry per tensor.
    """
    tensors = list(tensors)
    means = [t.mean().item() for t in tensors]
    stds = [t.std().item() for t in tensors]
    return means, stds


def _per_element_statistics(tensors):
    """Return ``{"mean": [...], "std": [...]}`` reduced over dim 0 of each tensor."""
    tensors = list(tensors)
    return {
        "mean": [t.mean(0) for t in tensors],
        # Bessel-corrected std over dim 0; a dim-0 size of 1 yields NaN.
        "std": [t.std(0) for t in tensors],
    }


def compute_statistics(weights, biases):
    """Build the per-element normalization statistics dict that gets saved.

    Presumably dim 0 of every tensor indexes the dataset samples, because the
    batch built in ``main`` spans the full split. TODO confirm against
    INRDataset's batch layout.

    Args:
        weights: Iterable of per-layer weight tensors.
        biases: Iterable of per-layer bias tensors.

    Returns:
        ``{"weights": {"mean": [...], "std": [...]},
        "biases": {"mean": [...], "std": [...]}}``. Each list holds one
        tensor per layer, reduced over dim 0.
    """
    return {
        "weights": _per_element_statistics(weights),
        "biases": _per_element_statistics(biases),
    }


def main():
    """Compute and save weight/bias statistics for the FMNIST INR train split.

    The whole train split is loaded as a single batch. The function prints
    scalar per-layer summaries and writes the per-element statistics to
    ``dataset/fmnist_statistics.pth``.
    """
    train_set = INRDataset(
        path="dataset/fmnist_splits.json",
        split="train",
        statistics_path=None,
        normalize=False,
    )

    # A single batch holding the entire split. With num_workers > 0, the
    # caller must run this under the __main__ guard below so that spawn-based
    # workers (Windows/macOS) do not re-execute the script on import.
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=len(train_set), shuffle=False, num_workers=4
    )
    train_data = next(iter(train_loader))

    weights_mean, weights_std = summarize_layers(train_data.weights)
    biases_mean, biases_std = summarize_layers(train_data.biases)

    print(f"weights_mean: {weights_mean}")
    print(f"weights_std: {weights_std}")
    print(f"biases_mean: {biases_mean}")
    print(f"biases_std: {biases_std}")

    statistics = compute_statistics(train_data.weights, train_data.biases)

    out_path = Path("dataset")
    out_path.mkdir(exist_ok=True, parents=True)
    torch.save(statistics, out_path / "fmnist_statistics.pth")


if __name__ == "__main__":
    main()
