import logging
import os
from datetime import datetime as dt

import numpy as np
import torch

from arguments import get_train_args
from config import RESULTS_DIR
from src.models.train import train_and_inference_loop
from src.utils.common import create_folders, create_experiment


def main():
    """Run one training-and-inference experiment from command-line arguments.

    If the parsed ``experiment_dir`` argument is set, the configuration
    stored in ``<RESULTS_DIR>/<experiment_dir>/config.npy`` replaces the
    parsed arguments (only ``experiment_dir`` is carried over). Otherwise
    ``create_folders`` provides a results folder and the parsed arguments
    are saved there as ``config.npy``.

    The torch and NumPy random generators are then seeded from the final
    configuration, an experiment object is created and given the results
    path, logging is directed to a timestamped file in that folder, and
    ``train_and_inference_loop`` is started.
    """
    args_dict = vars(get_train_args())
    # Remembered as a fallback in case a stored config has no 'seed' entry.
    cli_seed = args_dict['seed']

    if experiment_dir := args_dict['experiment_dir']:
        path = os.path.join(RESULTS_DIR, experiment_dir)
        # NOTE(security): allow_pickle=True unpickles the file, which can
        # execute arbitrary code. Only point this at experiment folders
        # that were produced by a trusted run.
        args_dict = np.load(os.path.join(path, 'config.npy'), allow_pickle=True).item()
        args_dict['experiment_dir'] = experiment_dir
    else:
        path = create_folders(args_dict)
        np.save(os.path.join(path, 'config.npy'), args_dict)

    # Seed only once the configuration is final: when a stored config is
    # loaded above, its seed (not the one from the command line) is the one
    # recorded for the experiment, so it is the one the generators must use.
    seed = args_dict.get('seed', cli_seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)

    experiment = create_experiment(args_dict)
    experiment.path = path

    # One log file per launch, named by start time (minute resolution);
    # filemode='w' overwrites a file with the same name.
    timestamp = dt.now().strftime('%Y-%m-%d_%H-%M')
    logging.basicConfig(level=logging.INFO, filename=os.path.join(path, f'{timestamp}_log.txt'), filemode='w')

    train_and_inference_loop(args_dict, experiment)
    # Alternative entry point, currently disabled (and not imported in this file):
    # k_fold_train_and_inference_loop(args_dict, experiment, k_folds=5)


# Start the experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
