#!/usr/bin/env python3

import numpy as np
from ml_logger import RUN
from params_proto.neo_hyper import Sweep
from dmc_gen.config import Args
from invr_thru_inf.config import Adapt, CollectData
import sys; sys.path.append('../..')
from const import envs

with Sweep(RUN, Args, Adapt, CollectData) as sweep:
    Adapt.latent_buffer_size = 1_000_000
    Adapt.policy_on_clean_buffer = 'random'
    Adapt.policy_on_distr_buffer = 'random'
    CollectData.buffer_type = 'target'

    Adapt.snapshot_prefix = None

    with sweep.zip:
        # Args.env_name = [f'dmc:{env_name}-v1' for env_name in envs]
        Args.env_name = [f'dmc:{env_name}-v1' for env_name in envs]
        Args.eval_env_name = [f'distracting_control:{env_name}-intensity-v1' for env_name in envs]

    with sweep.product:
        # Adapt.distraction_types = [['background'], ['camera'], ['video-background'], ['dmcgen-color-hard']]
        Adapt.distraction_types = [['background'], ['color'], ['camera']]
        Adapt.distraction_intensity = list(np.linspace(0, 1, 5 + 1))[1:]  # remove the first item (0)
        Args.seed = [(i + 1) * 100 for i in range(5)]


@sweep.each
def tail(RUN, Args, Adapt, CollectData):
    # NOTE: pretrained policy (Adapt.snapshot_prefix) does not matter!
    from invr_thru_inf.utils import get_buffer_prefix
    from const import get_distraction_coef
    Adapt.distraction_intensity = Adapt.distraction_intensity * get_distraction_coef(Adapt.distraction_types)

    if Args.env_name == 'dmc:Finger-spin-v1':
        Args.action_repeat = 2
    elif Args.env_name == 'dmc:Cartpole-swingup-v1':
        Args.action_repeat = 8
    else:
        Args.action_repeat = 4  # default

    RUN.job_name = f"{Adapt.agent_stem}/{get_buffer_prefix(Args.env_name, Args.action_repeat, Args.seed, Adapt, eval_env=Args.eval_env_name)}"


sweep.save("sweep_targ.jsonl")
