
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import json
import numpy as np
from model import get_model  # Required: model.py must implement the 'get_model' function
from options import options


def parameter_l2_norm(net):
    """Return the global L2 norm over all parameters of ``net``.

    Treats every tensor yielded by ``net.parameters()`` as one flat vector and
    returns sqrt(sum of squared entries). Registered buffers are not included
    because ``parameters()`` does not yield them.

    Squares are accumulated in float64 so that half-precision weights cannot
    overflow to ``inf`` and large float32 tensors do not lose precision.

    Args:
        net: a ``torch.nn.Module``.

    Returns:
        The norm as a Python float (0.0 if the module has no parameters).
    """
    total = 0.0
    for param in net.parameters():
        total += torch.sum(param.detach().double() ** 2).item()
    return float(np.sqrt(total))


# Parse command-line options; they are forwarded to `get_model` to rebuild the
# network architecture before each checkpoint is loaded into it.
args = options().parse_args()
print(args)

# NOTE(review): CUDA_VISIBLE_DEVICES is pinned to "1" at the top of the file;
# 'cuda' here therefore refers to that physical GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seeds whose checkpoint sub-directories will be scanned.
seeds = [0, 1, 2]

# Checkpoints are expected at <save_dir>/<seed>/*.pth
save_dir = './saved_models/corr_resnet'
output_file = 'l2_norms.json'

# Maps "seed_<seed>_<checkpoint filename>" -> L2 norm of that model's parameters.
l2_norms = {}

for seed in seeds:
    seed_dir = os.path.join(save_dir, str(seed))
    if not os.path.isdir(seed_dir):
        # Report and continue so one missing seed does not discard the others.
        print(f"Warning: directory {seed_dir} not found, skipping seed {seed}")
        continue

    # os.listdir order is arbitrary; sort so runs are reproducible.
    model_files = sorted(f for f in os.listdir(seed_dir) if f.endswith('.pth'))

    for model_file in model_files:
        model_path = os.path.join(seed_dir, model_file)

        # Fresh model per checkpoint so no state carries over between files.
        model = get_model(args, device=device)
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        # NOTE(review): torch.load unpickles the file; only use with trusted checkpoints.
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)

        l2_norm = parameter_l2_norm(model)

        model_key = f"seed_{seed}_{model_file}"
        l2_norms[model_key] = l2_norm
        print(f"Model {model_key} L2 norm: {l2_norm:.4f}")

# Persist all collected norms as pretty-printed JSON.
with open(output_file, 'w') as f:
    json.dump(l2_norms, f, indent=4)

print(f"All models' L2 norms have been saved to {output_file}")
