#! ./venv/bin/python

import logging
import os
import random

import hydra
import numpy as np
import torch as th
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from tensorboard.util import tb_logging

# Weights & Biases is optional: if importing it fails for any reason, `wandb`
# is set to None and the W&B-specific code paths are skipped. The broad
# `except Exception` is kept deliberately (best-effort import).
try:
    import wandb
except Exception:
    wandb = None

# W&B entity/project used for runs and artifact lookups. Defined outside the
# `try` so these names exist even when the wandb import failed.
WANDB_ENTITY = ""
WANDB_PROJECT = ""


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed unchanged to ``random.seed``, ``np.random.seed``
            and ``th.manual_seed``, in that order.
    """
    for seed_fn in (random.seed, np.random.seed, th.manual_seed):
        seed_fn(seed)


class _IgnoreTensorboardPathNotFound(logging.Filter):
    """Logging filter that drops tensorboard's "No path found after" records.

    ``custom_wandb_init`` attaches it to the logger returned by
    ``tb_logging.get_logger()``; every other record passes through.
    """

    # Substring identifying the records to suppress.
    _SUPPRESSED_TEXT = "No path found after"

    def filter(self, record):
        """Return False to drop ``record``, True to let it through.

        Args:
            record: The ``logging.LogRecord`` being emitted.

        Returns:
            False if the record's unformatted message contains the suppressed
            text, True otherwise.
        """
        # No assertion on record.name: an assert here would raise out of the
        # caller's logging call (and is stripped under -O anyway).
        # record.msg may be any object (e.g. an exception), so stringify it
        # before the substring test instead of risking a TypeError.
        return self._SUPPRESSED_TEXT not in str(record.msg)


def custom_wandb_init(*args, **kwargs):
    """Start a W&B run with tensorboard's "No path found after" noise filtered out.

    All positional and keyword arguments are forwarded unchanged to
    ``wandb.init``.

    Returns:
        Whatever ``wandb.init`` returns (the run object).
    """
    tb_logger = tb_logging.get_logger()
    # Install the filter at most once. Logger.addFilter only skips filters that
    # compare equal, and every _IgnoreTensorboardPathNotFound() is a distinct
    # object, so repeated calls would otherwise stack duplicate filters.
    # It is installed before wandb.init so it is already active for anything
    # logged while the run is starting.
    already_installed = any(
        isinstance(existing, _IgnoreTensorboardPathNotFound)
        for existing in tb_logger.filters
    )
    if not already_installed:
        tb_logger.addFilter(_IgnoreTensorboardPathNotFound())
    return wandb.init(*args, **kwargs)


def _start_wandb_run(cfg: DictConfig):
    """Start a W&B run for this job and record the main config values on it.

    Args:
        cfg: The Hydra job config received by ``main``.

    Returns:
        The run object returned by ``custom_wandb_init``.
    """
    hydra_config = HydraConfig.get()
    run = custom_wandb_init(
        entity=WANDB_ENTITY,
        project=WANDB_PROJECT,
        sync_tensorboard=True,
        monitor_gym=True,
        settings=wandb.Settings(start_method="thread"),
        group=cfg.wandb_group,
    )
    wandb.config.update(
        {
            "job": hydra_config.job.name,
            "model": cfg._model_name,
            "policy_nn": cfg.policy_nn.policy_class,
            "environment": cfg.environment.id,
            "seed": cfg.seed,
        }
    )
    # throw_on_missing=True makes a still-missing option raise instead of
    # being logged as a placeholder.
    wandb.config.update(
        OmegaConf.to_container(cfg.options, resolve=True, throw_on_missing=True)
    )
    return run


def _resolve_reuse_paths(cfg: DictConfig, run):
    """Locate the classifier and RL-policy checkpoints named by ``cfg.options.reuse_path``.

    If ``reuse_path`` exists on disk, it is used directly as the classifier path.
    The policy path is derived by replacing a trailing ``classifier`` with
    ``rl_policy``. Otherwise ``reuse_path`` is treated as the name of a W&B model
    artifact (version ``v0``). It is downloaded through ``run``, and the
    ``classifier`` and ``rl_policy`` entries are taken inside the download directory.

    Args:
        cfg: The Hydra job config.
        run: The active W&B run, or None when W&B is not in use.

    Returns:
        ``(classifier_path, policy_path)``. ``policy_path`` is None when a local
        ``reuse_path`` does not end in ``classifier``, because then no policy
        path can be derived from it.

    Raises:
        ValueError: if ``reuse_path`` is not a local path and no W&B run is active.
    """
    reuse_path = cfg.options.reuse_path
    if os.path.exists(reuse_path):
        policy_path = None
        # Only derive the policy path when the expected suffix is really
        # there; blindly slicing off len("classifier") characters would yield
        # a garbage path for any other name.
        if reuse_path.endswith("classifier"):
            policy_path = reuse_path[: -len("classifier")] + "rl_policy"
        return reuse_path, policy_path

    if run is None:
        raise ValueError(
            f"options.reuse_path {reuse_path!r} does not exist locally, and "
            "fetching it as a W&B artifact requires use_wandb=true with wandb "
            "installed."
        )
    artifact = run.use_artifact(
        f"{WANDB_ENTITY}/{WANDB_PROJECT}/{reuse_path}:v0",
        type="model",
        use_as="input",
    )
    artifacts_path = artifact.download()
    return (
        os.path.join(artifacts_path, "classifier"),
        os.path.join(artifacts_path, "rl_policy"),
    )


def _load_reused_weights(cfg: DictConfig, model, run) -> None:
    """Load previously saved weights into ``model`` according to ``cfg._model_name``.

    Args:
        cfg: The Hydra job config.
        model: The freshly instantiated model.
        run: The active W&B run, or None when W&B is not in use.

    Raises:
        ValueError: if the checkpoint location cannot be resolved, or a
            policy-based model is requested but no policy path is available.
    """
    classifier_path, policy_path = _resolve_reuse_paths(cfg, run)
    model_name = cfg._model_name

    if model_name == "PPOClassifier":
        model.classifier.load(
            classifier_path, load_optimizer=cfg.options.load_optimizer
        )
        return
    if model_name not in ("PPOPolicyReuse", "PPOWarmStart"):
        # Any other model name: nothing is loaded.
        return

    if policy_path is None:
        raise ValueError(
            f"Cannot derive an rl_policy path from options.reuse_path "
            f"{cfg.options.reuse_path!r}: expected it to end with 'classifier'."
        )
    if model_name == "PPOPolicyReuse":
        model.load_policy_reuse(policy_path)
    else:
        model.load_policy(policy_path)


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    """Build the configured model and train it.

    Steps:
        * seed the RNGs;
        * optionally start a W&B run;
        * optionally reload saved weights;
        * run ``model.learn`` with a separately instantiated evaluation env.

    Args:
        cfg: Job config composed by Hydra from ``conf/config``.
    """
    set_seed(cfg.seed)
    np.seterr(all="warn")

    # `wandb` is None when its optional import failed, so check both.
    use_wandb = bool(cfg.use_wandb and wandb)
    # Always bound, so the reuse path below never hits an undefined `run`.
    run = _start_wandb_run(cfg) if use_wandb else None

    model = hydra.utils.instantiate(cfg.model.instance)

    if cfg.options.reuse:
        _load_reused_weights(cfg, model, run)

    if cfg.options.train_classifier and use_wandb:
        wandb.watch(model.classifier)

    callback = hydra.utils.instantiate(cfg.callback)

    eval_env = hydra.utils.instantiate(cfg.env)
    try:
        model.learn(
            total_timesteps=cfg.total_timesteps,
            log_interval=cfg.log_interval,
            callback=callback,
            eval_env=eval_env,
            eval_freq=500,
        )
    except KeyboardInterrupt:
        # Ctrl-C stops training early but still reaches the W&B finish below.
        pass

    if use_wandb:
        wandb.finish(quiet=True)


if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on main() builds the
    # config and passes it in as `cfg`.
    main()
