import os
import socket
import uuid
from typing import Unpack
import rrls
from td3.trainer import Trainer
from td3.td3 import TD3Config
from utils import env_factory
from fire import Fire
from dotenv import load_dotenv

# Load variables from a `.env` file into the process environment at import
# time, before `main` (or anything it calls) reads them.
# NOTE(review): presumably this supplies experiment-tracking credentials used
# when `track=True` — confirm.
load_dotenv()


def main(
    experiment_name: str = "test",
    env_name: str = "Walker",
    max_steps: int = 1_000_000,
    start_steps: int = 25_000,
    seed: int = 0,
    eval_freq: int = 10_000,
    track: bool = True,
    device: str = "cuda:0",
    project_name: str = "vanilla",
    output_dir: str | None = None,
    **kwargs: Unpack[TD3Config],
):
    """Build training/evaluation environments and run a TD3 ``Trainer``.

    Used as the ``fire`` command-line entry point: each keyword below becomes
    a CLI flag, and any extra flags are forwarded to ``Trainer`` as
    ``TD3Config`` options.

    Args:
        experiment_name: Forwarded to ``Trainer.train``.
        env_name: Name passed to ``env_factory`` for both the training and the
            evaluation environment; also recorded in ``params``.
        max_steps: Forwarded to ``Trainer.train``.
        start_steps: Forwarded to ``Trainer.train``.
        seed: Forwarded to ``Trainer.train`` and recorded in ``params``.
        eval_freq: Forwarded to ``Trainer.train``.
        track: Forwarded to ``Trainer.train``.
        device: Device string handed to ``Trainer`` (default ``"cuda:0"``).
        project_name: Forwarded to ``Trainer.train``.
        output_dir: Parent directory for run artifacts. When given, a fresh
            ``<output_dir>/<uuid4>`` directory is created for this run and
            passed to ``Trainer`` as ``save_dir``. When ``None``, ``save_dir``
            is ``None``.
        **kwargs: Extra ``TD3Config`` options forwarded to ``Trainer``.
    """
    # A random per-run id gives repeated runs that share the same output_dir
    # their own directories.
    unique_id = str(uuid.uuid4())
    save_dir = None
    if output_dir is not None:
        # f-string (rather than os.path.join) tolerates non-str values that
        # Fire may produce when parsing the flag.
        save_dir = f"{output_dir}/{unique_id}"
        os.makedirs(save_dir, exist_ok=True)

    env = env_factory(env_name=env_name)
    eval_env = env_factory(env_name=env_name)
    params = {
        "env_name": env_name,
        "seed": seed,
        "machine_name": socket.gethostname(),
    }
    trainer = Trainer(
        env=env,
        eval_env=eval_env,
        device=device,
        # Previously the bare output_dir was passed here, leaving the
        # freshly created per-run directory unused.
        save_dir=save_dir,
        params=params,
        **kwargs,
    )
    trainer.train(
        experiment_name=experiment_name,
        max_steps=max_steps,
        start_steps=start_steps,
        seed=seed,
        eval_freq=eval_freq,
        track=track,
        project_name=project_name,
    )


if __name__ == "__main__":
    # Turn `main`'s keyword parameters into command-line flags; unknown flags
    # land in `main`'s **kwargs.
    Fire(main)
