import torch
import argparse
import numpy as np

from BC import TrainConfig, run_BC
from utils import get_setting

# Pick the torch device string once at import time: 'cuda' when torch reports
# an available CUDA runtime, otherwise fall back to 'cpu'.
CUDA_AVAILABLE = torch.cuda.is_available()
DEVICE = 'cuda' if CUDA_AVAILABLE else 'cpu'


def main():
    """Launch a single ``run_BC`` run chosen by the ``--setting`` CLI index.

    The function parses ``--setting`` (int, default 0), hands the
    hyper-parameter table below together with that index to ``get_setting``,
    and unpacks the result as keyword arguments into ``TrainConfig``. It then
    sets ``device`` and ``data_folder`` on the config and calls ``run_BC``.

    Raises:
        ValueError: if the selected configuration has a ``single_task`` that
            is not None together with a ``whitening`` other than ``'none'``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--setting', type=int, default=0)
    setting = cli.parse_args().setting

    # 9 log-spaced values from 1e-5 to 1e-1 (both ends included).
    lam_w_grid = np.logspace(-5, -1, 9, endpoint=True).tolist()

    # Flat table laid out as repeating (name, short tag, candidate values)
    # entries. NOTE(review): how the tag and the value lists are consumed is
    # decided by get_setting -- confirm against utils.get_setting.
    settings = [
        'seed', 'S', [0, 1],
        'env', 'E', ['carla2d'],

        'max_epochs', '', [100],
        'batch_size', '', [512],
        'data_size', '', [600000],
        'lamW', 'W', lam_w_grid,
        'whitening', 'Wh', ['none', 'whiten', 'normalize'],
        'arch', '', ['resnet18'],

        'optimizer', '', ['sgd'],
        'lr', 'lr', [1e-3],
        'eval_freq', '', [1],

        # None stands for multitask; invalid combinations are rejected below.
        'single_task', 'ST', [0, 1, None],
    ]

    actual_setting = get_setting(settings, setting)

    # Build the config from the selected values, then fill in the fields that
    # are fixed for every run.
    config = TrainConfig(**actual_setting)
    config.device = DEVICE
    config.data_folder = './dataset/carla2d/'

    # Invalid combinations abort with an error so the run is skipped:
    # whitening is only used for the multitask (single_task is None) case.
    if config.single_task is not None:
        if config.whitening != 'none':
            raise ValueError('we do not run experiments with whitening for single tasks')

    run_BC(config)


# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
