import os.path
import warnings
from dataclasses import dataclass
from typing import cast, Literal, Any

import numpy as np
from tianshou.data import Batch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import ObsBatchProtocol, ModelOutputBatchProtocol, RolloutBatchProtocol
from tianshou.policy import BasePolicy, TrainingStats
from tianshou.policy.modelfree.dqn import DQNPolicy, TDQNTrainingStats

from data.mpec.s_pi_memory import SPiMemory
from data.mpec.utils.non_domination_util import non_dominated_mask, duplicated_mask
from data.mpec.v_memory import VMemory
from policy.multi_objective.multi_objective_policy import MultiObjectivePolicy
from utils.collections_util import entry_or_none
from utils.environment_util import flatten_obs


@dataclass(kw_only=True)
class DQNTrainingStats(TrainingStats):
    """Training statistics that carry a single scalar loss.

    Declared keyword-only, so ``loss`` must be passed by name.
    """

    # Loss value reported for one learn step.
    loss: float


class MPECSPolicy(MultiObjectivePolicy, BasePolicy):

    def __init__(self, seed, MPEC_DICT, ordered_obs_keys, use_normalized_returns=False, *args, **kwargs):
        """Build the policy together with its value memory and policy memory.

        :param seed: seed stored on the policy and forwarded to ``SPiMemory``.
        :param MPEC_DICT: configuration mapping. It is expanded as keyword
            arguments into both ``VMemory`` and ``SPiMemory``, must contain
            ``'terminate_if_no_policy'``, and may contain ``'debug'`` and the
            ``'debug_*'`` switches read below (all default to ``False``).
        :param ordered_obs_keys: key order handed to ``flatten_obs`` whenever an
            observation is flattened.
        :param use_normalized_returns: when true, ``ask_policy`` shows
            normalized returns to the user instead of the raw values.
        :param args: positional arguments forwarded to the parent constructor.
        :param kwargs: keyword arguments forwarded to the parent constructor
            once the DQN-specific entries have been stripped.
        :raises KeyError: if ``MPEC_DICT`` lacks ``'terminate_if_no_policy'``.
        """

        self._learned_q_values = []
        self._num_of_policies = 0
        self.eps = 0.0

        self.seed = seed
        self.MPEC_DICT = MPEC_DICT
        self._ordered_obs_keys = ordered_obs_keys

        self.use_normalized_returns = use_normalized_returns

        self._chosen_policy_trade_off = None
        self._user_chosen_policy_trade_off = None
        self._terminate_if_no_policy = self.MPEC_DICT['terminate_if_no_policy']

        self.v_memory = VMemory(
            **MPEC_DICT,
        )

        self.pi_memory = SPiMemory(
            seed=seed,
            **MPEC_DICT,
        )

        # Debug switches; every one of them is off unless set in MPEC_DICT.
        self.is_debug = MPEC_DICT.get('debug', False)
        self._debug_disable_reconnection = MPEC_DICT.get('debug_disable_reconnection', False)
        self._debug_disable_ssm = MPEC_DICT.get('debug_disable_ssm', False)
        self._debug_disable_cycle_detection = MPEC_DICT.get('debug_disable_cycle_detection', False)
        self._debug_track_trajectories_length = MPEC_DICT.get('debug_track_trajectories_length', False)
        self._debug_track_trajectories_split_and_mismatches = MPEC_DICT.get('debug_track_trajectories_split_and_mismatches', False)
        # Same object as the one held by the value memory (not a copy).
        self.debug_trajectory_length_tracking = self.v_memory.debug_trajectory_length_tracking

        self.debug_trajectory_splits = 0
        self.debug_trajectory_mismatches = 0

        self.chosen_policy = None
        self.chosen_return = None

        # Number of transitions processed by ``learn``; incremented before use.
        self.count = -1

        self.collected_return = None

        # These keywords are not forwarded to the parent constructor. ``pop``
        # with a default tolerates callers that never supplied them, where
        # ``del`` would raise KeyError.
        for dqn_only_key in ('model', 'optim', 'discount_factor', 'estimation_step', 'target_update_freq'):
            kwargs.pop(dqn_only_key, None)
        super().__init__(*args, **kwargs)

    def set_eps(self, eps: float) -> None:
        """Set the exploration rate used for epsilon-greedy action selection.

        :param eps: value stored in ``self.eps``; it is compared against a
            uniform draw from ``[0, 1)`` when an epsilon-greedy action is chosen.
        """
        self.eps = eps

    def set_n_step_test(self, n_step, return_):
        return_step = np.abs(return_[-1])
        mod = n_step % return_step
        if mod != 0:
            n_step = n_step - mod + return_step
        return n_step

    def save_memories(self, path):
        self.v_memory.save(os.path.join(path, "v_memory"))
        self.pi_memory.save(os.path.join(path, "pi_memory"))

    def load_memories(self, path):
        self.v_memory.load(os.path.join(path, "v_memory"))
        self.pi_memory.load(os.path.join(path, "pi_memory"))

    def forward(
        self,
        last_obs_batch: ObsBatchProtocol,
        last_act,
        last_pi_option,
        obs_batch: ObsBatchProtocol,
        ready_env_ids,
        state: dict | BatchProtocol | np.ndarray | None = None,
        model: Literal["model", "model_old"] = "model",
        is_training=False,
        **kwargs: Any,
    ) -> ModelOutputBatchProtocol:
        """Choose an action and a policy option for every ready environment.

        Selection happens in two passes:

        1. Outside of training, an environment whose previous observation is
           usable keeps its previous policy option, and the action is looked
           up in the policy memory for the current observation.
        2. Every environment that still has no action is handled as follows:

           * while training, or when neither a user trade-off nor a policy has
             been chosen, an epsilon-greedy action is drawn;
           * otherwise, if ``terminate_if_no_policy`` is set, the environment is
             flagged in ``terminate_env`` and receives a random action;
           * otherwise the policy closest to the preferred trade-off is used,
             falling back to a random action when none is available.

        :param last_obs_batch: previous observations, indexed like ``ready_env_ids``.
        :param last_act: unused by this implementation.
        :param last_pi_option: previous policy options, indexed like ``ready_env_ids``.
        :param obs_batch: current observations, indexed like ``ready_env_ids``.
        :param ready_env_ids: ids of the environments awaiting an action; only
            its length is used here.
        :param state: unused by this implementation.
        :param model: unused by this implementation.
        :param is_training: skips the first pass and forces epsilon-greedy
            selection when true.
        :return: A :class:`~tianshou.data.Batch` with three keys:

            * ``act`` the chosen actions, cast to ``int``.
            * ``pi_option`` the policy option per environment (``None`` when no
              option applies).
            * ``terminate_env`` boolean flags for environments that should be
              terminated.
        """

        num_ready = len(ready_env_ids)
        terminate_env = np.zeros(num_ready, dtype=bool)

        # ``None`` marks "no decision yet" for the second pass below.
        act_normalized_RA = np.array([None] * num_ready)
        pi_option_R = np.array([None] * num_ready)
        if not is_training:
            for i in range(num_ready):
                last_observation = flatten_obs(last_obs_batch[i].obs, self._ordered_obs_keys)
                observation = flatten_obs(obs_batch[i].obs, self._ordered_obs_keys)
                # TODO: fix condition (None gets to be [None, None, ...] in a ndarray)
                if last_observation is not None and None not in last_observation:
                    pi_option = last_pi_option[i]
                    action = self.pi_memory.get_action(observation, pi_option)

                    act_normalized_RA[i] = action
                    pi_option_R[i] = pi_option

        # Element-wise comparison against None is intentional here (object array).
        missing_act_args = [arg[0] for arg in np.argwhere(act_normalized_RA == None)]

        for i in missing_act_args:
            observation = flatten_obs(obs_batch[i].obs, self._ordered_obs_keys)
            if is_training or (self._user_chosen_policy_trade_off is None and self.chosen_policy is None):
                action, pi_option = self.get_epsilon_greedy_action_pi_option(observation, i)
            else:
                if self._terminate_if_no_policy:
                    terminate_env[i] = True
                    # get_random_action returns an (action, pi_option) pair; only
                    # the action may be stored, otherwise the int cast below fails.
                    action, _ = self.get_random_action(observation, i)
                    pi_option = None
                else:
                    action, pi_option = self.get_closest_preferred_trade_off_action_pi_option(observation, i)

                    if pi_option is None:
                        action, pi_option = self.get_random_action(observation, i)
                    else:
                        action = self.pi_memory.get_action(observation, pi_option)

            act_normalized_RA[i] = action
            pi_option_R[i] = pi_option

        act_normalized_RA = act_normalized_RA.astype(int)
        result = Batch(act=act_normalized_RA, pi_option=pi_option_R, terminate_env=terminate_env)
        return cast(ModelOutputBatchProtocol, result)

    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats:
        """Fold every transition of ``batch`` into the policy and value memories.

        Transitions are visited from the last index of ``batch`` to the first.
        For each one the method:

        1. fetches the non-dominated set of the next observation from the value
           memory and the existing associations for this
           observation/action/next-observation from the policy memory;
        2. registers associations in the policy memory for non-dominated
           entries that are not linked yet;
        3. adds the raw transition to both memories when nothing is known
           about the next observation, or when a self-transition was found;
        4. updates the value memory for every association, handling detected
           cycles and collecting transitions to delete;
        5. removes the collected transitions and the dominated entries.

        :param batch: rollout data; each entry is read for ``obs``, ``act``,
            ``pi_option``, ``obs_next``, ``rew``, ``terminated`` and
            ``info.env_id``.
        :return: a ``DQNTrainingStats`` with ``loss=0.0``; this method computes
            no loss.
        """

        def _debug_as8_count_trajectory_split(transitions_to_delete):
            # Debug only: count (at most once per call) a deletion that removes
            # an entry which has both predecessors and successors.
            if self.is_debug and self._debug_track_trajectories_split_and_mismatches:
                for (state_pattern, pi_option) in transitions_to_delete:
                    info = self.pi_memory.info.get(state_pattern).get(pi_option)
                    if len(info.previous) != 0 and len(info.next) != 0:
                        self.debug_trajectory_splits += 1
                        return True
            return False

        def _debug_as8_count_trajectory_mismatches(spi_association):
            # Debug only: count associations whose 'cross_policy_length' does not
            # line up with the one recorded for their successor.
            if self.is_debug and self._debug_track_trajectories_split_and_mismatches:
                if spi_association is None:
                    self.debug_trajectory_mismatches += 1
                else:
                    info = entry_or_none(self.v_memory.info, [spi_association[0], spi_association[1]])
                    next_info = entry_or_none(self.v_memory.info, [spi_association[3], spi_association[4]])
                    if info is not None:
                        if next_info is not None:
                            if info['cross_policy_length'] != 0 and info['cross_policy_length'] != next_info['cross_policy_length'] + 1:
                                if spi_association[0] != spi_association[3]:
                                    self.debug_trajectory_mismatches += 1
                        elif info['cross_policy_length'] != 1:
                            self.debug_trajectory_mismatches += 1

        def _debug_as7_trajectory_length_tracking(state_pattern, pi_option):
            # Debug only, and not called anywhere in this method. It used to read
            # ``self.q_memory``, which this class never defines; the tracking
            # dict and ``info`` live on ``self.v_memory``.
            if self.is_debug and self._debug_track_trajectories_length:
                self.v_memory.debug_trajectory_length_tracking.pop((state_pattern, pi_option), None)
                pi_info = entry_or_none(self.pi_memory.info, [state_pattern, pi_option])
                if pi_info is not None:
                    next_state_pattern, value = list(pi_info.items())[0]
                    next_pi_option = value['next_pi_option']
                    next_info = entry_or_none(self.v_memory.info, [next_state_pattern, next_pi_option])
                    # The successor may be absent from the value memory.
                    if next_info is not None:
                        self.v_memory.debug_trajectory_length_tracking[
                            (next_state_pattern, next_pi_option)] = next_info['cross_policy_length']

        def _debug_as8_disable_reconnection(dominated_list):
            # Debug only: instead of re-attaching dominated trajectories, recolor
            # them in both memories and delete their first association.
            if self.is_debug and self._debug_disable_reconnection:
                for dominated in dominated_list:
                    old_pi_option = dominated[1]
                    trajectory = self.pi_memory.recall_trajectory(*dominated)
                    spi_associations = self.pi_memory._debug_recolor(trajectory)
                    for spi_association in spi_associations:
                        self.v_memory._debug_recolor(*spi_association, old_pi_option)
                    dominated_association = spi_associations[0]
                    self.v_memory.delete([[dominated_association[0], dominated_association[1]]])
                    self.pi_memory.delete([[dominated_association[0], dominated_association[1]]])

        def clean_up_dominated(observation):
            # Remove entries of ``observation`` that are dominated, transferring
            # their associations to the dominators where possible.
            dominated_list, dominators_list = self.v_memory.fetch_dominated_sets(observation)
            dominated_set = set(map(tuple, dominated_list))
            _debug_as8_count_trajectory_split(dominated_list)
            if self.is_debug and self._debug_disable_reconnection:
                _debug_as8_disable_reconnection(dominated_list)
                return
            for dominated, dominators in zip(dominated_list, dominators_list):
                spi_associations_to_update, spi_associations_to_delete = self.pi_memory.transfer(*dominated, dominators)
                # ``i`` stays -1 when there is nothing to update, so the slice
                # below then covers the (empty) list from its start.
                i = -1
                for i, (spi_association, old_pi_option) in enumerate(spi_associations_to_update):
                    obs = spi_association[0]
                    dominated_set.add((obs, old_pi_option))
                    _, transition_to_delete, _ = self.v_memory.reattach(*spi_association, old_pi_option)
                    if len(transition_to_delete) > 0:
                        if self.is_debug and self._debug_disable_ssm:
                            self.pi_memory.delete([spi_association[0:2]])
                            break
                        self.v_memory.delete([spi_association[0:2]])
                        self.pi_memory.delete([spi_association[0:2]])
                        break
                # Associations after the one that stopped the loop are dropped
                # from the policy memory.
                for spi_association, old_pi_option in spi_associations_to_update[i + 1:]:
                    obs = spi_association[0]
                    dominated_set.add((obs, old_pi_option))
                    self.pi_memory.delete([spi_association[0:2]])
                for (obs, old_pi_option) in spi_associations_to_delete:
                    dominated_set.add((obs, old_pi_option))
            self.v_memory.delete(dominated_set)

        # Walk the batch backwards: the latest transition is processed first.
        for idx in reversed(range(len(batch))):

            ready_env_ids_R = [batch.info.env_id[idx]]

            # Single-element list, so this inner loop runs exactly once per idx.
            for i in range(len(ready_env_ids_R)):

                self.count += 1

                observation = flatten_obs(batch[idx].obs, self._ordered_obs_keys)
                action = batch[idx].act
                pi_option = batch[idx].pi_option
                next_observation = flatten_obs(batch[idx].obs_next, self._ordered_obs_keys)
                reward = batch[idx].rew

                terminated = batch[idx].terminated

                abs_observation = self.pi_memory.abstract_state(observation)
                abs_action = self.pi_memory.abstract_action(action)

                non_dominated, v_values = self.v_memory.fetch_non_dominated_set(next_observation)
                non_dominated_set = set(map(tuple, non_dominated))

                # "v_value" = observation, pi_option -> v_value
                # "spia association" = observation, pi_option -> action
                # "spi association" = observation, pi_option, action, next_observation -> next_pi_option
                # "spi 'next'" = next_observation, next_pi_option
                # anchor pi_option
                spi_associations = self.pi_memory.fetch_spi_associations(observation, action, next_observation)
                spi_next_map = {(spi_association[3], spi_association[4]): (index, spi_association)
                                for index, spi_association in enumerate(spi_associations)}
                spi_next_set = set(spi_next_map.keys())

                # Set aside associations without a next pi_option; they are
                # appended again after the new associations were created.
                # NOTE(review): the unpacking below rebinds abs_observation,
                # pi_option and abs_action from the stored associations.
                none_next_pi_option_spi_associations = []
                next_pi_option_none_map = {}
                for a_i, spi_association in reversed(list(enumerate(spi_associations))):
                    (abs_observation, pi_option, abs_action, abs_next_observation, next_pi_option) = spi_association

                    if next_pi_option is None:
                        spi_associations.pop(a_i)
                        none_next_pi_option_spi_associations.append(spi_association)
                        a_j = len(none_next_pi_option_spi_associations) - 1
                        next_pi_option_none_map[pi_option] = a_j
                        continue

                # Non-dominated "next" entries not linked yet, in a stable order.
                spi_next_to_associate = sorted(non_dominated_set - spi_next_set, key=lambda x: (x is None, str(x)))
                add_current_transition = False
                spi_next_none_spi_associations_to_delete = []
                for (abs_next_observation, next_pi_option) in spi_next_to_associate:

                    pi_option = self.pi_memory.generate_pi_option_id(abs_observation, abs_next_observation, next_pi_option)
                    spi_association = (abs_observation, pi_option, abs_action, abs_next_observation, next_pi_option)

                    spi_next = (abs_next_observation, next_pi_option)
                    assert spi_next not in spi_next_set

                    if pi_option in next_pi_option_none_map:
                        spi_next_none_spi_associations_to_delete.append(next_pi_option_none_map[pi_option])

                    # A self-transition is stored as a plain transition instead.
                    if abs_observation == abs_next_observation:
                        add_current_transition = True
                        continue

                    if self.v_memory.is_above_maximum_length(abs_next_observation, next_pi_option):
                        continue

                    # pi_memory add
                    self.pi_memory.update(*spi_association)
                    spi_associations.append(spi_association)
                    spi_next_set.add(spi_next)

                # Pop from the highest index down so earlier indices stay valid.
                for a_j in reversed(sorted(spi_next_none_spi_associations_to_delete)):
                    none_next_pi_option_spi_associations.pop(a_j)
                spi_associations.extend(none_next_pi_option_spi_associations)

                # Finding a state-action-next_state for the first time
                if len(non_dominated_set) == 0:
                    add_current_transition = True

                detected_cycles = []
                if add_current_transition:
                    # v_memory add
                    pi_option = self.pi_memory.generate_pi_option_id()
                    self.pi_memory.add(observation, pi_option, action, next_observation)
                    detected_cycle = self.v_memory.add(observation, pi_option, action, next_observation, reward)
                    detected_cycles.extend(detected_cycle)

                transitions_to_delete = []
                for spi_association in spi_associations:
                    _debug_as8_count_trajectory_mismatches(spi_association)

                    # v_memory update
                    detected_cycle, transition_to_delete, to_fix_cycle = self.v_memory.update(*spi_association, reward)
                    if to_fix_cycle:
                        abs_observation, pi_option, next_abs_observation = detected_cycle[0][0:3]
                        cycle_spi_associations = self.pi_memory.fetch_trajectory(abs_observation, pi_option, next_abs_observation)
                        self.v_memory.fix_cycle(cycle_spi_associations, detected_cycle[0])
                    detected_cycles.extend(detected_cycle)
                    transitions_to_delete.extend(transition_to_delete)

                non_dominated_cycles = self.v_memory.fetch_non_dominated_detected_cycles(observation, detected_cycles)
                non_dominated_cycles, cycle_spi_associations_list, old_pi_option_list = self.pi_memory.copy_and_mark_cycles(non_dominated_cycles)
                self.v_memory.clone_and_mark_cycles(non_dominated_cycles, cycle_spi_associations_list, old_pi_option_list)

                _debug_as8_count_trajectory_split(transitions_to_delete)
                self.pi_memory.delete(transitions_to_delete)
                self.v_memory.delete(transitions_to_delete)

                # Clean up v-memory
                clean_up_dominated(observation)

        return DQNTrainingStats(loss=0.0)  # type: ignore[return-value]

    def ask_policy(self, obs, chosen_policy=None):
        """Return the action and policy option to follow from ``obs``.

        With ``chosen_policy`` given, its policy option is used directly.
        Otherwise the non-dominated policies for ``obs`` are fetched, sorted
        lexicographically by value, and offered through
        ``self.prompt_user_choice`` (not defined in this file, presumably
        inherited). The trade-off it returns is stored in
        ``self._user_chosen_policy_trade_off``. When no choice is made, the
        policy closest to the preferred trade-off is used instead.

        :param obs: observation; a ``Batch`` or ``dict`` is flattened first.
        :param chosen_policy: optional two-element entry whose second element
            is the policy option to follow.
        :return: ``(action, pi_option)``.
        """
        if isinstance(obs, (Batch, dict)):
            obs = flatten_obs(obs, self._ordered_obs_keys)

        # A pre-selected policy bypasses the prompt entirely.
        if chosen_policy is not None:
            _unused, preset_option = chosen_policy
            preset_action = self.pi_memory.get_action(obs, preset_option)
            return preset_action, preset_option

        candidates, values = self.fetch_policies(obs)

        if values.size != 0:
            # lexsort treats its last key as primary, hence the reversed columns.
            order = np.lexsort(values.T[::-1])
            candidates = candidates[order]
            values = values[order]

        if not self.use_normalized_returns:
            choice, trade_off = self.prompt_user_choice(values)
            self._user_chosen_policy_trade_off = trade_off
        else:
            normalized_values = self._get_normalized_v_values(values)
            offered_values = self._get_non_dominated_v_values(normalized_values)

            choice, trade_off = self.prompt_user_choice(offered_values)
            self._user_chosen_policy_trade_off = trade_off

            if choice is not None:
                # Map the index in the reduced list back to the full list.
                picked_value = offered_values[choice]
                row_matches = np.all(normalized_values == picked_value, axis=1)
                choice = np.where(row_matches)[0][0]

        if choice is None:
            action, pi_option = self.get_closest_preferred_trade_off_action_pi_option(obs, 0)
            return action, pi_option

        _unused, pi_option = candidates[choice]
        action = self.pi_memory.get_action(obs, pi_option)
        return action, pi_option

    def _get_action_space(self, env_index):
        try:
            action_space = self._action_space[env_index]
        # TODO: test whether envpool env explicitly
        except TypeError:  # envpool's action space is not for per-env
            action_space = self._action_space

        return action_space

    def get_random_action(self, state, env_index):
        action_space = self._get_action_space(env_index)

        action = action_space.sample()
        pi_option = None

        return action, pi_option

    def get_random_best_action_pi_option(self, state, env_index):
        """Pick a uniformly random non-dominated policy for ``state``.

        :param state: observation; a ``Batch`` is flattened first.
        :param env_index: index used to look up the action space, whose
            ``np_random`` generator drives the draw.
        :return: ``(action, pi_option)``. When the value memory holds no
            non-dominated entry for ``state``, the result of
            ``get_random_action`` is returned instead.
        """
        if isinstance(state, Batch):
            state = flatten_obs(state, self._ordered_obs_keys)
        non_dominated_set, _ = self.v_memory.fetch_non_dominated_set(state)

        if non_dominated_set.size == 0:
            return self.get_random_action(state, env_index)

        action_space = self._get_action_space(env_index)
        rng = action_space.np_random

        # Each entry is an (observation, pi_option) pair, as at the other call
        # sites in this class. Draw an index and unpack the entry, so the policy
        # memory receives the option alone rather than the whole pair. Drawing
        # an index also works with generators whose ``choice`` rejects
        # multi-dimensional input.
        chosen_index = rng.choice(len(non_dominated_set))
        _, pi_option = non_dominated_set[chosen_index]
        action = self.pi_memory.get_action(state, pi_option)

        return action, pi_option

    def get_epsilon_greedy_action_pi_option(self, state, env_index):
        """Choose an action epsilon-greedily.

        A uniform draw from the action space's ``np_random`` is compared with
        ``self.eps``: below it a random action is taken, otherwise a random
        non-dominated policy is followed. If that policy yields no action,
        a random action is used instead.

        :param state: observation; a ``Batch`` is flattened first.
        :param env_index: index used to look up the action space.
        :return: ``(action, pi_option)``.
        """
        if isinstance(state, Batch):
            state = flatten_obs(state, self._ordered_obs_keys)
        generator = self._get_action_space(env_index).np_random

        explore = generator.uniform(0, 1) < self.eps
        if explore:
            action, pi_option = self.get_random_action(state, env_index)
            return action, pi_option

        action, pi_option = self.get_random_best_action_pi_option(state, env_index)
        if action is None:
            # The selected policy has no action for this state.
            action, pi_option = self.get_random_action(state, env_index)
        return action, pi_option

    def get_closest_preferred_trade_off_action_pi_option(self, state, env_index):
        """Follow the policy whose projected return best matches the preferred trade-off.

        The values of every available policy are added to the return collected
        so far. The preferred trade-off is rescaled so that its last component
        equals the last component of each projected return. Both are then
        standardized per objective, and the policy at the smallest Euclidean
        distance is selected. A warning is issued whenever that distance is
        non-zero.

        :param state: observation; a ``Batch`` is flattened first.
        :param env_index: index handed to ``get_random_action`` for the fallback.
        :return: ``(action, pi_option)``; ``(None, None)`` when no policy is
            available for ``state``. If the selected policy yields no action,
            a random action is returned instead.
        """
        if isinstance(state, Batch):
            state = flatten_obs(state, self._ordered_obs_keys)
        non_dominated_set, v_values = self.fetch_policies(state)

        if len(v_values) == 0:
            return None, None

        # NOTE(review): ``self.collected_return`` is used as a whole and is not
        # indexed by ``env_index`` — confirm this is intended when several
        # environments are tracked.
        collected_return = 0
        if self.collected_return is not None:
            collected_return = self.collected_return

        # Explicit None check: ``a or b`` raises ValueError ("truth value of an
        # array ... is ambiguous") when the trade-off is a multi-element array.
        desired_trade_off = self._user_chosen_policy_trade_off
        if desired_trade_off is None:
            desired_trade_off = self.chosen_return

        next_collected_return = collected_return + v_values
        # Scale the trade-off so its last component equals that of each candidate.
        ideal_collected_return = (desired_trade_off / desired_trade_off[-1:]) * next_collected_return[:, -1:]

        mean = next_collected_return.mean(axis=0)
        std = next_collected_return.std(axis=0)
        # Avoid dividing by zero for objectives with no spread.
        std_safe = np.where(std == 0, 1, std)
        next_collected_return_norm = (next_collected_return - mean) / std_safe
        target_norm = (ideal_collected_return - mean) / std_safe

        distances = np.linalg.norm(target_norm - next_collected_return_norm, axis=1)
        min_index = np.argmin(distances)
        closest_trade_off = v_values[min_index]

        minimum_distance = distances[min_index]
        if minimum_distance > 0:
            warnings.warn(f"Switching from trade-off {self._user_chosen_policy_trade_off} "
                          f"to trade-off {closest_trade_off} with distance {minimum_distance}")

        _, pi_option = non_dominated_set[min_index]
        action = self.pi_memory.get_action(state, pi_option)
        if action is None:
            # ``get_random_action`` is the random fallback defined on this class;
            # the name previously called here is not defined in this file.
            action, pi_option = self.get_random_action(state, env_index)
            warnings.warn(f"Switching from trade-off {self._user_chosen_policy_trade_off} "
                          f"to a random action {action}, pi_option {pi_option}")

        return action, pi_option

    def _get_normalized_v_values(self, v_values):
        normalized_v_values = v_values / -v_values[:, -1:]
        return normalized_v_values

    def _get_non_dominated_v_values(self, v_values):
        """Keep the rows of ``v_values`` selected by both mask helpers.

        A row is kept when ``non_dominated_mask`` marks it and
        ``duplicated_mask`` does not.

        :param v_values: array of values, one row per policy.
        :return: the selected rows of ``v_values``.
        """
        keep_mask = non_dominated_mask(v_values)
        repeated_mask = duplicated_mask(v_values)
        selected_rows = keep_mask & ~repeated_mask
        return v_values[selected_rows]

    def fetch_policies(self, observation, ignore_trajectory_length=False, ignore_cross_policy=False, use_average=False,
                       use_average_for_cycle_trajectory=True, use_average_for_cycle=False, cycles_only=True):
        return self.v_memory.fetch_non_dominated_set(observation, ignore_trajectory_length, ignore_cross_policy, use_average,
                                                     use_average_for_cycle_trajectory, use_average_for_cycle, cycles_only)

    def update_collected_return(self, **kwargs):
        if 'last_rew_R' in kwargs and not np.all(kwargs['last_rew_R'] == None):
            last_rew_R = kwargs['last_rew_R'].copy()
            not_null_mask = np.all(last_rew_R != None, axis=1)

            n_rows, n_cols = last_rew_R.shape
            if self.collected_return is None:
                self.collected_return = np.zeros((n_rows, n_cols + 1))

            expanded = np.empty((n_rows, n_cols + 1), dtype=object)
            expanded[not_null_mask, :n_cols] = last_rew_R[not_null_mask]
            expanded[not_null_mask, n_cols] = -1.0

            self.collected_return[not_null_mask] = self.collected_return[not_null_mask] + expanded[not_null_mask]
