#!/usr/bin/env python3
import numpy as np
from ml_logger import RUN
from params_proto.neo_hyper import Sweep
from drqv2_invariance.config import Args
from invr_thru_inf.config import Adapt, CollectData
import sys; sys.path.append('../..')
from const import envs

# Declare the sweep: settings assigned directly under `Sweep` are shared by
# every job, while the `zip` / `product` sections list the values that vary.
with Sweep(RUN, Args, Adapt, CollectData) as sweep:
    # --- settings common to all jobs ---
    Adapt.latent_buffer_size = 1_000_000
    Adapt.policy_on_clean_buffer = 'random'
    Adapt.policy_on_distr_buffer = 'random'
    CollectData.buffer_type = 'original'

    # Distraction overrides are left commented out (not set by this sweep):
    # Adapt.distraction_types = None
    # Adapt.distraction_intensity = None

    # Both lists are derived from the same `envs` sequence, so entry i of
    # `train_env` and entry i of `eval_env` always name the same task.
    with sweep.zip:
        Args.train_env = [f'dmc:{name}-v1' for name in envs]
        Args.eval_env = [f'distracting_control:{name}-intensity-v1' for name in envs]

    # Five seeds: 100, 200, 300, 400, 500.
    with sweep.product:
        Args.seed = list(range(100, 600, 100))


@sweep.each
def tail(RUN, Args, Adapt, CollectData):
    """Fill in the per-job fields that are derived from the swept values.

    Sets ``Adapt.snapshot_prefix`` from the training environment and seed,
    then sets ``RUN.job_name`` to ``drqv2/targ/`` followed by the prefix
    returned by ``get_buffer_prefix``.
    """
    # NOTE: the pretrained policy referenced by Adapt.snapshot_prefix does not matter!
    from invr_thru_inf.utils import get_buffer_prefix

    # Lower-case the env spec and keep the segment right after the first ':'
    # (e.g. the part following 'dmc:').
    env_segment = Args.train_env.lower().split(':')[1]
    Adapt.snapshot_prefix = (
        f"model-free/model-free/baselines/drqv2_original/train/{env_segment}/{Args.seed}"
    )

    # snapshot_prefix is assigned before this call because `Adapt` is passed in.
    buffer_prefix = get_buffer_prefix(
        Args.train_env,
        Args.action_repeat,
        Args.seed,
        Adapt,
        eval_env=Args.eval_env,
    )
    RUN.job_name = f"drqv2/targ/{buffer_prefix}"


# Save the sweep declared above to "sweep_orig.jsonl". The filename is
# relative, so presumably it is resolved against the current working directory.
sweep.save("sweep_orig.jsonl")
