import sys
from yaml import dump
from os import path
import Model_Based.utils as utils
import numpy as np
import torch
from collections import OrderedDict

class Config(object):
    """Experiment configuration built from parsed command-line arguments.

    Constructing a ``Config`` has side effects: it may run a hyper-parameter
    sweep module against ``args``, seeds NumPy and PyTorch, creates the
    experiment directory tree, writes ``args.yaml`` into the experiment
    folder, and replaces ``sys.stdout`` with a ``utils.Logger``.

    Every attribute of ``args`` is copied onto the instance; in addition:

    * ``paths`` -- ``OrderedDict`` of locations (root, Experiments,
      experiment, logs, ckpt, results, dataset).
    * ``idx`` -- ``args.base + args.inc``; index passed to the sweep module.
    * ``save_after`` -- number of episodes between saves.
    * ``optim`` -- optimizer *class* (not an instance) chosen by ``args.optim``.
    * ``device`` -- ``torch.device('cuda')`` if ``args.gpu`` else CPU.
    """

    def __init__(self, args):
        """Build the configuration from ``args``.

        Args:
            args: Parsed argument namespace. Fields read here: ``base``,
                ``inc``, ``hyper``, ``seed``, ``max_episodes``,
                ``save_count``, ``experiment``, ``folder_suffix``, ``domain``,
                ``restore``, ``log_output``, ``gpu`` and ``optim``.

        Raises:
            ValueError: If ``args.optim`` is not ``'adam'``, ``'rmsprop'``
                or ``'sgd'``.
        """
        # SET UP PATHS
        self.paths = OrderedDict()
        self.paths['root'] = path.abspath(path.join(path.dirname(__file__), '..'))

        # Do hyper-parameter sweep, if needed. The dynamically loaded module's
        # ``set(args, idx)`` presumably overwrites fields of ``args`` in place
        # (confirm in the random_search module).
        self.idx = args.base + args.inc
        if self.idx >= 0 and args.hyper >= 0:
            sweep = utils.dynamic_load(self.paths['root'], "random_search" + str(args.hyper), load_class=False)
            sweep.set(args, self.idx)
            # *IMP: kept local (never stored on self) because an object holding
            # a reference to an imported library canNOT be deep-copied.
            del sweep

        # Make results reproducible
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

        # Copy all the variables from args to config
        self.__dict__.update(vars(args))

        # Resolve the optimizer class before touching the filesystem, so an
        # invalid choice fails fast instead of leaving a half-created
        # experiment folder and a redirected sys.stdout behind. This must come
        # after the sweep, which may have modified ``args``.
        optimizers = {
            'adam': torch.optim.Adam,
            'rmsprop': torch.optim.RMSprop,
            'sgd': torch.optim.SGD,
        }
        if args.optim not in optimizers:
            raise ValueError('Undefined type of optimizer: {}'.format(args.optim))
        self.optim = optimizers[args.optim]

        # Frequency of saving results and models (in episodes). A
        # non-positive save_count would otherwise divide by zero or give a
        # negative interval, so fall back to saving once at max_episodes.
        if 0 < args.save_count < args.max_episodes:
            self.save_after = args.max_episodes // args.save_count
        else:
            self.save_after = args.max_episodes

        # add path to models
        folder_suffix = args.experiment + args.folder_suffix
        self.paths['Experiments'] = path.join(self.paths['root'], 'Model_Based_Experiments')
        self.paths['experiment'] = path.join(self.paths['Experiments'], folder_suffix)

        # Per-seed sub-folders for logs, checkpoints and results.
        seed_dir = path.join(self.paths['experiment'], str(args.seed))
        self.paths['logs'] = path.join(seed_dir, 'Logs/')
        self.paths['ckpt'] = path.join(seed_dir, 'Checkpoints/')
        self.paths['results'] = path.join(seed_dir, 'Results/')

        # Dataset file depends on the domain. Note: these paths are relative
        # to the current working directory, not to paths['root'].
        if args.domain == 'land':
            self.paths['dataset'] = './data_2D.xlsx'
        else:
            self.paths['dataset'] = './data_3D.xlsx'

        # Create directories ('root' already exists; 'dataset' is a file).
        for key, val in self.paths.items():
            if key not in ('root', 'dataset', 'data'):
                utils.create_directory_tree(val)

        # Save all the configuration settings. The `with` block closes the
        # file so the YAML is flushed to disk right away.
        # NOTE(review): args.yaml lives in the experiment folder, not the
        # per-seed folder, so runs with different seeds overwrite it.
        with open(path.join(self.paths['experiment'], 'args.yaml'), 'w') as args_file:
            dump(args.__dict__, args_file, default_flow_style=False,
                 explicit_start=True)

        # Output logging
        sys.stdout = utils.Logger(self.paths['logs'], args.restore, args.log_output)

        # GPU
        self.device = torch.device('cuda' if args.gpu else 'cpu')

        print("=====Configurations=====\n", args)


if __name__ == "__main__":
    # This module only defines Config for import elsewhere; running it
    # directly intentionally does nothing.
    ...