import logging
from copy import deepcopy
from typing import Any

import gym
import numpy as np
import torch
from gymnasium.spaces import Dict

from tianshou.data import CollectStats
from tianshou.env import BaseVectorEnv
from tianshou.policy import BasePolicy

from _tianshou_custom.data import CustomCollector
from _tianshou_custom.data.buffer.CustomReplayBuffer import CustomReplayBuffer
from _tianshou_custom.data.buffer.CustomVectorReplayBuffer import CustomVectorReplayBuffer
from utils.environment_util import flatten_obs

log = logging.getLogger(__name__)


class PolicySelectionCollector(CustomCollector):
    # NAMING CONVENTION (mostly suffixes):
    # episode - An episode means a rollout until done (terminated or truncated). After an episode is completed,
    # the corresponding env is either reset or removed from the ready envs.
    # N - number of envs, always fixed and >= R.
    # R - number of ready env ids. Note that this might change when envs get idle.
    #     This can only happen in n_episode case, see explanation in the corresponding block.
    #     For n_step, we always use all envs to collect the data, while for n_episode,
    #     R will be at most n_episode at the beginning, but can decrease during the collection.
    # O - dimension(s) of observations
    # A - dimension(s) of actions
    # H - dimension(s) of hidden state
    # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case.
    # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration.
    #     Only used in n_episode case. Then, R becomes R-S.

    def __init__(
        self,
        policy: BasePolicy,
        env: gym.Env | BaseVectorEnv,
        buffer: CustomReplayBuffer | None = None,
        exploration_noise: bool = False,
        MPEC_DICT=None,
        training=False,
        *args,
        **kwargs
    ):
        """Set up the collector state and delegate to the parent constructor.

        :param policy: the policy used for collection; forwarded to the parent.
        :param env: a single environment or a vectorized environment; forwarded
            to the parent.
        :param buffer: replay buffer to fill. If None, a
            ``CustomVectorReplayBuffer`` sized from the number of envs is created.
        :param exploration_noise: forwarded to the parent constructor.
        :param MPEC_DICT: configuration mapping; only the key
            ``'dont_ask_for_policy'`` is read here. If None, the flag defaults
            to False.
        :param training: stored on the instance as ``self.training``.
        """
        self.training = training

        self.num_of_policies = 0
        self._initial_obs = None
        # Set together with ``_initial_obs`` on the first environment reset.
        self._initial_info = None
        self._learned_values = []

        # ``MPEC_DICT`` is optional in the signature, so it must not be
        # subscripted unconditionally. A supplied mapping is still required to
        # contain the key.
        if MPEC_DICT is None:
            self.dont_ask_for_policy = False
        else:
            self.dont_ask_for_policy = MPEC_DICT['dont_ask_for_policy']

        if buffer is None:
            # A plain (non-vectorized) env has no ``__len__``; treat it as one env.
            env_count = len(env) if hasattr(env, '__len__') else 1
            buffer = CustomVectorReplayBuffer(env_count, env_count)

        # These attributes are assigned before the parent constructor runs,
        # in case it triggers a reset that reads them.
        super().__init__(policy, env, buffer, exploration_noise, *args, **kwargs)

    @torch.no_grad()
    def collect(
        self,
        n_step: int | None = None,
        n_episode: int | None = None,
        random: bool = False,
        render: float | None = None,
        reset_before_collect: bool = False,
        gym_reset_kwargs: dict[str, Any] | None = None,
        prompt_policies=False,
        *args,
        **kwargs
    ) -> CollectStats:
        """Collect a specified number of steps or episodes.

        The collection itself is delegated to the parent collector. Before
        delegating, this method optionally pins a chosen policy on
        ``self.policy``, optionally resets the environments, and refreshes the
        tracked set of learned policies.

        :param n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect.
        :param random: whether to use random policy for collecting data.
        :param render: the sleep time between rendering consecutive frames.
        :param reset_before_collect: whether to reset the environment before collecting data.
            (The collector needs the initial obs and info to function properly.)
        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
            reset function.
        :param prompt_policies: forwarded to :meth:`reset`; only used if
            ``reset_before_collect`` is True.
        :param kwargs: ``chosen_policy`` and ``chosen_return`` are read here. When
            ``chosen_policy`` is not None, both are stored on ``self.policy``.
            All keyword arguments are also forwarded to :meth:`reset`.

        .. note::

            One and only one collection number specification is permitted, either
            ``n_step`` or ``n_episode``.

        :return: The collected stats
        """
        chosen_policy = kwargs.get('chosen_policy')
        if chosen_policy is not None:
            self.policy.chosen_policy = chosen_policy
            self.policy.chosen_return = kwargs.get('chosen_return')

        if reset_before_collect:
            # The leading ``reset`` parameters are passed positionally
            # (reset_buffer=False, reset_stats=True) so that extra ``*args``
            # land after them. Unpacking ``*args`` after keyword arguments
            # would bind the first extra positional to ``reset_buffer`` and
            # raise "multiple values for argument".
            self.reset(False, True, gym_reset_kwargs, prompt_policies, *args, **kwargs)

        self._track_policies()

        # The reset (if requested) was already done above, so the parent is
        # told not to reset again.
        return super().collect(
            n_step,
            n_episode,
            random,
            render,
            False,
            gym_reset_kwargs,
        )

    def reset(
        self,
        reset_buffer: bool = True,
        reset_stats: bool = True,
        gym_reset_kwargs: dict[str, Any] | None = None,
        prompt_policies=False,
        *args,
        **kwargs
    ) -> tuple[np.ndarray, np.ndarray]:
        """Reset the environment, statistics, and data needed to start the collection.

        :param reset_buffer: if true, reset the replay buffer attached
            to the collector.
        :param reset_stats: if true, reset the statistics attached to the collector.
        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
            reset function. Defaults to None (extra keyword arguments)
        :param prompt_policies: forwarded to :meth:`reset_env`.
        :param args: extra positional arguments forwarded to :meth:`reset_env`.
        :param kwargs: extra keyword arguments forwarded to :meth:`reset_env`.
        :return: The initial observation and info from the environment.
        """
        # The two named parameters are passed positionally so that ``*args``
        # follows them. With keywords first, the first extra positional would
        # collide with ``gym_reset_kwargs`` and raise a TypeError.
        obs_NO, info_N = self.reset_env(gym_reset_kwargs, prompt_policies, *args, **kwargs)
        if reset_buffer:
            self.reset_buffer()
        if reset_stats:
            self.reset_stat()
        self._is_closed = False
        return obs_NO, info_N

    def reset_env(
        self,
        gym_reset_kwargs: dict[str, Any] | None = None,
        prompt_policies=False,
        *args,
        **kwargs
    ) -> tuple[np.ndarray, np.ndarray]:
        """Reset the environments and prepare the per-env pre-collect state.

        After the parent reset, this method:

        * caches the flattened first observation and its info the first time it
          is called,
        * hands the collector's action space to the policy and clears
          ``policy.collected_return``,
        * re-initialises the ``_pre_collect_*`` bookkeeping arrays (one slot per env),
        * seeds the per-env action / policy-option / reward arrays from
          ``kwargs`` (or from ``policy.ask_policy`` when prompting is enabled),
          and fills every still-missing action by querying the policy.

        :param gym_reset_kwargs: extra keyword arguments forwarded to the parent
            ``reset_env``.
        :param prompt_policies: if true and ``self.dont_ask_for_policy`` is false,
            ``policy.ask_policy`` is called once on the first env's observation and
            its result is used for all envs.
        :param kwargs: optional per-env presets under the keys ``'action'``,
            ``'pi_option'`` and ``'rew'``. A missing or empty ``'action'`` means
            no presets at all. A missing ``'pi_option'`` or ``'rew'`` defaults to
            all-None.
        :return: The initial observation and info from the environment.
        """
        obs_NO, info_N = super().reset_env(gym_reset_kwargs)

        if self._initial_obs is None:
            obs_space = self.env.observation_space
            keys = list(obs_space.spaces.keys()) if isinstance(obs_space, Dict) else None
            # Deep copy so flattening cannot alias or mutate the live observation.
            self._initial_obs = flatten_obs(deepcopy(obs_NO[0]), keys)
            self._initial_info = info_N[0]

        self.policy._action_space = self._action_space
        self.policy.collected_return = None

        env_num = self.env_num
        self._pre_collect_second_to_last_feature = np.array([None] * env_num)
        self._pre_collect_is_second_to_last_obs_RO_valid = np.array([False] * env_num)
        self._pre_collect_second_to_last_obs_RO = np.array([None] * env_num)
        self._pre_collect_second_to_last_info_R = np.array([None] * env_num)

        self._pre_collect_last_act_normalized_RA = np.array([None] * env_num)
        self._pre_collect_last_pi_option_R = np.array([None] * env_num)

        preset_actions = kwargs.get('action')
        try:
            # Test emptiness by length, not truthiness. ``not ndarray`` raises
            # for multi-element arrays and is True for ``np.array([0])``.
            no_presets = preset_actions is None or len(preset_actions) == 0
        except TypeError:
            # Unsized value (e.g. a scalar): keep the plain truthiness rule.
            no_presets = not preset_actions

        if no_presets:
            preset_actions = [None] * env_num
            preset_pi_options = [None] * env_num
            preset_rewards = [None] * env_num
        else:
            # Actions may be supplied without the companion entries.
            preset_pi_options = kwargs.get('pi_option')
            if preset_pi_options is None:
                preset_pi_options = [None] * env_num
            preset_rewards = kwargs.get('rew')
            if preset_rewards is None:
                preset_rewards = [None] * env_num

        if prompt_policies and not self.dont_ask_for_policy:
            # A prompted choice overrides any presets and applies to every env.
            action, pi_option = self.policy.ask_policy(obs_NO[0])
            preset_actions = [action] * env_num
            preset_pi_options = [pi_option] * env_num

        self._pre_collect_act_normalized_RA = np.array(preset_actions)
        self._pre_collect_pi_option_R = np.array(preset_pi_options)
        self._pre_collect_rew_R = np.array(preset_rewards)

        # Elementwise comparison against None is intentional here: ``is None``
        # would test the array object itself, not its entries.
        empty_indices = np.where(self._pre_collect_act_normalized_RA == None)[0]  # noqa: E711

        for i in empty_indices:
            # Tolerate policies on which ``chosen_policy`` was never set.
            chosen_policy = getattr(self.policy, 'chosen_policy', None)
            if chosen_policy is not None:
                action, pi_option = self.policy.ask_policy(obs_NO[i], chosen_policy)
            else:
                # No policy pinned: fall back to the policy's epsilon-greedy pick.
                action, pi_option = self.policy.get_epsilon_greedy_action_pi_option(obs_NO[i], i)
            self._pre_collect_act_normalized_RA[i] = action
            self._pre_collect_pi_option_R[i] = pi_option

        # All slots are filled now, so the (possibly object-dtype) array can
        # be converted to integers.
        self._pre_collect_act_normalized_RA = self._pre_collect_act_normalized_RA.astype(int)

        return obs_NO, info_N

    def close(self) -> None:
        """Close via the parent class, then drop the cached pre-collect actions and policy options."""
        super().close()
        self._pre_collect_act_normalized_RA = self._pre_collect_pi_option_R = None

    def _track_policies(self):
        """Refresh the tracked set of learned policy values and print it on change.

        Does nothing until an initial observation has been cached. Otherwise the
        policy is asked for its policies at that observation. If the returned
        values differ from the last stored ones (compared row-order
        independently, within a tolerance), they are stored,
        ``self.num_of_policies`` is updated, and the values are printed in
        lexicographic row order.

        NOTE(review): assumes ``values`` is 2-D with one row per policy —
        confirm against ``policy.fetch_policies``.
        """
        def all_close(list1, list2, rtol=1e-5, atol=1.e-8):
            # Order-insensitive closeness test: sort the rows of both inputs
            # lexicographically, then compare.
            arr1 = np.array(list1)
            arr2 = np.array(list2)

            if arr1.shape != arr2.shape:
                return False

            if arr1.size == 0:
                return True

            arr1 = arr1[np.lexsort(arr1.T[::-1])]
            arr2 = arr2[np.lexsort(arr2.T[::-1])]
            return np.allclose(arr1, arr2, rtol=rtol, atol=atol)

        if self._initial_obs is None:
            return

        _, values = self.policy.fetch_policies(self._initial_obs)
        # Normalise once so that ``.T`` and index-array selection below work
        # even if the policy returns a plain sequence.
        values = np.asarray(values)

        if all_close(self._learned_values, values):
            return

        self._learned_values = values
        self.num_of_policies = len(values)
        print(f"Number of policies: {self.num_of_policies}")
        if self.num_of_policies != 0:
            # Reversing the key order makes column 0 the primary sort key.
            order = np.lexsort(values.T[::-1])
            for i, option in enumerate(values[order], 1):
                print(f"{i}. {option}")
