import os.path
import warnings
from collections import defaultdict
from dataclasses import dataclass
from typing import cast, Literal, Any

import numpy as np
from tianshou.data import Batch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import ObsBatchProtocol, ModelOutputBatchProtocol, RolloutBatchProtocol
from tianshou.policy import BasePolicy, TrainingStats
from tianshou.policy.modelfree.dqn import DQNPolicy, TDQNTrainingStats

from data.mpec.ns_pi_memory import NSPiMemory
from data.mpec.q_memory import QMemory
from data.mpec.utils.non_domination_util import non_dominated_mask, duplicated_mask
from policy.multi_objective.multi_objective_policy import MultiObjectivePolicy
from utils.collections_util import entry_or_none
from utils.environment_util import flatten_obs


@dataclass(kw_only=True)
class DQNTrainingStats(TrainingStats):
    """Training statistics returned by ``MPECNSPolicy.learn``.

    NOTE(review): ``kw_only=True`` presumably lets the required ``loss`` field
    follow defaulted fields inherited from ``TrainingStats`` — confirm against
    the installed tianshou version.
    """

    # Scalar loss for the update step (``MPECNSPolicy.learn`` reports 0.0).
    loss: float


class MPECNSPolicy(MultiObjectivePolicy, BasePolicy):
    """Multi-objective policy backed by a Q-memory and a Pi-memory.

    * ``q_memory`` maps (observation, action, pi_option) to vector Q-values and
      can return the non-dominated entries for an observation.
    * ``pi_memory`` stores "sapi associations":
      (observation, action, pi_option, next_observation) -> (next_action, next_pi_option),
      which lets the policy follow a previously selected trade-off step by step.
    """

    def __init__(self, seed, MPEC_DICT, ordered_obs_keys, use_normalized_returns=False, *args, **kwargs):
        """Build the Q-/Pi-memories and forward the remaining arguments to the parents.

        :param seed: seed forwarded to ``NSPiMemory``.
        :param MPEC_DICT: configuration mapping; it is expanded as keyword arguments
            into both ``QMemory`` and ``NSPiMemory``. ``'terminate_if_no_policy'`` is
            required; the ``'debug*'`` keys are optional and default to ``False``.
        :param ordered_obs_keys: key order passed to ``flatten_obs``.
        :param use_normalized_returns: when true, ``ask_policy`` presents Q-values
            normalized by their negated last component.
        """

        self._learned_q_values = []
        self._num_of_policies = 0
        self.eps = 0.0

        self.seed = seed
        self.MPEC_DICT = MPEC_DICT
        self._ordered_obs_keys = ordered_obs_keys

        self.use_normalized_returns = use_normalized_returns

        self._chosen_policy_trade_off = None
        self._user_chosen_policy_trade_off = None
        self._terminate_if_no_policy = self.MPEC_DICT['terminate_if_no_policy']

        self.q_memory = QMemory(
            **MPEC_DICT,
        )

        self.pi_memory = NSPiMemory(
            seed=seed,
            **MPEC_DICT,
        )

        self.is_debug = MPEC_DICT.get('debug', False)
        self._debug_track_trajectories_length = MPEC_DICT.get('debug_track_trajectories_length', False)
        self._debug_naive_selection = MPEC_DICT.get('debug_naive_selection', False)
        self._debug_disable_trajectory_length = MPEC_DICT.get('debug_disable_trajectory_length', False)
        self._debug_disable_average_reward = MPEC_DICT.get('debug_disable_average_reward', False)
        self.debug_trajectory_length_tracking = self.q_memory.debug_trajectory_length_tracking

        self.chosen_policy = None
        self.chosen_return = None

        self.count = -1

        self.collected_return = None

        # DQN-style constructor arguments that this tabular policy does not use;
        # they are stripped (if present) before forwarding to the parent constructors.
        for unused_key in ('model', 'optim', 'discount_factor', 'estimation_step', 'target_update_freq'):
            kwargs.pop(unused_key, None)
        super().__init__(*args, **kwargs)

    def set_eps(self, eps: float) -> None:
        """Set the eps for epsilon-greedy exploration."""
        self.eps = eps

    def set_n_step_test(self, n_step, return_):
        """Round ``n_step`` up to the next multiple of ``|return_[-1]|``.

        NOTE(review): ``return_[-1]`` is presumably the trajectory-length component
        (``update_collected_return`` accumulates -1 per step in the last column).
        A zero step leaves ``n_step`` unchanged instead of taking a modulo by zero.
        """
        return_step = np.abs(return_[-1])
        if return_step == 0:
            return n_step
        mod = n_step % return_step
        if mod != 0:
            n_step = n_step - mod + return_step
        return n_step

    def save_memories(self, path):
        """Save the Q- and Pi-memories to ``path/q_memory`` and ``path/pi_memory``."""
        self.q_memory.save(os.path.join(path, "q_memory"))
        self.pi_memory.save(os.path.join(path, "pi_memory"))

    def load_memories(self, path):
        """Load the Q- and Pi-memories from ``path/q_memory`` and ``path/pi_memory``."""
        self.q_memory.load(os.path.join(path, "q_memory"))
        self.pi_memory.load(os.path.join(path, "pi_memory"))

    def forward(
        self,
        last_obs_batch: ObsBatchProtocol,
        last_act,
        last_pi_option,
        obs_batch: ObsBatchProtocol,
        ready_env_ids,
        state: dict | BatchProtocol | np.ndarray | None = None,
        model: Literal["model", "model_old"] = "model",
        is_training=False,
        **kwargs: Any,
    ) -> ModelOutputBatchProtocol:
        """Select an action and a pi_option for every ready environment.

        Outside training, each environment first tries to continue along the
        association stored in ``pi_memory`` for its previous
        (observation, action, pi_option) transition. In naive-selection debug mode
        it instead picks a random non-dominated action. Environments still without
        an action then fall back to:

        * epsilon-greedy selection while training, or when neither a user trade-off
          nor a chosen policy is set;
        * otherwise a random action with ``terminate_env`` set, if
          ``terminate_if_no_policy`` is enabled;
        * otherwise the action whose Q-value is closest to the preferred trade-off
          (random if none is available).

        :param last_obs_batch: previous observations, indexed per ready env.
        :param last_act: previous actions, indexed per ready env.
        :param last_pi_option: previous pi_options, indexed per ready env.
        :param obs_batch: current observations, indexed per ready env.
        :param ready_env_ids: ids of the environments to act for.
        :param state: unused.
        :param model: unused.
        :param is_training: whether to use the training (epsilon-greedy) path.
        :return: A :class:`~tianshou.data.Batch` with keys ``act`` (int actions),
            ``pi_option`` and ``terminate_env`` (bool mask).
        """

        terminate_env = np.zeros(len(ready_env_ids), dtype=bool)

        # Object arrays so that "no decision yet" can be represented by None.
        act_normalized_RA = np.array([None] * len(ready_env_ids))
        pi_option_R = np.array([None] * len(ready_env_ids))
        if not is_training:
            for i, ready_env_id in enumerate(ready_env_ids):
                last_observation = flatten_obs(last_obs_batch[i].obs, self._ordered_obs_keys)
                observation = flatten_obs(obs_batch[i].obs, self._ordered_obs_keys)
                # TODO: fix condition (None gets to be [None, None, ...] in a ndarray)
                if last_observation is not None and None not in last_observation:
                    if self._debug_naive_selection:
                        action, pi_option = self.get_random_best_action_pi_option(observation, i)
                    else:
                        action, pi_option = self.pi_memory.get(
                            last_observation, last_act[i], last_pi_option[i], observation)

                    act_normalized_RA[i] = action
                    pi_option_R[i] = pi_option

        # Environments for which no action could be derived above (still None).
        missing_act_args = [arg[0] for arg in np.argwhere(act_normalized_RA == None)]

        for i in missing_act_args:
            observation = flatten_obs(obs_batch[i].obs, self._ordered_obs_keys)
            if is_training or (self._user_chosen_policy_trade_off is None and self.chosen_policy is None):
                action, pi_option = self.get_epsilon_greedy_action_pi_option(observation, i)
            else:
                if self._terminate_if_no_policy:
                    terminate_env[i] = True
                    action, pi_option = self.get_random_action_pi_option(observation, i)
                else:
                    action, pi_option = self.get_closest_preferred_trade_off_action_pi_option(observation, i)

                    if pi_option is None:
                        action, pi_option = self.get_random_action_pi_option(observation, i)

            act_normalized_RA[i] = action
            pi_option_R[i] = pi_option

        act_normalized_RA = act_normalized_RA.astype(int)
        result = Batch(act=act_normalized_RA, pi_option=pi_option_R, terminate_env=terminate_env)
        return cast(ModelOutputBatchProtocol, result)

    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats:
        """Update the Q- and Pi-memories from a batch of transitions.

        For every transition:

        * fetch the non-dominated (state, action, pi_option) entries of the next
          observation;
        * delete Pi-memory associations whose "next" entry is no longer
          non-dominated, and create associations for newly non-dominated entries;
        * update the Q-memory along every remaining association;
        * remove dominated Q-memory entries for (observation, action).

        Terminal transitions also get a terminal association and a zero-reward
        terminal Q-memory entry.

        ``batch.pi_option`` is written in place whenever a new pi_option id is
        generated.

        :return: a ``DQNTrainingStats`` with ``loss=0.0`` (no gradient step is taken).
        """

        def _debug_as6_trajectory_length_tracking(state_pattern, action_pattern, pi_option, dominated_sets):
            # Debug-only bookkeeping: move trajectory-length tracking from the given
            # entry to its Pi-memory successor (unless that successor is excluded).
            if self.is_debug and self._debug_track_trajectories_length:
                self.q_memory.debug_trajectory_length_tracking.pop((state_pattern, action_pattern, pi_option), None)
                pi_info = entry_or_none(self.pi_memory.info, [state_pattern, action_pattern, pi_option])
                if pi_info is not None:
                    next_state_pattern, value = list(pi_info.items())[0]
                    next_action_pattern, next_pi_option = value['next_action_pattern'], value['next_pi_option']
                    if (next_state_pattern, next_action_pattern, next_pi_option) not in dominated_sets:
                        next_info = entry_or_none(
                            self.q_memory.info, [next_state_pattern, next_action_pattern, next_pi_option])
                        self.q_memory.debug_trajectory_length_tracking[
                            (next_state_pattern, next_action_pattern, next_pi_option)] = next_info['trajectory_length']

        def _debug_as6_trajectory_length_tracking__dominated():
            # Reads abs_observation / abs_action / current_pi_option / dominated_sets
            # from the enclosing loop at call time.
            if self.is_debug and self._debug_track_trajectories_length:
                state_pattern = abs_observation
                action_pattern = abs_action
                pi_option = current_pi_option
                _debug_as6_trajectory_length_tracking(state_pattern, action_pattern, pi_option, dominated_sets)

        def _debug_as6_trajectory_length_tracking__delete():
            # Reads observation / action / pi_option / sapi_next_to_delete from the
            # enclosing loop at call time.
            if self.is_debug and self._debug_track_trajectories_length:
                state_pattern = self.q_memory.abstract_state(observation)
                action_pattern = self.q_memory.abstract_action(action)
                _debug_as6_trajectory_length_tracking(state_pattern, action_pattern, pi_option, sapi_next_to_delete)

        ready_env_ids_R = batch.info.env_id

        for i in range(len(ready_env_ids_R)):

            self.count += 1

            observation = flatten_obs(batch.obs[i], self._ordered_obs_keys)
            action = batch.act[i]
            pi_option = batch.pi_option[i]
            next_observation = flatten_obs(batch.obs_next[i], self._ordered_obs_keys)
            reward = batch.rew[i]

            terminated = batch.terminated[i]

            non_dominated, q_values = self.q_memory.fetch_non_dominated_set(next_observation)
            non_dominated_set = set(map(tuple, non_dominated))

            # TODO: Handle non-deterministic transitions
            # The next observation has no Q-memory entries yet (e.g. it is visited for
            # the first time): register this transition under a fresh pi_option.
            if len(non_dominated_set) == 0:
                # q_memory add
                pi_option = batch.pi_option[i] = self.pi_memory.generate_pi_option_id(observation, action)
                self.q_memory.add(observation, action, pi_option, next_observation, reward)

            # "q_value" = observation, action, pi_option -> q_value
            # "sapi association" = observation, action, pi_option, next_observation -> next_action, next_pi_option
            # "sapi 'next'" = next_observation, next_action, next_pi_option
            sapi_associations = self.pi_memory.fetch_sapi_associations(observation, action, next_observation)
            sapi_next_set = set(map(tuple, [
                (abs_next_observation, abs_next_action, next_pi_option)
                for (_, _, _, abs_next_observation, abs_next_action, next_pi_option) in sapi_associations
            ]))

            # Associations whose successor is no longer non-dominated are removed.
            # Iterating in reverse keeps the remaining indices valid while popping.
            # NOTE(review): the loop target rebinds ``pi_option`` (and abs_observation /
            # abs_action); the ``pi_option is None`` check in the terminal branch below
            # therefore sees the last unpacked value when this loop ran — confirm intended.
            sapi_next_to_delete = (sapi_next_set - non_dominated_set)
            for a_i, (abs_observation, abs_action, pi_option, abs_next_observation, abs_next_action, abs_next_pi_option) in reversed(list(enumerate(sapi_associations))):
                sapi_next = (abs_next_observation, abs_next_action, abs_next_pi_option)
                if sapi_next in sapi_next_to_delete:
                    # pi_memory delete
                    _debug_as6_trajectory_length_tracking__delete()
                    self.q_memory.update(abs_observation, abs_action, pi_option, None, None, None, reward)
                    self.pi_memory.delete(observation, action, pi_option, next_observation)
                    sapi_associations.pop(a_i)

            # Newly non-dominated successors get an association with a fresh pi_option.
            sapi_next_to_associate = (non_dominated_set - sapi_next_set)
            abs_observation = self.pi_memory.abstract_state(observation)
            abs_action = self.pi_memory.abstract_action(action)
            for abs_next_sapi in sapi_next_to_associate:
                new_pi_option = self.pi_memory.generate_pi_option_id(observation, action)
                sapi_association = (abs_observation, abs_action, new_pi_option, *abs_next_sapi)

                # pi_memory add
                self.pi_memory.add(*sapi_association)
                sapi_associations.append(sapi_association)

            for sapi_association in sapi_associations:
                # q_memory update; a falsy result drops the association from the Pi-memory
                below_max_length = self.q_memory.update(*sapi_association, reward)
                if not below_max_length:
                    current_pi_option = sapi_association[2]
                    self.pi_memory.delete(observation, action, current_pi_option)

            # Clean up q-memory
            dominated_sets = self.q_memory.delete_dominated_sets(observation, action)
            for (_, _, current_pi_option) in dominated_sets:
                _debug_as6_trajectory_length_tracking__dominated()
                self.pi_memory.delete(observation, action, current_pi_option)

            if terminated:
                # Add terminal state association to the pi_memory
                if not sapi_associations:
                    abs_next_observation = self.pi_memory.abstract_state(next_observation)
                    if pi_option is None:
                        pi_option = batch.pi_option[i] = self.pi_memory.generate_pi_option_id(observation, action)
                    sapi_association = (abs_observation, abs_action, pi_option, abs_next_observation, None, None)
                    self.pi_memory.add(*sapi_association)

                # Add terminal state to the q_memory
                # It helps during q-update with pi_memory
                if non_dominated.size == 0:
                    self.q_memory.add(next_observation, None, None, None, np.zeros_like(reward))

        # self._track_policies()

        return DQNTrainingStats(loss=0.0)  # type: ignore[return-value]

    def ask_policy(self, obs, chosen_policy=None):
        """Return the ``(action, pi_option)`` with which to start from ``obs``.

        If ``chosen_policy`` is given, elements ``[1:3]`` of it are returned as
        ``(action, pi_option)``. It is presumably a ``(state, action, pi_option)``
        entry like those in the non-dominated set — confirm with callers.

        Otherwise the user is prompted (``prompt_user_choice``) to choose among the
        non-dominated Q-values of ``obs``. When ``use_normalized_returns`` is set,
        the Q-values are first normalized by their negated last component. The
        selected trade-off is stored in ``_user_chosen_policy_trade_off``. If no
        choice is made, the action closest to the preferred trade-off is returned.
        """

        if isinstance(obs, (Batch, dict)):
            obs = flatten_obs(obs, self._ordered_obs_keys)

        if chosen_policy is not None:
            # Use the entry that was actually passed in and checked above.
            action, pi_option = chosen_policy[1:3]
            return action, pi_option

        non_dominated_set, q_values = self.q_memory.fetch_non_dominated_set(obs, True)

        if q_values.size != 0:
            # Sort rows lexicographically with the first Q-value column as primary key
            # (np.lexsort uses the last key as primary, hence the reversal).
            indices = np.lexsort(q_values.T[::-1])
            non_dominated_set, q_values = non_dominated_set[indices], q_values[indices]

        if self.use_normalized_returns:
            normalized_q_values = self._get_normalized_q_values(q_values)
            non_dominated_normalized_q_values = self._get_non_dominated_q_values(normalized_q_values)

            choice, trade_off = self.prompt_user_choice(non_dominated_normalized_q_values)
            self._user_chosen_policy_trade_off = trade_off

            if choice is not None:
                # Map the choice back to its row in the full (sorted) normalized set.
                chosen_q_value = non_dominated_normalized_q_values[choice]
                choice = np.where(np.all(normalized_q_values == chosen_q_value, axis=1))[0][0]
        else:
            choice, trade_off = self.prompt_user_choice(q_values)
            self._user_chosen_policy_trade_off = trade_off

        if choice is not None:
            _, action, pi_option = non_dominated_set[choice]
        else:
            action, pi_option = self.get_closest_preferred_trade_off_action_pi_option(obs, 0)

        return action, pi_option

    def _get_action_space(self, env_index):
        """Return the action space of environment ``env_index``.

        Falls back to the shared ``_action_space`` when it is not indexable
        per environment (``TypeError``).
        """
        try:
            action_space = self._action_space[env_index]
        # TODO: test whether envpool env explicitly
        except TypeError:  # envpool's action space is not for per-env
            action_space = self._action_space

        return action_space

    def get_random_action_pi_option(self, state, env_index):
        """Sample a random action and, if the Q-memory knows any, a random pi_option for it.

        :return: ``(action, pi_option)``; ``pi_option`` is ``None`` when the Q-memory
            has no pi_options for ``(state, action)``.
        """
        if isinstance(state, Batch):
            state = flatten_obs(state, self._ordered_obs_keys)
        action_space = self._get_action_space(env_index)
        rng = action_space.np_random

        action = action_space.sample()
        pi_options = self.q_memory.fetch_pi_options(state, action)
        pi_option = None
        if pi_options:
            pi_option = rng.choice(pi_options)

        return action, pi_option

    def get_random_best_action_pi_option(self, state, env_index):
        """Pick a random action among the non-dominated entries of ``state``.

        Falls back to a random action when the state has no non-dominated entries.

        :return: ``(action, pi_option)``, or ``(None, None)`` when every
            non-dominated entry has a ``None`` action (terminal entries).
        """
        if isinstance(state, Batch):
            state = flatten_obs(state, self._ordered_obs_keys)
        non_dominated_set, _ = self.q_memory.fetch_non_dominated_set(state)

        if non_dominated_set.size == 0:
            return self.get_random_action_pi_option(state, env_index)

        items = defaultdict(list)
        for _, action, pi_option in non_dominated_set:
            # Terminal entries are stored with a ``None`` action; skip only those.
            # A truthiness test would also discard the valid action 0.
            if action is not None:
                items[action].append(pi_option)

        action_space = self._get_action_space(env_index)
        rng = action_space.np_random

        action_choices = np.unique(list(items.keys()))
        if action_choices.size == 0:
            return None, None

        action = rng.choice(action_choices)
        pi_option = rng.choice(items[action])

        return action, pi_option

    def get_epsilon_greedy_action_pi_option(self, state, env_index):
        """With probability ``eps`` act randomly, otherwise pick a random non-dominated action.

        Falls back to a random action when no non-dominated action is available.
        """
        if isinstance(state, Batch):
            state = flatten_obs(state, self._ordered_obs_keys)
        action_space = self._get_action_space(env_index)
        rng = action_space.np_random

        if rng.uniform(0, 1) < self.eps:
            action, pi_option = self.get_random_action_pi_option(state, env_index)
        else:
            action, pi_option = self.get_random_best_action_pi_option(state, env_index)
            if action is None:
                action, pi_option = self.get_random_action_pi_option(state, env_index)

        return action, pi_option

    def get_closest_preferred_trade_off_action_pi_option(self, state, env_index):
        """Pick the non-dominated entry whose projected return best matches the preferred trade-off.

        The preferred trade-off is ``_user_chosen_policy_trade_off`` when set,
        otherwise ``chosen_return``.

        1. Each candidate's projected return (``collected_return + q_value``) is
           compared to the trade-off, rescaled so that its last component matches the
           candidate's last component.
        2. Distances are Euclidean, after per-column standardization over the
           candidates.

        :return: ``(action, pi_option)``, or ``(None, None)`` when the state has no
            Q-values or no preferred trade-off is available.
        """
        if isinstance(state, Batch):
            state = flatten_obs(state, self._ordered_obs_keys)
        non_dominated_set, q_values = self.q_memory.fetch_non_dominated_set(state)

        if len(q_values) == 0:
            return None, None

        collected_return = 0
        if self.collected_return is not None:
            collected_return = self.collected_return

        # Explicit None checks: the trade-offs may be numpy arrays, whose truth value
        # is ambiguous (``a or b`` raises ValueError for multi-element arrays).
        if self._user_chosen_policy_trade_off is not None:
            desired_trade_off = self._user_chosen_policy_trade_off
        else:
            desired_trade_off = self.chosen_return
        if desired_trade_off is None:
            return None, None
        desired_trade_off = np.asarray(desired_trade_off)

        next_collected_return = collected_return + q_values
        ideal_collected_return = (desired_trade_off / desired_trade_off[-1:]) * next_collected_return[:,-1:]

        mean = next_collected_return.mean(axis=0)
        std = next_collected_return.std(axis=0)
        # Avoid division by zero for constant columns.
        std_safe = np.where(std == 0, 1, std)
        next_collected_return_norm = (next_collected_return - mean) / std_safe
        target_norm = (ideal_collected_return - mean) / std_safe

        distances = np.linalg.norm(target_norm - next_collected_return_norm, axis=1)
        min_index = np.argmin(distances)
        closest_trade_off = q_values[min_index]

        minimum_distance = np.min(distances)
        if minimum_distance > 0:
            warnings.warn(f"Switching from trade-off {self._user_chosen_policy_trade_off} "
                          f"to trade-off {closest_trade_off} with distance {minimum_distance}")

        _, action, pi_option = non_dominated_set[min_index]
        if action is None:
            # The closest entry is terminal (no action): act randomly instead.
            action, pi_option = self.get_random_action_pi_option(state, env_index)
            warnings.warn(f"Switching from trade-off {self._user_chosen_policy_trade_off} "
                          f"to a random action {action}, pi_option {pi_option}")

        return action, pi_option

    def _get_normalized_q_values(self, q_values):
        """Divide each Q-value row by the negation of its last component."""
        normalized_q_values = q_values / -q_values[:, -1:]
        return normalized_q_values

    def _get_non_dominated_q_values(self, q_values):
        """Return the rows of ``q_values`` that are non-dominated and not flagged as duplicates."""
        non_dominated_indices = non_dominated_mask(q_values)
        duplicated_indices = duplicated_mask(q_values)
        return q_values[non_dominated_indices & ~duplicated_indices]

    def fetch_policies(self, observation, ignore_trajectory_length=False, use_average=False):
        """Return the Q-memory's non-dominated set for ``observation``.

        When debugging with ``debug_disable_average_reward``, the Q-memory debug
        flag is turned off for the duration of the fetch. In debug mode it is set
        back to ``True`` afterwards, even if the fetch raises.
        """
        if self.is_debug and self._debug_disable_average_reward:
            self.q_memory.is_debug = False
        try:
            return self.q_memory.fetch_non_dominated_set(observation, ignore_trajectory_length, use_average)
        finally:
            if self.is_debug:
                self.q_memory.is_debug = True

    def update_collected_return(self, **kwargs):
        """Accumulate ``kwargs['last_rew_R']`` into ``self.collected_return``.

        ``last_rew_R`` is a 2-D array (one row per environment). Only rows without
        ``None`` entries are accumulated. Unless trajectory-length tracking is
        disabled in debug mode, an extra last column is maintained that gains -1.0
        per accumulated step. Does nothing when the key is missing or all entries
        are ``None``.
        """
        if 'last_rew_R' in kwargs and not np.all(kwargs['last_rew_R'] == None):
            last_rew_R = kwargs['last_rew_R'].copy()
            not_null_mask = np.all(last_rew_R != None, axis=1)

            n_rows, n_cols = last_rew_R.shape
            if self.is_debug and self._debug_disable_trajectory_length:
                if self.collected_return is None:
                    self.collected_return = np.zeros((n_rows, n_cols))

                expanded = np.empty((n_rows, n_cols), dtype=object)
                expanded[not_null_mask, :] = last_rew_R[not_null_mask]
            else:
                if self.collected_return is None:
                    self.collected_return = np.zeros((n_rows, n_cols + 1))

                expanded = np.empty((n_rows, n_cols + 1), dtype=object)
                expanded[not_null_mask, :n_cols] = last_rew_R[not_null_mask]
                # Trajectory-length component: one negative unit per step.
                expanded[not_null_mask, n_cols] = -1.0

            self.collected_return[not_null_mask] = self.collected_return[not_null_mask] + expanded[not_null_mask]

    # def _track_policies(self):
    #     non_dominated_set, q_values = self.q_memory.fetch_non_dominated_set([0, 0, 1])
    #     if set(map(tuple, self._learned_q_values)) != set(map(tuple, q_values)):
    #         self._learned_q_values = q_values
    #         self._num_of_policies = len(q_values)
