import itertools
import string

from _gen_args.common import gen_args, write_to_file


def arg_generator(env_id, aug_function, aug_ratio, lambda_, aug_n):
    """Build the (mem, disk, args) triple for a lambda-weighted augmentation run.

    The augmentation ratio actually passed on the command line is
    ``0 + lambda_``. ``aug_ratio`` only appears in the experiment name.
    The data factor is ``1 + (1 - lambda_)``, and ``--aug-n`` is the scaled
    ratio times ``aug_n``.

    For ``aug_function == 'translate_proximal'`` the experiment name gets an
    ``_0`` suffix, and ``--aug-function-kwargs p:0`` is inserted after
    ``--aug-function``.

    Returns:
        tuple: ``(1.2, 8, args)`` where ``args`` is the argument string.
    """
    proximal = aug_function == 'translate_proximal'

    offset = 0
    scaled_ratio = offset + lambda_

    exp_root = f"{aug_function}_0" if proximal else aug_function
    extra_kwargs = " --aug-function-kwargs p:0" if proximal else ""

    pieces = [
        f" --env {env_id}",
        f" -exp {exp_root}/ratio_{aug_ratio}/lambda_{lambda_}/",
        f" --data-factor {1 + (1 - lambda_)}",
        f" --aug-function {aug_function}{extra_kwargs}"
        f" --aug-ratio {scaled_ratio} --aug-n {scaled_ratio * aug_n}",
    ]
    return 1.2, 8, "".join(pieces)

def arg_generator_df(env_id, df):
    """Build the (mem, disk, args) triple for a run without augmentation.

    The experiment is named ``no_aug/df_<df>/``, and ``df`` is passed
    through unchanged as ``--data-factor``.

    Returns:
        tuple: ``(1.2, 8, args)`` where ``args`` is the argument string.
    """
    flags = (
        f" --env {env_id}",
        f" -exp no_aug/df_{df}/",
        f" --data-factor {df}",
    )
    return 1.2, 8, "".join(flags)

def arg_generator_prox(env_id, aug_function, aug_ratio, aug_n, p):
    """Build the (mem, disk, args) triple for a run parameterised by ``p``.

    ``p`` is forwarded through ``--aug-function-kwargs p:<p>`` and is also
    encoded in the experiment name. ``aug_n`` is truncated with ``int()``.
    The run always uses ``--hyperparams train_freq:1 gradient_steps:1``.

    Returns:
        tuple: ``(1.2, 8, args)`` where ``args`` is the argument string.
    """
    experiment = f"{aug_function}/p_{p}/ratio_{aug_ratio}/"
    aug_flags = (f"--aug-function {aug_function} --aug-function-kwargs p:{p}"
                 f" --aug-ratio {aug_ratio} --aug-n {int(aug_n)}")
    command = (f" --env {env_id} -exp {experiment} {aug_flags}"
               " --hyperparams train_freq:1 gradient_steps:1")
    return 1.2, 8, command

def write_args(path="commands/reward_density_goal2d.txt",
               env_ids=("Goal2D-v0",),
               p_values=(0, 0.001, 0.01, 0.1, 0.25, 0.5, 0.75, 1)):
    """Write one entry per (environment, p) pair to the commands file.

    For every environment in ``env_ids`` and every value in ``p_values``,
    ``gen_args`` is called with ``arg_generator_prox`` and
    ``aug_function='translate_proximal'``. The result is appended to the
    file via ``write_to_file``.

    Args:
        path: Output file. It is opened in ``"w"`` mode, so any existing
            content is replaced.
        env_ids: Environment ids to sweep. For example, pass
            ``("PandaPickAndPlace-v3",)`` to use the alternative environment.
        p_values: Values forwarded as ``p`` to ``arg_generator_prox``.
    """
    # The original handle was never closed, so buffered output could be lost.
    # The ``with`` block guarantees flush/close, even if gen_args or
    # write_to_file raises.
    with open(path, "w") as f:
        for env_id in env_ids:
            for p in p_values:
                args = gen_args(
                    device='cpu',
                    length="short",
                    arg_generator=arg_generator_prox,
                    aug_function='translate_proximal',
                    env_id=env_id,
                    aug_ratio=1,
                    aug_n=1,
                    p=p)
                write_to_file(f, args)


if __name__ == "__main__":
    # Only regenerate the commands file when run as a script, not on import.
    write_args()


