import logging
import time
from copy import copy, deepcopy
from typing import Any, cast

import gymnasium as gym
import numpy as np
import torch

from tianshou.data import (
    Batch,
    ReplayBuffer,
    to_numpy,
)
from tianshou.data.types import (
    ObsBatchProtocol,
    RolloutBatchProtocol,
)
from tianshou.data import AsyncCollector, CollectStats
from tianshou.env import BaseVectorEnv
from tianshou.policy import BasePolicy
from tianshou.data.collector import _nullable_slice, _HACKY_create_info_batch, _dict_of_arr_to_arr_of_dicts

from data.mpec.utils.policy_selection_collector import PolicySelectionCollector

from data.necsa.abstraction_mode import HS
from utils.json_util import encode_dict_with_tuple_keys

log = logging.getLogger(__name__)

class MPECCollector(PolicySelectionCollector):
    # NAMING CONVENTION (mostly suffixes):
    # episode - An episode means a rollout until done (terminated or truncated). After an episode is completed,
    # the corresponding env is either reset or removed from the ready envs.
    # N - number of envs, always fixed and >= R.
    # R - number of ready env ids. Note that this might change when envs get idle.
    #     This can only happen in n_episode case, see explanation in the corresponding block.
    #     For n_step, we always use all envs to collect the data, while for n_episode,
    #     R will be at most n_episode at the beginning, but can decrease during the collection.
    # O - dimension(s) of observations
    # A - dimension(s) of actions
    # H - dimension(s) of hidden state
    # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case.
    # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration.
    #     Only used in n_episode case. Then, R becomes R-S.

    def __init__(
        self,
        policy: BasePolicy,
        env: gym.Env | BaseVectorEnv,
        buffer: ReplayBuffer | None = None,
        exploration_noise: bool = False,
        seed=None,
        MPEC_DICT=None,
        training=False,
        *args,
        **kwargs
    ) -> None:
        """Initialise the collector and read its MPEC configuration.

        :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
        :param env: a ``gym.Env`` environment or an instance of the
            :class:`~tianshou.env.BaseVectorEnv` class.
        :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class
            (or None); forwarded unchanged to the parent constructor.
        :param exploration_noise: forwarded unchanged to the parent constructor.
        :param seed: must not be None. It is only checked here; this constructor
            does not use it for anything else.
        :param MPEC_DICT: configuration mapping, also forwarded to the parent
            constructor. The keys ``'circular_buffer'`` and ``'mode'`` are required;
            ``'debug'`` and the ``'debug_*'`` switches are optional and default to False.
        :param training: stored as ``self.training`` and forwarded to the parent
            constructor.
        """
        super().__init__(policy, env, buffer, exploration_noise, MPEC_DICT, training, *args, **kwargs)

        assert seed is not None

        self.training = training

        # Mandatory configuration entries.
        self.is_circular_buffer = MPEC_DICT['circular_buffer']
        self.mode = MPEC_DICT['mode']

        # Reward bookkeeping: one running reward list per env, plus the rewards
        # of the most recently finished episode.
        self.reward_lists = [[] for _ in range(self.env_num)]
        self.ep_reward = []

        ########## result ###########
        self.ep_reward_recorder = []

        # Optional debug switches (all off unless set in MPEC_DICT).
        self.is_debug = MPEC_DICT.get('debug', False)
        self._debug_disable_reconnection = MPEC_DICT.get('debug_disable_reconnection', False)
        self._debug_disable_ssm = MPEC_DICT.get('debug_disable_ssm', False)
        self._debug_disable_cycle_detection = MPEC_DICT.get('debug_disable_cycle_detection', False)
        self._debug_track_trajectories_length = MPEC_DICT.get('debug_track_trajectories_length', False)
        self._debug_track_trajectories_split_and_mismatches = MPEC_DICT.get(
            'debug_track_trajectories_split_and_mismatches', False
        )
        self._debug_track_policies = MPEC_DICT.get('debug_track_policies', False)
        self.debug_policy_tracking = []

    def _compute_action_pi_option_policy_hidden(
        self,
        random: bool,
        ready_env_ids_R: np.ndarray,
        second_to_last_obs_R: np.ndarray,
        second_to_last_info_R: np.ndarray,
        last_act_RA: np.ndarray,
        last_pi_option_R: np.ndarray,
        last_obs_RO: np.ndarray,
        last_info_R: np.ndarray,
        last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None,
        last_rew_R: np.ndarray | None = None
    ) -> tuple[
        np.ndarray,
        np.ndarray,
        Batch,
        np.ndarray | torch.Tensor | Batch | None,
        np.ndarray | None,
        np.ndarray,
    ]:
        """Query the policy (or its random sampler) for the given envs.

        :param random: if True, actions come from
            ``self.policy.get_random_action_pi_option`` (called once per env) and no
            env is flagged for termination; otherwise ``self.policy`` is called on
            the batched inputs.
        :param ready_env_ids_R: ids of the envs an action is requested for.
        :param second_to_last_obs_R: observations preceding ``last_obs_RO``
            (only used when ``random`` is False).
        :param second_to_last_info_R: infos belonging to ``second_to_last_obs_R``.
        :param last_act_RA: previously executed (normalized) actions.
        :param last_pi_option_R: previously selected ``pi_option`` values.
        :param last_obs_RO: current observations.
        :param last_info_R: infos belonging to ``last_obs_RO``.
        :param last_hidden_state_RH: hidden state handed to the policy.
        :param last_rew_R: last rewards, handed to the policy as ``last_rew_R``.
        :return: a 6-tuple ``(act_normalized_RA, pi_option_R, policy_R,
            hidden_state_RH, feature, act_terminated_R)``. ``feature`` is None
            unless ``self.mode == HS`` and the policy output carries a
            ``"feature"`` entry.
        :raises RuntimeError: if the ``"policy"`` entry of the policy output is not
            a :class:`Batch`.
        """
        if random:
            # One (action, pi_option) pair per env, then transposed into two arrays.
            act_pi_option = [
                self.policy.get_random_action_pi_option(last_obs_RO[i], ready_env_id)
                for i, ready_env_id in enumerate(ready_env_ids_R)
            ]
            act_normalized_RA, pi_option_R = map(np.array, zip(*act_pi_option))
            act_batch_RA = Batch(act=act_normalized_RA, pi_option=pi_option_R)
            act_terminated_R = np.zeros(len(ready_env_ids_R), dtype=bool)
        else:
            second_to_last_info_batch = _HACKY_create_info_batch(second_to_last_info_R)
            second_to_last_obs_batch_R = cast(
                ObsBatchProtocol,
                Batch(obs=second_to_last_obs_R, info=second_to_last_info_batch),
            )

            info_batch = _HACKY_create_info_batch(last_info_R)
            obs_batch_R = cast(ObsBatchProtocol, Batch(obs=last_obs_RO, info=info_batch))

            act_batch_RA = self.policy(
                second_to_last_obs_batch_R,
                last_act_RA,
                last_pi_option_R,
                obs_batch_R,
                ready_env_ids_R,
                last_hidden_state_RH,
                is_training=self.training,
                last_rew_R=last_rew_R
            )
            act_normalized_RA = act_batch_RA.act
            pi_option_R = act_batch_RA.pi_option

            act_terminated_R = act_batch_RA.terminate_env

        # TODO: cleanup the whole policy in batch thing
        # todo policy_R can also be none, check
        policy_R = act_batch_RA.get("policy", Batch())
        if not isinstance(policy_R, Batch):
            raise RuntimeError(
                f"The policy result should be a {Batch}, but got {type(policy_R)}",
            )

        hidden_state_RH = act_batch_RA.get("state", None)
        # TODO: do we need the conditional? Would be better to just add hidden_state which could be None
        if hidden_state_RH is not None:
            policy_R.hidden_state = (
                hidden_state_RH  # save state into buffer through policy attr
            )

        feature = None
        if self.mode == HS:
            feature = act_batch_RA.get("feature", None)
            # The random branch builds a Batch without a "feature" entry, so the
            # lookup can legitimately yield None; only convert when present.
            if feature is not None:
                feature = feature.detach().cpu().numpy()

        return act_normalized_RA, pi_option_R, policy_R, hidden_state_RH, feature, act_terminated_R

    def _compute_feature(
        self,
        obs_next_RO: np.ndarray,
        last_info_R: np.ndarray,
        last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None,
    ) -> np.ndarray | None:
        """Return the policy's ``"feature"`` output for the given observations.

        Only does work when ``self.mode == HS``; in every other mode it returns
        None without calling the policy.

        :param obs_next_RO: observations to compute the feature for.
        :param last_info_R: infos belonging to those observations.
        :param last_hidden_state_RH: hidden state handed to the policy.
        :return: the feature as a numpy array (detached and moved to CPU), or
            None when the collector is not in HS mode.
        :raises RuntimeError: if the policy output has no ``"feature"`` entry
            while in HS mode.
        """
        if self.mode != HS:
            return None

        # for necsa
        info_batch = _HACKY_create_info_batch(last_info_R)
        obs_next_batch_R = cast(ObsBatchProtocol, Batch(obs=obs_next_RO, info=info_batch))

        # NOTE(review): the policy is called with two positional arguments here,
        # whereas the action computation in this class passes six plus keywords —
        # confirm the policy's forward signature supports this call form.
        act_batch_RA = self.policy(
            obs_next_batch_R,
            last_hidden_state_RH,
        )

        feature = act_batch_RA.get("feature", None)
        if feature is None:
            # Fail with a clear message instead of an AttributeError on None.detach().
            raise RuntimeError(
                "The policy output does not contain a 'feature' entry, "
                "which is required in HS mode.",
            )

        return feature.detach().cpu().numpy()

    # TODO: reduce complexity, remove the noqa
    def _collect(
        self,
        n_step: int | None = None,
        n_episode: int | None = None,
        random: bool = False,
        render: float | None = None,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> CollectStats:
        """Roll out the policy in the vectorized env and store the transitions in the buffer.

        Each loop iteration performs one env step for every ready env: it asks the
        policy for actions for those envs whose action slot is still None, steps the
        env, adds the transition batch to the buffer, records the rewards per env,
        and resets the envs that reached done. The loop ends once ``step_count``
        reaches ``n_step`` or ``num_collected_episodes`` reaches ``n_episode``.

        :param n_step: collect (at least) this many steps using all envs.
        :param n_episode: collect this many episodes, using at most ``n_episode`` envs.
        :param random: use the policy's random action sampler instead of the policy
            forward pass.
        :param render: if truthy, ``self.env.render()`` is called after every step and,
            unless the value is close to 0, the loop sleeps for ``render`` seconds.
        :param gym_reset_kwargs: extra keyword arguments for ``self.env.reset``.
        :return: statistics of this call (episode returns and lengths, number of
            collected episodes/steps, collect time and speed).
        :raises ValueError: if the env is asynchronous, if neither ``n_step`` nor
            ``n_episode`` is given, or if the pre-collect obs/info are None.
        """
        # TODO: can't do it init since AsyncCollector is currently a subclass of Collector
        if self.env.is_async:
            raise ValueError(
                f"Please use {AsyncCollector.__name__} for asynchronous environments. "
                f"Env class: {self.env.__class__.__name__}.",
            )

        # n_step: all envs are used; n_episode: at most n_episode envs are used.
        if n_step is not None:
            ready_env_ids_R = np.arange(self.env_num)
        elif n_episode is not None:
            ready_env_ids_R = np.arange(min(self.env_num, n_episode))
        else:
            raise ValueError("Either n_step or n_episode should be set.")

        start_time = time.time()
        if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None:
            raise ValueError(
                "Initial obs and info should not be None. "
                "Either reset the collector (using reset or reset_env) or pass reset_before_collect=True to collect.",
            )

        # Counters and per-episode statistics of this call. The rollout state itself
        # is taken from the _pre_collect_* fields below, so in the n_step case a new
        # call continues the running trajectories instead of restarting them.
        step_count = 0
        num_collected_episodes = 0
        episode_returns: list[float] = []
        episode_lens: list[int] = []
        episode_start_indices: list[int] = []

        # State of the step before the last one (feature, obs, info and a validity mask).
        second_to_last_feature = _nullable_slice(self._pre_collect_second_to_last_feature, ready_env_ids_R)
        is_second_to_last_obs_RO_valid = _nullable_slice(self._pre_collect_is_second_to_last_obs_RO_valid, ready_env_ids_R)
        second_to_last_obs_RO = _nullable_slice(self._pre_collect_second_to_last_obs_RO, ready_env_ids_R)
        second_to_last_info_R = _nullable_slice(self._pre_collect_second_to_last_info_R, ready_env_ids_R)

        # Action and pi_option that were executed in the last step.
        last_act_normalized_RA = _nullable_slice(self._pre_collect_last_act_normalized_RA, ready_env_ids_R)
        last_pi_option_R = _nullable_slice(self._pre_collect_last_pi_option_R, ready_env_ids_R)

        # in case we select fewer episodes than envs, we run only some of them
        last_obs_RO = _nullable_slice(self._pre_collect_obs_RO, ready_env_ids_R)
        last_info_R = _nullable_slice(self._pre_collect_info_R, ready_env_ids_R)
        last_hidden_state_RH = _nullable_slice(
            self._pre_collect_hidden_state_RH,
            ready_env_ids_R,
        )

        # Action / pi_option slots for the upcoming step; entries that are None are
        # filled in by the policy inside the loop.
        act_normalized_RA = _nullable_slice(self._pre_collect_act_normalized_RA, ready_env_ids_R)
        pi_option_R = _nullable_slice(self._pre_collect_pi_option_R, ready_env_ids_R)

        last_rew_R = _nullable_slice(self._pre_collect_rew_R, ready_env_ids_R)

        while True:

            # Feature of the current observation; only computed in HS mode,
            # otherwise it stays an array of None.
            feature = np.array([None] * len(ready_env_ids_R))
            if self.mode == HS:
                feature = self._compute_feature(
                    obs_next_RO=last_obs_RO,
                    last_info_R=last_info_R,
                    last_hidden_state_RH=last_hidden_state_RH,
                )

            act_terminated_R = np.zeros(len(ready_env_ids_R), dtype=bool)

            # Local indices of the envs that do not have an action yet.
            # `== None` is deliberate: presumably act_normalized_RA is an object
            # array here, so the comparison is elementwise (`is None` would test
            # the array object itself).
            missing_act_args = [arg[0] for arg in np.argwhere(act_normalized_RA == None)]

            # Blank out second-to-last observations that are flagged as invalid
            # (the flag is set to False below for envs that were just reset).
            second_to_last_obs_RO[~is_second_to_last_obs_RO_valid] = None

            # NOTE(review): every per-env argument is restricted to missing_act_args
            # except last_hidden_state_RH and last_rew_R, which are passed at full
            # length R — confirm the policy expects that.
            act_normalized_RA_new, pi_option_R_new, policy_R, hidden_state_RH, _, act_terminated_R_new = (
                self._compute_action_pi_option_policy_hidden(
                    random=random,
                    ready_env_ids_R=ready_env_ids_R[missing_act_args],
                    second_to_last_obs_R=second_to_last_obs_RO[missing_act_args],
                    second_to_last_info_R=second_to_last_info_R[missing_act_args],
                    last_act_RA=last_act_normalized_RA[missing_act_args],
                    last_pi_option_R=last_pi_option_R[missing_act_args],
                    last_obs_RO=last_obs_RO[missing_act_args],
                    last_info_R=last_info_R[missing_act_args],
                    last_hidden_state_RH=last_hidden_state_RH,
                    last_rew_R=last_rew_R,
                )
            )

            # TODO: remove workaround when possible. Tianshou is sharing reference somewhere
            #       so last_obs get updated when observation is a dict.
            last_obs_RO = deepcopy(last_obs_RO)

            # Merge the freshly computed values into the slots that were missing.
            act_normalized_RA[missing_act_args] = act_normalized_RA_new
            pi_option_R[missing_act_args] = pi_option_R_new
            act_terminated_R[missing_act_args] = act_terminated_R_new

            # NOTE(review): act_RA (result of map_action_inverse, cast to int64) is
            # what gets stored in the buffer, while the env is stepped with
            # act_normalized_RA — confirm this asymmetry is intended.
            act_RA = self.policy.map_action_inverse(act_normalized_RA).astype(np.int64)

            obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step(
                act_normalized_RA,
                ready_env_ids_R,
            )
            if isinstance(info_R, dict):  # type: ignore[unreachable]
                # This can happen if the env is an envpool env. Then the info returned by step is a dict
                info_R = _dict_of_arr_to_arr_of_dicts(info_R)  # type: ignore[unreachable]
            # A termination requested by the policy counts like an env termination.
            terminated_R = np.logical_or(terminated_R, act_terminated_R)
            done_R = np.logical_or(terminated_R, truncated_R)

            self.policy.update_collected_return(last_rew_R=rew_R)

            current_iteration_batch = cast(
                RolloutBatchProtocol,
                Batch(
                    obs=last_obs_RO,
                    act=act_RA,
                    pi_option=pi_option_R,
                    policy=policy_R,
                    obs_next=obs_next_RO,
                    rew=rew_R,
                    terminated=terminated_R,
                    truncated=truncated_R,
                    done=done_R,
                    info=info_R,
                ),
            )

            # TODO: only makes sense if render_mode is human.
            #  Also, doubtful whether it makes sense at all for true vectorized envs
            if render:
                self.env.render()
                if not np.isclose(render, 0):
                    time.sleep(render)

            # add data into the buffer
            ptr_R, ep_rew_R, ep_len_R, ep_idx_R = self.buffer.add(
                current_iteration_batch,
                buffer_ids=ready_env_ids_R,
            )

            # Track the rewards per env. When an env finishes an episode, its reward
            # list becomes self.ep_reward and is handed to update_buffer.
            for i, ready_env_id in enumerate(ready_env_ids_R):
                reward = rew_R[i]
                self.reward_lists[ready_env_id].append(reward)

                if done_R[i]:
                    self.ep_reward = deepcopy(self.reward_lists[ready_env_id])
                    self.reward_lists[ready_env_id] = []

                    self.update_buffer(ready_env_id, ep_idx_R[i], ep_len_R[i])

            # collect statistics
            num_episodes_done_this_iter = np.sum(done_R)
            num_collected_episodes += num_episodes_done_this_iter
            step_count += len(ready_env_ids_R)

            # preparing for the next iteration
            # obs_next, info and hidden_state will be modified inplace in the code below,
            # so we copy to not affect the data in the buffer
            # (what was "last" becomes "second to last", the step result becomes "last")
            second_to_last_feature = copy(feature)
            is_second_to_last_obs_RO_valid = np.array([True] * len(ready_env_ids_R), dtype=bool)
            second_to_last_obs_RO = np.array(copy(last_obs_RO), dtype=object)
            second_to_last_info_R = copy(last_info_R)
            last_act_normalized_RA = copy(act_normalized_RA)
            last_pi_option_R = copy(pi_option_R)

            last_obs_RO = copy(obs_next_RO)
            last_info_R = copy(info_R)
            last_hidden_state_RH = copy(hidden_state_RH)

            # Clear the action slots so that every env gets a new action next iteration.
            act_normalized_RA = np.array([None] * len(ready_env_ids_R))
            pi_option_R = np.array([None] * len(ready_env_ids_R))

            # NOTE(review): last_rew_R is a copy taken here; the later edits and the
            # surplus filtering below are applied to rew_R, not to last_rew_R, and
            # rew_R is what gets persisted as _pre_collect_rew_R — confirm which of
            # the two is meant to carry over.
            last_rew_R = copy(rew_R)

            # Preparing last_obs_RO, last_info_R, last_hidden_state_RH for the next while-loop iteration
            # Resetting envs that reached done, or removing some of them from the collection if needed (see below)
            if num_episodes_done_this_iter > 0:
                # TODO: adjust the whole index story, don't use np.where, just slice with boolean arrays
                # D - number of envs that reached done in the rollout above
                env_ind_local_D = np.where(done_R)[0]
                env_ind_global_D = ready_env_ids_R[env_ind_local_D]
                episode_lens.extend(ep_len_R[env_ind_local_D])
                episode_returns.extend(ep_rew_R[env_ind_local_D])
                episode_start_indices.extend(ep_idx_R[env_ind_local_D])
                # now we copy obs_next to obs, but since there might be
                # finished episodes, we have to reset finished envs first.

                gym_reset_kwargs = gym_reset_kwargs or {}
                obs_reset_DO, info_reset_D = self.env.reset(
                    env_id=env_ind_global_D,
                    **gym_reset_kwargs,
                )

                # Set the hidden state to zero or None for the envs that reached done
                # TODO: does it have to be so complicated? We should have a single clear type for hidden_state instead of
                #  this complex logic
                self._reset_hidden_state_based_on_type(env_ind_local_D, last_hidden_state_RH)

                # A freshly reset env has no history: clear its second-to-last state
                # and its last action / pi_option.
                second_to_last_feature[env_ind_local_D] = None
                is_second_to_last_obs_RO_valid[env_ind_local_D] = False
                second_to_last_obs_RO[env_ind_local_D] = None
                second_to_last_info_R[env_ind_local_D] = None
                # presumably cast to object so that None can be stored in the array
                last_act_normalized_RA = last_act_normalized_RA.astype(object)
                last_act_normalized_RA[env_ind_local_D] = None
                last_pi_option_R[env_ind_local_D] = None

                # preparing for the next iteration
                last_obs_RO[env_ind_local_D] = obs_reset_DO
                last_info_R[env_ind_local_D] = info_reset_D

                if not self.training:
                    # Outside of training, reset envs get their action, pi_option and
                    # reward from the stored _pre_collect_* values instead of None.
                    # NOTE(review): the _pre_collect_* arrays are indexed with local
                    # indices (env_ind_local_D) here, while at the top of this method
                    # they are sliced with env ids — verify for the case R < N.
                    act_normalized_RA[env_ind_local_D] = _nullable_slice(self._pre_collect_act_normalized_RA, env_ind_local_D)
                    pi_option_R[env_ind_local_D] = _nullable_slice(self._pre_collect_pi_option_R, env_ind_local_D)

                    rew_R = rew_R.astype(np.float32, copy=False) # TODO delete that
                    rew_R[env_ind_local_D] = _nullable_slice(self._pre_collect_rew_R, env_ind_local_D)

                # Handling the case when we have more ready envs than desired and are not done yet
                #
                # This can only happen if we are collecting a fixed number of episodes
                # If we have more ready envs than there are remaining episodes to collect,
                # we will remove some of them for the next rollout
                # One effect of this is the following: only envs that have completed an episode
                # in the last step can ever be removed from the ready envs.
                # Thus, this guarantees that each env will contribute at least one episode to the
                # collected data (the buffer). This effect was previously called "avoiding bias in selecting environments"
                # However, it is not at all clear whether this is actually useful or necessary.
                # Additional naming convention:
                # S - number of surplus envs
                # TODO: can the whole block be removed? If we have too many episodes, we could just strip the last ones.
                #   Changing R to R-S highly increases the complexity of the code.
                if n_episode:
                    remaining_episodes_to_collect = n_episode - num_collected_episodes
                    surplus_env_num = len(ready_env_ids_R) - remaining_episodes_to_collect
                    if surplus_env_num > 0:
                        # R becomes R-S here, preparing for the next iteration in while loop
                        # Everything that was of length R needs to be filtered and become of length R-S.
                        # Note that this won't be the last iteration, as one iteration equals one
                        # step and we still need to collect the remaining episodes to reach the breaking condition.

                        # creating the mask
                        env_to_be_ignored_ind_local_S = env_ind_local_D[:surplus_env_num]
                        env_should_remain_R = np.ones_like(ready_env_ids_R, dtype=bool)
                        env_should_remain_R[env_to_be_ignored_ind_local_S] = False
                        # stripping the "idle" indices, shortening the relevant quantities from R to R-S
                        ready_env_ids_R = ready_env_ids_R[env_should_remain_R]
                        last_obs_RO = last_obs_RO[env_should_remain_R]
                        last_info_R = last_info_R[env_should_remain_R]
                        # hidden_state_RH is only used for the None check; the
                        # filtered object is last_hidden_state_RH (a copy of it).
                        if hidden_state_RH is not None:
                            last_hidden_state_RH = last_hidden_state_RH[env_should_remain_R]  # type: ignore[index]

                        second_to_last_feature = second_to_last_feature[env_should_remain_R]
                        is_second_to_last_obs_RO_valid = is_second_to_last_obs_RO_valid[env_should_remain_R]
                        second_to_last_obs_RO = second_to_last_obs_RO[env_should_remain_R]
                        second_to_last_info_R = second_to_last_info_R[env_should_remain_R]
                        last_act_normalized_RA = last_act_normalized_RA[env_should_remain_R]
                        last_pi_option_R = last_pi_option_R[env_should_remain_R]

                        act_normalized_RA = act_normalized_RA[env_should_remain_R]
                        pi_option_R = pi_option_R[env_should_remain_R]

                        # NOTE(review): rew_R is shortened here but last_rew_R keeps
                        # its length R for the next iteration — confirm.
                        rew_R = rew_R[env_should_remain_R]

            # Stop once the requested number of steps or episodes has been reached.
            if (n_step and step_count >= n_step) or (
                n_episode and num_collected_episodes >= n_episode
            ):
                break

        # generate statistics
        self.collect_step += step_count
        self.collect_episode += num_collected_episodes
        collect_time = max(time.time() - start_time, 1e-9)  # lower bound avoids a zero division below
        self.collect_time += collect_time

        # Only the first entry of the policy's collected return is reported.
        collected_return = None
        if self.policy.collected_return is not None:
            collected_return = self.policy.collected_return[0]

        if n_step:
            # persist for future collect iterations
            self._pre_collect_second_to_last_feature = second_to_last_feature
            self._pre_collect_is_second_to_last_obs_RO_valid = is_second_to_last_obs_RO_valid
            self._pre_collect_second_to_last_obs_RO = second_to_last_obs_RO
            self._pre_collect_second_to_last_info_R = second_to_last_info_R

            self._pre_collect_last_act_normalized_RA = last_act_normalized_RA
            self._pre_collect_last_pi_option_R = last_pi_option_R

            self._pre_collect_obs_RO = last_obs_RO
            self._pre_collect_info_R = last_info_R
            self._pre_collect_hidden_state_RH = last_hidden_state_RH

            self._pre_collect_act_normalized_RA = act_normalized_RA
            self._pre_collect_pi_option_R = pi_option_R

            self._pre_collect_rew_R = rew_R
        elif n_episode:
            # reset envs and the _pre_collect fields
            self.reset_env(gym_reset_kwargs)  # todo still necessary?

        # Debug results are recorded on every call, eval results only outside of training.
        self._update_policy_debug_results()
        if not self.training:
            self._update_policy_eval_results(collected_return)

        return CollectStats.with_autogenerated_stats(
            returns=np.array(episode_returns),
            lens=np.array(episode_lens),
            n_collected_episodes=num_collected_episodes,
            n_collected_steps=step_count,
            collect_time=collect_time,
            collect_speed=step_count / collect_time,
        )

    def update_buffer(self, slot_idx, ep_idx, ep_len):

        buffer_num = 1
        if hasattr(self.buffer, 'buffer_num'):
            buffer_num = self.buffer.buffer_num

        buffer_slot_size = self.buffer.maxsize // buffer_num

        if self.is_circular_buffer:
            # Add to buffer in a circular manner
            u_bound = slot_idx * buffer_slot_size + buffer_slot_size
            l_bound = slot_idx * buffer_slot_size

            if ep_len > buffer_slot_size:
                #warnings.warn(f"ep_len is bigger than buffer slot size: {ep_len} > {buffer_slot_size}")
                return

            if ep_idx + ep_len > u_bound:
                update_len = u_bound - ep_idx
                self.buffer.rew[ep_idx:u_bound] = self.ep_reward[:update_len]
                self.buffer.rew[l_bound:l_bound + (ep_len - update_len)] = self.ep_reward[update_len:]
            else:
                self.buffer.rew[ep_idx: ep_idx + ep_len] = self.ep_reward
        else:
            if self.collect_step < buffer_slot_size:
                self.buffer.rew[ep_idx: ep_idx + ep_len] = self.ep_reward

    def _update_policy_debug_results(self):
        result_dict = {}

        if not self.is_debug:
            self.policy_debug_results.append(result_dict)
            return

        if self._debug_track_trajectories_split_and_mismatches:
            result_dict["debug_trajectory_splits"] = self.policy.debug_trajectory_splits
            result_dict["debug_trajectory_mismatches"] = self.policy.debug_trajectory_mismatches
        if self._debug_track_trajectories_length:
            if len(self.policy_debug_results) > 0:
                self.policy_debug_results[-1]["debug_track_trajectories_length"] = {}
            result_dict["debug_track_trajectories_length"] = encode_dict_with_tuple_keys(self.policy.debug_trajectory_length_tracking)
        if self._debug_track_policies:
            result_dict["debug_track_policies"] = self.debug_policy_tracking

        self.policy_debug_results.append(result_dict)

    def _update_policy_eval_results(self, collected_return):

        if self.policy.chosen_return is None or collected_return is None:
            return

        result_dict = {
            'chosen_return': list(self.policy.chosen_return),
            'collected_return': deepcopy(list(collected_return))
        }

        self.policy_eval_results.append(result_dict)

    def _track_policies(self):
        """Run the parent's policy tracking and, if enabled, refresh ``self.debug_policy_tracking``.

        The refresh only happens when both ``self.is_debug`` and
        ``self._debug_track_policies`` are set; it stores the second element of
        the pair returned by ``self.policy.fetch_policies(self._initial_obs)``.
        """
        super()._track_policies()

        if not (self.is_debug and self._debug_track_policies):
            return

        _unused, tracked_values = self.policy.fetch_policies(self._initial_obs)
        self.debug_policy_tracking = tracked_values
