import logging
import os
import time
from typing import Optional, Dict, Callable

import gin
import torch
import torch.multiprocessing as mp

from extensions.rl_lighthouse.lighthouse_environment import LightHouseEnvironment
from extensions.rl_lighthouse.lighthouse_experiments.base import (
    BaseLightHouseExperimentConfig,
)
from main import _get_args, _init_logging, _load_config
from onpolicy_sync.engine import OnPolicyTrainer, OnPolicyTester
from rl_base.common import ActorCriticOutput

mp = mp.get_context("forkserver")
import queue
from setproctitle import setproctitle as ptitle
import pandas as pd
import numpy as np

LOGGER = logging.getLogger("embodiedrl")


def iteratively_run_lighthouse_experiments(
    process_id: int,
    gpu_id: Optional[int],
    args,
    world_dim: int,
    world_radius: int,
    input_queue: mp.Queue,
    output_queue: mp.Queue,
    should_log: bool,
    str_to_extra_metrics_func: Optional[
        Dict[str, Callable[[ActorCriticOutput, Dict[str, torch.Tensor]], float]]
    ] = None,
    test_seed_offset: Optional[int] = None,
):
    """Iteratively train lighthouse models with different levels of
    supervision.

    This function is meant to be run as a subprocess. It repeatedly takes a
    tuple `(view_radius, expert_view_radius, seed, lr)` from `input_queue`
    that defines the next experiment to run. It then trains a model for this
    experiment, evaluates it, and puts the results into the `output_queue`
    which is collated by the main process. The function returns once no new
    experiment could be fetched from `input_queue` within one second.

    Note that the experiment settings are written to the *class* of the loaded
    experiment config (`type(cfg)`), not to the config instance.

    # Parameters

    process_id : This process' id.
    gpu_id : The gpu to run experiments on.
    args : Command line arguments specifying the experiment config to run. E.g.
        `extensions.rl_lighthouse.lighthouse_experiments.advisor`. Details of this
        experiment config, for instance the agent's `view_radius` config are modified
        by this function based on the values from the `input_queue`.
    world_dim : The world dimension used in all experiments.
    world_radius : The world radius used in all experiments.
    input_queue : The queue from which experiment details are taken.
    output_queue : The queue into which the results of running an experiment are saved.
        Each item is a pair `(experiment_tuple, results_dict)`.
    should_log : Value assigned to the `SHOULD_LOG` class constant associated with
        the experiment config.
    str_to_extra_metrics_func : A collection of named metric functions to be called during
        evaluation, see the use of `compute_expert_log_probs_from_rollouts` for an example.
        This can be useful for computing metrics during evaluation that aren't provided by
        default (e.g. losses).
    test_seed_offset : If not `None`, used to redefine the `TEST_SEED_OFFSET` class constant
        associated with the experiment config.
    """
    ptitle("({}) Create Im. Mat. Runner".format(process_id))

    _init_logging()

    # Grab the experiment config and set its GPU_ID.
    cfg: BaseLightHouseExperimentConfig
    gin.clear_config()
    cfg, _ = _load_config(args)  # type: ignore
    type(cfg).GPU_ID = gpu_id
    type(cfg).SHOULD_LOG = should_log

    if test_seed_offset is not None:
        type(cfg).TEST_SEED_OFFSET = test_seed_offset

    # Tracks whether `DEFAULT_LR` has ever been overridden by this worker.
    lr_changed = False

    if str_to_extra_metrics_func is None:
        str_to_extra_metrics_func = {}

    while True:
        # Sample a new set of values defining the new experiment to run.
        # Only the `get` is guarded so that a `queue.Empty` raised from
        # within training/testing is not mistaken for "no more work".
        try:
            view_radius, expert_view_radius, seed, lr = input_queue.get(timeout=1)
        except queue.Empty:
            LOGGER.info("Queue empty for worker {}, exiting.".format(process_id))
            return

        optimal_ave_ep_length = LightHouseEnvironment.optimal_ave_ep_length(
            world_dim=world_dim, world_radius=world_radius, view_radius=view_radius
        )

        if lr is not None:
            type(cfg).DEFAULT_LR = lr
            lr_changed = True
        # Once overridden, the original default lr is lost, so an experiment
        # with `lr=None` could no longer run with the intended value.
        assert (
            not lr_changed
        ) or lr is not None, (
            "If lr is changed once, it must be changed on every iteration."
        )

        LOGGER.info(
            "Running with (view, expert view, seed, lr) = ({}, {}, {}, {:.3g}). Target optimal ep length: {}.".format(
                view_radius,
                expert_view_radius,
                seed,
                cfg.DEFAULT_LR,
                optimal_ave_ep_length,
            )
        )

        type(cfg).VIEW_RADIUS = view_radius
        type(cfg).EXPERT_VIEW_RADIUS = expert_view_radius

        # Train the agent based on the experiment config.
        trainer = OnPolicyTrainer(
            config=cfg,
            output_dir=args.output_dir,
            loaded_config_src_files=None,
            seed=seed,
            deterministic_cudnn=args.deterministic_cudnn,
            extra_tag=args.extra_tag,
            single_process_training=args.single_process_training,
        )
        trainer.run_pipeline()

        start_time_str = trainer.local_start_time_str

        # Save the model after the above training.
        chkpt_file_path = trainer.checkpoint_save()

        try:
            # Use the above checkpoint and run a test evaluation.
            test_results = OnPolicyTester(
                config=cfg,
                output_dir=args.output_dir,
                seed=seed,
                deterministic_cudnn=args.deterministic_cudnn,
                single_process_training=args.single_process_training,
                should_log=False,
            ).run_test(
                experiment_date="",
                checkpoint_file_name=chkpt_file_path,
                skip_checkpoints=args.skip_checkpoints,
                deterministic_agent=False,
                str_to_extra_metrics_func=str_to_extra_metrics_func,
            )
        finally:
            # Remove the checkpoint file saved above as we no longer need it;
            # done in `finally` so that a failed evaluation does not leave
            # the temporary checkpoint behind.
            os.remove(chkpt_file_path)

        # Only the first entry of the test results is reported.
        test_result = test_results[0]

        # Put results from test evaluation into the output queue to be
        # collated by the main thread.
        output_queue.put(
            (
                (view_radius, expert_view_radius, seed, lr),
                {
                    "view_radius": int(view_radius),
                    "expert_view_radius": None
                    if expert_view_radius is None
                    else int(expert_view_radius),
                    "optimal": optimal_ave_ep_length,
                    # 1 if the average episode length is within 10% of optimal.
                    "reached_near_optimal": 1
                    * (test_result["ep_length"] < optimal_ave_ep_length * 1.1),
                    "avg_ep_length": float(test_result["ep_length"]),
                    "train_steps": int(test_result["training_steps"]),
                    "seed": seed,
                    "start_time_str": start_time_str,
                    "lr": lr,
                    **{key: test_result[key] for key in str_to_extra_metrics_func},
                },
            )
        )


