import itertools
import logging
import math
import os
import time
from typing import Tuple, Dict, List

import torch
import torch.multiprocessing as mp

from extensions.rl_poisoneddoors.poisoneddoors_scripts.save_poisoneddoors_data import (
    iteratively_run_poisoneddoors_experiments,
)
from main import _get_args
from utils.misc_utils import uninterleave, rand_float

# Rebind `mp` to a "forkserver" multiprocessing context (this shadows the
# `torch.multiprocessing` module imported above). The `mp.Queue` and
# `mp.Process` objects created in the script body below come from this context.
mp = mp.get_context("forkserver")
from setproctitle import setproctitle as ptitle
import pandas as pd
import numpy as np

# Module-level logger registered under the name "embodiedrl".
LOGGER = logging.getLogger("embodiedrl")


if __name__ == "__main__":
    # Random hyperparameter search over the PoisonedDoors experiment types.
    #
    # Samples `nsamples` hyperparameter settings, pairs each with every
    # experiment type, farms the runs out to worker processes, and appends each
    # finished run to a TSV file in `--output_dir`. Runs already present in that
    # TSV are skipped, so the script can be re-launched to resume a search.
    #
    # Run this with the following command:
    #
    #     pipenv run python \
    #     extensions/rl_poisoneddoors/poisoneddoors_scripts/poisoneddoors_random_hp_search.py \
    #     --experiment_base extensions/rl_poisoneddoors/poisoneddoors_experiments/ \
    #     --single_process_training \
    #     --output_dir experiment_output/poisoneddoors_random_hp_runs \
    #     --env_name PoisonedDoors

    args = _get_args()

    nsamples = 50
    # One worker when CUDA is unavailable, otherwise up to 56; never more than
    # the number of CPUs.
    nprocesses = min(1 if not torch.cuda.is_available() else 56, mp.cpu_count())
    gpu_ids = (
        [] if not torch.cuda.is_available() else list(range(torch.cuda.device_count()))
    )

    # The global numpy RNG is re-seeded before each hyperparameter is drawn so
    # that each hyperparameter's samples do not depend on the others.
    # NOTE(review): `rand_float` presumably draws `nsamples` uniform values in
    # the given range from numpy's global RNG -- confirm in utils.misc_utils.
    np.random.seed(1)
    # Sampled in log-space, then exponentiated: lr lies in [1e-4, 0.5].
    lr_samples = np.exp(rand_float(math.log(1e-4), math.log(0.5), nsamples))

    np.random.seed(2)
    tf_ratios = rand_float(0.1, 0.9, nsamples)

    np.random.seed(3)
    fixed_alphas = [np.random.choice([1.0, 20.0]) for _ in range(nsamples)]

    # Each sampled value is rendered as a "hyperparams.<name> = <value>" string.
    lr_gps = ["hyperparams.lr = {}".format(lr) for lr in lr_samples]
    tf_ratio_gps = ["hyperparams.tf_ratio = {}".format(ratio) for ratio in tf_ratios]
    fixed_alpha_gps = [
        "hyperparams.fixed_alpha = {}".format(fixed_alpha)
        for fixed_alpha in fixed_alphas
    ]

    # Per-sample tuples of parameter strings; index i always refers to the
    # i-th sampled setting.
    lr_tf_ratio_gps = list(zip(lr_gps, tf_ratio_gps))

    fixed_advisor_gps = list(zip(lr_gps, fixed_alpha_gps))

    dagger_fixed_advisor_gps = list(zip(lr_gps, tf_ratio_gps, fixed_alpha_gps))

    # Wrap each lr string in a 1-tuple so all variants share the tuple shape.
    lr_gps = list(zip(lr_gps))

    # Every experiment type uses the same sampled learning rates; the types
    # differ only in which extra hyperparameters (tf_ratio, fixed_alpha) are set.
    experiment_types_and_gps: Dict[str, List[Tuple[str, ...]]] = {
        "bc_teacher_forcing_then_ppo": lr_tf_ratio_gps,
        "bc_teacher_forcing_then_advisor_fixed_alpha_different_head_weights": dagger_fixed_advisor_gps,
        "bc": lr_gps,
        "dagger": lr_tf_ratio_gps,
        "ppo": lr_gps,
        "advisor_fixed_alpha_different_heads": fixed_advisor_gps,
        "bc_teacher_forcing": lr_gps,
    }

    # Writing things this way to mirror `minigrid_random_hp.search`
    # and make it easier to diff
    experiment_types_and_gps.update(
        {
            "bc_then_ppo": lr_tf_ratio_gps,
            "dagger_then_ppo": lr_tf_ratio_gps,
            "dagger_then_advisor_fixed_alpha_different_head_weights": dagger_fixed_advisor_gps,
        }
    )

    experiment_types_and_gps.update(
        {
            "ppo_with_offpolicy_advisor_fixed_alpha_different_heads": fixed_advisor_gps,
            "ppo_with_offpolicy": lr_gps,
            "pure_offpolicy": lr_gps,
        }
    )

    # Currently, saving data for one task at a time
    task_names = [args.env_name]
    exp_type_gp_params = []

    for exp_type, gp_params_variants in experiment_types_and_gps.items():
        if not gp_params_variants:
            # No hyperparameter variants: still schedule a single run.
            gp_params_variants = [None]
        # The index of the sampled setting doubles as the run's seed.
        for seed, gp_params in enumerate(gp_params_variants):
            exp_type_gp_params.append((exp_type, gp_params, seed))

    ptitle("Master ({})".format(" and ".join(task_names)))

    output_dir = args.output_dir

    os.makedirs(output_dir, exist_ok=True)

    assert len(task_names) == 1
    matrix_save_data_path = os.path.join(
        output_dir, "random_hp_search_poisoneddoors_runs_{}.tsv".format(task_names[0]),
    )

    if os.path.exists(matrix_save_data_path):
        # Resume: load previously saved results.
        df = pd.read_csv(matrix_save_data_path, sep="\t")
        df = df.where(pd.notnull(df), None)
        # Missing values become the string "None", matching `str(None)` below.
        df["gp_params"] = df["gp_params"].astype(str)
    else:
        df = pd.DataFrame(
            dict(
                poisoneddoors_env=[],
                exp_type=[],
                gp_params=[],
                success=[],
                reached_near_optimal=[],
                avg_ep_length=[],
                train_steps=[],
                found_goal=[],
                max_comb_correct=[],
                seed=[],
                extra_tag=[],
                lr=[],
            )
        )

    # Keys of the runs that are already recorded on disk.
    seen_tuples = set(
        zip(df["poisoneddoors_env"], df["exp_type"], df["gp_params"], df["seed"])
    )
    all_tuples_to_train_set = set()

    input_queue = mp.Queue()
    total_runs = 0

    # Flatten the parts returned by `uninterleave` lazily instead of with
    # `sum(..., [])`, which copies the accumulated list on every addition.
    runs_in_order = itertools.chain.from_iterable(
        uninterleave(
            list(itertools.product(task_names, exp_type_gp_params)), parts=nsamples,
        )
    )
    for poisoneddoors_env, (exp_type, gp_params, seed) in runs_in_order:
        total_runs += 1
        t = (poisoneddoors_env, exp_type, gp_params, seed)
        # df loads gp_params as a string
        t_for_matching = (poisoneddoors_env, exp_type, str(gp_params), seed)
        all_tuples_to_train_set.add(t_for_matching)
        if t_for_matching not in seen_tuples:
            input_queue.put(t)

    # Ignore rows on disk that are not part of the current search.
    seen_tuples = seen_tuples & all_tuples_to_train_set
    output_queue = mp.Queue()

    print(
        "{} (of {}) experiments already completed! Running the rest.".format(
            len(seen_tuples), total_runs
        )
    )

    processes = []
    # Never start more workers than there are runs left to do.
    for i in range(min(nprocesses, total_runs - len(seen_tuples))):
        processes.append(
            mp.Process(
                target=iteratively_run_poisoneddoors_experiments,
                kwargs=dict(
                    process_id=i,
                    # Assign GPUs round-robin; None when no GPU is available.
                    gpu_id=gpu_ids[i % len(gpu_ids)] if len(gpu_ids) != 0 else None,
                    args=args,
                    input_queue=input_queue,
                    output_queue=output_queue,
                    should_log=not args.disable_logging,
                ),
            )
        )
        processes[-1].start()
        time.sleep(0.1)

    # Collect results until every scheduled run is accounted for.
    # NOTE(review): `output_queue.get()` has no timeout, so this loop blocks
    # indefinitely if a worker dies without reporting its result.
    while len(seen_tuples) != total_runs:
        output_seed, run_data = output_queue.get()

        seen_tuples.add(
            (
                run_data["poisoneddoors_env"],
                run_data["exp_type"],
                # Same string form as the keys built from the TSV and the queue.
                str(run_data["gp_params"]),
                output_seed,
            )
        )

        # `DataFrame.append` was removed in pandas 2.0; `pd.concat` with a
        # one-row frame is the supported equivalent.
        df = pd.concat([df, pd.DataFrame([run_data])], ignore_index=True)

        # Persist after every result so an interrupted search can be resumed.
        df.to_csv(matrix_save_data_path, sep="\t", index=False)

    for p in processes:
        try:
            # Best-effort: wait at most one second per worker.
            p.join(1)
        except Exception:
            LOGGER.warning(
                "Failed to join worker process %s", p.name, exc_info=True
            )

    print("Saving poisoneddoors data is done!")
