import os
from pathlib import Path

# all possible values
all_datasets = ['carina', 'mscoco', 'reuters', 'scene', 'cifar10', 'urbansound8k', 'agnews', 'letter']
all_dataset_types = ['complete', '2k']
all_budgets = [2, 20, 200]
all_train_paradigms = ['sl', 'semi-sl', 'augmented']
all_qms = ['badge', 'bald', 'beal', 'kmeans', 'multilabel_simple_crw', 'random', 'ratio_max']
all_weight_inits = ['tl', 'random', 'self-sl']
all_training = ['frozen', 'finetune-last-2', 'finetune-last-5', 'finetune']

# selected values
max_train_samples_2k = 2000
max_train_samples_complete = 1400
random_seeds_2k = 30
random_seeds_complete = 5

# Run-time selection; stays None until initialize_config() fills it in.
datasets = None
dataset_types = None
init_train_samples = None
add_train_samples = None
train_paradigms = None
qms = None
trainings = None
weight_inits = None
nr_processing_pool = None


# use parsed argument configuration
def initialize_config(args):
    """Populate the module-level selection globals from parsed arguments.

    Each list-valued attribute of ``args`` may be a single value, a list or a
    tuple. Values that are not in the matching ``all_*`` list are dropped,
    and a ``UserWarning`` names them so a typo does not silently produce an
    empty configuration. ``None`` yields an empty list.

    ``args.budgets`` feeds both ``init_train_samples`` and
    ``add_train_samples``, which receive equal but distinct lists.

    Args:
        args: object exposing ``datasets``, ``dataset_types``, ``budgets``,
            ``train_paradigms``, ``qms``, ``weight_inits``, ``trainings`` and
            ``nr_processing_pool`` attributes (e.g. an ``argparse.Namespace``).
    """
    import warnings

    global datasets, dataset_types, \
        init_train_samples, add_train_samples, \
        train_paradigms, \
        qms, \
        trainings, \
        weight_inits, \
        nr_processing_pool

    def ensure_list(x, allowed=None, name='values'):
        """Normalize ``x`` to a list and keep only items found in ``allowed``."""
        if x is None:
            return []
        # Treat tuples like lists; any other object (including str) is a
        # single value and must not be split into characters.
        lst = list(x) if isinstance(x, (list, tuple)) else [x]
        if allowed is None:
            return lst
        kept = [item for item in lst if item in allowed]
        dropped = [item for item in lst if item not in allowed]
        if dropped:
            # stacklevel=3 attributes the warning to initialize_config's caller.
            warnings.warn(
                f'ignoring unsupported {name} {dropped!r}; allowed values are {allowed!r}',
                stacklevel=3,
            )
        return kept

    datasets = ensure_list(args.datasets, allowed=all_datasets, name='datasets')
    dataset_types = ensure_list(args.dataset_types, allowed=all_dataset_types, name='dataset_types')
    # Filter budgets once (so problems are reported once), then hand out copies.
    budgets = ensure_list(args.budgets, allowed=all_budgets, name='budgets')
    init_train_samples = budgets
    add_train_samples = list(budgets)
    train_paradigms = ensure_list(args.train_paradigms, allowed=all_train_paradigms, name='train_paradigms')
    qms = ensure_list(args.qms, allowed=all_qms, name='qms')
    weight_inits = ensure_list(args.weight_inits, allowed=all_weight_inits, name='weight_inits')
    trainings = ensure_list(args.trainings, allowed=all_training, name='trainings')
    nr_processing_pool = args.nr_processing_pool

# training variables
nr_epochs = 50
batch_size = 32
verbose = 0


# df metadata
# Prefix used for label column names.
label_prefix = 'label_'

# Tag values distinguishing the data splits.
tag_train = 'train'
tag_validate = 'validate'
tag_evaluate = 'evaluate'
tag_unlabelled = 'unlabelled'

# paths (all relative to the directory containing this file)
path_data = Path(__file__).parent / 'data'
path_exp = Path(__file__).parent / 'exp'
path_fig = Path(__file__).parent / 'fig'
path_model = Path(__file__).parent / 'models'

# local paths for pre-processing
# Root folder of the raw datasets. The default is the original author's local
# machine; set the PATH_DATASETS environment variable to use another location.
# An empty value falls back to the default.
path_datasets = Path(os.environ.get('PATH_DATASETS') or Path('C:\\Users', 'blond', 'Datasets'))

