# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the CC BY-NC 4.0 license found in the
# LICENSE file in the root directory of this source tree.

import copy
import dataclasses
from typing import Literal

import tyro

from metamotivo.envs.dmc_tasks import ALL_TASKS
from metamotivo.misc.launcher_utils import all_combinations_of_nested_dicts_for_sweep, flatten, \
    launch_trials, unflatten, setup_wandb_config

BASE_CFG = {
    "num_train_steps": 1_000_000,
    "data": {
        "name": "dmc",
        "domain": "walker",
        "load_n_episodes": 5_000,
        "obs_type": "state",
        "buffer_type": "parallel",
    },
    "env": {"name": "dmc", "domain": "walker", "task": "walk"},
    "agent": {
        "name": "",  # to be updated later
        "compile": True,
        "model": {
            "device": "cuda",
            "obs_normalizer": {
                "name": "IdentityNormalizerConfig",
            },
            "archi": {
                "successor_features": {
                    "name": "ForwardArchi",
                    "hidden_dim": 1024,
                    "hidden_layers": 1,
                },
                "actor": {
                    "hidden_dim": 1024,
                    "hidden_layers": 1,
                    "name": "simple",
                },
                "features": {
                    "name": "BackwardArchi",
                    "hidden_dim": 256,
                    "hidden_layers": 2,
                    "norm": True,
                },
                "left_encoder": {
                    "name": "BackwardArchi",
                    "hidden_dim": 256,
                    "hidden_layers": 0,
                    "norm": True,
                },
                "L_dim": 256,
                "z_dim": 50,
                "norm_z": True,
            },
        },
        "train": {
            "batch_size": 1024,
            "discount": 0.98,
            "sf_target_tau": 0.001,
            "features_target_tau": 0.001,
        },
    },
}

# extra params for the specific instance of SF agent
BASE_AGENT_CFG = {
    "sf": {"agent.name": "SFAgent"},
    "laplacian": {"agent.name": "LaplacianAgent"},
    "hilp": {
        "agent.name": "HilpAgent",
        "agent.train.expectile": 0.5,
        "agent.train.prob_random_goal": 0.375,
        "agent.model.archi.features.norm": False,
        "agent.model.center_features": True,
        "data.future": 0.99,
    },
    "byol": {
        "agent.name": "BYOLAgent",
        "agent.train.lr_predictor": 1e-4,
        "agent.train.ortho_coef": 0.1,
        "agent.train.multi_step": False,
        "agent.model.archi.predictor.hidden_dim": 1024,
        "agent.model.archi.predictor.hidden_layers": 1,
    },
    "byol-gamma": {
        "agent.name": "BYOLAgent",
        "agent.train.lr_predictor": 1e-4,
        "agent.train.ortho_coef": 0.1,
        "agent.train.multi_step": True,
        "data.future": 0.98,
        "agent.model.archi.predictor.hidden_dim": 1024,
        "agent.model.archi.predictor.hidden_layers": 1,
    },
    "icvf": {
        "agent.name": "ICVFAgent",
        "agent.train.expectile": 0.9,
        "agent.train.prob_random_goal": 0.375,
        "agent.model.archi.features.norm": True,
        "data.future": 0.99,
        "agent.model.archi.L_dim": 50,
    },
}


def sweep_byol_walker():
    conf = {
        "seed": [3917, 3502, 8948, 9460, 4729],
        "env.domain": ["walker"],
        "agent": {
            "model": {"archi": {"z_dim": [50]}},
            "train": {"lr_features": [1e-4, 1e-5], "ortho_coef": [0.001, 0.01, 0.1]},
        },
    }
    return conf


def sweep_byol_cheetah():
    conf = {
        "seed": [3917, 3502, 8948, 9460, 4729],
        "env.domain": ["cheetah"],
        "agent": {
            "model": {"archi": {"z_dim": [50]}},
            "train": {"lr_features": [1e-4, 1e-5], "ortho_coef": [0.001, 0.01, 0.1]},
        },
    }
    return conf


