
import re
import os
import sys
import yaml
import torch

class Config(object):
    """Run configuration merged from YAML property files and CLI arguments.

    Values are layered into ``final_config_dict`` in this order (later wins):
    the internal YAML files (model, dataset, overall), then the command-line
    arguments plus derived keys (``rootPath``, ``start_order``, ``end_order``),
    then the auto-selected ``device``.

    Lookup styles:
        ``config[key]``      -> value, or ``None`` when missing.
        ``config.get(key)``  -> value, or ``KeyError`` when missing
                                (unless a ``default`` is supplied).
    """

    # Name of the directory treated as the project root when resolving rootPath.
    _ROOT_DIR_NAME = 'Causally_new'
    # Sentinel distinguishing "no default given" from an explicit default of None.
    _MISSING = object()

    def __init__(self, args=None):
        """Build the merged configuration.

        Args:
            args: object whose ``vars()`` holds the CLI options and which exposes
                ``model`` and ``dataset`` attributes (presumably an
                ``argparse.Namespace``). The signature defaults to ``None`` for
                compatibility, but a value is required.

        Raises:
            ValueError: if ``args`` is ``None``.
        """
        if args is None:
            raise ValueError('Config requires parsed arguments providing `model` and `dataset`.')

        self._init_parameters_category()
        self.yaml_loader = self._build_yaml_loader()

        self.model, self.dataset = args.model, args.dataset
        self._load_internal_config_dict(self.model, self.dataset)
        self._load_args_config_dict(args=args)
        self.final_config_dict = self._get_final_config_dict()

        self._init_device()

    def _init_parameters_category(self):
        """Locate the project root and store it in ``self.rootPath``.

        The root is the prefix of this file's absolute directory path up to and
        including the first occurrence of ``_ROOT_DIR_NAME``.

        Raises:
            RuntimeError: if this file's path does not contain ``_ROOT_DIR_NAME``.
        """
        cur_path = os.path.abspath(os.path.dirname(__file__))
        marker = self._ROOT_DIR_NAME
        idx = cur_path.find(marker)
        if idx == -1:
            # Slicing with find()'s -1 would silently yield a truncated,
            # meaningless prefix of cur_path instead of failing.
            raise RuntimeError(
                "Cannot locate project root: '{}' is not part of '{}'.".format(marker, cur_path))
        self.rootPath = cur_path[:idx + len(marker)]

    def _build_yaml_loader(self):
        """Return a YAML loader class with an extended float resolver.

        The extra implicit resolver lets values such as ``1e-3`` (exponent
        without a decimal point) resolve as floats. It is registered on a
        private subclass so the global ``yaml.FullLoader`` is left untouched
        and resolvers do not pile up across ``Config`` instances.
        """
        class _ConfigLoader(yaml.FullLoader):
            """FullLoader subclass that owns its own implicit resolvers."""

        _ConfigLoader.add_implicit_resolver(
            u'tag:yaml.org,2002:float',
            re.compile(u'''^(?:
             [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
            |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
            |\\.[0-9_]+(?:[eE][-+][0-9]+)?
            |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
            |[-+]?\\.(?:inf|Inf|INF)
            |\\.(?:nan|NaN|NAN))$''', re.X),
            list(u'-+0123456789.'))
        return _ConfigLoader

    def _load_args_config_dict(self, args=None):
        """Copy the CLI args into ``self.args_config_dict`` and add derived keys.

        Adds ``rootPath``, ``start_order`` and ``end_order``. Reads
        ``self.internal_config_dict``, ``self.rootPath`` and ``self.dataset``,
        so it must run after those are set.
        """
        # Copy: vars() returns the args object's own __dict__, and writing into
        # it would add these derived keys as attributes on the caller's object.
        self.args_config_dict = dict(vars(args))
        self.args_config_dict['rootPath'] = self.rootPath
        # A truthy 'stop' starts from the configured data_id; otherwise start at 1.
        if self.internal_config_dict.get('stop'):
            self.args_config_dict['start_order'] = self.internal_config_dict['data_id']
        else:
            self.args_config_dict['start_order'] = 1
        dataset_dir = os.path.join(self.rootPath, 'dataset', self.dataset)
        # NOTE(review): entry count of the dataset directory minus one; presumably
        # one entry is not a data item — confirm against the dataset layout.
        self.args_config_dict['end_order'] = len(os.listdir(dataset_dir)) - 1

    def _load_internal_config_dict(self, model, dataset):
        """Load ``properties/*.yaml`` into ``self.internal_config_dict``.

        Files that are missing or empty are skipped. Files are applied in the
        order model, dataset, overall, so keys in ``overall.yaml`` override
        the model- and dataset-specific files.
        NOTE(review): general-overrides-specific is unusual; confirm it is intended.
        """
        current_path = os.path.dirname(os.path.realpath(__file__))
        overall_init_file = os.path.join(current_path, '../properties/overall.yaml')
        model_init_file = os.path.join(current_path, '../properties/model/' + model + '.yaml')
        dataset_init_file = os.path.join(current_path, '../properties/dataset/' + dataset + '.yaml')

        self.internal_config_dict = dict()
        for file in [model_init_file, dataset_init_file, overall_init_file]:
            if os.path.isfile(file):
                with open(file, 'r', encoding='utf-8') as f:
                    config_dict = yaml.load(f.read(), Loader=self.yaml_loader)
                    if config_dict is not None:
                        self.internal_config_dict.update(config_dict)

    def _get_final_config_dict(self):
        """Return internal YAML values overlaid with the CLI-derived values."""
        return {**self.internal_config_dict, **self.args_config_dict}

    def _init_device(self):
        """Set ``device`` to CUDA when available, else CPU (overrides any configured value)."""
        self.final_config_dict['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __setitem__(self, key, value):
        """Set a config value; keys must be ``str``.

        Raises:
            TypeError: if ``key`` is not a ``str``.
        """
        if not isinstance(key, str):
            raise TypeError("index must be a str.")
        self.final_config_dict[key] = value

    def __getitem__(self, item):
        """Return the value for ``item``, or ``None`` when it is not set."""
        return self.final_config_dict.get(item)

    def __contains__(self, key):
        """Return whether ``key`` is set; keys must be ``str``.

        Raises:
            TypeError: if ``key`` is not a ``str``.
        """
        if not isinstance(key, str):
            raise TypeError("index must be a str.")
        return key in self.final_config_dict

    def __str__(self):
        """Render ``[model,dataset]`` followed by every other key=value pair."""
        args_info = '[{},{}] '.format(self.model, self.dataset)

        args_info += 'Hyper Parameters: \n'
        args_info += '\n'.join(
            ["{}={}".format(arg, value)
             for arg, value in self.final_config_dict.items() if arg != 'model' and arg != 'dataset'])
        args_info += '\n\n'
        return args_info

    def __repr__(self):
        return self.__str__()

    def get(self, key, default=_MISSING):
        """Return the value for ``key``.

        Args:
            key: config key to look up.
            default: value returned when ``key`` is missing. If omitted,
                a missing key raises ``KeyError`` (the original behavior).

        Raises:
            KeyError: if ``key`` is missing and no ``default`` was given.
        """
        try:
            return self.final_config_dict[key]
        except KeyError:
            if default is self._MISSING:
                raise
            return default