if __name__ == "__main__":
    # Controls the master process that:
    # (1) Instantiates several subprocesses which run the experiments.
    # (2) Collates the results from the experiments run in the subprocesses.

    # Get command line arguments that define the experiment. For instance, we might run
    # this script (from within the `rl_lighthouse` directory), with arguments
    #
    # ```
    #   dagger_then_ppo \
    #   --experiment_base ../lighthouse_experiments \
    #   --single_process_training \
    #   --output_dir pairwise_training
    # ```
    #
    # And this will exhaustively train using the `dagger_then_ppo` experiment
    # with various agent/expert view radii.

    # Get command line arguments
    args = _get_args()

    # Define fixed parameters
    world_dim = 2
    world_radius = 15
    view_radii = list(range(1, 16, 2))
    use_experts = args.experiment.split(".")[-1] not in ["a2c", "ppo"]
    nrepeats = 25 if use_experts else 50  # Number of random seeds per experiment
    nprocesses = 1 if not torch.cuda.is_available() else 50
    gpu_ids = (
        [] if not torch.cuda.is_available() else list(range(torch.cuda.device_count()))
    )
    # How long (in seconds) to wait for a result before checking that
    # at least one worker is still alive.
    result_poll_seconds = 10

    ptitle("Master (pairwise)")

    output_dir = args.output_dir

    os.makedirs(output_dir, exist_ok=True)

    # Where to save data
    tsv_save_data_path = os.path.join(
        output_dir,
        "{}__{}_{}.tsv".format(
            args.experiment.replace(".", "_"), world_dim, world_radius
        ),
    )

    # Get any experiment data already saved (e.g. from previous runs)
    if os.path.exists(tsv_save_data_path):
        df = pd.read_csv(tsv_save_data_path, sep="\t")
    else:
        df = pd.DataFrame(
            dict(
                view_radius=[],
                expert_view_radius=[],
                reached_near_optimal=[],
                avg_ep_length=[],
                train_steps=[],
                seed=[],
                start_time_str=[],
            )
        )

    # The experiments we've already run. A missing expert view radius is
    # stored as NaN in the tsv and is mapped back to `None` here.
    seen_tuples = set(
        zip(
            df["view_radius"],
            [None if np.isnan(x) else x for x in df["expert_view_radius"]],
            df["seed"],
        )
    )

    # Add experiments details into the `input_queue` but
    # don't include experiments we've already run. The number of enqueued
    # experiments is counted explicitly: `seen_tuples` may contain rows that
    # are not part of the grid below, so `total_runs - len(seen_tuples)` is
    # not a reliable count of the remaining work.
    input_queue = mp.Queue()
    total_runs = 0
    num_pending = 0
    for i, view_radius in enumerate(view_radii):
        for expert_view_radius in view_radii[i:] if use_experts else [None]:
            for seed in range(nrepeats):
                total_runs += 1
                t = (view_radius, expert_view_radius, seed)
                if t not in seen_tuples:
                    input_queue.put(t + (None,))
                    num_pending += 1

    output_queue = mp.Queue()

    # Create the subprocesses that run experiments.
    processes = []
    for i in range(min(nprocesses, num_pending)):
        processes.append(
            mp.Process(
                target=iteratively_run_lighthouse_experiments,
                args=(
                    i,
                    gpu_ids[i % len(gpu_ids)] if len(gpu_ids) != 0 else None,
                    args,
                    world_dim,
                    world_radius,
                    input_queue,
                    output_queue,
                    True,
                ),
            )
        )
        processes[-1].start()
        time.sleep(0.1)

    # Save experimental results from the subprocesses into a tsv file.
    num_collected = 0
    while num_collected < num_pending:
        # Liveness is sampled *before* waiting on the queue: if every worker
        # had already exited and the queue is still empty after the wait, no
        # further results can arrive and waiting longer would hang forever.
        any_worker_alive = any(p.is_alive() for p in processes)
        try:
            new_seen_tuple, run_data = output_queue.get(timeout=result_poll_seconds)
        except queue.Empty:
            if any_worker_alive:
                continue
            break

        num_collected += 1
        seen_tuples.add(new_seen_tuple[:-1])  # Don't include the learning rate

        # `DataFrame.append` was removed in pandas 2.0, `pd.concat` is the
        # supported way of adding a row.
        df = pd.concat([df, pd.DataFrame([run_data])], ignore_index=True)

        # Rewrite the file after every result so that progress survives a crash.
        df.to_csv(tsv_save_data_path, sep="\t", index=False)

    for p in processes:
        p.join(1)
        if p.is_alive():
            print("Worker process {} did not exit within 1 second.".format(p.pid))

    if num_collected == num_pending:
        print("Saving pairwise imitation data is done!")
    else:
        print(
            "All workers exited but only {} of {} pending experiments reported"
            " results. Re-run this script to complete the remaining ones.".format(
                num_collected, num_pending
            )
        )
