
import re
import os
import sys
import yaml
import torch

class Config(object):
    """Experiment configuration merged from YAML property files and CLI args.

    Values are read with ``config[key]`` (a missing key yields ``None``) and
    written with ``config[key] = value`` (keys must be ``str``).

    Precedence, lowest to highest:
    model YAML < dataset YAML < overall YAML < trainer YAML < command-line args.
    """

    # Directory name whose first occurrence in this file's path marks the
    # project root.
    ROOT_DIR_MARKER = 'iclr2026-main'

    def __init__(self, args=None):
        """Build ``self.final_config_dict`` from YAML files and ``args``.

        Args:
            args: parsed command-line namespace providing at least ``model``,
                ``dataset`` and ``trainer``. Despite the ``None`` default it is
                required. ``vars(args)`` is used without copying, so ``args``
                also gains ``rootPath``, ``start_order`` and ``end_order``
                attributes.
        """
        self._init_parameters_category()
        self.yaml_loader = self._build_yaml_loader()

        # The dataset name may arrive wrapped in single quotes from the shell.
        self.model, self.dataset = args.model, args.dataset.strip("'")
        self.trainer = args.trainer
        self._load_internal_config_dict(self.model, self.dataset, self.trainer)
        self._load_args_config_dict(args=args)
        self.final_config_dict = self._get_final_config_dict()

        # 'v_input_type' is 'x' for the standard trainer and 'x_and_z' otherwise.
        if self.trainer == 'standard_trainer':
            self.final_config_dict['v_input_type'] = 'x'
        else:
            self.final_config_dict['v_input_type'] = 'x_and_z'
        # self._init_device()

    def _init_parameters_category(self):
        """Locate the project root, print it, and store it in ``self.rootPath``."""
        cur_path = os.path.abspath(os.path.dirname(__file__))
        root_path = self._resolve_root_path(cur_path)
        print('now pc root path:', root_path)
        self.rootPath = root_path

    @classmethod
    def _resolve_root_path(cls, cur_path, marker=None):
        """Return the prefix of ``cur_path`` that ends with the root marker.

        Args:
            cur_path: absolute directory path to search.
            marker: directory name marking the project root; defaults to
                ``ROOT_DIR_MARKER``.

        Returns:
            ``cur_path`` truncated right after the first occurrence of
            ``marker``. If ``marker`` does not occur, returns the parent of
            ``cur_path``. Without this check, ``str.find`` returning -1 would
            truncate the path to an arbitrary short prefix.
        """
        if marker is None:
            marker = cls.ROOT_DIR_MARKER
        idx = cur_path.find(marker)
        if idx == -1:
            # assumes this file lives one directory below the project root,
            # as the '../properties' lookup suggests — TODO confirm
            return os.path.dirname(cur_path)
        return cur_path[:idx + len(marker)]

    def _build_yaml_loader(self):
        """Return a ``yaml.FullLoader`` subclass with a broader float resolver.

        PyYAML's stock YAML 1.1 float pattern requires a dot and a signed
        exponent, so values such as ``1e-3`` would otherwise load as strings.
        The pattern below also accepts those forms.

        The resolver is registered on a private subclass. This avoids mutating
        the process-wide ``yaml.FullLoader``, which would otherwise accumulate
        a duplicate resolver each time a Config is built.
        """
        class _FloatFriendlyLoader(yaml.FullLoader):
            pass

        _FloatFriendlyLoader.add_implicit_resolver(
            u'tag:yaml.org,2002:float',
            re.compile(u'''^(?:
             [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
            |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
            |\\.[0-9_]+(?:[eE][-+][0-9]+)?
            |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
            |[-+]?\\.(?:inf|Inf|INF)
            |\\.(?:nan|NaN|NAN))$''', re.X),
            list(u'-+0123456789.'))
        return _FloatFriendlyLoader

    def _load_args_config_dict(self, args=None):
        """Collect the CLI args plus derived ``rootPath``/``start_order``/``end_order``.

        ``start_order`` is the YAML ``data_id`` when the YAML ``stop`` flag is
        truthy, and 1 otherwise (including when ``stop`` is absent).
        ``end_order`` is the number of entries in ``<rootPath>/dataset/<dataset>``.
        """
        # vars() returns args.__dict__ itself (no copy), so the keys written
        # below also appear as attributes on ``args``; kept in case callers
        # rely on that.
        self.args_config_dict = vars(args)
        self.args_config_dict['rootPath'] = self.rootPath
        # .get(): YAML sets without a 'stop' key must not raise KeyError.
        if self.internal_config_dict.get('stop'):
            self.args_config_dict['start_order'] = self.internal_config_dict['data_id']
        else:
            self.args_config_dict['start_order'] = 1

        dataset_dir = os.path.join(self.rootPath, 'dataset', self.dataset.strip("'"))
        # Counts every directory entry (files, sub-directories, hidden files).
        self.args_config_dict['end_order'] = len(os.listdir(dataset_dir))

    def _load_internal_config_dict(self, model, dataset, trainer):
        """Merge the model, dataset, overall and trainer YAML files, in that order.

        Later files override keys from earlier ones. Missing or empty files
        are skipped silently.
        """
        current_path = os.path.dirname(os.path.realpath(__file__))
        properties_dir = os.path.join(current_path, '..', 'properties')
        config_files = [
            os.path.join(properties_dir, 'model', model + '.yaml'),
            os.path.join(properties_dir, 'dataset', dataset + '.yaml'),
            os.path.join(properties_dir, 'overall.yaml'),
            os.path.join(properties_dir, 'trainer', trainer + '.yaml'),
        ]

        self.internal_config_dict = dict()
        for config_file in config_files:
            if not os.path.isfile(config_file):
                continue
            with open(config_file, 'r', encoding='utf-8') as f:
                config_dict = yaml.load(f.read(), Loader=self.yaml_loader)
            if config_dict is not None:
                self.internal_config_dict.update(config_dict)

    def _get_final_config_dict(self):
        """Return a new dict of the YAML values overridden by the CLI args."""
        final_config_dict = dict(self.internal_config_dict)
        final_config_dict.update(self.args_config_dict)
        return final_config_dict

    def _init_device(self):
        """Store ``torch.device(args.device)`` under the ``'device'`` key."""
        self.final_config_dict['device'] = torch.device(self.args_config_dict['device'])

    def __setitem__(self, key, value):
        if not isinstance(key, str):
            raise TypeError("index must be a str.")
        self.final_config_dict[key] = value

    def __getitem__(self, item):
        # Missing keys yield None rather than raising KeyError.
        return self.final_config_dict.get(item)

    def __contains__(self, key):
        if not isinstance(key, str):
            raise TypeError("index must be a str.")
        return key in self.final_config_dict

    def __str__(self):
        """Return a header line plus one ``key=value`` line per setting."""
        args_info = '[{},{},{}] '.format(self.trainer, self.dataset, self.model)
        args_info += 'Hyper Parameters: \n'
        args_info += '\n'.join(
            "{}={}".format(arg, value)
            for arg, value in self.final_config_dict.items()
            if arg != 'model' and arg != 'dataset')
        args_info += '\n\n'
        return args_info

    def __repr__(self):
        return self.__str__()
