import logging
import os
import time
from queue import Empty
from typing import Dict

import torch
import torch.multiprocessing as mp

from extensions.rl_lighthouse.lighthouse_scripts.save_pairwise_imitation_data import (
    iteratively_run_lighthouse_experiments,
)
from main import _get_args
from rl_base.common import ActorCriticOutput

# Rebind `mp` to a "forkserver" multiprocessing context: every Queue/Process
# created through `mp` below uses the forkserver start method.
mp = mp.get_context("forkserver")
from setproctitle import setproctitle as ptitle
import pandas as pd
import numpy as np

# Named logger shared with the rest of the project ("embodiedrl").
LOGGER = logging.getLogger("embodiedrl")


def compute_expert_log_probs_from_rollouts(
    actor_critic_output: ActorCriticOutput, observations: Dict[str, torch.Tensor]
):
    """Average log-probability the agent assigns to the expert's action distribution.

    The agent's per-action log-probabilities are weighted by the expert policy
    (all but its last column), summed over the action dimension, and averaged
    over every remaining dimension. The result is returned as a Python float.

    Args:
        actor_critic_output: Model output whose
            ``distributions.log_probs_tensor`` holds per-action log-probabilities.
        observations: Must contain ``"expert_policy"``, indexed as at least 2-D.
            Its last column is dropped before weighting. NOTE(review): presumably
            that column is not an action probability (e.g. a validity flag) —
            confirm against the sensor that produces it.

    Returns:
        A scalar float. It is the negated expert cross-entropy, so values are
        <= 0 when the weights form a distribution.
    """
    agent_log_probs = actor_critic_output.distributions.log_probs_tensor
    # Drop the trailing column so the weights line up with the action dimension.
    expert_action_weights = observations["expert_policy"][:, :-1]

    weighted = agent_log_probs * expert_action_weights
    per_step_expectation = weighted.sum(-1)
    return per_step_expectation.mean().item()


if __name__ == "__main__":
    # Controls the master process that:
    # (1) Instantiates several subprocesses which run the experiments.
    # (2) Collates the results from the experiments run in the subprocesses.

    # Get command line arguments that define the experiment. For instance, we might run
    # this script (from within the `rl_lighthouse` directory), with arguments
    #
    # ```
    #   --expertiment dagger_then_ppo \
    #   --experiment_base ../lighthouse_experiments \
    #   --single_process_training \
    #   --output_dir pairwise_training
    # ```
    #
    # And this will train using the `dagger_then_ppo` experiment using the 10
    # possible view radii (view radius for the expert will always be the maximum value)
    # and varying the learning rate between 1e-1 and 1e-5 on a log scale.

    # Get command line arguments
    args = _get_args()

    # Define fixed parameters
    world_dim = 2
    world_radius = 15
    view_radii = list(range(1, 16, 2))
    use_expert = args.experiment.split(".")[-1] not in ["a2c", "ppo"]
    should_log = not torch.cuda.is_available()
    test_seed_offset = int(1e6)

    learning_rates = np.exp(np.linspace(np.log(10), np.log(1e-4), num=100))

    nprocesses = min(1 if not torch.cuda.is_available() else 56, mp.cpu_count())
    gpu_ids = (
        [] if not torch.cuda.is_available() else list(range(torch.cuda.device_count()))
    )

    ptitle("Master (lr optimizer, LH)")

    output_dir = args.output_dir

    os.makedirs(output_dir, exist_ok=True)

    # Where to save data
    tsv_save_data_path = os.path.join(
        output_dir,
        "{}__{}_{}.tsv".format(
            args.experiment.replace(".", "_"), world_dim, world_radius
        ),
    )

    # Get any experiment data already saved (e.g. from previous runs)
    if os.path.exists(tsv_save_data_path):
        df = pd.read_csv(tsv_save_data_path, sep="\t")
    else:
        df = pd.DataFrame(
            dict(
                view_radius=[],
                expert_view_radius=[],
                reached_near_optimal=[],
                avg_ep_length=[],
                train_steps=[],
                seed=[],
                start_time_str=[],
                lr=[],
                expert_ce=[],
            )
        )

    # The experiments we've already run
    seen_tuples = set(
        zip(
            df["view_radius"],
            [None if np.isnan(x) else x for x in df["expert_view_radius"]],
            df["seed"],
            df["lr"],
        )
    )

    # Add experiments details into the `input_queue` but
    # don't include experiments we've already run.
    input_queue = mp.Queue()
    total_runs = 0
    for i, view_radius in enumerate(view_radii):
        for seed, lr in enumerate(learning_rates):
            total_runs += 1
            t = (view_radius, max(view_radii) if use_expert else None, seed, lr)
            if t not in seen_tuples:
                input_queue.put(t)

    output_queue = mp.Queue()

    # Create the subprocesses that run experiments.
    processes = []
    for i in range(min(nprocesses, total_runs - len(seen_tuples))):
        processes.append(
            mp.Process(
                target=iteratively_run_lighthouse_experiments,
                args=(
                    i,
                    gpu_ids[i % len(gpu_ids)] if len(gpu_ids) != 0 else None,
                    args,
                    world_dim,
                    world_radius,
                    input_queue,
                    output_queue,
                    should_log,
                    {"expert_ce": compute_expert_log_probs_from_rollouts},
                    test_seed_offset,
                ),
            )
        )
        processes[-1].start()
        time.sleep(0.1)

    # Save experimental results from the subprocesses into a tsv file.
    while len(seen_tuples) != total_runs:
        try:
            new_seen_tuple, run_data = output_queue.get(timeout=60)

            seen_tuples.add(new_seen_tuple)

            df = df.append(run_data, ignore_index=True)

            df.to_csv(tsv_save_data_path, sep="\t", index=False)
        except Empty:
            print("Queue empty for 60 seconds, trying again...")

    for p in processes:
        try:
            p.join(1)
        except Exception as _:
            pass

    print("Saving learning rate optimization data is done!")
