import logging
import os
import random

import gymnasium as gym
import hydra
import numpy as np
import torch
from omegaconf import DictConfig

from .logging_utils import MetricLogger, setup_logging


def set_seed(seed: int, *, deterministic: bool = True) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN.

    Args:
        seed: Value passed to ``random.seed``, ``np.random.seed`` and
            ``torch.manual_seed``. The last one seeds every torch device,
            including CUDA. The value is also exported as ``PYTHONHASHSEED``.
        deterministic: When True (the default, matching the historical
            behaviour), cuDNN is restricted to deterministic algorithms and
            its benchmarking autotuner is disabled. Pass False to trade
            bitwise reproducibility for the speed of ``cudnn.benchmark``.
    """
    random.seed(seed)
    # The running interpreter fixed its str-hash seed at startup, so this
    # only influences Python subprocesses launched after this point.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = deterministic
    # The autotuner may pick different kernels run to run, so it is only
    # enabled when determinism is not requested.
    torch.backends.cudnn.benchmark = not deterministic


def _resolve_device(requested) -> torch.device:
    """Map the configured device name to a ``torch.device``.

    ``"cuda"`` and indexed forms such as ``"cuda:1"`` select that CUDA device
    when CUDA is available. Every other case falls back to the CPU, but a
    warning is logged so a misconfigured run does not silently train on CPU.
    """
    name = str(requested)
    log = logging.getLogger(__name__)
    if name == "cuda" or name.startswith("cuda:"):
        if torch.cuda.is_available():
            return torch.device(name)
        log.warning("Device %r requested but CUDA is not available; using CPU.", name)
    elif name != "cpu":
        log.warning("Unrecognised device %r; using CPU.", name)
    return torch.device("cpu")


@hydra.main(config_path="../conf", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    """Train the configured algorithm and save a final checkpoint.

    Hydra composes ``cfg`` from ``../conf/config`` plus any command-line
    overrides. This entry point:

    1. sets up logging;
    2. resolves the compute device and seeds the RNGs;
    3. instantiates ``cfg.algorithm`` via Hydra;
    4. trains it and saves it under the tag ``"final"``.
    """
    logging_info = setup_logging(cfg)
    logger = logging_info["logger"]
    wandb_run = logging_info["wandb_run"]
    exp_dir = logging_info["exp_dir"]

    metric_logger = MetricLogger(wandb_run=wandb_run, log_interval=1)

    device = _resolve_device(cfg.device)
    # Seed before the algorithm is built so any randomness consumed during
    # its construction is reproducible.
    set_seed(cfg.seed)
    # _recursive_=False: Hydra does not instantiate nested config nodes, so
    # the algorithm receives env_cfg and friends as plain config objects.
    algo = hydra.utils.instantiate(
        cfg.algorithm,
        env_cfg=cfg.env,
        exp_dir=exp_dir,
        metric_logger=metric_logger,
        info_logger=logger,
        device=device,
        seed=cfg.seed,
        _recursive_=False,
    )

    algo.train()
    algo.save("final")
    # NOTE(review): wandb_run is never finished here; confirm setup_logging
    # (or an exit hook) closes it, especially under Hydra multirun.


if __name__ == "__main__":
    main()
