import os
import queue
import random
import typing
from collections import defaultdict
from typing import Dict, Union, Tuple, Iterator, List, Any
from typing import Optional

import babyai.utils
import numpy as np
import torch
from gym_minigrid.minigrid import MiniGridEnv

from extensions.rl_babyai.babyai_sensors import BabyAIMissionSensor
from offpolicy_sync.losses.abstract_offpolicy_loss import AbstractOffPolicyLoss
from onpolicy_sync.engine import LOGGER
from onpolicy_sync.losses.advisor import AlphaScheduler, AdvisorWeightedStage
from onpolicy_sync.policy import ActorCriticModel

# Module-level cache of loaded demonstration datasets, keyed by the absolute
# path they were loaded from, so each dataset file is read at most once per
# process (see `create_babyai_offpolicy_data_iterator`).
_DATASET_CACHE: Dict[str, Any] = {}


class BabyAIOffPolicyExpertCELoss(AbstractOffPolicyLoss[ActorCriticModel]):
    """Off-policy imitation loss: cross entropy against recorded expert actions.

    The loss is the mean negative log-probability that the model's action
    distribution assigns to the expert action at every step of the batch.
    """

    def __init__(self, total_episodes_in_epoch: Optional[int] = None):
        """
        # Parameters

        total_episodes_in_epoch : If given, `loss` additionally reports an
            `"epoch_progress"` value computed as the running count of episode
            starts seen so far divided by this number.
        """
        super().__init__()
        self.total_episodes_in_epoch = total_episodes_in_epoch

    def loss(
        self,
        model: ActorCriticModel,
        batch: Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]],
        memory: Dict[str, torch.Tensor],
        *args,
        **kwargs
    ) -> Tuple[torch.FloatTensor, Dict[str, float], Dict[str, torch.Tensor]]:
        """Compute the expert cross-entropy loss on one off-policy batch.

        # Parameters

        model : Actor-critic model. It must expose `num_recurrent_layers`,
            `recurrent_hidden_state_size` and a `forward` returning
            `(actor_critic_output, new_hidden_states)`.
        batch : Must contain `"minigrid_ego_image"` (a 5-dimensional tensor
            whose first two dimensions are read as rollout length and number
            of rollouts), `"masks"` and `"expert_action"`; may contain
            `"babyai_mission"`.
        memory : Dictionary carried between successive calls. This method
            reads and writes `"recurrent_hidden_states"` and, when
            `total_episodes_in_epoch` is set, `"completed_episode_count"`.

        # Returns

        A tuple `(loss, info, memory)` where `info` holds `"expert_ce"` and
        optionally `"epoch_progress"`, and `memory` is the (mutated) input
        dictionary.
        """
        rollout_len, nrollouts, _, _, _ = batch["minigrid_ego_image"].shape
        nsteps = rollout_len * nrollouts

        # Only build the zero initial state when nothing has been carried
        # over yet (avoids an allocation on every call).
        if "recurrent_hidden_states" in memory:
            recurrent_hidden_states = memory["recurrent_hidden_states"]
        else:
            recurrent_hidden_states = torch.zeros(
                (
                    model.num_recurrent_layers,
                    nrollouts,
                    model.recurrent_hidden_state_size,
                ),
                device=batch["minigrid_ego_image"].device,
            )

        # Merge the (rollout_len, nrollouts) leading dimensions into one.
        observations = {
            k: batch[k].view(nsteps, *batch[k].shape[2:])
            for k in ["minigrid_ego_image", "babyai_mission"]
            if k in batch
        }

        ac_out, new_hidden_states = model.forward(
            observations=observations,
            recurrent_hidden_states=recurrent_hidden_states,
            prev_actions=None,
            masks=batch["masks"].view(nsteps, -1),
        )

        # Carry the *updated* recurrent state over to the next batch (the
        # previous code stored the input state again, so the state never
        # advanced). It is detached so that the next backward pass does not
        # reach into this call's computation graph.
        if isinstance(new_hidden_states, torch.Tensor):
            memory["recurrent_hidden_states"] = new_hidden_states.detach()
        else:
            # Model returned no tensor state: keep the previous behaviour.
            memory["recurrent_hidden_states"] = recurrent_hidden_states

        expert_ce_loss = -ac_out.distributions.log_probs(
            batch["expert_action"].view(nsteps, 1)
        ).mean()

        info = {"expert_ce": expert_ce_loss.item()}

        if self.total_episodes_in_epoch is not None:
            if "completed_episode_count" not in memory:
                memory["completed_episode_count"] = 0
            # Number of mask entries minus their sum; with 0/1 masks this is
            # the number of zero entries in the batch.
            memory["completed_episode_count"] += (
                int(np.prod(batch["masks"].shape)) - batch["masks"].sum().item()
            )
            info["epoch_progress"] = (
                memory["completed_episode_count"] / self.total_episodes_in_epoch
            )

        return expert_ce_loss, info, memory


