#!/usr/bin/env python3

import numpy as np
from ml_logger import RUN
from params_proto.neo_hyper import Sweep
from drqv2_invariance.config import Args
from invr_thru_inf.config import Adapt
import sys; sys.path.append('../..')
from const import envs

with Sweep(RUN, Args, Adapt) as sweep:
    # Settings shared by every configuration produced by this sweep.
    # Assignment order is kept as-is because the Sweep context records these
    # attribute writes.
    Adapt.latent_buffer_size = 10 ** 6
    # Policy names for the clean / distracted buffers — presumably used for
    # data collection; confirm in invr_thru_inf.
    Adapt.policy_on_clean_buffer = "random"
    Adapt.policy_on_distr_buffer = "random"

    # No pretrained snapshot is loaded for this sweep.
    Adapt.snapshot_prefix = None

    with sweep.zip:
        # Pair each DMC training env with its distracting_control evaluation
        # env of the same name (zipped, not crossed).
        Args.train_env = [f"dmc:{name}-v1" for name in envs]
        Args.eval_env = [f"distracting_control:{name}-intensity-v1" for name in envs]

    with sweep.product:
        # Disabled alternative option set:
        #   [['background'], ['camera'], ['video-background'], ['dmcgen-color-hard']]
        Adapt.distraction_types = [["background"], ["color"], ["camera"]]
        # Six evenly spaced points on [0, 1] with the leading 0 dropped,
        # i.e. five nonzero intensity levels ending at 1.0.
        Adapt.distraction_intensity = list(np.linspace(0, 1, 6)[1:])
        # Seeds 100, 200, 300, 400, 500.
        Args.seed = list(range(100, 501, 100))


@sweep.each
def tail(RUN, Args, Adapt):
    """Post-process one expanded sweep configuration.

    Scales ``Adapt.distraction_intensity`` by the coefficient that
    ``const.get_distraction_coef`` returns for ``Adapt.distraction_types``.
    Then sets ``RUN.job_name`` to ``"drqv2/"`` followed by the prefix that
    ``get_buffer_prefix`` builds from this configuration.

    NOTE: the pretrained policy (``Adapt.snapshot_prefix``) does not matter
    here.
    """
    from invr_thru_inf.utils import get_buffer_prefix
    from const import get_distraction_coef

    # Read the swept intensity first, then the coefficient, which is the
    # same evaluation order as a single inline product.
    base_intensity = Adapt.distraction_intensity
    coef = get_distraction_coef(Adapt.distraction_types)
    Adapt.distraction_intensity = base_intensity * coef

    buffer_prefix = get_buffer_prefix(
        Args.train_env,
        Args.action_repeat,
        Args.seed,
        Adapt,
        eval_env=Args.eval_env,
    )
    RUN.job_name = f"drqv2/{buffer_prefix}"


# Write the expanded sweep configurations to "sweep_intensity.jsonl".
# NOTE(review): this runs on import as well as on direct execution, and the
# relative path is presumably resolved against the current working directory.
# Confirm how params_proto's Sweep.save resolves it.
sweep.save("sweep_intensity.jsonl")
