import torch
import argparse
import numpy as np

from BC import TrainConfig, run_BC
from utils import get_setting

# Use the GPU whenever PyTorch reports one is available; otherwise use the CPU.
CUDA_AVAILABLE = torch.cuda.is_available()
DEVICE = 'cuda' if CUDA_AVAILABLE else 'cpu'


def main():
    """Run one behavior-cloning experiment chosen from a hyper-parameter grid.

    The integer ``--setting`` command-line argument is passed, together with
    the grid defined below, to ``get_setting`` to select one combination of
    values. That combination is turned into a ``TrainConfig``, the
    per-environment values are filled in, and the config is handed to
    ``run_BC``.

    Raises:
        ValueError: if the selected combination is outside the experiment
            design (these grid indices are meant to be skipped), or if the
            selected environment has no max_epochs/data_size overrides.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--setting', type=int, default=0,
        help='index of the hyper-parameter combination to run',
    )
    args = parser.parse_args()
    setting = args.setting

    # Flat list of (config field name, short tag, candidate values) triples.
    # NOTE(review): the short tag presumably labels the run (empty string =
    # not included) -- confirm against utils.get_setting.
    settings = [
        'seed', 'S', [0, 1, 2],
        'env', 'E', ['swimmer', 'reacher', 'hopper'],

        'max_epochs', '', [0],  # placeholder, replaced per env below
        'batch_size', '', [256],
        'data_size', '', [0],  # placeholder, replaced per env below
        'lamW', 'W', np.logspace(-5, -1, 9, endpoint=True).tolist(),
        'whitening', 'Wh', ['none', 'whiten', 'normalize'],
        'arch', '', ['256-R-256-R-256-R|T'],

        'optimizer', '', ['sgd'],
        'lr', 'lr', [1e-2],
        'eval_freq', '', [100],

        'single_task', 'ST', [0, 1, 2, None],  # None = multitask; invalid combos are rejected below
    ]

    actual_setting = get_setting(settings, setting)
    config = TrainConfig(**actual_setting)

    # Reject combinations that are not part of the experiment design before
    # doing any further work, so these grid indices are skipped.
    if config.single_task == 2 and config.env != 'hopper':
        raise ValueError('single_task=2 is only valid for hopper')
    if config.single_task is not None and config.whitening != 'none':
        raise ValueError('we do not run experiments with whitening for single tasks')

    # Replace the placeholder grid values with per-environment settings.
    if config.env == 'swimmer':
        config.max_epochs = int(2e5)
        config.data_size = 1000
    elif config.env == 'reacher':
        config.max_epochs = int(2e5)
        config.lamW = 1.5 * config.lamW  # reacher uses a 1.5x larger lamW grid
        config.data_size = 1000
    elif config.env == 'hopper':
        config.max_epochs = 40000
        config.data_size = 10000
    else:
        # Without overrides, max_epochs/data_size would silently stay at the
        # 0 placeholders and launch an empty run.
        raise ValueError(f'no max_epochs/data_size overrides defined for env {config.env!r}')

    config.device = DEVICE
    config.data_folder = './dataset/mujoco/'

    run_BC(config)


# Launch the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
