
import click
import os

import numpy as np
import torch
import yaml

from rlkit.envs import ENVS
from rlkit.envs.wrappers import NormalizedBoxEnv
from maml_rl.meta_learner import MetaLearner
from maml_rl.trpo import TRPO

@click.command()
@click.option(
    "--train_env",
    required=True,
    help="Key into ENVS; also the file stem of the YAML config in maml_rl/configs.",
)
@click.option(
    "--seed",
    default=0,
    type=int,
    show_default=True,
    help="Seed applied to the NumPy and PyTorch random number generators.",
)
def main(train_env: str, seed: int) -> None:
    """Run MAML meta-training with a TRPO agent on the selected environment.

    Loads ``maml_rl/configs/<train_env>.yaml`` (relative to the current
    working directory), builds the normalized environment, the TRPO agent
    and the MetaLearner from that configuration, then starts meta-training.

    Args:
        train_env: Name used both to look up the environment class in
            ``ENVS`` and to locate the YAML configuration file.
        seed: Seed passed to ``np.random.seed`` and ``torch.manual_seed``.

    Raises:
        FileNotFoundError: If no configuration file exists for ``train_env``.
        KeyError: If the configuration lacks one of the keys read below
            (lookup of ``train_env`` in ``ENVS`` may fail similarly).
    """
    # --train_env is required: without it the path below would silently
    # become "maml_rl/configs/None.yaml" and fail with a misleading error.
    config_path = os.path.join("maml_rl/configs", f"{train_env}.yaml")
    with open(config_path, "r", encoding="utf-8") as file:
        # NOTE(review): configs are assumed to be trusted local files; if they
        # can come from an untrusted source, switch to yaml.safe_load.
        env_config = yaml.load(file, Loader=yaml.FullLoader)

    env = NormalizedBoxEnv(ENVS[train_env]())
    env.set_train_task(env_config["train_tasks"])
    tasks = env.get_all_task_idx()

    # NOTE(review): seeding happens after the environment is constructed;
    # confirm that environment/task construction does not depend on the
    # global RNG state before relying on --seed for full reproducibility.
    np.random.seed(seed)
    torch.manual_seed(seed)

    observ_dim: int = env.observation_space.shape[0]
    action_dim: int = env.action_space.shape[0]
    policy_hidden_dim: int = env_config["policy_hidden_dim"]
    vf_hidden_dim: int = env_config["value_function_hidden_dim"]

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    agent = TRPO(
        observ_dim=observ_dim,
        action_dim=action_dim,
        policy_hidden_dim=policy_hidden_dim,
        vf_hidden_dim=vf_hidden_dim,
        device=device,
        **env_config["pg_params"],
    )

    meta_learner = MetaLearner(
        env=env,
        train_env=train_env,
        agent=agent,
        observ_dim=observ_dim,
        action_dim=action_dim,
        train_tasks=tasks,
        device=device,
        **env_config["maml_params"],
    )

    # Start MAML training.
    meta_learner.meta_train()


# Script entry point: main is a click command, so it is called without
# arguments and click supplies train_env/seed from the command line.
if __name__ == "__main__":
    main()