#!/usr/bin/env python3

import numpy as np
from ml_logger import RUN
from params_proto.neo_hyper import Sweep
from dmc_gen.config import Args
from invr_thru_inf.config import Adapt
import sys; sys.path.append('../..')
from const import envs

with Sweep(RUN, Args, Adapt) as sweep:
    Adapt.latent_buffer_size = 1_000_000
    Adapt.policy_on_clean_buffer = 'random'
    Adapt.policy_on_distr_buffer = 'random'

    with sweep.zip:
        Args.env_name = [f'dmc:{env_name}-v1' for env_name in envs]
        Args.eval_env_name = [f'distracting_control:{env_name}-intensity-v1' for env_name in envs]

    with sweep.product:
        # Adapt.distraction_types = [['background'], ['camera'], ['video-background'], ['dmcgen-color-hard']]
        Adapt.distraction_types = [['background'], ['color'], ['camera']]
        Adapt.distraction_intensity = list(np.linspace(0, 1, 5 + 1))[1:]  # remove the first item (0)
        Adapt.agent_stem = ['sac', 'soda', 'svea', 'pad']
        Args.seed = [(i + 1) * 100 for i in range(5)]


@sweep.each
def tail(RUN, Args, Adapt):
    # NOTE: pretrained policy (Adapt.snapshot_prefix) does not matter!
    from invr_thru_inf.utils import get_buffer_prefix
    from const import get_distraction_coef
    Adapt.distraction_intensity = Adapt.distraction_intensity * get_distraction_coef(Adapt.distraction_types)

    Adapt.snapshot_prefix = f"model-free/model-free/baselines/dmc_gen/run/{Adapt.agent_stem}/{Args.env_name.lower().split(':')[1][:-3]}/{Args.seed}"
    RUN.job_name = f"{Adapt.agent_stem}/{get_buffer_prefix(Args.env_name, Args.action_repeat, Args.seed, Adapt, eval_env=Args.eval_env_name)}"


sweep.save("sweep_intensity.jsonl")
