import torch
import yaml
import time
from models.two_layer_neural_network import TwoLayerNet
from data_generator.data_generator import data_generator
from utils.data_loader import get_data_loaders
from utils.trainer import train
from utils.utils import init_all_params
from utils.utils import get_local_time_str
from utils.plot_utils import plot_training_curves
import os
from utils.utils import save_model


# Timer for the whole run. perf_counter() is monotonic and intended for
# measuring elapsed intervals, unlike time.time(), which can jump if the
# system clock is adjusted.
start_time = time.perf_counter()


# Load the experiment configuration. The path is relative to the current
# working directory, so the script must be launched from the directory that
# contains `configs/`.
with open("configs/default.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)


# Use the configured experiment name. Fall back to a timestamped name when the
# entry is null or missing entirely (the latter previously raised KeyError).
experiment_name = config.get("experiment_name")
if experiment_name is None:
    experiment_name = 'exp_' + get_local_time_str()
save_dir = os.path.join(config["save_dir"], experiment_name)

# Honour the configured device only when CUDA is available; otherwise run on CPU.
device = torch.device(config["device"] if torch.cuda.is_available() else "cpu")

# Synthetic training split.
features_train, labels_train = data_generator(num_samples=config["num_samples"], input_dim=config["input_size"])
train_dataset, train_iter = get_data_loaders(config["batch_size"], features_train, labels_train)

# Synthetic test split. NOTE: it is currently not passed to train() below.
# It is still generated so that the order of data generation (and any random
# draws data_generator makes) is unchanged before model initialisation.
features_test, labels_test = data_generator(num_samples=config["test_num_samples"], input_dim=config["input_size"])
test_dataset, test_iter = get_data_loaders(config["test_batch_size"], features_test, labels_test)

model = TwoLayerNet(config["input_size"], config["hidden_size"], config["output_size"])

# Optional custom initialisation. A null `init_gamma` keeps the model's
# default parameter initialisation.
if config["init_gamma"] is not None:
    print("Initializing all parameters with gamma = ", config["init_gamma"])
    model = init_all_params(model, config["hidden_size"], config["init_gamma"])

# Evaluation on the held-out split is disabled; pass test_iter=test_iter to
# enable it.
model = train(model, train_iter, device, save_dir, config, test_iter=None)


# Persist the trained weights and the training curves into the run directory.
save_model(model, os.path.join(save_dir, "model.pth"))

plot_training_curves(save_dir, config["plot_file"] + ".png")

end_time = time.perf_counter()


print(f"Execution time: {end_time - start_time:.6f} seconds")