def sweep_byol_quadruped():
    conf = {
        "seed": [3917, 3502, 8948, 9460, 4729],
        "env.domain": ["quadruped"],
        "agent": {
            "model": {"archi": {"z_dim": [50]}},
            "train": {"lr_features": [1e-4, 1e-5], "ortho_coef": [0.001, 0.01, 0.1]},
        },
    }
    return conf


def sweep_byol_pointmass():
    conf = {
        "seed": [3917, 3502, 8948, 9460, 4729],
        "env.domain": ["pointmass"],
        "env.task": ["reach_top_left"],
        "agent": {
            "model": {"archi": {"z_dim": [50]}},
            "train": {"lr_features": [1e-4, 1e-5], "ortho_coef": [0.001, 0.01, 0.1], "discount": [0.99], "lr_actor": [1e-6]},
        },
    }
    return conf


def sweep_hilp_walker():
    conf = {
        "seed": [3917, 3502, 8948, 9460, 4729],
        "env.domain": ["walker"],
        "agent": {
            "model": {"archi": {"z_dim": [50]}},
            "train": {"lr_features": [1e-4, 1e-5], "prob_random_goal": [0.375, 0.2, 0.5]},
        },
    }
    return conf


def sweep_hilp_cheetah():
    conf = {
        "seed": [3917, 3502, 8948, 9460, 4729],
        "env.domain": ["cheetah"],
        "agent": {
            "model": {"archi": {"z_dim": [50]}},
            "train": {"lr_features": [1e-4, 1e-5], "prob_random_goal": [0.375, 0.2, 0.5]},
        },
    }
    return conf


def sweep_hilp_quadruped():
    conf = {
        "seed": [3917, 3502, 8948, 9460, 4729],
        "env.domain": ["quadruped"],
        "agent": {
            "model": {"archi": {"z_dim": [50]}},
            "train": {"lr_features": [1e-4, 1e-5], "prob_random_goal": [0.375, 0.2, 0.5]},
        },
    }
    return conf


def sweep_hilp_pointmass():
    conf = {
        "seed": [3917, 3502, 8948, 9460, 4729],
        "env.domain": ["pointmass"],
        "agent": {
            "model": {"archi": {"z_dim": [50]}},
            "train": {"lr_features": [1e-4, 1e-5], "prob_random_goal": [0.375, 0.2, 0.5], "discount": [0.99], "lr_actor": [1e-6]},
        },
    }
    return conf


def sweep_laplacian_walker():
    conf = {
        "seed": [3917, 3502, 8948, 9460, 4729],
        "env.domain": ["walker"],
        "agent": {
            "model": {"archi": {"z_dim": [50]}},
            "train": {"lr_features": [1e-4, 1e-5], "ortho_coef": [0, 0.1, 1]},
        },
    }
    return conf


def sweep_laplacian_cheetah():
    conf = {
        "seed": [3917, 3502, 8948, 9460, 4729],
        "env.domain": ["cheetah"],
        "agent": {
            "model": {"archi": {"z_dim": [50]}},
            "train": {"lr_features": [1e-4, 1e-5], "ortho_coef": [0, 0.1, 1]},
        },
    }
    return conf


def sweep_laplacian_quadruped():
    conf = {
        "seed": [3917, 3502, 8948, 9460, 4729],
        "env.domain": ["quadruped"],
        "agent": {
            "model": {"archi": {"z_dim": [50]}},
            "train": {"lr_features": [1e-4, 1e-5], "ortho_coef": [0, 0.1, 1]},
        },
    }
    return conf


def sweep_laplacian_pointmass():
    conf = {
        "seed": [3917, 3502, 8948, 9460, 4729],
        "env.domain": ["pointmass"],
        "env.task": ["reach_top_left"],
        "agent": {
            "model": {"archi": {"z_dim": [50]}},
            "train": {"lr_features": [1e-4, 1e-5], "discount": [0.99], "lr_actor": [1e-6], "ortho_coef": [0, 0.1, 1]},
        },
    }
    return conf


