import argparse
import os
import pickle
from datetime import datetime

import torch
import yaml


def get_device():
    """Pick the torch device to run on and report the choice on stdout.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` is true,
        otherwise ``cpu``.
    """
    use_cuda = torch.cuda.is_available()
    print("Running on GPU." if use_cuda else "Running on CPU.")
    return torch.device("cuda" if use_cuda else "cpu")


def parse_args():
    description = 'Parses configuration file with paths to inputs and outputs.'
    parser = argparse.ArgumentParser(description)

    parser.add_argument('--config', type=str, help='Path to config.yaml')

    # Different arguments are used by different scripts:
    parser.add_argument('--size', type=int,
                        help='Size to resize THINGS images.')

    parser.add_argument('--initial_size', type=int,
                        help='Initial size to resize THINGS images.')
    parser.add_argument('--final_size', type=int,
                        help='Final size to resize THINGS images.')
    parser.add_argument('--step_size', type=int,
                        help='Step size to resize THINGS images.')
    parser.add_argument('--k', type=int, help='Number of neighbors in the MLE'
                                              'method to calculate ID.')

    parser.add_argument('--subject_ids', nargs='+', type=int,
                        help='List of subjects to be processed.')
    parser.add_argument('--hemis', nargs='+', type=str,
                        help='List of hemispheres to be processed.')
    parser.add_argument('--num_neighbors', type=int,
                        help='Number of voxel neighbors to calculate '
                             'ED and ID.')

    parser.add_argument('--trial_types', nargs='+', type=str,
                        help='List of trials types from train and test.')
    parser.add_argument('--num_components', type=int,
                        help='Number of PCA components to save.')

    parser.add_argument('--model_name', type=str,
                        help='Model name to get the embeddings.')

    args = parser.parse_args()

    return args


def log(message):
    """Print ``message`` to stdout, prefixed by the local time as ``[HH:MM:SS]``."""
    print('{} {}'.format(datetime.now().strftime("[%H:%M:%S]"), message))


def load_paths(path_config):
    """Load a YAML path configuration and resolve its file-system entries.

    Every value under the ``raw`` and ``processed`` sections has ``~``
    expanded and is then joined onto the directory containing the config
    file. Relative entries therefore resolve against that directory, while
    absolute entries are kept unchanged, because ``os.path.join`` discards
    the base when the second argument is absolute. Other top-level keys are
    returned untouched.

    Args:
        path_config: Path to the YAML configuration file.

    Returns:
        dict: The parsed configuration, with ``raw`` and ``processed``
        entries rewritten in place.

    Raises:
        ValueError: If ``path_config`` is ``None`` (e.g. ``--config`` was not
            passed on the command line).
        KeyError: If the configuration lacks a ``raw`` or ``processed``
            section.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if path_config is None:
        raise ValueError('You need to pass the configuration file path!')
    base_dir = os.path.dirname(os.path.abspath(path_config))

    with open(path_config, 'r', encoding='utf-8') as file:
        paths = yaml.safe_load(file)

    for section in ('raw', 'processed'):
        entries = paths[section]
        for key, path in entries.items():
            # Replacing values of existing keys does not resize the dict,
            # so updating while iterating is safe here.
            entries[key] = os.path.join(base_dir, os.path.expanduser(path))

    return paths


def load_pickle(path_pickle):
    """Read and return the object pickled in the file at ``path_pickle``.

    Note: unpickling can execute arbitrary code, so only call this on files
    from a trusted source.
    """
    with open(path_pickle, 'rb') as handle:
        return pickle.load(handle)


def save_pickle(data, path_pickle):
    """Serialize ``data`` to ``path_pickle`` with ``pickle.HIGHEST_PROTOCOL``.

    The file is opened in ``'wb'`` mode, so any existing content is overwritten.
    """
    protocol = pickle.HIGHEST_PROTOCOL
    with open(path_pickle, 'wb') as handle:
        pickle.dump(data, handle, protocol)
