import logging
import os
import time
import typing

import babyai
import blosc
import gin
import torch
import torch.multiprocessing as mp
from tqdm import tqdm

from extensions.rl_minigrid.minigrid_experiments.base import (
    MiniGridBaseExperimentConfig,
)
from extensions.rl_minigrid.minigrid_tasks import MiniGridTaskSampler
from main import _get_args, _load_config
from utils.misc_utils import partition_sequence

# Rebind `mp` from the `torch.multiprocessing` module to a "forkserver"
# context object: every `mp.Queue` / `mp.Process` / `mp.cpu_count` used below
# goes through this context instead of the platform's default start method.
mp = mp.get_context("forkserver")
import queue
from setproctitle import setproctitle as ptitle
import numpy as np

# Named logger used by the worker processes in this script.
# NOTE(review): handlers for "embodiedrl" are presumably configured elsewhere
# in the project — confirm they are also set up inside forkserver children.
LOGGER = logging.getLogger("embodiedrl")


def _rollout_expert(task):
    """Run `task` to completion by following its expert and record the trajectory.

    Before every step the current observation is queried; the expert action
    found in it is both recorded and executed.

    Returns:
        A tuple `(images, actions, directions)` of equally long lists holding,
        per step, the "minigrid_ego_image" observation, the executed expert
        action (as an `int`) and `task.env.agent_dir`.
    """
    images = []
    actions = []
    directions = []

    while not task.is_done():
        obs = task.get_observations()
        # Only the first entry of the flattened "expert_action" observation is
        # used as the action. NOTE(review): presumably the remaining entries
        # carry expert metadata — confirm against the sensor definition.
        action = int(obs["expert_action"].reshape(-1)[0])

        images.append(obs["minigrid_ego_image"])
        actions.append(action)
        directions.append(task.env.agent_dir)

        task.step(action=action)

    return images, actions, directions


def collect_demos(
    process_id: int, args, input_queue: mp.Queue, output_queue: mp.Queue,
):
    """Worker loop: turn seeds from `input_queue` into expert demos on `output_queue`.

    The worker loads the experiment config from `args`, builds a single
    deterministic task sampler and then repeatedly pulls a batch of seeds,
    rolls out the expert for each seed and buffers the resulting demo. Buffered
    demos are sent to `output_queue` as a list once 100 have accumulated, and
    whatever is left is sent when the worker stops.

    The worker stops when no seed batch arrives within one second. If it stops
    because of an unexpected error, the error is logged and re-raised, but the
    demos finished so far are still sent so that completed work is not lost.

    Args:
        process_id: Index of this worker; used for the process title and logs.
        args: Parsed command-line arguments, forwarded to `_load_config`.
        input_queue: Queue of seed batches (each item is an iterable of seeds).
        output_queue: Queue receiving lists of demo dicts with the keys
            "seed", "images" (blosc-packed array), "actions" and "directions".
    """
    ptitle("({}) Demo Saver".format(process_id))

    # Number of finished episodes to buffer before sending them to the master.
    wait_episodes = 100

    output_data_list = []
    try:
        cfg: MiniGridBaseExperimentConfig
        gin.clear_config()
        cfg, _ = _load_config(args)  # type: ignore

        # Set on the config *class* so that everything created from it sees it.
        type(cfg).AGENT_VIEW_CHANNELS = 3

        task_sampler_args = cfg.train_task_sampler_args(
            process_ind=0, total_processes=0,
        )
        task_sampler = typing.cast(
            MiniGridTaskSampler,
            cfg.make_sampler_fn(
                **{
                    **task_sampler_args,
                    # Placeholder; slot 0 is overwritten with the real seed
                    # before every `next_task()` call below.
                    "task_seeds_list": ["UNDEFINED"],
                    "deterministic_sampling": True,
                    "repeat_failed_task_for_min_steps": 0,
                }
            ),
        )

        while True:
            # Only a timeout on *this* call means "no more work"; a
            # `queue.Empty` raised anywhere else is a genuine error.
            try:
                seeds = input_queue.get(timeout=1)
            except queue.Empty:
                break

            for seed in seeds:
                task_sampler.task_seeds_list[0] = seed
                images, actions, directions = _rollout_expert(
                    task_sampler.next_task()
                )

                output_data_list.append(
                    {
                        "seed": seed,
                        "images": blosc.pack_array(np.array(images)),
                        "actions": actions,
                        "directions": directions,
                    }
                )

                if len(output_data_list) >= wait_episodes:
                    output_queue.put(output_data_list)
                    output_data_list = []
    except Exception:
        LOGGER.exception(
            "Worker %s failed; sending its %d buffered demos before exiting.",
            process_id,
            len(output_data_list),
        )
        raise
    finally:
        # Runs on normal exit and on failure: never drop finished demos.
        if output_data_list:
            output_queue.put(output_data_list)

    LOGGER.info("Queue empty for worker {}, exiting.".format(process_id))


