import pandas as pd
from drqv2_invariance.config import Args, Agent
from ml_logger import RUN
from params_proto.neo_hyper import Sweep
from io import StringIO

# Hyperparameter table; each column read below supplies the values for one
# swept field. The path is relative, so the script must be run from the
# directory that contains the CSV.
hyperparams = pd.read_csv('pretrain_hyperparams.csv')

# Build the sweep over RUN / Args / Agent. The outer `product` context
# presumably crosses the zipped hyperparameter rows with the seed list
# assigned at the bottom (params_proto semantics -- confirm against its docs).
with Sweep(RUN, Args, Agent).product as sweep:
    with sweep.chain:
        # All lists below come from columns of the same DataFrame, so they
        # have equal length; under `zip` entry i of every list presumably
        # forms one configuration (i.e. one CSV row).
        with sweep.zip:
            # Training and evaluation use the same environment column.
            Args.train_env = hyperparams['env_name'].tolist()
            Args.eval_env = hyperparams['env_name'].tolist()
            Args.train_frames = hyperparams['train_frames'].tolist()
            Args.replay_buffer_size = hyperparams['replay_buffer_size'].tolist()
            Args.batch_size = hyperparams['batch_size'].tolist()
            Args.nstep = hyperparams['nstep'].tolist()
            Agent.stddev_schedule = hyperparams['stddev_schedule'].tolist()
            Agent.lr = hyperparams['lr'].tolist()
            Agent.feature_dim = hyperparams['feature_dim'].tolist()

    Args.seed = [(i + 1) * 100 for i in range(5)]  # 5 fixed seeds: 100, 200, 300, 400, 500


# NOTE: DrQ-v2 task difficulty levels and their settings
# easy ==> 1M steps, exploration stddev: linear(1.0, 0.1, 100_000)
# med ==> 3M steps, exploration stddev: linear(1.0, 0.1, 500_000)
# hard ==> 30M steps, exploration stddev: linear(1.0, 0.1, 2_000_000)

@sweep.each
def tail(RUN, Args, Agent):
    """Apply the settings shared by every configuration of the sweep.

    The job name is ``"<env>/<seed>"``, where ``<env>`` is the lower-cased
    second ``':'``-separated field of ``Args.train_env`` and ``<seed>`` is
    ``Args.seed``. The remaining assignments are constants.
    """
    env_field = Args.train_env.split(':')[1]
    RUN.job_name = f"{env_field.lower()}/{Args.seed}"

    # Original author's note: this slows things down a bit but should avoid
    # the annoying OOM issues.
    Args.replay_buffer_num_workers = 3

    hours = 24.3
    Args.time_limit = 60 * 60 * hours  # 24.3 hours, expressed in seconds

    # Both coefficients are set to zero, inv_coef first and then rew_coef.
    Agent.inv_coef = Agent.rew_coef = 0.0


# Write the assembled sweep to "sweep.jsonl" (a relative path).
sweep.save("sweep.jsonl")
