import logging
from typing import Optional, Dict, Callable

import gin
import torch
import torch.multiprocessing as mp

from extensions.rl_minigrid.minigrid_experiments.base import (
    MiniGridBaseExperimentConfig,
)
from main import _init_logging, _load_config
from onpolicy_sync.engine import OnPolicyTrainer, OnPolicyTester
from rl_base.common import ActorCriticOutput

# Rebind `mp` to the "forkserver" multiprocessing context. From this point on,
# `mp.Queue` (used in the annotations below) is resolved through that context
# object rather than through the `torch.multiprocessing` module itself.
mp = mp.get_context("forkserver")
import queue
from setproctitle import setproctitle as ptitle

# Logger named "embodiedrl", used for this module's status messages.
LOGGER = logging.getLogger("embodiedrl")


def iteratively_run_minigrid_experiments(
    process_id: int,
    gpu_id: Optional[int],
    args,
    input_queue: mp.Queue,
    output_queue: mp.Queue,
    should_log: bool,
    str_to_extra_metrics_func: Optional[
        Dict[str, Callable[[ActorCriticOutput, Dict[str, torch.Tensor]], float]]
    ] = None,
    test_seed_offset: Optional[int] = None,
):
    """Train and test MiniGrid experiments pulled one at a time from a queue.

    Each item taken from `input_queue` is a tuple
    `(task_name, experiment_str, gp_params, seed)`. For every item the
    experiment config is reloaded, a model is trained with `OnPolicyTrainer`,
    the final checkpoint is evaluated with `OnPolicyTester`, and a summary is
    put on `output_queue` as `(seed, output_data)`.

    The worker returns once `input_queue.get` times out (1 second) with
    `queue.Empty`. Only that call is guarded: a `queue.Empty` raised from
    anywhere else (config loading, training, testing) propagates to the caller
    instead of being mistaken for "no more work".

    The original docstring refers to `iteratively_run_lighthouse_experiments`
    (not defined in this module) as the sibling function this one mirrors; here
    the training regime, given by `experiment_str`, is what varies per item.

    # Parameters

    process_id : Identifier used in the process title and the exit log message.
    gpu_id : Value assigned to the `GPU_ID` class attribute of the loaded
        config's class.
    args : Argument object. Its `experiment` and `gp` attributes are
        overwritten for every item; `output_dir`, `deterministic_cudnn`,
        `extra_tag`, `single_process_training` and `skip_checkpoints` are read.
    input_queue : Source of `(task_name, experiment_str, gp_params, seed)`
        tuples.
    output_queue : Destination for `(seed, output_data)` result tuples.
    should_log : Value assigned to the `SHOULD_LOG` class attribute of the
        loaded config's class.
    str_to_extra_metrics_func : Optional mapping from metric name to metric
        function. It is forwarded to `run_test`, and each of its keys is copied
        from the test results into `output_data`. `None` is treated as `{}`.
    test_seed_offset : If not `None`, assigned to the `TEST_SEED_OFFSET` class
        attribute of the loaded config's class.
    """
    ptitle("({}) Create Minigrid Runner".format(process_id))

    _init_logging()

    str_to_extra_metrics_func = (
        {} if str_to_extra_metrics_func is None else str_to_extra_metrics_func
    )

    while True:
        # Keep the `try` limited to the queue read so that a `queue.Empty`
        # raised inside the training/testing code is never silently
        # interpreted as the work queue having drained.
        try:
            work_item = input_queue.get(timeout=1)
        except queue.Empty:
            LOGGER.info("Queue empty for worker {}, exiting.".format(process_id))
            return

        task_name, experiment_str, gp_params, seed = work_item

        # Point the shared `args` object at this item's experiment and gin
        # parameter bindings before (re)loading the config.
        args.experiment = experiment_str
        args.gp = [
            "task_name.name = '{}'".format(task_name),
        ]
        if gp_params and gp_params[0] is not None:
            args.gp.extend(gp_params)

        cfg: MiniGridBaseExperimentConfig
        # Drop gin bindings left over from the previous item.
        gin.clear_config()
        cfg, _ = _load_config(args)  # type: ignore

        # These are set on the config *class*, not on the instance.
        type(cfg).GPU_ID = gpu_id
        type(cfg).SHOULD_LOG = should_log

        if test_seed_offset is not None:
            type(cfg).TEST_SEED_OFFSET = test_seed_offset

        # May be None when the task info has no "optimal_ave_ep_length" entry.
        optimal_ave_ep_length = cfg.task_info().get("optimal_ave_ep_length")
        LOGGER.info(
            "Running with (minigrid, exp, gp_params, seed) = ({},{},{},{}).".format(
                task_name, experiment_str, gp_params, seed
            )
        )

        trainer = OnPolicyTrainer(
            config=cfg,
            output_dir=args.output_dir,
            loaded_config_src_files=None,
            seed=seed,
            deterministic_cudnn=args.deterministic_cudnn,
            extra_tag=args.extra_tag,
            single_process_training=args.single_process_training,
        )
        trainer.run_pipeline()

        # NOTE: the checkpoint written here is not removed after testing.
        chkpt_file_path = trainer.checkpoint_save()

        test_results = OnPolicyTester(
            config=cfg,
            output_dir=args.output_dir,
            seed=seed,
            deterministic_cudnn=args.deterministic_cudnn,
            single_process_training=args.single_process_training,
            should_log=False,
        ).run_test(
            experiment_date="",
            checkpoint_file_name=chkpt_file_path,
            skip_checkpoints=args.skip_checkpoints,
            deterministic_agent=False,
            str_to_extra_metrics_func=str_to_extra_metrics_func,
        )

        # Only the first entry of the test results is reported.
        first_result = test_results[0]

        output_data = {
            "exp_type": experiment_str,
            "minigrid_env": task_name,
            "gp_params": gp_params,
            "reward": float(first_result["reward"]),
            "avg_ep_length": float(first_result["ep_length"]),
            "train_steps": int(first_result["training_steps"]),
            "seed": seed,
            "lr": cfg.lr(),
            "extra_tag": cfg.extra_tag(),
            **{key: first_result[key] for key in str_to_extra_metrics_func},
        }
        if optimal_ave_ep_length is not None:
            output_data.update(
                {
                    # 1 when the tested episode length is within 10% of the
                    # optimal average episode length, 0 otherwise.
                    "reached_near_optimal": 1
                    * (first_result["ep_length"] < optimal_ave_ep_length * 1.1),
                    "optimal_avg_ep_length": optimal_ave_ep_length,
                }
            )
        # Optional metrics: copied only when present in the test results.
        for k in ["success", "found_goal", "max_comb_correct"]:
            if k in first_result:
                output_data[k] = float(first_result[k])

        output_queue.put((seed, output_data,))
