import os
import warnings

import torch
import torch.nn as nn

class ExperimentConfig:
    """
    The configuration for the project.

    Uses a plain Python object (rather than a static config file) to allow
    for some simple scripting and for presets, such as what the train
    environment is. Presets subclass this and override individual
    attributes after calling the base ``__init__``.

    Note: constructing an instance has a process-wide side effect, since it
    calls ``torch.set_num_threads(1)``.
    """
    def __init__(self):
        # torch: limit torch to a single thread (this is a global setting,
        # not per-instance).
        torch.set_num_threads(1)

        # cuda: runs on CPU by default. Set ``use_cuda = True`` to use the
        # GPU when one is available. ``torch.cuda.is_available()`` is only
        # queried when CUDA is actually requested.
        self.use_cuda = False
        self.device = torch.device(
            'cuda' if self.use_cuda and torch.cuda.is_available() else 'cpu')

        # environment constants
        self.headlight_range = 2
        self.speed_limit = 3

        # train_models.py constants
        self.min_seed = 0
        self.max_seed = 2000
        self.n_epochs = 12800 * 4
        self.batchsize = 800
        self.head_model_seed_amount = 1000
        self.criterion = nn.MSELoss()  # mean-squared-error loss module
        self.randomize_targets_train = False
        self.scale = 10
        self.use_checkpoints = True

        # obtain_policies.py constants
        # NOTE(review): gamma/theta look like a discount factor and a
        # convergence threshold; confirm against obtain_policies.py.
        self.k = [1, 2, 4, 5, 10, 20]  # fresh list per instance
        self.n = 20
        self.gamma = 0.99
        self.theta = 1e-4
        self.resample_models = True
        self.num_iterations = 100

        # etc
        self.num_gif_frames = 50
        self.use_pretrained_models = False
        self.loss_graph_truncation = 300
        self.barplot_title = False

        # Data/output directories, one level above the working directory
        # at construction time ('..' is kept unnormalized in the paths).
        base_dir = os.getcwd()
        self.models_dir = os.path.join(base_dir, '../models')
        self.heads_dir = os.path.join(base_dir, '../heads')
        self.checkpoints_dir = os.path.join(base_dir, '../checkpoints')
        self.model_figs_dir = os.path.join(base_dir, '../figs')
        self.head_figs_dir = os.path.join(base_dir, '../figs_head')
        self.policies_dir = os.path.join(base_dir, '../policies')
        self.head_policies_dir = os.path.join(base_dir, '../policies_head')
        self.pkls_git_dir = os.path.join(base_dir, '../pkls_git')


class DryRunConfig(ExperimentConfig):
    """
    Configuration preset for a quick dry run.

    Starts from every ``ExperimentConfig`` default and then shrinks the
    workload (seed range, epochs, iteration count, gif length). The point
    is only to check that everything runs without failing, at least.
    """
    def __init__(self):
        super().__init__()

        # train_models.py overrides: far fewer seeds and epochs
        self.max_seed = 10
        self.n_epochs = 10
        self.head_model_seed_amount = 2

        # obtain_policies.py overrides (n and resample_models match the
        # base preset but are kept explicit here)
        self.n = 20
        self.theta = 1
        self.resample_models = True
        self.num_iterations = 1

        # etc
        self.num_gif_frames = 20
        self.num_gif_frames = 20

# Pick the active configuration class. Leaving DRY_RUN unset, or setting it
# to '0', selects the full experiment; any other value selects the
# lightweight dry-run preset (and emits a warning).
if os.environ.get('DRY_RUN', '0') == '0':
    Config = ExperimentConfig
else:
    warnings.warn(
        'Running a dry run experiment, because the environment variable '
        '`DRY_RUN` was not set to 0.')
    Config = DryRunConfig
