import os
import sys
sys.path.append("..")
from globa_utils import PathConfig


class CVConfig(PathConfig):
    """Dataset, model and seed locations for the CIFAR-10/100 CV experiments.

    Every value ultimately comes from ``self.cfg``, which ``PathConfig`` builds
    from ``path_config.yaml`` (presumably a mapping loaded from YAML -- TODO
    confirm against ``globa_utils.PathConfig``).

    The dataset-specific getters append a fixed sub-path to a configured base
    directory. They use ``os.path.join`` instead of ``+`` so that a base
    directory written without a trailing separator still yields a valid path;
    bases that already end in ``/`` produce exactly the same strings as before.
    """

    def __init__(self):
        super(CVConfig, self).__init__(cfg_path="path_config.yaml")

    def get_dataset_path(self):
        """Return the ``dataset_path`` config entry."""
        return self.cfg['dataset_path']

    def get_data_pool_path(self):
        """Return the ``data_pool_path`` config entry."""
        return self.cfg['data_pool_path']

    def get_fe_path(self):
        """Return the ``feature_extractor_save_path`` config entry."""
        return self.cfg['feature_extractor_save_path']

    def get_distribution_path(self):
        """Return the ``class_distri_save_path`` config entry."""
        return self.cfg['class_distri_save_path']

    def get_cifar10_dataset_path(self):
        """Return the CIFAR-10 dataset location (same as ``get_dataset_path``)."""
        return self.get_dataset_path()

    def get_cifar100_dataset_path(self):
        """Return the CIFAR-100 dataset location (same as ``get_dataset_path``)."""
        return self.get_dataset_path()

    def get_cifar10_fe_path(self):
        """Return ``<feature_extractor_save_path>/cifar10/model.th``."""
        return os.path.join(self.get_fe_path(), 'cifar10/model.th')

    def get_cifar100_fe_path(self):
        """Return ``<feature_extractor_save_path>/cifar100/resnet20/model.th``."""
        return os.path.join(self.get_fe_path(), 'cifar100/resnet20/model.th')

    def get_cifar10_distribution_save_path(self):
        """Return the ``cifar10/`` sub-directory of ``class_distri_save_path``.

        The trailing ``/`` is preserved so callers can append file names.
        """
        return os.path.join(self.get_distribution_path(), 'cifar10/')

    def get_cifar100_distribution_save_path(self):
        """Return the ``cifar100/`` sub-directory of ``class_distri_save_path``.

        The trailing ``/`` is preserved so callers can append file names.
        """
        return os.path.join(self.get_distribution_path(), 'cifar100/')

    def get_global_random_seed(self, dst_name):
        """Return the global random seed configured for dataset ``dst_name``.

        Args:
            dst_name: key looked up inside the ``global_random_seed`` config
                entry (presumably a dataset name such as ``'cifar10'`` --
                TODO confirm against path_config.yaml).

        Raises:
            Whatever indexing ``self.cfg`` raises for a missing key
            (``KeyError`` if ``self.cfg`` is a plain dict).
        """
        seed_dict = self.cfg['global_random_seed']
        return seed_dict[dst_name]

    def get_kfoldCV_seed(self):
        """Return the ``kfoldCV_seed`` config entry."""
        return self.cfg['kfoldCV_seed']

    def get_cifar10_data_pool_path(self):
        """Return ``<data_pool_path>/seed0/cifar10_pool/`` (trailing ``/`` kept)."""
        return os.path.join(self.get_data_pool_path(), 'seed0/cifar10_pool/')

    def get_cifar100_data_pool_path(self):
        """Return ``<data_pool_path>/cifar100_pool_v5/`` (trailing ``/`` kept)."""
        return os.path.join(self.get_data_pool_path(), 'cifar100_pool_v5/')

    def get_data_pool_info_path(self):
        """Return the ``augment_data_info_path`` config entry."""
        return self.cfg['augment_data_info_path']

    def get_cifar10_data_pool_info(self):
        """Return the path of ``cifar10/seed0/cifar10_info_kfold3_resnet20.pkl``
        under ``augment_data_info_path``.
        """
        # Earlier variant used 'cifar10/seed0/cifar10_info_kfold2_resnet20.pkl'.
        return os.path.join(self.get_data_pool_info_path(),
                            'cifar10/seed0/cifar10_info_kfold3_resnet20.pkl')

    def get_cifar100_data_pool_info(self):
        """Return the path of ``cifar100/cifar100_info_kfold3_resnet20.pkl``
        under ``augment_data_info_path``.
        """
        # Earlier variant used 'cifar100/cifar100_info.pkl' (tagged "v4").
        # NOTE(review): this line also carried a "# v4" tag, apparently copied
        # from that older entry, while the CIFAR-100 pool directory is
        # 'cifar100_pool_v5/' -- confirm which pool version this file matches.
        return os.path.join(self.get_data_pool_info_path(),
                            'cifar100/cifar100_info_kfold3_resnet20.pkl')

    # def get_val_index_path(self):
    #     return self.cfg['val_index_path']
    #
    # def get_cifar10_val_index_path(self, sample_num_per_class=500):
    #     return self.get_val_index_path() + 'val_cifar10_'+str(sample_num_per_class)+'.pkl'
    #
    # def get_cifar100_val_index_path(self, sample_num_per_class=50):
    #     return  self.get_val_index_path() + 'val_cifar100_'+str(sample_num_per_class)+'.pkl'


if __name__ == '__main__':
    # Quick manual check: build the config from path_config.yaml and show
    # where the CIFAR-100 dataset is expected to live.
    config = CVConfig()
    dataset_root = config.get_cifar100_dataset_path()
    print(dataset_root)