class BabyAIOffPolicyAdvisorLoss(AbstractOffPolicyLoss[ActorCriticModel]):
    """Off-policy loss that delegates to an `AdvisorWeightedStage` loss.

    The model is run on a batch of recorded expert data and the resulting
    actor-critic output, together with the expert actions, is handed to an
    `AdvisorWeightedStage` instance constructed with `rl_loss=None`.
    """

    def __init__(
        self,
        total_episodes_in_epoch: Optional[int] = None,
        fixed_alpha: Optional[float] = 1,
        fixed_bound: Optional[float] = 0.1,
        alpha_scheduler: Optional[AlphaScheduler] = None,
        smooth_expert_weight_decay: Optional[float] = None,
        *args,
        **kwargs
    ):
        """
        # Parameters

        total_episodes_in_epoch : If given, `loss` additionally reports an
            `"epoch_progress"` value computed as the running count of episode
            starts seen so far divided by this number.
        fixed_alpha, fixed_bound, alpha_scheduler, smooth_expert_weight_decay :
            Forwarded unchanged to `AdvisorWeightedStage`.
        args, kwargs : Forwarded unchanged to `AdvisorWeightedStage`.
        """
        super().__init__()

        # NOTE(review): extra positional `*args` are passed positionally and
        # would therefore bind to AdvisorWeightedStage's leading parameters
        # -- confirm this is intended before relying on it.
        self.advisor_loss = AdvisorWeightedStage(
            rl_loss=None,
            fixed_alpha=fixed_alpha,
            fixed_bound=fixed_bound,
            alpha_scheduler=alpha_scheduler,
            smooth_expert_weight_decay=smooth_expert_weight_decay,
            *args,
            **kwargs
        )
        self.total_episodes_in_epoch = total_episodes_in_epoch

    def loss(
        self,
        step_count: int,
        model: ActorCriticModel,
        batch: Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]],
        memory: Dict[str, torch.Tensor],
        *args,
        **kwargs
    ) -> Tuple[torch.FloatTensor, Dict[str, float], Dict[str, torch.Tensor]]:
        """Compute the advisor loss on one off-policy batch.

        # Parameters

        step_count : Forwarded to `AdvisorWeightedStage.loss`.
        model : Actor-critic model. It must expose `num_recurrent_layers`,
            `recurrent_hidden_state_size` and a `forward` returning
            `(actor_critic_output, new_hidden_states)`.
        batch : Must contain `"minigrid_ego_image"` (a 5-dimensional tensor
            whose first two dimensions are read as rollout length and number
            of rollouts), `"masks"` and `"expert_action"`; may contain
            `"babyai_mission"`.
        memory : Dictionary carried between successive calls. This method
            reads and writes `"recurrent_hidden_states"` and, when
            `total_episodes_in_epoch` is set, `"completed_episode_count"`.

        # Returns

        A tuple `(loss, info, memory)`. `info` contains every entry returned
        by the advisor loss with its key prefixed by `"offpolicy_"`, plus
        optionally `"epoch_progress"`; `memory` is the (mutated) input
        dictionary.
        """
        rollout_len, nrollouts, _, _, _ = batch["minigrid_ego_image"].shape
        nsteps = rollout_len * nrollouts

        # Only build the zero initial state when nothing has been carried
        # over yet (avoids an allocation on every call).
        if "recurrent_hidden_states" in memory:
            recurrent_hidden_states = memory["recurrent_hidden_states"]
        else:
            recurrent_hidden_states = torch.zeros(
                (
                    model.num_recurrent_layers,
                    nrollouts,
                    model.recurrent_hidden_state_size,
                ),
                device=batch["minigrid_ego_image"].device,
            )

        # Merge the (rollout_len, nrollouts) leading dimensions into one.
        observations = {
            k: batch[k].view(nsteps, *batch[k].shape[2:])
            for k in ["minigrid_ego_image", "babyai_mission"]
            if k in batch
        }

        ac_out, new_hidden_states = model.forward(
            observations=observations,
            recurrent_hidden_states=recurrent_hidden_states,
            prev_actions=None,
            masks=batch["masks"].view(nsteps, -1),
        )

        # Carry the *updated* recurrent state over to the next batch (the
        # previous code stored the input state again, so the state never
        # advanced). It is detached so that the next backward pass does not
        # reach into this call's computation graph.
        if isinstance(new_hidden_states, torch.Tensor):
            memory["recurrent_hidden_states"] = new_hidden_states.detach()
        else:
            # Model returned no tensor state: keep the previous behaviour.
            memory["recurrent_hidden_states"] = recurrent_hidden_states

        expert_action = batch["expert_action"].view(nsteps, 1)
        # The expert action is paired with a constant column of ones.
        # NOTE(review): presumably a "the expert action is valid" flag
        # expected by AdvisorWeightedStage -- confirm against that class.
        expert_action_with_flag = torch.cat(
            (
                expert_action,
                torch.ones(nsteps, 1, dtype=torch.int64).to(expert_action.device),
            ),
            dim=1,
        )

        total_loss, losses_dict = self.advisor_loss.loss(
            step_count=step_count,
            batch={"observations": {"expert_action": expert_action_with_flag}},
            actor_critic_output=ac_out,
        )

        info = {"offpolicy_" + key: val for key, val in losses_dict.items()}

        if self.total_episodes_in_epoch is not None:
            if "completed_episode_count" not in memory:
                memory["completed_episode_count"] = 0
            # Number of mask entries minus their sum; with 0/1 masks this is
            # the number of zero entries in the batch.
            memory["completed_episode_count"] += (
                int(np.prod(batch["masks"].shape)) - batch["masks"].sum().item()
            )
            info["epoch_progress"] = (
                memory["completed_episode_count"] / self.total_episodes_in_epoch
            )

        return (
            total_loss,
            info,
            memory,
        )


