import json
from pathlib import Path

import torch

# Compute per-layer mean / standard deviation of the weights and biases of all
# training checkpoints listed in the split file, and save them to
# dataset/cnn_zoo_statistics.pth.


def _mean_and_std(chunks):
    """Return ``(mean, std)`` as Python floats over all values in *chunks*.

    *chunks* is a list of 1-D tensors; they are concatenated once so that the
    mean and the standard deviation share the same flattened tensor.
    """
    flat = torch.cat(chunks)
    return flat.mean(dim=0).item(), flat.std(dim=0).item()


with open('dataset/cnn_generalization_splits.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

files = data['train']['path']

# all_weights[i] / all_biases[i] collect, across checkpoints, the flattened
# i-th tensor whose key contains 'weight' / 'bias' (in iteration order of the
# checkpoint mapping).
# NOTE(review): this assumes every checkpoint iterates its layers in the same
# order — confirm against how the checkpoints were produced.
all_weights = []
all_biases = []

for file in files:
    # map_location='cpu' lets checkpoints saved from CUDA tensors load on
    # CPU-only machines and keeps this offline pass off the GPU.
    # NOTE: torch.load unpickles its input; only run this on trusted files.
    ckpt = torch.load(file, map_location='cpu')
    weight_counter = 0
    bias_counter = 0
    for k, v in ckpt.items():
        if 'weight' in k:
            # First time this layer position is seen: open a new bucket.
            if weight_counter == len(all_weights):
                all_weights.append([])
            all_weights[weight_counter].append(v.flatten())
            weight_counter += 1
        if 'bias' in k:
            if bias_counter == len(all_biases):
                all_biases.append([])
            all_biases[bias_counter].append(v.flatten())
            bias_counter += 1

# One concatenation per layer, processed one layer at a time so that only a
# single concatenated copy is alive at any moment.
_weight_stats = [_mean_and_std(layer) for layer in all_weights]
_bias_stats = [_mean_and_std(layer) for layer in all_biases]

per_layer_weights_mean = [mean for mean, _ in _weight_stats]
per_layer_weights_std = [std for _, std in _weight_stats]
per_layer_biases_mean = [mean for mean, _ in _bias_stats]
per_layer_biases_std = [std for _, std in _bias_stats]

statistics = {
    "weights": {"mean": per_layer_weights_mean, "std": per_layer_weights_std},
    "biases": {"mean": per_layer_biases_mean, "std": per_layer_biases_std},
}

out_path = Path("dataset")
out_path.mkdir(exist_ok=True, parents=True)
torch.save(statistics, out_path / "cnn_zoo_statistics.pth")