@dataclasses.dataclass
class LaunchArgs:
    # dataset and working paths
    data_path: str = "/home/jovyan/.exorl/expl_datasets"
    workdir_root: str = "/path/to/workdir"
    # wandb config
    use_wandb: bool = False
    wandb_gname: str | None = "td_jepa"
    wandb_ename: str | None = None
    wandb_pname: str | None = "td_jepa"
    # specify to run sweeps
    sweep_config: str | None = None
    # instead of launching all experiments, only run the first one
    first_only: bool = False
    # print out the configs instead of running the experiments
    dry: bool = False
    # which sf agent to run
    sf_agent: str = "sf"
    # selects the depth of the left encoder
    left_encoder: Literal["shallow", "deep"] = "shallow"
    compile: bool = True

def main(args: LaunchArgs):
    base_cfg = copy.deepcopy(BASE_CFG)
    base_cfg["work_dir"] = args.workdir_root
    base_cfg["data"]["dataset_root"] = args.data_path
    base_cfg["agent"]["compile"] = args.compile
    flat = flatten(base_cfg)
    flat.update(flatten(BASE_AGENT_CFG[args.sf_agent]))
    base_cfg = unflatten(flat)
    match args.left_encoder:
        case "shallow":
            pass
        case "deep":
            base_cfg["agent"]["model"]["archi"]["left_encoder"]["hidden_layers"] = 2
            base_cfg["agent"]["model"]["archi"]["L_dim"] = 50
        case _:
            raise NotImplementedError("Unknown left encoder configuration: ", args.left_encoder)

    if args.sweep_config is None:
        sweep_params = {}
    else:
        if args.sweep_config in globals().keys():
            sweep_params = globals()[args.sweep_config]()
        else:
            raise RuntimeError("Unknown sweep configuration")

    trials = []
    for i, trial in enumerate(all_combinations_of_nested_dicts_for_sweep(sweep_params)):
        trial = flatten(trial)
        trial.update(
            flatten(
                {
                    **setup_wandb_config(
                        trial,
                        base_cfg,
                        args_use_wandb=args.use_wandb,
                        args_wandb_pname=args.wandb_pname,
                        args_wandb_gname=args.wandb_gname,
                        args_wandb_ename=args.wandb_ename,
                    ),
                    "work_dir": f"{args.workdir_root}/{trial['env.domain']}/{trial['seed']}",
                    "data.domain": trial["env.domain"],
                    "env.task": ALL_TASKS[trial["env.domain"]][0],
                    "evaluations": [
                        {
                            "name": "dmc_reward_eval",
                            "env": {
                                "name": "dmc",
                                "domain": trial["env.domain"],
                                "task": ALL_TASKS[trial["env.domain"]][0],
                            },
                            "tasks": ALL_TASKS[trial["env.domain"]],
                            "num_episodes": 30,
                            "num_inference_samples": 10_000,
                            "disable_tqdm": False,
                            "record_video": False,
                        },
                    ],
                }
            )
        )
        trials.append(trial)

    launch_trials(base_cfg, trials, args.first_only, args.dry)


