import argparse

import numpy as np
import torch


def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)


def to_device(gpu):
    if torch.cuda.is_available() and gpu is not None:
        if isinstance(gpu, list):
            gpu = gpu[0]
        return torch.device(f'cuda:{gpu}')
    else:
        return torch.device('cpu')


def parse_args():
    parser = argparse.ArgumentParser()

    # Experimental settings.
    parser.add_argument('--gpu', type=int, nargs='+', default=[0])
    parser.add_argument('--data-path', type=str, default='../data')
    parser.add_argument('--out-path', type=str, default='../out')
    parser.add_argument('--save', action='store_true', default=False)
    parser.add_argument('--explain', action='store_true', default=False)
    parser.add_argument('--data', type=str, default=None)
    parser.add_argument('--model', type=str, default='DTN-S')

    # Model settings.
    parser.add_argument('--layers', type=int, default=16)
    parser.add_argument('--units', type=int, default=128)

    # Hyperparameters for DTN (ignored if model=DTN-S).
    parser.add_argument('--width', type=int, default=None)
    parser.add_argument('--activation', type=str, default='softmax')
    parser.add_argument('--prune', type=int, default=None)

    # Hyperparameters for experiments.
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--seeds', type=int, nargs='+', default=[0, 1, 2, 3, 4])
    parser.add_argument('--lr', type=float, default=5e-3)

    return parser.parse_args()
