import jax.numpy as jnp
from evosax import Strategies


def get_strategy_and_params(popsize, num_dims, part_size, padding, args):
    """Instantiate the strategy named by ``args.strategy`` and its parameters.

    Args:
        popsize: Population size forwarded to the strategy constructor.
        num_dims: Number of search dimensions forwarded to the constructor.
        part_size: Forwarded only to the ``SoES`` / ``SoGradES`` strategies.
        padding: Forwarded only to the ``SoES`` / ``SoGradES`` strategies.
        args: Config object. ``args.strategy`` selects the strategy class
            from ``Strategies``; ``args.opt_name`` is read only for
            ``MahiES``, ``OpenES``, ``PGPE`` and ``CMA_ES``. The whole object
            is also passed to ``strategy.update_params``.

    Returns:
        Tuple ``(strategy, es_params)``. ``es_params`` is the strategy's
        ``default_params`` after ``strategy.update_params(params, args)``.

    Raises:
        KeyError: If ``args.strategy`` is not one of the supported names.
    """
    optimizer_based = ('MahiES', 'OpenES', 'PGPE', 'CMA_ES')
    partitioned = ('SoES', 'SoGradES')

    name = args.strategy
    kwargs = {'popsize': popsize, 'num_dims': num_dims}
    if name in optimizer_based:
        # Read opt_name only when the chosen strategy takes it, so configs
        # for the partitioned strategies need not define this attribute.
        kwargs['opt_name'] = args.opt_name
    elif name in partitioned:
        kwargs['part_size'] = part_size
        kwargs['padding'] = padding
    else:
        supported = ', '.join(optimizer_based + partitioned)
        # KeyError keeps the exception type the original dict lookup raised.
        raise KeyError(f"Unknown strategy {name!r}; expected one of: {supported}")

    strategy = Strategies[name](**kwargs)
    es_params = strategy.default_params
    es_params = strategy.update_params(es_params, args)
    return strategy, es_params