path_scene = Path(path_datasets, 'Scene')
path_data_scene = Path(path_scene, 'scene_csv.csv')

path_mscoco = Path(path_datasets, 'coco2017')
path_data_mscoco_train = Path(path_mscoco, 'train2017')
path_data_mscoco_test = Path(path_mscoco, 'test2017')
path_label_mscoco = Path(path_mscoco, 'annotations')

path_carina = Path(path_datasets, 'Audio', 'General', 'CARInA_Complete')

path_reuters = Path(path_datasets, 'Reuters-21578 (Text Categorization)')

path_urbansound8k = Path(path_datasets, 'Audio', 'General', 'UrbanSound8K')
path_urbansound8k_metadata = Path(path_urbansound8k, 'metadata', 'UrbanSound8K.csv')

path_agnews = Path(path_datasets, 'AG News Corpus')

path_letter = Path(path_datasets, 'Letter recognition dataset', 'letter-recognition.data')


def path_results(dataset, dataset_type):
    """Return the results folder ``<path_exp>/<dataset>/<dataset_type>/results``."""
    experiment_dir = path_exp / dataset / dataset_type
    return experiment_dir / 'results'


def path_init_metadata(experiment, random_seed, dataset_type=None):
    """Return the path of the initial-metadata CSV for one random seed.

    The file is ``initial_metadata_seed_<random_seed>.csv`` inside the results
    folder that ``path_results(experiment, dataset_type)`` returns.

    Args:
        experiment: dataset name; passed as the first argument of path_results.
        random_seed: seed number embedded in the file name.
        dataset_type: dataset variant (e.g. 'complete' or '2k'). It defaults
            to None only to keep the original positional signature. It must
            be supplied, because path_results needs it.

    Raises:
        TypeError: if ``dataset_type`` is not given.
    """
    if dataset_type is None:
        # path_results requires both the dataset and its type; fail loudly
        # with an explanation instead of an arity error from deeper down.
        raise TypeError('path_init_metadata() requires dataset_type to build the results path')
    return Path(path_results(experiment, dataset_type), f'initial_metadata_seed_{random_seed}.csv')


def get_iteration_col(int_value):
    """Build the column name for an iteration index, e.g. ``iteration_3``."""
    return 'iteration_{}'.format(int_value)


# Plot styling, keyed by query-method name, dataset name, or other plotted series.
# Method/series entries hold a legend label, color, marker and marker size
# (the color/marker values look like matplotlib specs); dataset entries hold
# only the display text.
figures = {
    # query methods (same keys as all_qms)
    'badge': dict(label='BADGE', color='tab:blue', marker='x', marker_size=3),
    'bald': dict(label='BALD', color='tab:brown', marker='1', marker_size=3),
    'beal': dict(label='BEAL', color='tab:green', marker='+', marker_size=3),
    'kmeans': dict(label='k-means', color='tab:red', marker='o', marker_size=3),
    'multilabel_simple_crw': dict(label='CRW', color='tab:purple', marker='^', marker_size=3),
    'random': dict(label='random', color='tab:gray', marker='s', marker_size=3),
    'ratio_max': dict(label='ratio max', color='tab:orange', marker='D', marker_size=3),
    # datasets (same keys as all_datasets) -> display text
    'carina': dict(text='CARInA'),
    'mscoco': dict(text='MS COCO'),
    'reuters': dict(text='Reuters-21578'),
    'scene': dict(text='Scene'),
    'urbansound8k': dict(text='UrbanSound8k'),
    'cifar10': dict(text='CIFAR-10'),
    'agnews': dict(text='AG News'),
    'letter': dict(text='Letter Recognition'),
    # other plotted series
    'lc_mean': dict(label='LC mean', color='tab:pink', marker='d', marker_size=3),
    'lc_aulc_norm': dict(
        label='$\\frac{\\text{AULC}_{\\text{qm}}}{\\text{AULC}_{\\text{rand}}}$',
        color='tab:olive',
        marker='v',
        marker_size=3,
    ),
    'cutpoints': dict(label='cut-points', color='m', marker='H', marker_size=3),
    's_fix': dict(label='S fix', color='tab:cyan', marker='X', marker_size=3),
    's': dict(label='S', color='r', marker='$S$', marker_size=3),
}
