# Standard library imports
import os
import traceback
import logging
import random
import signal
import psutil
from time import sleep

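# Cap OpenMP at one thread per process. This is set before torch is imported
# (below) so the OpenMP runtime picks it up when it initialises.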
os.environ["OMP_NUM_THREADS"] = "1"

# Third party imports
import torch
from torch import multiprocessing as mp
import wandb
import hydra
from omegaconf import DictConfig, OmegaConf
import numpy as np

# Local application imports
from il_scale.nethack.utils.setup import DDPUtil, set_device
from il_scale.nethack.utils.model import count_params
from il_scale.nethack.agent import Agent
from il_scale.nethack.data.tty_data import TTYData
from il_scale.nethack.data.parquet_data import ParquetData
from il_scale.nethack.trainers.bc_trainer import BCTrainer
from il_scale.nethack.logger import Logger

# Root logging: INFO and above. Each record carries level, process id,
# module:line and timestamp; the PID tells apart lines coming from the
# separate processes started by mp.spawn in main().
logging.basicConfig(
    level=logging.INFO,
    format="[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] %(message)s",
)

def train(
    rank: int,
    world_size: int,
    ddp_util: DDPUtil,
    cfg: DictConfig,
    data: TTYData,
):
    """Per-process training entry point launched by ``mp.spawn`` in ``main``.

    Seeds the Python, NumPy and torch RNGs, sets up DDP for this rank, builds
    the agent and the behavior-cloning trainer, wraps the model in DDP and runs
    training. Only rank 0 sets up / shuts down the ``Logger`` and logs the
    parameter count to wandb; ``logger.shutdown()`` is reached only if
    ``bc_trainer.train()`` returns normally.

    Args:
        rank: Index of this process (supplied by ``mp.spawn``); also passed to
            ``agent.to`` as the target device.
        world_size: Total number of training processes.
        ddp_util: DDP helper; ``setup(rank, world_size)`` is called on it
            before the model is built, and it is handed to the trainer.
        cfg: Full Hydra configuration.
        data: Dataset handed to the trainer (``TTYData`` or ``ParquetData``).
    """
    is_main_process = rank == 0
    seed = cfg.setup.seed

    # Same seed for every RNG the training code may draw from.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Every rank owns a Logger object, but only rank 0 initialises it.
    logger = Logger(cfg)
    if is_main_process:
        logger.setup()

    ddp_util.setup(rank, world_size)

    agent = Agent(cfg, logger)
    agent.construct_model()
    model = agent.model

    # Report model size (wandb from rank 0 only, stdout logging from all ranks).
    total_params = count_params(model)
    if is_main_process:
        wandb.log({"num_params": total_params})
    logging.info(f"Created model with {total_params} total params.")
    sequence_params = (
        count_params(model.core)
        + count_params(model.policy_head)
        + count_params(model.modality_mixer)
    )
    logging.info(f"Params in sequence model: {sequence_params}")

    agent.to(rank)

    bc_trainer = BCTrainer(cfg, logger, agent, data, ddp_util)

    # Enable DDP's find_unused_parameters unless the core is Mamba/LSTM and
    # neither the message nor the blstats transformer is in use (presumably
    # those are the setups where every parameter receives a gradient -- confirm).
    net_cfg = cfg.network
    find_unused = (
        net_cfg.core_mode not in ("mamba", "lstm")
        or net_cfg.use_message_transformer
        or net_cfg.use_blstats_transformer
    )
    agent.move_to_ddp(rank, world_size, find_unused_parameters=find_unused)

    bc_trainer.train()

    if is_main_process:
        logger.shutdown()

def signal_handler(signum, frame):
    """Forward a signal received by the main process to its direct children.

    ``main`` installs this for SIGUSR2; the direct children include the worker
    processes started by ``mp.spawn``. Children that exit before they can be
    signalled are skipped instead of letting ``psutil.NoSuchProcess`` escape
    the handler (an exception raised in a signal handler propagates into
    whatever the main thread was running, here the ``mp.spawn`` join).

    Args:
        signum: Number of the received signal; the same signal is re-sent.
        frame: Current stack frame (unused; required by ``signal.signal``).
    """
    logging.info(f'SIGNAL: {signum} received in main process! Forward to children ...')
    current_process = psutil.Process()
    # NOTE: it's important to use recursive=False, otherwise the signal will
    # be sent to the children of the ddp processes, which includes stuff like wandb
    # and they don't seem to like it when you send them signals.
    for child in current_process.children(recursive=False):
        try:
            # Use the Process object from children() rather than re-creating one
            # from the PID: psutil checks it against the recorded creation time,
            # so a PID recycled by an unrelated process is not signalled.
            child.send_signal(signum)
        except psutil.NoSuchProcess:
            # The child exited after children() listed it; nothing to forward.
            logging.info(f'SIGNAL: child {child.pid} already exited, skipping.')

@hydra.main(version_base=None, config_path="../../conf", config_name="nethack_config")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: load the dataset and launch one training process per GPU.

    Installs ``signal_handler`` for SIGUSR2, logs the resolved config, calls
    ``set_device``, builds the dataset selected by ``cfg.data.dataset_type``,
    and runs ``train`` through ``mp.spawn`` with ``cfg.setup.num_gpus``
    processes, blocking until they all finish (``join=True``).

    On any exception, ``DDPUtil.cleanup()`` is called, the traceback is
    printed, and the exception is re-raised.

    Raises:
        ValueError: If ``cfg.data.dataset_type`` is neither ``'ttyrec'`` nor
            ``'parquet'``.
    """
    signal.signal(signal.SIGUSR2, signal_handler)

    try:
        logging.info(OmegaConf.to_yaml(cfg))
        set_device(cfg)

        dataset_type = cfg.data.dataset_type
        if dataset_type == 'ttyrec':
            data = TTYData(cfg.data)
        elif dataset_type == 'parquet':
            data = ParquetData(cfg.data)
        else:
            # Fail fast with a clear message instead of an UnboundLocalError
            # on `data` when mp.spawn's arguments are built below.
            raise ValueError(
                f"Unknown data.dataset_type {dataset_type!r}; "
                "expected 'ttyrec' or 'parquet'."
            )

        ddp_util = DDPUtil()

        mp.spawn(
            train,
            args=(cfg.setup.num_gpus, ddp_util, cfg, data),
            nprocs=cfg.setup.num_gpus,
            join=True,
        )

    except Exception:
        DDPUtil.cleanup()
        traceback.print_exc()
        raise


if __name__ == "__main__":
    # `cfg` is supplied by the @hydra.main decorator, so main() is called
    # without arguments here.
    main()
