"""Sweep for pretraining."""
import os
import fire

from pretrain_alg_configs import alg_configs
from sweep_utils import (
    write_slurm_file,
    log_gridlist_and_config,
    update_data_dirs,
    ROOT_DIR,
)

from imitation_pretraining.experiments.training import train
import configs
from imitation_pretraining.experiments.evaluation import loading

# Job resources forwarded to ``write_slurm_file`` when ``main`` runs with idx=0.
MODE = "GPU"
HRS = 1

# Candidate algorithm configs for the sweep.
ALG_LIST = alg_configs(identity_encoder=True)

# Settings shared by every job; merged on top of each algorithm config.
COMMON_DICT = {
    "observation_adapter_name": "embedding",
    "encode_data": True,
    "pretrain": True,
    "num_steps": 10_000,
    # Rollouts are disabled: the frequency is far beyond num_steps and zero
    # rollouts are requested.
    "rollout_freq": 1_000_000,
    "num_rollouts": 0,
}

# Pretraining sweep whose saved encoder configs are reused by this sweep.
PRETRAIN_DATE = "Apr-2023"
PRETRAIN_SWEEP_ID = "32658444"
# Only encoders from pretraining jobs with ``job_id`` up to this value are used.
MAX_ENCODER_JOB_ID = 45

# Dataset size (``max_episodes``) to keep for each evaluation environment.
SIZE_FILTER = {
    "point_mass": 100,
    "metaworld_pick_place_nogoal": 100,
    "kitchen_split_0": 450,
    "metaworld_pretrain_split_door": 100,
    "metaworld_pretrain_split_0": 1000,
    "metaworld_pretrain_split_r3m": 1000,
}


def _keep_encoder(encoder_config):
    """Return whether a pretrained encoder config belongs in this sweep.

    An encoder is kept when its ``job_id`` is at most ``MAX_ENCODER_JOB_ID``
    and its ``max_episodes`` equals the ``SIZE_FILTER`` entry for its
    ``eval_env_name``.

    Raises:
        KeyError: If an encoder within the job-id cutoff has an
            ``eval_env_name`` with no ``SIZE_FILTER`` entry. Failing loudly
            keeps a newly added environment from being dropped unnoticed.
    """
    if encoder_config["job_id"] > MAX_ENCODER_JOB_ID:
        return False
    env_name = encoder_config["eval_env_name"]
    if env_name not in SIZE_FILTER:
        raise KeyError(f"No SIZE_FILTER entry for eval_env_name={env_name!r}")
    return encoder_config["max_episodes"] == SIZE_FILTER[env_name]


ENCODER_CONFIGS = [
    e
    for e in loading.load_sweep_configs(ROOT_DIR, PRETRAIN_DATE, PRETRAIN_SWEEP_ID)
    if _keep_encoder(e)
]
print(f"Number of encoders: {len(ENCODER_CONFIGS)}")

def _build_sweep_list(encoder_configs, alg_list, common_dict):
    """Build one job config per (encoder, reconstruction algorithm) pair.

    Args:
        encoder_configs: Configs of the pretrained encoders to build on.
        alg_list: Candidate algorithm configs; only those whose
            ``agent_name`` is ``"reconstruction"`` are used.
        common_dict: Settings applied to every job; they override keys of the
            same name in the algorithm config.

    Returns:
        List of job config dicts, ordered by encoder, then by algorithm.
    """
    # The algorithm filter does not depend on the encoder, so apply it once.
    recon_algs = [a for a in alg_list if a["agent_name"] == "reconstruction"]

    sweep_list = []
    for encoder_config in encoder_configs:
        for alg_config in recon_algs:
            full_config = dict(alg_config, **common_dict)
            full_config["encoder_name"] = encoder_config["agent_name"]
            full_config["encoder_config"] = encoder_config
            # TODO: need to add logic to actually use the config in the preprocessing
            #    - Don't pop pixels from the config
            #    - Add obs adapter that treats state and next state as different
            #          - Requires passing next_obs flag into all the obs adapters (breaking change)
            full_config["encode_next_obs"] = False

            # Use same dataset as encoder
            for key in ("ep", "per", "seed", "eval_env_name", "max_episodes"):
                full_config[key] = encoder_config[key]

            # Decoder architecture depends on the environment; the names suggest
            # point_mass frames are 84 px and the others 120 px (TODO confirm).
            full_config["decoder_network_name"] = (
                "decoder-conv-84"
                if full_config["eval_env_name"] == "point_mass"
                else "decoder-conv-120"
            )

            sweep_list.append(full_config)
    return sweep_list


FULL_SWEEP_LIST = _build_sweep_list(ENCODER_CONFIGS, ALG_LIST, COMMON_DICT)


print("Total jobs:", len(FULL_SWEEP_LIST))


def main(idx: int, sweep_id: int = 0, test: bool = False):
    """Launch one sweep job, or write the SLURM launch file.

    Args:
        idx: ``0`` writes a SLURM file covering all jobs in
            ``FULL_SWEEP_LIST``. Values ``1..len(FULL_SWEEP_LIST)`` select
            the job at that 1-based position and run training for it.
        sweep_id: Stored in the job config under ``"sweep_id"``.
        test: Passed through to ``train.run``. When true, the sweep grid is
            not logged.

    Raises:
        IndexError: If ``idx`` is negative or larger than the number of jobs.
    """
    num_jobs = len(FULL_SWEEP_LIST)
    if idx == 0:
        write_slurm_file(num_jobs, os.path.basename(__file__), mode=MODE, hrs=HRS)
        return

    # Check explicitly: a negative idx would otherwise wrap around through
    # Python's negative indexing and silently train the wrong job.
    if not 1 <= idx <= num_jobs:
        raise IndexError(f"idx must be in [0, {num_jobs}], got {idx}")

    job_config = FULL_SWEEP_LIST[idx - 1]
    config = configs.get_config(job_config["agent_name"])
    config.update(job_config)  # Sweep-specific values override the defaults.
    config = update_data_dirs(config)
    config["job_id"] = idx
    config["sweep_id"] = sweep_id
    if idx == 1 and not test:  # Log the sweep grid once, from the first job.
        log_gridlist_and_config(FULL_SWEEP_LIST, config)
    train.run(config, test=test)


if __name__ == "__main__":
    # Python Fire exposes main's parameters (idx, sweep_id, test) as CLI args.
    fire.Fire(component=main)
