import itertools
import string

import numpy as np

from _gen_args.common import gen_args, write_to_file

# NOTE(review): DISK is not referenced by the code visible in this file
# (arg_generator returns its own local disk value) -- confirm it is still needed.
DISK: int = 7
# Presumably the generated commands are launched with `--run-id-offset 20` -- confirm.

def arg_generator(env_id, aug_function, aug_ratio, lambda_, aug_n, mem, rr=1):
    """Build the command-line argument string for one TD3 coverage run.

    Args:
        env_id: Environment id, passed as ``--env``.
        aug_function: Augmentation function name; passed as
            ``--aug-function`` and used as the prefix of the ``-exp`` path.
        aug_ratio: Used only as the ``ratio_`` label in the ``-exp`` path and
            as a factor of the ``--aug-n`` value.
        lambda_: Mixing coefficient. Passed as ``--aug-ratio``, formatted with
            two decimals in the ``-exp`` path, and used to derive
            ``--data-factor`` (``2 - lambda_``) and ``--aug-n``.
        aug_n: Not used by this function.
        mem: Ignored; the returned memory value is always 1 (see note below).
        rr: Multiplier applied to ``--aug-n`` (``rr * aug_ratio * lambda_``).

    Returns:
        Tuple ``(mem, disk, args)``: ``mem`` is 1, ``disk`` is 9, and ``args``
        is the argument string (it begins with a leading space).
    """
    disk = 9

    data_factor = 2 - lambda_
    aug_n_value = rr * aug_ratio * lambda_

    # NOTE(review): `--aug-ratio` receives `lambda_` (not `aug_ratio`), while
    # the `-exp` path is labelled `ratio_{aug_ratio}`; `aug_n` is not used at
    # all. Kept unchanged so generated commands stay identical -- confirm intent.
    args = (
        f" --env {env_id} --algo td3"
        f" -exp {aug_function}_0/ratio_{aug_ratio}/lambda_{lambda_:.2f}/"
        f" --aug-function {aug_function} --aug-function-kwargs p:0"
        f" --aug-ratio {lambda_} --aug-n {aug_n_value}"
        f" --data-factor {data_factor}"
        " --hyperparams train_freq:60 gradient_steps:30"
    )

    # NOTE(review): the caller-supplied `mem` is overridden here. For example,
    # write_args() passes mem=1.5 to gen_args, which presumably forwards it
    # here, yet the returned request is 1. Kept as-is to avoid changing the
    # generated job specs -- confirm whether the parameter should be honored.
    mem = 1
    return mem, disk, args

def write_args():
    """Write the Goal2D coverage-sweep commands to ``commands/coverage_goal2d.txt``.

    For every (aug_function, env_id, lambda_) combination, builds the run
    arguments with ``gen_args`` (using ``arg_generator``) and writes them with
    ``write_to_file``. The output file is truncated once, then closed when the
    sweep finishes.
    """
    env_ids = ['Goal2D-v0']
    aug_functions = ['translate_proximal']
    lambdas = [0, 1/3, 2/3, 1]

    # Open the output once, outside the sweep loops. The path does not depend
    # on aug_function, so reopening it with "w" per aug_function would
    # truncate earlier output. The `with` block also guarantees the handle is
    # closed and flushed.
    with open("commands/coverage_goal2d.txt", "w", encoding="utf-8") as f:
        # product() iterates in the same order as the nested loops
        # (aug_function outermost, lambda_ innermost).
        for aug_function, env_id, lambda_ in itertools.product(
                aug_functions, env_ids, lambdas):
            args = gen_args(
                device='cpu',
                length="short",
                arg_generator=arg_generator,
                aug_function=aug_function,
                env_id=env_id,
                aug_ratio=1,
                aug_n=1,
                mem=1.5,
                lambda_=lambda_)
            write_to_file(f, args)

    # Alternative sweep configurations (disabled), kept for reference.
    #
    # batch_size = None
    # for aug_function in ['translate_goal_proximal_0', 'translate_object_proximal_0', 'coda_proximal_0']:
    #     f = open(f"args/coverage_{aug_function}_cuda.txt", "w")
    #     for env_id in env_ids:
    #         for lambda_ in [1/3, 2/3]:
    #             args = gen_args(
    #                 device='cuda',
    #                 length="short",
    #                 arg_generator=arg_generator,
    #                 aug_function=aug_function,
    #                 env_id=env_id,
    #                 aug_ratio=1,
    #                 aug_n=1,
    #                 batch_size=batch_size,
    #                 mem=4,
    #                 lambda_=lambda_)
    #             write_to_file(f, args)

    # env_ids = ['PandaPush-v3', 'PandaSlide-v3']
    # batch_size = 512
    # for aug_function in ['translate_goal_proximal_0', 'translate_object_proximal_0']:
    #     f = open(f"args/coverage_rr0.5_b{batch_size}.txt", "w")
    #     for env_id in env_ids:
    #         for lambda_ in [0, 1/3, 2/3, 1]:
    #             args = gen_args(
    #                 device='cpu',
    #                 length="short",
    #                 arg_generator=arg_generator,
    #                 aug_function=aug_function,
    #                 env_id=env_id,
    #                 aug_ratio=1,
    #                 aug_n=1,
    #                 mem=2,
    #                 batch_size=batch_size,
    #                 lambda_=lambda_)
    #             write_to_file(f, args)
    #
    # env_ids = ['PandaPush-v3', 'PandaSlide-v3', 'PandaFlip-v3', 'PandaPickAndPlace-v3']
    # f = open(f"args/coverage_rr1.txt", "w")
    # batch_size = None
    # for aug_function in ['translate_goal_proximal_0', 'translate_object_proximal_0']:
    #     f = open(f"args/coverage_rr0.5_b{batch_size}.txt", "w")
    #     for env_id in env_ids:
    #         for lambda_ in [0, 1/3, 2/3, 1]:
    #             args = gen_args(
    #                 device='cuda',
    #                 length="short",
    #                 arg_generator=arg_generator,
    #                 aug_function=aug_function,
    #                 env_id=env_id,
    #                 aug_ratio=1,
    #                 aug_n=1,
    #                 train_freq=1,
    #                 batch_size=batch_size,
    #                 mem=4,
    #                 lambda_=lambda_)
    #             write_to_file(f, args)

if __name__ == "__main__":
    # Run the sweep-command generator only when executed as a script,
    # not when this module is imported.
    write_args()


