import itertools
import string

from _gen_args.common import gen_args, write_to_file


def arg_generator(env_id, aug_function, aug_ratio, lambda_, aug_n, p=0):
    """Build the resource request and CLI arguments for a lambda-scaled augmentation run.

    Args:
        env_id: Environment id passed via ``--env``.
        aug_function: Augmentation function name passed via ``--aug-function``.
        aug_ratio: Used only to label the experiment directory (``ratio_{aug_ratio}``).
            The value actually passed as ``--aug-ratio`` is derived from ``lambda_``.
            NOTE(review): confirm it is intended that this does not affect the flags.
        lambda_: Scaling factor. ``--aug-ratio`` is ``0 + lambda_``,
            ``--data-factor`` is ``1 + (1 - lambda_)`` and ``--aug-n`` is
            ``lambda_ * aug_n`` (not cast to int).
        aug_n: Base augmentation count, scaled by ``lambda_``.
        p: Value forwarded as ``--aug-function-kwargs p:{p}`` when
            ``aug_function == 'translate_proximal'``; also appended to that run's
            experiment name as ``{aug_function}_{p}``. Defaults to 0, which is the
            value this generator always used before ``p`` was exposed.

    Returns:
        Tuple ``(mem, disk, args)`` where ``mem`` is 1.2, ``disk`` is 8 and
        ``args`` is the argument string (it starts with a leading space).
    """
    mem = 1.2
    disk = 8
    base = 0  # offset added to lambda_; with 0 the passed ratio equals lambda_
    aug_ratio_scaled = base + lambda_

    # translate_proximal additionally takes a `p` kwarg, and its runs are stored
    # under a p-suffixed experiment name; every other flag is shared.
    if aug_function == 'translate_proximal':
        exp_name = f"{aug_function}_{p}"
        kwargs_flag = f" --aug-function-kwargs p:{p}"
    else:
        exp_name = aug_function
        kwargs_flag = ""

    args = (
        f" --env {env_id}"
        f" -exp {exp_name}/ratio_{aug_ratio}/lambda_{lambda_}/"
        f" --data-factor {(1 + (1 - lambda_))}"
        f" --aug-function {aug_function}{kwargs_flag}"
        f" --aug-ratio {aug_ratio_scaled} --aug-n {aug_ratio_scaled * aug_n}"
    )
    return mem, disk, args

def arg_generator_df(env_id, df):
    """Return ``(mem, disk, args)`` for a baseline run with no augmentation.

    ``df`` is passed as ``--data-factor`` and also names the experiment
    directory ``no_aug/df_{df}/``. ``mem`` is 1.2 and ``disk`` is 8.
    """
    pieces = [
        f" --env {env_id}",
        f" -exp no_aug/df_{df}/",
        f" --data-factor {df}",
    ]
    return 1.2, 8, "".join(pieces)

def arg_generator_prox(env_id, aug_function, aug_ratio, aug_n, p):
    """Return ``(mem, disk, args)`` for an augmentation run with a ``p`` kwarg.

    ``p`` is forwarded as ``--aug-function-kwargs p:{p}`` and appears in the
    experiment path ``{aug_function}/p_{p}/ratio_{aug_ratio}/``. ``aug_n`` is
    cast to ``int`` for ``--aug-n``. Training hyperparameters are fixed to
    ``train_freq:1 gradient_steps:1``.
    """
    mem_request, disk_request = 1.2, 8

    pieces = (
        f" --env {env_id}",
        f" -exp {aug_function}/p_{p}/ratio_{aug_ratio}/",
        f" --aug-function {aug_function} --aug-function-kwargs p:{p}"
        f" --aug-ratio {aug_ratio} --aug-n {int(aug_n)}",
        " --hyperparams train_freq:1 gradient_steps:1",
    )
    return mem_request, disk_request, "".join(pieces)

def arg_generator_aug_n(env_id, aug_function, aug_ratio, aug_n):
    """Return ``(mem, disk, args)`` for a TD3 run sweeping the augmentation count.

    The experiment path is ``{aug_function}/ratio_{aug_ratio}/n_{aug_n}`` (using
    ``aug_n`` as given), while ``--aug-n`` receives ``int(aug_n)``.
    ``mem`` is 1.2 and ``disk`` is 8.
    """
    exp_dir = f"{aug_function}/ratio_{aug_ratio}/n_{aug_n}"
    n_copies = int(aug_n)
    # NOTE: two spaces precede --aug-ratio; kept so generated command lines
    # stay byte-identical to previously produced ones.
    cli = (
        f" --env {env_id} --algo td3"
        f" -exp {exp_dir}"
        f" --aug-function {aug_function}  --aug-ratio {aug_ratio} --aug-n {n_copies}"
    )
    return 1.2, 8, cli

def write_args():
    """Write the Goal2D augmentation-count sweep to ``commands/replay_ratio_goal2d.txt``.

    For every combination of augmentation function (``translate``, ``rotate``),
    ``aug_n`` in ``[1, 2, 4, 8]`` and environment id, this calls ``gen_args`` with
    ``arg_generator_aug_n``, ``aug_ratio=1``, ``device='cpu'`` and
    ``length="short"``, then passes the result to ``write_to_file``.
    The output file is opened in ``"w"`` mode, so it is overwritten on each call.
    """
    env_ids = ['Goal2D-v0']  # alternative: ['PandaPickAndPlace-v3']
    aug_functions = ['translate', 'rotate']
    aug_ns = [1, 2, 4, 8]

    # `with` guarantees the file is flushed and closed, even if gen_args or
    # write_to_file raises part-way through the sweep.
    with open("commands/replay_ratio_goal2d.txt", "w") as f:
        # product() yields in nested-loop order: function, then aug_n, then env.
        for aug_function, aug_n, env_id in itertools.product(aug_functions, aug_ns, env_ids):
            args = gen_args(
                device='cpu',
                length="short",
                arg_generator=arg_generator_aug_n,
                aug_function=aug_function,
                env_id=env_id,
                aug_ratio=1,
                aug_n=aug_n)
            write_to_file(f, args)

if __name__ == '__main__':
    # Script entry point: regenerate the sweep command file via write_args().
    write_args()