class ExpertTrajectoryIterator(Iterator):
    """Iterator producing fixed-length batches of expert demonstration steps.

    `nrollouts` independent FIFO queues are kept, each filled with the steps
    of one demonstration at a time (demonstrations are drawn in a random
    order without replacement). Every call to `__next__` takes `rollout_len`
    steps from each queue and stacks them into tensors whose first two
    dimensions are `(rollout_len, nrollouts)`. Iteration stops as soon as a
    queue runs dry and no unused demonstration is left to refill it.
    """

    def __init__(
        self,
        data: List[Tuple[str, bytes, List[int], MiniGridEnv.Actions]],
        nrollouts: int,
        rollout_len: int,
        instr_len: Optional[int],
        restrict_max_steps_in_dataset: Optional[int] = None,
    ):
        """
        # Parameters

        data : Demonstrations in the format accepted by
            `babyai.utils.demos.transform_demos`. The length of element `2`
            of each demonstration is used as its number of steps.
        nrollouts : Number of parallel rollouts per batch; must not exceed
            the number of (possibly restricted) demonstrations.
        rollout_len : Number of steps taken from each rollout per batch.
        instr_len : If not `None`, a `BabyAIMissionSensor` configured with
            this `"instr_len"` is used to add a `"babyai_mission"` entry.
        restrict_max_steps_in_dataset : If not `None`, only a prefix of
            `data` is kept: demonstrations are added in order until the
            accumulated step count reaches this value.
        """
        super(ExpertTrajectoryIterator, self).__init__()
        self.restrict_max_steps_in_dataset = restrict_max_steps_in_dataset

        if restrict_max_steps_in_dataset is not None:
            kept_demos = []
            steps_so_far = 0
            for demo in data:
                if steps_so_far >= restrict_max_steps_in_dataset:
                    break
                kept_demos.append(demo)
                steps_so_far += len(demo[2])
            data = kept_demos

        self.data = data
        self.trajectory_inds = list(range(len(data)))
        self.instr_len = instr_len
        # Indices are consumed with `.pop()`, so shuffling randomises the
        # order in which demonstrations are used.
        random.shuffle(self.trajectory_inds)

        assert nrollouts <= len(self.trajectory_inds), "Too many rollouts requested."

        self.nrollouts = nrollouts
        self.rollout_len = rollout_len

        self.rollout_queues: List[queue.Queue] = []
        for _ in range(nrollouts):
            self.rollout_queues.append(queue.Queue())
        for rollout_queue in self.rollout_queues:
            self.add_data_to_rollout_queue(rollout_queue)

        self.babyai_mission_sensor: Optional[BabyAIMissionSensor] = (
            None
            if instr_len is None
            else BabyAIMissionSensor({"instr_len": instr_len,})
        )

    def add_data_to_rollout_queue(self, q: queue.Queue) -> bool:
        """Fill the (empty) queue `q` with the steps of one unused demonstration.

        Each queued item is the step tuple produced by `transform_demos`
        extended with a boolean that is `True` only for the first step.

        # Returns

        `False` if no unused demonstration remains (nothing is queued),
        `True` otherwise.
        """
        assert q.empty()
        if not self.trajectory_inds:
            return False

        transform_demos = babyai.utils.demos.transform_demos
        next_demo = self.data[self.trajectory_inds.pop()]
        demo_steps = transform_demos([next_demo])[0]
        for step_ind, step in enumerate(demo_steps):
            q.put((*step, step_ind == 0))

        return True

    def get_data_for_rollout_ind(self, rollout_ind: int) -> Dict[str, np.ndarray]:
        """Collect the next `rollout_len` steps for a single rollout.

        # Returns

        A dictionary with `"masks"` (float32, `0` at the first step of a
        demonstration and `1` elsewhere), `"minigrid_ego_image"`,
        `"expert_action"` (int64, one action per row) and, when a mission
        sensor is configured, `"babyai_mission"`.

        # Raises

        StopIteration : If the queue is empty and cannot be refilled.
        """
        q = self.rollout_queues[rollout_ind]
        sensor = self.babyai_mission_sensor

        masks = []
        images = []
        missions = []
        actions = []
        while len(masks) != self.rollout_len:
            if q.empty() and not self.add_data_to_rollout_queue(q):
                raise StopIteration()

            queued_step = q.get_nowait()
            obs, expert_action, _, is_first_obs = typing.cast(
                Tuple[
                    Dict[str, Union[np.array, int, str]],
                    MiniGridEnv.Actions,
                    bool,
                    bool,
                ],
                queued_step,
            )

            masks.append(not is_first_obs)
            images.append(obs["image"])
            if sensor is not None:
                # noinspection PyTypeChecker
                mission = sensor.get_observation(
                    env=None, task=None, minigrid_output_obs=obs
                )
                missions.append(mission)
            actions.append([expert_action])

        rollout_data = {
            "masks": np.array(masks, dtype=np.float32),
            "minigrid_ego_image": np.stack(images, axis=0),
            "expert_action": np.array(actions, dtype=np.int64),
        }
        if sensor is not None:
            rollout_data["babyai_mission"] = np.stack(missions, axis=0)
        return rollout_data

    def __next__(self) -> Dict[str, torch.Tensor]:
        """Return the next batch, stacking all rollouts along dimension 1."""
        arrays_by_key = defaultdict(list)
        for rollout_ind in range(self.nrollouts):
            rollout_data = self.get_data_for_rollout_ind(rollout_ind=rollout_ind)
            for key, value in rollout_data.items():
                arrays_by_key[key].append(value)
        return {
            key: torch.from_numpy(np.stack(arrays, axis=1))
            for key, arrays in arrays_by_key.items()
        }