if __name__ == "__main__":
    args = tyro.cli(LaunchArgs)
    main(args)
    # uv run -m scripts.train.proprio.launch_sf_dmc --use_wandb --wandb_gname byol_walker_proprio --data_path datasets --workdir_root results --sweep_config sweep_byol_walker --sf_agent byol
    # uv run -m scripts.train.proprio.launch_sf_dmc --use_wandb --wandb_gname byol_cheetah_proprio --data_path datasets --workdir_root results --sweep_config sweep_byol_cheetah --sf_agent byol
    # uv run -m scripts.train.proprio.launch_sf_dmc --use_wandb --wandb_gname byol_quadruped_proprio --data_path datasets --workdir_root results --sweep_config sweep_byol_quadruped --sf_agent byol
    # uv run -m scripts.train.proprio.launch_sf_dmc --use_wandb --wandb_gname byol_pointmass_proprio --data_path datasets --workdir_root results --sweep_config sweep_byol_pointmass --sf_agent byol

    # uv run -m scripts.train.proprio.launch_sf_dmc --use_wandb --wandb_gname byol_gamma_walker_proprio --data_path datasets --workdir_root results --sweep_config sweep_byol_walker --sf_agent byol-gamma
    # uv run -m scripts.train.proprio.launch_sf_dmc --use_wandb --wandb_gname byol_gamma_cheetah_proprio --data_path datasets --workdir_root results --sweep_config sweep_byol_cheetah --sf_agent byol-gamma
    # uv run -m scripts.train.proprio.launch_sf_dmc --use_wandb --wandb_gname byol_gamma_quadruped_proprio --data_path datasets --workdir_root results --sweep_config sweep_byol_quadruped --sf_agent byol-gamma
    # uv run -m scripts.train.proprio.launch_sf_dmc --use_wandb --wandb_gname byol_gamma_pointmass_proprio --data_path datasets --workdir_root results --sweep_config sweep_byol_pointmass --sf_agent byol-gamma

    # uv run -m scripts.train.proprio.launch_sf_dmc --use_wandb --wandb_gname hilp_walker_proprio --data_path datasets --workdir_root results --sweep_config sweep_hilp_walker --sf_agent hilp
    # uv run -m scripts.train.proprio.launch_sf_dmc --use_wandb --wandb_gname hilp_cheetah_proprio --data_path datasets --workdir_root results --sweep_config sweep_hilp_cheetah --sf_agent hilp
    # uv run -m scripts.train.proprio.launch_sf_dmc --use_wandb --wandb_gname hilp_quadruped_proprio --data_path datasets --workdir_root results --sweep_config sweep_hilp_quadruped --sf_agent hilp
    # uv run -m scripts.train.proprio.launch_sf_dmc --use_wandb --wandb_gname hilp_pointmass_proprio --data_path datasets --workdir_root results --sweep_config sweep_hilp_pointmass --sf_agent hilp

    # uv run -m scripts.train.proprio.launch_sf_dmc --use_wandb --wandb_gname laplacian_walker_proprio --data_path datasets --workdir_root results --sweep_config sweep_laplacian_walker --sf_agent laplacian
    # uv run -m scripts.train.proprio.launch_sf_dmc --use_wandb --wandb_gname laplacian_cheetah_proprio --data_path datasets --workdir_root results --sweep_config sweep_laplacian_cheetah --sf_agent laplacian
    # uv run -m scripts.train.proprio.launch_sf_dmc --use_wandb --wandb_gname laplacian_quadruped_proprio --data_path datasets --workdir_root results --sweep_config sweep_laplacian_quadruped --sf_agent laplacian
    # uv run -m scripts.train.proprio.launch_sf_dmc --use_wandb --wandb_gname laplacian_pointmass_proprio --data_path datasets --workdir_root results --sweep_config sweep_laplacian_pointmass --sf_agent laplacian

    # uv run -m scripts.train.proprio.launch_sf_dmc --use_wandb --wandb_gname icvf_walker_proprio --data_path datasets --workdir_root results --sweep_config sweep_hilp_walker --sf_agent icvf
    # uv run -m scripts.train.proprio.launch_sf_dmc --use_wandb --wandb_gname icvf_cheetah_proprio --data_path datasets --workdir_root results --sweep_config sweep_hilp_cheetah --sf_agent icvf
    # uv run -m scripts.train.proprio.launch_sf_dmc --use_wandb --wandb_gname icvf_quadruped_proprio --data_path datasets --workdir_root results --sweep_config sweep_hilp_quadruped --sf_agent icvf
    # uv run -m scripts.train.proprio.launch_sf_dmc --use_wandb --wandb_gname icvf_pointmass_proprio --data_path datasets --workdir_root results --sweep_config sweep_hilp_pointmass --sf_agent icvf
