import torch
import yaml
import time
from models.two_layer_neural_network import TwoLayerNet, CIFAR10CNN
from data_generator.data_generator import data_generator, get_cifar10_loaders
from utils.data_loader import get_data_loaders
from utils.trainer import train
from utils.utils import init_all_params
from utils.utils import get_local_time_str
from utils.plot_utils import plot_training_curves
import os
from utils.utils import save_model


# Entry-point script: load the YAML config, build the CIFAR-10 data loaders and
# model, train it, then save the model and the training-curve plot under
# <save_dir>/<experiment_name>.

# perf_counter is a monotonic clock intended for measuring elapsed intervals;
# time.time() can jump if the system clock is adjusted mid-run.
start_time = time.perf_counter()

# Load run configuration. safe_load avoids constructing arbitrary Python
# objects from the YAML file.
with open("configs/default.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Use the configured experiment name, or fall back to a timestamped one when
# the key is missing or explicitly null.
if config.get("experiment_name") is None:
    experiment_name = 'exp_' + get_local_time_str()
else:
    experiment_name = config["experiment_name"]
save_dir = os.path.join(config["save_dir"], experiment_name)

# Use the configured device only when CUDA is available; otherwise fall back
# to CPU regardless of what the config asks for.
device = torch.device(config["device"] if torch.cuda.is_available() else "cpu")

train_dataset, train_iter, test_dataset, test_iter = get_cifar10_loaders(
    config["batch_size"], config["num_samples"]
)

model = CIFAR10CNN()
# Optional custom initialization; skipped when init_gamma is missing or null.
if config.get("init_gamma") is not None:
    print("Initializing all parameters with gamma = ", config["init_gamma"])
    # NOTE(review): config["hidden_size"] is passed even though the model is
    # CIFAR10CNN; confirm init_all_params uses it meaningfully for this model.
    model = init_all_params(model, config["hidden_size"], config["init_gamma"])

# NOTE(review): test_iter is built above but evaluation is disabled here
# (test_iter=None); pass test_iter to train() if test metrics are wanted.
model = train(model, train_iter, device, save_dir, config, test_iter=None)

# Make sure the output directory exists before writing artifacts into it
# (no-op if train() already created it).
os.makedirs(save_dir, exist_ok=True)
save_model(model, os.path.join(save_dir, "model.pth"))

plot_training_curves(save_dir, config["plot_file"] + ".png")

end_time = time.perf_counter()

print(f"Execution time: {end_time - start_time:.6f} seconds")