def create_babyai_offpolicy_data_iterator(
    path: str,
    nrollouts: int,
    rollout_len: int,
    instr_len: Optional[int],
    restrict_max_steps_in_dataset: Optional[int] = None,
) -> ExpertTrajectoryIterator:
    """Build an `ExpertTrajectoryIterator` over the demonstrations at `path`.

    The dataset is loaded with `babyai.utils.load_demos` the first time a
    given (absolute) path is requested and then served from the module-level
    `_DATASET_CACHE` on subsequent calls.

    # Parameters

    path : Location of the demonstrations; converted to an absolute path
        before being used as the cache key.
    nrollouts, rollout_len, instr_len, restrict_max_steps_in_dataset :
        Passed unchanged to `ExpertTrajectoryIterator`.

    # Raises

    AssertionError : If loading yields `None` or an empty dataset.
    """
    path = os.path.abspath(path)

    if path not in _DATASET_CACHE:
        LOGGER.info("Loading babyai dataset from {} for first time...".format(path))
        demos = babyai.utils.load_demos(path)
        # Validate *before* caching so that a failed load is not remembered
        # (otherwise later calls would skip this check and fail elsewhere).
        # AssertionError is kept as the exception type for compatibility.
        if demos is None or len(demos) == 0:
            raise AssertionError(
                "No demonstrations could be loaded from {}".format(path)
            )
        _DATASET_CACHE[path] = demos
        LOGGER.info(
            "Loading babyai dataset complete, it contains {} trajectories".format(
                len(_DATASET_CACHE[path])
            )
        )
    return ExpertTrajectoryIterator(
        data=_DATASET_CACHE[path],
        nrollouts=nrollouts,
        rollout_len=rollout_len,
        instr_len=instr_len,
        restrict_max_steps_in_dataset=restrict_max_steps_in_dataset,
    )
