import os
import torch
import torchvision.models as tv_models
from train_classification.resnet import *

print("Creating and saving standard initial weights...")

# Seed the global torch RNG so the weights produced below are the same on
# every run of this script.
torch.manual_seed(seed=42)

# The model keeps whatever initialisation ResNet18WithBN's constructor applies.
# NOTE(review): an alternative, `better_initialization_resnet18(model)`, was
# left disabled here; re-enable it deliberately if that initialisation is wanted.
initial_model_resnet18 = ResNet18WithBN()

def init_weights(m):
    """Re-initialise a single module in place (suitable for ``model.apply``).

    * ``nn.Conv2d``: Kaiming-normal weights (``mode='fan_out'``, ReLU gain);
      any conv bias is left unchanged.
    * ``nn.Linear``: weights drawn from N(0, 0.01**2); bias set to zero when
      the layer has one.

    Modules of any other type are left untouched.

    NOTE(review): this function is not applied anywhere in the visible script,
    so it currently has no effect on the saved weights.

    Args:
        m: The module to (re-)initialise.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        # Layers built with ``bias=False`` have ``m.bias is None``; passing
        # that to ``constant_`` would raise instead of skipping the bias.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

# Resolve <directory of this file>/../initial_weights to an absolute path and
# make sure it exists before writing into it.
save_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "initial_weights")
)
os.makedirs(save_path, exist_ok=True)
print(save_path)

# Only the state_dict (parameter and buffer tensors) is written, not the
# module object itself.
checkpoint_file = os.path.join(save_path, "ResNet18_init.pth")
torch.save(initial_model_resnet18.state_dict(), checkpoint_file)

print(f"All fixed initial weights saved to: {save_path}")