def create_demos(args, nprocesses, min_demos):
    """Collect missing expert demos for one task in parallel and save them.

    Demos are stored in a list indexed by seed. Seeds in `0 .. min_demos - 1`
    that have no demo yet (either because no file exists or because the entry
    in the existing file is `None`) are handed to worker processes running
    `collect_demos`. Results are gathered here and written with
    `babyai.utils.save_demos`, once every 10000 demos and once at the end.

    Args:
        args: Parsed command-line arguments. `experiment`, `gp` and
            `output_dir` are read here; the whole object is forwarded to the
            workers. `gp[0]` is expected to look like
            `task_name.name = '<TaskName>'`.
        nprocesses: Upper bound on the number of worker processes to start.
        min_demos: Minimum number of demos the saved file should contain. If
            the existing file already holds more, that larger count is used.

    Returns:
        A tuple `(num_demos, num_actions)`: the length of the demo list and
        the summed length of the action sequences (4th tuple entry) of all
        demos.

    Raises:
        AssertionError: If the arguments fail the sanity checks below.
        RuntimeError: If every worker has exited while seeds are still
            missing. The demos gathered so far are saved first, so a re-run
            only has to produce the missing seeds.
    """
    assert args.experiment in ["", "bc"], "`--experiment` must be either empty or 'bc'."
    assert len(args.gp) >= 1
    # NOTE(review): `os.path.relpath` raises `ValueError` for an empty path, so
    # in practice this line rejects an empty `--output_dir`.
    assert os.path.relpath(args.output_dir) != ""

    # "task_name.name = 'Foo'" -> "Foo": take the right-hand side, drop quotes.
    task_name = args.gp[0].split("=")[-1].strip()[1:-1]

    # `task_name` is a single string; joining it with " and " would put the
    # separator between its individual characters.
    ptitle("Master (DEMOs {})".format(task_name))

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    demos_save_path = os.path.join(output_dir, "MiniGrid-{}-v0.pkl".format(task_name))

    if os.path.exists(demos_save_path):
        # Resume: keep what is there, pad with `None` up to the target size and
        # treat every `None` slot as a seed that still has to be generated.
        demos_list = babyai.utils.load_demos(demos_save_path)
        min_demos = max(min_demos, len(demos_list))
        demos_list.extend([None] * (min_demos - len(demos_list)))
        remaining_seeds = {i for i, d in enumerate(demos_list) if d is None}
    else:
        demos_list = [None] * min_demos
        remaining_seeds = set(range(min_demos))

    if not remaining_seeds:
        print("No more demos to save for task {}".format(task_name))
        return len(demos_list), sum(len(dl[3]) for dl in demos_list)

    print("Beginning to save demos with {} remaining".format(len(remaining_seeds)))

    input_queue = mp.Queue()
    for seeds in partition_sequence(
        list(remaining_seeds), min(2 ** 15 - 1, len(remaining_seeds))
    ):
        # Annoyingly a mp.Queue can hold a max of 2**15 - 1 items so we have to do this hack
        input_queue.put(seeds)

    output_queue = mp.Queue()

    processes = []
    for i in range(min(nprocesses, len(remaining_seeds))):
        process = mp.Process(
            target=collect_demos,
            kwargs=dict(
                process_id=i,
                args=args,
                input_queue=input_queue,
                output_queue=output_queue,
            ),
        )
        process.start()
        processes.append(process)
        time.sleep(0.1)  # Stagger worker start-up.

    with tqdm(total=len(remaining_seeds)) as pbar:
        total_demos_created = sum(d is not None for d in demos_list)
        while remaining_seeds:
            # Sample liveness *before* waiting: if no worker was alive at this
            # point and nothing arrives during the whole timeout, nothing ever
            # will, and looping again would hang forever.
            workers_alive = any(p.is_alive() for p in processes)
            try:
                run_data_list = output_queue.get(timeout=60)
            except queue.Empty:
                print("No demo saved for 60 seconds")
                if not workers_alive:
                    babyai.utils.save_demos(demos_list, demos_save_path)
                    raise RuntimeError(
                        "All demo workers exited but {} seeds are still missing; "
                        "progress was saved to {}.".format(
                            len(remaining_seeds), demos_save_path
                        )
                    )
                continue

            for run_data in run_data_list:
                seed = run_data["seed"]
                remaining_seeds.remove(seed)

                # NOTE(review): the empty first entry is presumably the
                # (unused) mission string of the babyai demo format — confirm.
                demos_list[seed] = (
                    "",
                    run_data["images"],
                    run_data["directions"],
                    run_data["actions"],
                )

                total_demos_created += 1
                if total_demos_created % 10000 == 0:
                    # Periodic checkpoint; unfinished seeds are saved as `None`.
                    babyai.utils.save_demos(demos_list, demos_save_path)

                pbar.update(1)

    babyai.utils.save_demos(demos_list, demos_save_path)

    # Best-effort clean-up: give each worker up to a second to finish exiting.
    for p in processes:
        p.join(1)

    print("Single stage of saving data is done!")

    return len(demos_list), sum(len(dl[3]) for dl in demos_list)


if __name__ == "__main__":
    """Run this with the following command:

    Command:
    pipenv run python minigrid_scripts/save_expert_demos.py \
    --experiment bc \
    --experiment_base minigrid_experiments/key_corridor \
    --output_dir minigrid_data/minigrid_demos \
    --gp "task_name.name = 'WallCrossingCorruptExpertS25N10'"
    """

    args = _get_args()
    assert len(args.gp) == 1

    # Worker budget: a small pool for the first (tiny) round, a larger one
    # afterwards; both are capped by the number of CPUs.
    cuda_available = torch.cuda.is_available()
    cpu_total = mp.cpu_count()
    first_round_workers = min(10 if cuda_available else 6, cpu_total)
    later_round_workers = min(56 if cuda_available else 6, cpu_total)

    # Silly but necessary; the value is essentially ignored.
    args.gp.append("hyperparams.lr = 12")

    # Keep generating demos until they contain at least this many actions.
    target_actions = int(1e6)

    demo_target = 20
    action_total = 0
    while action_total < target_actions:
        workers = first_round_workers if action_total == 0 else later_round_workers
        demo_target, action_total = create_demos(
            args, nprocesses=workers, min_demos=demo_target,
        )

        # Extrapolate from the average demo length how many demos are needed
        # to reach the action target, then add a margin of 100.
        actions_per_demo = action_total / demo_target
        demo_target = max(int(1e6 / actions_per_demo), demo_target) + 100

    print("Saving explore combination data is done!")
