import copy
import pickle
import warnings
from collections import defaultdict

import h5py
import numpy as np

from data.mpec.memory import Memory
from data.mpec.utils.non_domination_util import non_dominated_mask, duplicated_mask
from utils.collections_util import dict_to_defaultdict, defaultdict_to_dict, entry_or_none, np_concatenate, \
    reference_preserving_copy
from utils.double_linked_list import DoubleLinkedList, Node
from utils.experiment_util import get_max_trajectory_length


def update_v(info, rewards, next_info, alpha=0.1, gamma=0.99):
    """Return an n-step TD update of ``info['v_value']`` without modifying ``info``.

    The target is ``sum_i gamma**i * rewards[i] + gamma**len(rewards) * v_next``, where
    ``v_next`` is the ``'v_value'`` of the first policy entry in ``next_info``.

    Args:
        info: mapping holding the current estimate under ``'v_value'`` (0 if absent).
        rewards: sequence of rewards observed along the n steps.
        next_info: mapping ``policy -> policy_info``. Only the first entry is used. A missing
            ``'v_value'`` or an empty mapping bootstraps from 0, mirroring the default used
            for ``info``.
        alpha: learning rate.
        gamma: discount factor.

    Returns:
        The updated value estimate as a new object. A numpy ``'v_value'`` stored in ``info``
        is not modified in place.
    """
    v_s = info.get('v_value', 0)
    # Only the first successor entry matters; avoid materialising the whole list.
    first_policy_info = next(iter(next_info.values()), {})
    v_s_next = first_policy_info.get('v_value', 0)

    g = sum(gamma ** i * r for i, r in enumerate(rewards)) + gamma ** len(rewards) * v_s_next
    # Not ``+=``: for a numpy v_s that would silently overwrite info['v_value'] in place.
    return v_s + alpha * (g - v_s)

class VMemory(Memory):

    def __init__(self, policy_domination_decimal_places=None, max_trajectory_length=None, *args, **kwargs):
        """Create the memory and read the debug switches from ``kwargs``.

        Args:
            policy_domination_decimal_places: optional value stored as a numpy array when given
                (presumably a rounding precision for policy domination checks; not used in
                this view of the file).
            max_trajectory_length: cap on trajectory length. When None it is derived from
                ``self.gamma`` via ``get_max_trajectory_length`` and a warning is emitted.
            *args, **kwargs: forwarded unchanged to the base class initializer. ``kwargs`` must
                contain ``debug``, ``debug_disable_ssm``, ``debug_disable_cycle_detection`` and
                ``debug_track_trajectories_length``; otherwise a KeyError is raised.
        """
        self._policy_domination_decimal_places = (
            None if policy_domination_decimal_places is None else np.array(policy_domination_decimal_places))

        self.is_debug = kwargs['debug']
        self._debug_disable_ssm = kwargs['debug_disable_ssm']
        self._debug_disable_cycle_detection = kwargs['debug_disable_cycle_detection']
        self._debug_track_trajectories_length = kwargs['debug_track_trajectories_length']
        self.debug_trajectory_length_tracking = {}

        super().__init__(*args, **kwargs)

        # self.gamma is only available after the base initializer has run.
        if max_trajectory_length is not None:
            self._max_trajectory_length = max_trajectory_length
            return
        self._max_trajectory_length = get_max_trajectory_length(self.gamma)
        warnings.warn(f"Gamma {self.gamma} sets max trajectory length to {self._max_trajectory_length}")

    def setup(self):
        """Run the base setup, then create empty auto-vivifying lookup tables.

        - ``info``: state_pattern -> pi_option -> per-entry dict.
        - ``pi_list_info``: state_pattern -> DoubleLinkedList of options (precedence order).
        - ``anchor_point_dict``: anchor key -> Node.
        """
        super().setup()

        def _per_option_table():
            return defaultdict(dict)

        self.info = defaultdict(_per_option_table)
        self.pi_list_info = defaultdict(DoubleLinkedList)
        self.anchor_point_dict = defaultdict(Node)

    def save(self, env_name):
        """Persist the base memory plus the auxiliary lookup tables to ``{env_name}.h5``.

        ``pi_list_info`` and ``anchor_point_dict`` are converted to plain dicts, pickled, and
        stored as opaque ``np.void`` datasets named after the attribute. ``create_dataset``
        raises if a dataset of that name already exists in the file.
        """
        super().save(env_name)
        # One file handle for both datasets instead of reopening the file per dataset.
        with h5py.File(f"{env_name}.h5", 'a') as f:
            for dataset_name, table in (('pi_list_info', self.pi_list_info),
                                        ('anchor_point_dict', self.anchor_point_dict)):
                data = pickle.dumps(defaultdict_to_dict(table))
                f.create_dataset(dataset_name, data=np.void(data))

    def load(self, env_name):
        """Restore the state written by ``save`` from ``{env_name}.h5``.

        After the base load, ``info`` is converted back to nested defaultdicts. The pickled
        ``pi_list_info`` / ``anchor_point_dict`` tables are merged into the current ones and
        converted back to defaultdicts as well.

        NOTE(security): ``pickle.loads`` can execute arbitrary code embedded in the data. Only
        load files produced by a trusted ``save``.
        """
        super().load(env_name)
        self.info = dict_to_defaultdict(self.info, [defaultdict, defaultdict, dict])

        # Read both tables through one handle before mutating anything, so a missing or
        # corrupt dataset does not leave the object half-loaded.
        with h5py.File(f"{env_name}.h5", 'r') as f:
            loaded_pi_list_info = pickle.loads(bytes(f['pi_list_info'][()]))
            loaded_anchor_point_dict = pickle.loads(bytes(f['anchor_point_dict'][()]))

        self.pi_list_info.update(loaded_pi_list_info)
        self.pi_list_info = dict_to_defaultdict(self.pi_list_info, [defaultdict, DoubleLinkedList])

        self.anchor_point_dict.update(loaded_anchor_point_dict)
        self.anchor_point_dict = dict_to_defaultdict(self.anchor_point_dict, [defaultdict, Node])

    def add(self, state, pi_option, action=None, next_state=None, reward=None):
        """Register a single-step entry for ``(state, pi_option)``.

        The state (and, when ``next_state`` is given, the action and next state) is abstracted
        into patterns. The trajectory fields of ``self.info[state_pattern][pi_option]`` are then
        (re)initialised as a trajectory of length 1, whose value is the mean observed reward
        with a trailing ``-1`` component appended.

        Args:
            state: raw state, abstracted via ``self.abstract_state``.
            pi_option: policy-option key under which the entry is stored.
            action: raw action; only abstracted when ``next_state`` is given.
            next_state: raw successor state, or None.
            reward: reward value; wrapped into a 1-element array only when ``next_state`` is
                given.

        Returns:
            list: detected cycles as tuples
            ``(state_pattern, pi_option, state_pattern, pi_option, action_pattern, cycle_info)``.
            A cycle is reported only for a self-loop (``state_pattern == next_state_pattern``).
            The list is forced empty when debug cycle detection is disabled.
        """

        assert self.step == 1, "Not implemented yet"

        state_pattern = self.abstract_state(state)
        action_pattern = None
        next_state_pattern = None
        if next_state is not None:
            action_pattern = self.abstract_action(action)
            next_state_pattern = self.abstract_state(next_state)
            reward = np.array([reward])

        # The successor must not already have an entry under the same option.
        assert entry_or_none(self.info, [next_state_pattern, pi_option]) is None

        detected_cycle = []

        info = self.info[state_pattern][pi_option]
        info['action_pattern'] = action_pattern

        # Always starts a fresh trajectory_info dict (any previous one is replaced).
        trajectory_info = info['trajectory_info'] = {}

        trajectory_info['length'] = 1
        trajectory_info['persistent_length'] = trajectory_info['length']

        info['cycle_info'] = None

        # NOTE(review): no visible code stores a 'v_value' key directly on ``info`` (it lives in
        # trajectory_info, and update() compares against 'cross_policy_value'). old_v_s is
        # therefore presumably always None here; confirm whether 'cross_policy_value' was meant.
        old_v_s = info.get('v_value', None)
        if reward is None:
            info['reward_sum'] = 0
            info['reward_count'] = 0
            trajectory_info['v_value'] = None
        else:
            info['reward_sum'] = reward
            info['reward_count'] = 1
            # Mean reward with a trailing -1 appended. Because v_value = average_return * length,
            # the last component of v_value equals -length.
            trajectory_info['average_return'] = np.append(info['reward_sum'] / info['reward_count'], -1)

            trajectory_info['v_value'] = trajectory_info['average_return'] * trajectory_info['length']

        info['cross_policy_length'] = trajectory_info['length']
        info['cross_policy_value'] = trajectory_info['v_value']

        # Attach the entry to the shared root anchor point keyed (0, None). The data tuple uses
        # the same (value, length, pi_option) layout that update() writes.
        anchor_point_key = (0, None)
        anchor_point_data = (0, 0, None)
        self.anchor_point_dict[anchor_point_key].data = anchor_point_data
        info['stationary_segments'] = self.anchor_point_dict.get(anchor_point_key)

        # If the value changed and is non-zero, move pi_option to the end of this state's
        # precedence list: append a new node, then drop the previous one.
        if np.any(old_v_s != info['cross_policy_value']) and np.any(info['cross_policy_value'] != 0):
            pi_list_ref = info.get('pi_list_ref', None)
            node = self.pi_list_info[state_pattern].append(pi_option)
            info['pi_list_ref'] = node.id
            if pi_list_ref is not None:
                self.pi_list_info[state_pattern].pop(id_=pi_list_ref)

        if state_pattern == next_state_pattern:
            # cycle detected
            v_value, length = trajectory_info['v_value'], trajectory_info['length']
            average_reward = np.append(info['reward_sum'] / info['reward_count'], -1)
            cycle_info = self._collect_cycle_info(length, length, v_value, v_value, average_reward)
            cycle = (state_pattern, pi_option, state_pattern, pi_option, action_pattern, cycle_info)
            detected_cycle.append(cycle)

        if self.is_debug and self._debug_track_trajectories_length:
            self.debug_trajectory_length_tracking[(state_pattern, pi_option)] = 1

        if self.is_debug and self._debug_disable_cycle_detection:
            detected_cycle = []

        return detected_cycle

    def update(self, state_pattern, pi_option, action_pattern, next_state_pattern, next_pi_option, reward=None,
               disable_cycle_detection=False):
        """Back up the successor's trajectory statistics into ``(state_pattern, pi_option)``.

        The method reads the entry of ``(next_state_pattern, next_pi_option)``, or a synthetic
        zero-length entry when ``next_pi_option`` is None. It accumulates ``reward`` into this
        entry's reward statistics, then recomputes from the successor:

        - the same-option trajectory aggregates (``trajectory_info``);
        - the cross-policy aggregates (``cross_policy_length`` / ``cross_policy_value``).

        When the option changes along the transition, an anchor point for the successor's
        segment is chained in front of ``stationary_segments``.

        Args:
            state_pattern, pi_option: key of the entry being updated (created if missing).
            action_pattern: abstracted action stored on the entry.
            next_state_pattern, next_pi_option: key of the successor entry. ``next_pi_option``
                None means the transition has no tracked successor.
            reward: optional reward, wrapped into a 1-element array before accumulation.
            disable_cycle_detection: skip ``self._detect_cycle``.

        Returns:
            tuple: ``(detected_cycle, transition_to_delete, to_fix_cycle)``. When the successor's
            cross-policy length + 1 exceeds the maximum trajectory length, no entry is touched
            and the transition is returned in ``transition_to_delete``.
        """

        def check_trajectory_gap(info, next_info):
            """Update this entry's persistent_length from the successor and report a "gap".

            Reads ``pi_options_match`` from the enclosing scope; it is assigned below, before
            the first call.

            - Options match, and this entry has no persistent_length yet (or the successor's is
              smaller): persistent_length becomes successor's + 1 and no gap is reported.
            - Options match otherwise (successor's persistent_length >= this entry's): a gap is
              reported and persistent_length is kept.
            - Options differ: persistent_length is reset to 1 and no gap is reported.
            """

            trajectory_info = info.get('trajectory_info')
            next_info_trajectory_info = next_info.get('trajectory_info')
            next_persistent_length = next_info_trajectory_info.get('persistent_length')

            persistent_trajectory_gap = False
            if pi_options_match:
                persistent_length = trajectory_info.get('persistent_length', 0)

                if persistent_length == 0 or next_persistent_length < persistent_length:
                    persistent_length = next_persistent_length + 1

                if next_persistent_length >= persistent_length:
                    persistent_trajectory_gap = True
            else:
                persistent_length = trajectory_info['persistent_length'] = 1

            trajectory_info['persistent_length'] = persistent_length

            return persistent_trajectory_gap

        if reward is not None:
            reward = np.array([reward])

        assert self.step == 1, "Not implemented yet"

        assert next_state_pattern is not None

        # A missing successor option counts as a match (the transition stays on pi_option).
        pi_options_match = True
        if pi_option != next_pi_option and next_pi_option is not None:
            pi_options_match = False

        if next_pi_option is None:
            # Synthetic zero-length successor used when there is no tracked successor entry.
            next_info = {
                'trajectory_info': {
                    'average_return': 0,
                    'length': 0,
                    'persistent_length': 0,
                },
                'cross_policy_length': 0,
                'cycle_info': None,
                'stationary_segments': []
            }
        else:
            next_info = self.info.get(next_state_pattern).get(next_pi_option)

        detected_cycle = []
        transition_to_delete = []
        to_fix_cycle = False

        # Reject the transition before touching any entry if it would make the trajectory too long.
        if self.is_above_maximum_length(next_state_pattern, next_pi_option):
            transition_to_delete.append((state_pattern, pi_option, action_pattern, next_state_pattern, next_pi_option))
            return detected_cycle, transition_to_delete, to_fix_cycle

        next_info_trajectory_info = next_info['trajectory_info']
        next_cross_policy_length = next_info.get('cross_policy_length')

        if self.is_debug and self._debug_track_trajectories_length:
            if entry_or_none(self.info, [state_pattern, pi_option]) is None:
                self.debug_trajectory_length_tracking[(state_pattern, pi_option)] = next_cross_policy_length + 1
                self.debug_trajectory_length_tracking.pop((next_state_pattern, next_pi_option), None)
            else:
                if self.debug_trajectory_length_tracking.get((state_pattern, pi_option), None) is not None:
                    self.debug_trajectory_length_tracking[(state_pattern, pi_option)] = next_cross_policy_length + 1

        info = self.info[state_pattern][pi_option]
        trajectory_info = info['trajectory_info'] = info.get('trajectory_info', {})
        info['action_pattern'] = action_pattern

        if reward is not None:
            info['reward_sum'] = info.get('reward_sum', 0) + reward
            info['reward_count'] = info.get('reward_count', 0) + 1
        # Mean reward with a trailing -1 component (same layout as in add()).
        average_reward = np.append(info['reward_sum'] / info['reward_count'], -1)

        persistent_trajectory_gap = check_trajectory_gap(info, next_info)

        if next_pi_option is not None:
            info['stationary_segments'] = next_info['stationary_segments']

        old_v_s = info.get('cross_policy_value', 0)

        # cross_policy_length == 0 marks an entry that clone_and_mark_cycles turned into a cycle
        # member; its aggregates are left as they are.
        is_cycle_transition = info.get('cross_policy_length') == 0
        if not is_cycle_transition:
            if persistent_trajectory_gap:
                # Gap: keep previously stored cross-policy aggregates; only fill them if absent.
                info['cross_policy_length'] = info.get('cross_policy_length', next_cross_policy_length + 1)
                info['cross_policy_value'] = info.get('cross_policy_value', average_reward)
            else:
                # Extend the successor's cross-policy aggregate by this step.
                info['cross_policy_length'] = next_cross_policy_length + 1
                info['cross_policy_value'] = next_info.get('cross_policy_value', 0) + average_reward

            if pi_options_match:
                if not persistent_trajectory_gap:
                    # Running mean of average_reward over next_length + 1 same-option steps.
                    next_return = next_info_trajectory_info.get('average_return')
                    next_length = next_info_trajectory_info.get('length')
                    trajectory_info['length'] = next_length + 1
                    trajectory_info['average_return'] = next_return + (average_reward - next_return) / (next_length + 1)
            else:
                # The option changes here, so a new same-option trajectory starts at this entry.
                trajectory_info['length'] = 1
                trajectory_info['average_return'] = average_reward

            trajectory_info['v_value'] = trajectory_info.get('average_return') * trajectory_info.get('length')

            if not pi_options_match:
                # Record an anchor for the successor's segment, keyed by (its trajectory length,
                # its option) with data (v_value, length, option). Link it in front of the
                # inherited stationary_segments chain unless that head already has the same key.
                anchor_point_key = (next_info_trajectory_info['length'], next_pi_option)
                anchor_point_data = (next_info_trajectory_info['v_value'], next_info_trajectory_info['length'], next_pi_option)
                self.anchor_point_dict[anchor_point_key].data = anchor_point_data
                if anchor_point_key != info['stationary_segments'].data[1:3]:
                    self.anchor_point_dict[anchor_point_key].next = info['stationary_segments']
                info['stationary_segments'] = self.anchor_point_dict[anchor_point_key]

        info['cycle_info'] = next_info['cycle_info']

        if not disable_cycle_detection:
            detected_cycle, transition_to_delete, to_fix_cycle = self._detect_cycle(
                state_pattern, pi_option, action_pattern, next_state_pattern, next_pi_option)
            # _detect_cycle may replace the stored entry, so re-read it.
            info = self.info[state_pattern][pi_option]

        # If the value changed and is non-zero, move pi_option to the end of this state's
        # precedence list: append a new node, then drop the previous one.
        if np.any(old_v_s != info['cross_policy_value']) and np.any(info['cross_policy_value'] != 0):
            pi_list_ref = info.get('pi_list_ref', None)
            node = self.pi_list_info[state_pattern].append(pi_option)
            info['pi_list_ref'] = node.id
            if pi_list_ref is not None:
                self.pi_list_info[state_pattern].pop(id_=pi_list_ref)

        return detected_cycle, transition_to_delete, to_fix_cycle

    def is_above_maximum_length(self, next_state_pattern, next_pi_option):
        """Return True if stepping into the successor would exceed the maximum trajectory length.

        The candidate length is the successor's ``cross_policy_length`` + 1.

        Args:
            next_state_pattern: abstracted successor state.
            next_pi_option: successor option. None means there is no tracked successor and is
                never too long.

        Returns:
            bool: ``cross_policy_length + 1 > self._max_trajectory_length``. When the debug
            ``_debug_disable_ssm`` switch is on, an unknown successor returns False.

        Raises:
            AttributeError: outside that debug mode, when the successor entry does not exist.
        """
        if next_pi_option is None:
            return False

        # Default to an empty mapping so an unknown *state* pattern also reaches the debug guard
        # below, instead of failing on ``None.get``.
        next_info = self.info.get(next_state_pattern, {}).get(next_pi_option)
        if self.is_debug and self._debug_disable_ssm and next_info is None:
            return False

        return next_info.get('cross_policy_length') + 1 > self._max_trajectory_length

    def reattach(self, state_pattern, pi_option, action_pattern, next_state_pattern, next_pi_option, old_pi_option):
        """Seed ``(state_pattern, pi_option)`` with the reward statistics of ``old_pi_option`` and update it.

        The reward sum/count are copied from the old option's entry, then the transition is
        processed by ``update``. In debug mode with SSM disabled, an unknown successor entry is
        treated as "no successor" (``next_pi_option`` set to None). If the successor's
        trajectory is already too long, the transition is returned for deletion instead.

        Returns:
            tuple: ``(detected_cycle, transition_to_delete, to_fix_cycle)`` as from ``update``.
        """
        assert self.step == 1, "Not implemented yet"

        source_info = self.info.get(state_pattern).get(old_pi_option)
        target_info = self.info[state_pattern][pi_option]
        for stat in ('reward_sum', 'reward_count'):
            target_info[stat] = source_info[stat]

        if (self.is_debug and self._debug_disable_ssm
                and entry_or_none(self.info, [next_state_pattern, next_pi_option]) is None):
            next_pi_option = None

        if not self.is_above_maximum_length(next_state_pattern, next_pi_option):
            return self.update(state_pattern, pi_option, action_pattern, next_state_pattern, next_pi_option)

        # NOTE(review): target_info was just inserted under state_pattern, so this mapping is
        # never empty here and the pop below does not run in practice.
        if not self.info[state_pattern]:
            self.info.pop(state_pattern)

        return [], [(state_pattern, pi_option, action_pattern, next_state_pattern, next_pi_option)], False

    def _debug_recolor(self, state_pattern, pi_option, action_pattern, next_state_pattern, next_pi_option, old_pi_option):
        """Debug helper: re-key the entry stored under ``old_pi_option`` to ``pi_option`` in place.

        No copy is made. The same info dict is moved to the new key, its node in the state's
        precedence list is relabelled, and an anchor point keyed by
        ``(trajectory length, old_pi_option)`` (if present) is moved to
        ``(trajectory length, pi_option)``. The state's mapping is dropped once it becomes empty.
        ``action_pattern``, ``next_state_pattern`` and ``next_pi_option`` are unused.
        """
        assert self.step == 1, "Not implemented yet"

        options = self.info.get(state_pattern)
        moved = options.get(old_pi_option)
        options[pi_option] = moved
        self.pi_list_info[state_pattern].id_map[moved['pi_list_ref']].data = pi_option

        trajectory_length = moved['trajectory_info']['length']
        old_anchor_key = (trajectory_length, old_pi_option)
        if self.anchor_point_dict.get(old_anchor_key) is not None:
            new_anchor_key = (trajectory_length, pi_option)
            reference_preserving_copy(self.anchor_point_dict[old_anchor_key], self.anchor_point_dict[new_anchor_key])
            new_anchor = self.anchor_point_dict[new_anchor_key]
            new_anchor.data = (new_anchor.data[0], new_anchor.data[1], pi_option)
            self.anchor_point_dict.pop(old_anchor_key)

        options.pop(old_pi_option)
        if not options:
            self.info.pop(state_pattern)

    def _copy_and_update(self, state_pattern, pi_option, action_pattern, next_state_pattern, next_pi_option, old_pi_option):
        """Deep-copy the entry of ``old_pi_option`` under ``pi_option``, then refresh it via ``update``.

        The clone starts without a precedence-list node and without stationary segments. It is
        updated with cycle detection disabled, and is guaranteed a node in the state's
        precedence list afterwards.

        Returns:
            tuple: whatever ``update`` returned.
        """
        assert self.step == 1, "Not implemented yet"

        clone = copy.deepcopy(self.info.get(state_pattern).get(old_pi_option))
        self.info[state_pattern][pi_option] = clone
        # Do not keep references to the source's list node or its deep-copied anchor chain.
        clone['pi_list_ref'] = None
        clone['stationary_segments'] = None

        result = self.update(state_pattern, pi_option, action_pattern, next_state_pattern, next_pi_option, disable_cycle_detection=True)

        # update() only appends to the precedence list when the value changed and is non-zero,
        # so register the clone explicitly if it did not.
        if clone.get('pi_list_ref') is None:
            clone['pi_list_ref'] = self.pi_list_info[state_pattern].append(pi_option).id

        assert clone.get('pi_list_ref') is not None

        return result

    def clone_and_mark_cycles(self, detected_cycles, cycle_spi_associations_list, old_pi_option_list):
        """Clone each detected cycle's transitions into the cycle's option and mark them as cycle members.

        For every cycle, the option is taken from the first association. A root anchor keyed
        ``(0, option)`` is (re)initialised, and every member entry is reset to:

        - ``cycle_info`` set to the cycle's;
        - zero length and zero value (marking it as a cycle transition);
        - ``average_return`` set to the cycle's.

        Each member is also attached to that anchor.
        """
        self._clone_transitions(cycle_spi_associations_list, old_pi_option_list)

        for detected, associations in zip(detected_cycles, cycle_spi_associations_list):
            _, _, _, _, _, cycle_info = detected

            cycle_option = associations[0][1]
            anchor_key = (0, cycle_option)
            self.anchor_point_dict[anchor_key].data = (0, 0, cycle_option)
            anchor_node = self.anchor_point_dict.get(anchor_key)

            for member_state, _, _, _, _ in associations:
                member = self.info[member_state][cycle_option]
                member['cycle_info'] = cycle_info

                member_trajectory = member['trajectory_info']
                member_trajectory['persistent_length'] = 1
                # Length 0 is what update()/fetch_* use to recognise cycle transitions.
                member_trajectory['length'] = member['cross_policy_length'] = 0
                # Both fields share one zero array shaped like the previous value.
                member_trajectory['v_value'] = member['cross_policy_value'] = np.zeros_like(member_trajectory['v_value'])
                member_trajectory['average_return'] = cycle_info['average_return']

                member['stationary_segments'] = anchor_node

    def get(self, state, pi_option=None):
        """Return the info dict stored for ``state`` under ``pi_option``.

        Returns None when the abstracted state is unknown (or any lookup raises
        AttributeError). Missing entries are not created.
        """
        state_pattern = self.abstract_state(state)

        try:
            per_option_info = self.info.get(state_pattern)
            found = per_option_info.get(pi_option)
        except AttributeError:
            # ``self.info.get`` yields None for an unknown state pattern.
            found = None
        return found

    def _clone_transitions(self, cycle_spi_associations_list, old_pi_option_list):
        """Clone every cycle's transitions from their old options via ``_copy_and_update``.

        Each association is a ``(state_pattern, pi_option, action_pattern, next_state_pattern,
        next_pi_option)`` tuple, paired positionally with an old option. If a cycle has more
        than one association, the last one is cloned first with its ``next_pi_option``
        cleared, so ``update`` treats it as having no tracked successor. The remaining
        associations are then cloned from last to first (presumably so each successor's clone
        exists before its predecessor is updated; confirm against the cycle ordering).
        """
        # Distinct loop names: the original rebound the ``old_pi_option_list`` parameter here.
        for associations, old_pi_options in zip(cycle_spi_associations_list, old_pi_option_list):
            if len(associations) > 1:
                associations, closing_association = associations[:-1], associations[-1]
                old_pi_options, closing_old_pi_option = old_pi_options[:-1], old_pi_options[-1]
                closing_args = list(closing_association)
                closing_args[4] = None  # next_pi_option
                self._copy_and_update(*closing_args, closing_old_pi_option)

            for association, old_pi_option in reversed(list(zip(associations, old_pi_options))):
                self._copy_and_update(*association, old_pi_option)

    def fetch_non_dominated_set(self, state, ignore_trajectory_length=False, ignore_cross_policy=False, use_average=False,
                                use_average_for_cycle_trajectory=False, use_average_for_cycle=False, cycles_only=False):
        """Return the ``(state_pattern, pi_option)`` entries of ``state`` that are not dominated.

        The state's options are visited in precedence-list order and split into three groups:

        - cycle transitions: ``cycle_info`` set and ``cross_policy_length == 0``;
        - trajectories leading into a cycle: ``cycle_info`` set, non-zero length;
        - plain trajectories: no ``cycle_info``.

        Cycles are compared with ``_check_cycle_domination``. Unless ``cycles_only``, plain
        trajectories are compared with ``_check_trajectory_domination`` against the cycle
        values (cycle value + trajectory value for cycle-bound trajectories).

        Args:
            state: raw state, abstracted via ``self.abstract_state``.
            ignore_trajectory_length: drop the last component of trajectory values, and pass
                the same flag as ``ignore_cycle_length``.
            ignore_cross_policy: for plain trajectories use ``trajectory_info['v_value']``
                instead of ``cross_policy_value``.
            use_average: shorthand enabling both ``use_average_for_*`` flags.
            use_average_for_cycle_trajectory: use ``trajectory_info['average_return']`` (an
                all-inf array for cycle transitions) and put every cycle entry in the
                cycle-bound group.
            use_average_for_cycle: use ``cycle_info['average_return']`` instead of its
                ``v_value``.
            cycles_only: skip plain trajectories and the trajectory domination step.

        Returns:
            tuple: ``(non_dominated_set, non_dominated_v_value_list)``.

        NOTE(review): the v-value list comes only from the cycle check.
        ``trajectory_non_dominated_v_value_list`` is discarded, so when ``cycles_only`` is False
        the two returned sequences presumably do not align; confirm with callers.
        NOTE(review): with ``ignore_trajectory_length`` the cycle values are not trimmed while
        trajectory values are. The sum ``cycle_trajectory_cycle_info_v_value_list +
        cycle_trajectory_trajectory_info_v_value_list`` may then mismatch in shape; verify the
        shape of ``cycle_info['v_value']``.
        """

        if use_average:
            use_average_for_cycle_trajectory = True
            use_average_for_cycle = True

        ignore_cycle_length = ignore_trajectory_length

        state_pattern = self.abstract_state(state)

        cycle_transitions_spi_list = []
        cycle_trajectory_cycle_info_spi_list = []
        # NOTE(review): filled below but never read afterwards in this method.
        cycle_trajectory_trajectory_info_spi_list = []
        trajectory_info_spi_list = []

        cycle_transitions_v_value_list = []
        cycle_trajectory_cycle_info_v_value_list = []
        cycle_trajectory_trajectory_info_v_value_list = []
        trajectory_info_v_value_list = []

        pi_option_precedence_info = self.pi_list_info.get(state_pattern, [])
        pi_option_info = self.info.get(state_pattern, dict())
        for _, pi_option in pi_option_precedence_info:
            info = pi_option_info.get(pi_option)
            trajectory_info = info.get('trajectory_info')
            cycle_info = info.get('cycle_info')

            if cycle_info is not None:
                if use_average_for_cycle:
                    cycle_v_value = cycle_info.get('average_return')
                else:
                    cycle_v_value = cycle_info.get('v_value')

                if use_average_for_cycle_trajectory:
                    if info.get('cross_policy_length') == 0:
                        v_value = np.full_like(trajectory_info.get('average_return'), np.inf)
                    else:
                        v_value = trajectory_info.get('average_return')
                else:
                    if info.get('cross_policy_length') == 0:
                        # Cycle transition: compared by its cycle value only.
                        cycle_transitions_v_value_list.append(cycle_v_value)
                        cycle_transitions_spi_list.append((state_pattern, pi_option))
                        continue
                    else:
                        v_value = info.get('cross_policy_value')

                if ignore_trajectory_length:
                    v_value = v_value[:-1]

                cycle_trajectory_cycle_info_v_value_list.append(cycle_v_value)
                cycle_trajectory_cycle_info_spi_list.append((state_pattern, pi_option))
                cycle_trajectory_trajectory_info_v_value_list.append(v_value)
                cycle_trajectory_trajectory_info_spi_list.append((state_pattern, pi_option))
            else:
                if cycles_only:
                    continue

                if ignore_cross_policy:
                    v_value = trajectory_info.get('v_value')
                else:
                    v_value = info.get('cross_policy_value')
                if ignore_trajectory_length:
                    v_value = v_value[:-1]

                trajectory_info_v_value_list.append(v_value)
                trajectory_info_spi_list.append((state_pattern, pi_option))

        cycle_transitions_spi_list = np.array(cycle_transitions_spi_list)
        cycle_trajectory_cycle_info_spi_list = np.array(cycle_trajectory_cycle_info_spi_list)
        trajectory_info_spi_list = np.array(trajectory_info_spi_list)

        cycle_transitions_v_value_list = np.array(cycle_transitions_v_value_list)
        cycle_trajectory_cycle_info_v_value_list = np.array(cycle_trajectory_cycle_info_v_value_list)
        cycle_trajectory_trajectory_info_v_value_list = np.array(cycle_trajectory_trajectory_info_v_value_list)
        trajectory_info_v_value_list = np.array(trajectory_info_v_value_list)

        non_dominated_set, non_dominated_v_value_list = (
            self._check_cycle_domination(
                cycle_transitions_spi_list, cycle_trajectory_cycle_info_spi_list,
                cycle_transitions_v_value_list, cycle_trajectory_cycle_info_v_value_list, cycle_trajectory_trajectory_info_v_value_list,
                ignore_cycle_length))

        if not cycles_only:
            # A cycle-bound trajectory is worth its cycle value plus its own trajectory value.
            all_cycle_trajectory_v_value_list = np_concatenate([
                cycle_transitions_v_value_list, cycle_trajectory_cycle_info_v_value_list + cycle_trajectory_trajectory_info_v_value_list])

            trajectory_non_dominated_set, trajectory_non_dominated_v_value_list = (
                self._check_trajectory_domination(
                    trajectory_info_spi_list, all_cycle_trajectory_v_value_list, trajectory_info_v_value_list, ignore_trajectory_length))

            non_dominated_set = np_concatenate([trajectory_non_dominated_set, non_dominated_set])

        return non_dominated_set, non_dominated_v_value_list

    def fetch_dominated_sets(self, state):
        """Return the dominated ``(state_pattern, pi_option)`` entries of ``state`` with their dominators.

        Entries are grouped as in ``fetch_non_dominated_set`` (cycle transitions, cycle-bound
        trajectories, plain trajectories), using ``cycle_info['v_value']`` and
        ``cross_policy_value``. Processing then runs in three steps:

        1. Cycle domination over cycle transitions + cycle-bound trajectories gives the
           dominated cycle entries.
        2. Trajectory domination compares plain trajectories against cycle values (cycle value
           + trajectory value for cycle-bound entries).
        3. For the cycle-bound slice, an entry counts as non-dominated if either check says so
           (``|``). Dominator sets are unioned, then cleared for entries that count as
           non-dominated.

        Dominator sets are finally stripped of indices that are themselves dominated.

        Returns:
            tuple: ``(dominated_set, dominators)``. ``dominated_set`` lists dominated spi pairs
            (plain trajectories first, then cycles). ``dominators[i]`` is the list of
            ``(state_pattern, pi_option)`` tuples dominating ``dominated_set[i]``.

        NOTE(review): the slicing below assumes the masks and index sets returned by
        ``_check_cycle_domination`` / ``_check_trajectory_domination`` use the same element
        ordering as the concatenated spi lists built here (for example ``all_spi_list``).
        That ordering is not visible in this view; confirm.
        """

        def cap_domination(non_dominated_mask, dominator_indices):
            # Remove every index marked dominated in the mask from each dominator set.
            if len(non_dominated_mask) > 0:
                dominated_indices = set(np.where(~non_dominated_mask)[0])
                dominator_indices = np.array(list(map(lambda x: x - dominated_indices, dominator_indices)))

            return non_dominated_mask, dominator_indices

        state_pattern = self.abstract_state(state)

        cycle_transitions_spi_list = []
        cycle_trajectory_cycle_info_spi_list = []
        cycle_trajectory_trajectory_info_spi_list = []
        trajectory_info_spi_list = []

        cycle_transitions_v_value_list = []
        cycle_trajectory_cycle_info_v_value_list = []
        cycle_trajectory_trajectory_info_v_value_list = []
        trajectory_info_v_value_list = []

        pi_option_precedence_info = self.pi_list_info.get(state_pattern, [])
        for _, pi_option in pi_option_precedence_info:
            info = self.info.get(state_pattern).get(pi_option)
            cycle_info = info.get('cycle_info')

            if cycle_info is not None:
                cycle_v_value = cycle_info.get('v_value')
                if info.get('cross_policy_length') == 0:
                    cycle_transitions_v_value_list.append(cycle_v_value)
                    cycle_transitions_spi_list.append((state_pattern, pi_option))
                else:
                    v_value = info.get('cross_policy_value')

                    cycle_trajectory_cycle_info_v_value_list.append(cycle_v_value)
                    cycle_trajectory_cycle_info_spi_list.append((state_pattern, pi_option))
                    cycle_trajectory_trajectory_info_v_value_list.append(v_value)
                    cycle_trajectory_trajectory_info_spi_list.append((state_pattern, pi_option))
            else:
                trajectory_info_v_value_list.append(info.get('cross_policy_value'))
                trajectory_info_spi_list.append((state_pattern, pi_option))

        cycle_transitions_spi_list = np.array(cycle_transitions_spi_list)
        cycle_trajectory_cycle_info_spi_list = np.array(cycle_trajectory_cycle_info_spi_list)
        cycle_trajectory_trajectory_info_spi_list = np.array(cycle_trajectory_trajectory_info_spi_list)
        trajectory_info_spi_list = np.array(trajectory_info_spi_list)

        cycle_transitions_v_value_list = np.array(cycle_transitions_v_value_list)
        cycle_trajectory_cycle_info_v_value_list = np.array(cycle_trajectory_cycle_info_v_value_list)
        cycle_trajectory_trajectory_info_v_value_list = np.array(cycle_trajectory_trajectory_info_v_value_list)
        trajectory_info_v_value_list = np.array(trajectory_info_v_value_list)

        # Step 1: domination among cycle entries.
        non_dominated_cycle_mask, cycle_dominator_indices, cycle_transitions_elements_size = self._check_cycle_domination(
            cycle_transitions_spi_list, cycle_trajectory_cycle_info_spi_list,
            cycle_transitions_v_value_list, cycle_trajectory_cycle_info_v_value_list, cycle_trajectory_trajectory_info_v_value_list,
            return_mask=True, return_dominators=True)
        cycle_dominated_set, cycle_dominators = np.array([]), []
        cycle_spi_list = np_concatenate([cycle_transitions_spi_list, cycle_trajectory_cycle_info_spi_list])
        if len(non_dominated_cycle_mask) > 0:
            dominated_indices = set(np.where(~non_dominated_cycle_mask)[0])
            cycle_dominated_set = cycle_spi_list[~non_dominated_cycle_mask]
            dominator_indices = cycle_dominator_indices[~non_dominated_cycle_mask]
            dominator_indices = list(map(lambda x: x - dominated_indices, dominator_indices))
            cycle_dominators = list(map(lambda x: [tuple(cycle_spi_list[i]) for i in x], dominator_indices))

        # Step 2: plain trajectories against cycle values (cycle + trajectory for cycle-bound entries).
        all_cycle_trajectory_spi_list = np_concatenate([cycle_transitions_spi_list, cycle_trajectory_trajectory_info_spi_list])
        all_cycle_trajectory_v_value_list = np_concatenate([
            cycle_transitions_v_value_list, cycle_trajectory_cycle_info_v_value_list + cycle_trajectory_trajectory_info_v_value_list])

        full_non_dominated_trajectory_mask, full_trajectory_dominator_indices, trajectory_elements_size = self._check_trajectory_domination(
            trajectory_info_spi_list, all_cycle_trajectory_v_value_list, trajectory_info_v_value_list, return_mask=True, return_dominators=True)

        # Step 3: merge both verdicts for the cycle-bound slice.
        # NOTE(review): when trajectory_elements_size == 0 the slice ``[a:-0]`` equals ``[a:0]``
        # and is empty, so the merge is skipped. That is harmless only because no plain-trajectory
        # output is produced in that case (see the guard further down).
        if (len(non_dominated_cycle_mask[cycle_transitions_elements_size:]) > 0
                and len(full_non_dominated_trajectory_mask[cycle_transitions_elements_size:-trajectory_elements_size]) > 0):
            non_dominated_cycle__trajectories_mask = (
                    non_dominated_cycle_mask[cycle_transitions_elements_size:]
                    | full_non_dominated_trajectory_mask[cycle_transitions_elements_size:-trajectory_elements_size])
            full_non_dominated_trajectory_mask[cycle_transitions_elements_size:-trajectory_elements_size] = non_dominated_cycle__trajectories_mask

            cycle__trajectory_domination = full_trajectory_dominator_indices[cycle_transitions_elements_size:-trajectory_elements_size]
            cycle_domination = cycle_dominator_indices[cycle_transitions_elements_size:]
            full_trajectory_dominator_indices[cycle_transitions_elements_size:-trajectory_elements_size] = (
                np.vectorize(lambda x, y: x | y)(cycle__trajectory_domination, cycle_domination))
            full_trajectory_dominator_indices[cycle_transitions_elements_size:-trajectory_elements_size][non_dominated_cycle__trajectories_mask] = set()

        non_dominated_trajectory_mask, trajectory_dominator_indices = cap_domination(
            full_non_dominated_trajectory_mask, full_trajectory_dominator_indices)

        trajectory_dominated_set, trajectory_dominators = np.array([]), []
        trajectory_spi_list = np_concatenate([trajectory_info_spi_list, all_cycle_trajectory_spi_list])
        # The guard on trajectory_elements_size also avoids ``[-0:]``, which would select everything.
        if len(non_dominated_trajectory_mask) > 0 and trajectory_elements_size > 0:
            trajectory_dominated_set = trajectory_info_spi_list[~non_dominated_trajectory_mask[-trajectory_elements_size:]]
            all_spi_list = np_concatenate([cycle_spi_list, trajectory_spi_list])
            trajectory_dominator_indices = trajectory_dominator_indices[-trajectory_elements_size:][~non_dominated_trajectory_mask[-trajectory_elements_size:]]
            trajectory_dominators = list(map(lambda x: [tuple(all_spi_list[i]) for i in x], trajectory_dominator_indices))

        dominated_set = np_concatenate([trajectory_dominated_set, cycle_dominated_set])
        dominators = trajectory_dominators + cycle_dominators

        return dominated_set, dominators

    def fetch_non_dominated_detected_cycles(self, state, detected_cycles):
        """Keep only the detected cycles that are not dominated by cycles already stored for ``state``.

        Each detected cycle is keyed by its elements ``[2:4]`` (a ``(state_pattern, pi_option)``
        pair). Only cycles whose key matches an option in this state's precedence list are
        candidates. Their ``cycle_info['v_value']`` (element ``[5]``) is compared with
        ``_check_trajectory_domination`` against the ``v_value`` of every stored ``cycle_info``
        of the state.

        Returns:
            list: the surviving detected-cycle tuples (empty if none).
        """
        state_pattern = self.abstract_state(state)

        cycles_by_key = {tuple(c[2:4]): c for c in detected_cycles}
        options = self.info.get(state_pattern)

        stored_cycle_values = []
        candidate_keys = []
        candidate_values = []
        for _, pi_option in self.pi_list_info.get(state_pattern, []):
            stored_cycle = options.get(pi_option).get('cycle_info')
            if stored_cycle is not None:
                stored_cycle_values.append(stored_cycle.get('v_value'))

            key = (state_pattern, pi_option)
            if key in cycles_by_key:
                candidate_keys.append(key)
                candidate_values.append(cycles_by_key[key][5].get('v_value'))

        stored_cycle_values = np.array(stored_cycle_values)
        candidate_keys = np.array(candidate_keys)
        candidate_values = np.array(candidate_values)

        mask, _, candidate_count = self._check_trajectory_domination(
            candidate_keys, stored_cycle_values, candidate_values, return_mask=True, return_dominators=True)

        if not (len(mask) > 0 and candidate_count > 0):
            return []

        # The candidates' verdicts sit at the tail of the mask.
        survivors = candidate_keys[mask[-candidate_count:]]
        return [cycles_by_key.get(tuple(k)) for k in survivors]

    def _check_cycle_domination(
            self, cycle_transitions_spi_list, cycle_trajectory_cycle_info_spi_list,
            cycle_transitions_v_value_list, cycle_trajectory_cycle_info_v_value_list, cycle_trajectory_trajectory_info_v_value_list,
            ignore_cycle_length=False, return_mask=False, return_dominators=False):
        """Filter cycle candidates down to the non-dominated, non-duplicated ones.

        Two candidate sources are pooled, transitions first and cycle trajectories
        second: ``cycle_transitions_*`` and ``cycle_trajectory_cycle_info_*``. The
        pooled cycle v-values are tested with ``non_dominated_mask`` and
        ``duplicated_mask``. An element must be non-dominated in the pool. If it is
        also duplicated in the pool, it gets a second chance:
          * transitions survive if they are not duplicated among the transitions alone;
          * cycle trajectories survive if they are non-dominated and non-duplicated
            inside their group (group ids from
            ``duplicated_mask(..., group_duplicates=True)``), compared on
            ``cycle_trajectory_trajectory_info_v_value_list``.

        Args:
            cycle_transitions_spi_list: rows aligned with ``cycle_transitions_v_value_list``
                (presumably ``(state_pattern, pi_option)`` pairs -- confirm with callers).
            cycle_trajectory_cycle_info_spi_list: rows aligned with
                ``cycle_trajectory_cycle_info_v_value_list``.
            cycle_transitions_v_value_list: numpy array of transition cycle v-values.
            cycle_trajectory_cycle_info_v_value_list: cycle v-values of the cycle trajectories.
            cycle_trajectory_trajectory_info_v_value_list: trajectory v-values of the cycle
                trajectories. They are aligned with the previous argument and used only
                for the within-group secondary check (never shifted).
            ignore_cycle_length: when False, the pooled cycle v-values are shifted by one
                common offset so that their minimum is >= 1e-8 before testing.
            return_mask: return the boolean mask over the pooled array instead of the
                filtered entries.
            return_dominators: with ``return_mask``, also return per-element dominator
                index sets and the number of transition elements.

        Returns:
            ``mask``, or ``(mask, dominators, cycle_transitions_elements_size)`` when masks
            are requested. Otherwise ``(non_dominated_spi_list, non_dominated_v_value_list)``,
            taken from the pooled, unshifted inputs.

        Raises:
            NotImplementedError: if ``return_dominators`` is set without ``return_mask``.
        """

        def flatten_dominators(nested, group_indices):
            """Map per-group dominator index sets back to original (ungrouped) positions.

            ``nested[g]`` holds one set per member of the g-th group, taken in ascending
            group-id order (``np.unique`` order). Indices inside each set are local to
            that group. The result is an object array, in the original element order,
            whose sets contain original positions.
            """

            sorted_idx = np.argsort(group_indices, kind="stable")

            # Inverse permutation: original position -> position in sorted order.
            inverse_sorted_idx = [0] * len(sorted_idx)
            for i, val in enumerate(sorted_idx):
                inverse_sorted_idx[val] = i

            flattened = []
            offset = 0
            for group_pos, _ in enumerate(np.unique(group_indices)):
                group_len = len(nested[group_pos])
                # Original positions of this group's members, in stable sorted order.
                sorted_idx_group = sorted_idx[offset:offset + group_len]
                a_list = []
                for s in nested[group_pos]:
                    set_ = set()
                    for e in s:
                        set_.add(sorted_idx_group[e])
                    a_list.append(set_)
                flattened.extend(a_list)
                offset += group_len

            # ``flattened`` is in sorted order; reorder it back to the original order.
            corrected = np.array(flattened)[inverse_sorted_idx]

            return corrected

        def shift_domination(dominator_indices, trajectory_elements_size):
            """Offset every index in every set by ``trajectory_elements_size``."""
            return np.array([{e + trajectory_elements_size for e in s} for s in dominator_indices])

        cycle_dominators = []
        dominators = []

        # Pool: transitions first, cycle trajectories after them.
        cycle_v_value_list = np_concatenate([cycle_transitions_v_value_list, cycle_trajectory_cycle_info_v_value_list])

        cycle_transitions_elements_size = cycle_transitions_v_value_list.shape[0]

        cycle_v_values = cycle_v_value_list
        if not ignore_cycle_length and cycle_v_value_list.size > 0:
            # Shift a copy by one common offset so that its minimum becomes >= 1e-8.
            # The unshifted ``cycle_v_value_list`` is what gets returned below.
            cycle_v_values = cycle_v_value_list.copy()
            e = 1e-8
            cycle_v_values += max(0, -cycle_v_values.min()) + e

        if len(cycle_v_values) == 0:
            if return_mask:
                if return_dominators:
                    return np.array([]), np.array([]), cycle_transitions_elements_size
                return np.array([])

            if return_dominators:
                raise NotImplementedError

            cycle_spi_list = np.empty((0, 2))
            return cycle_spi_list, cycle_v_values

        cycle_non_dominated_indices = non_dominated_mask(cycle_v_values, self._policy_domination_decimal_places, return_dominators)
        # group_indices: one group id per pooled element (presumably equal values share
        # an id -- confirm in duplicated_mask).
        cycle_all_duplicated_indices, group_indices = duplicated_mask(
            cycle_v_values, self._policy_domination_decimal_places, group_duplicates=True)
        # Duplicate check restricted to the transitions prefix of the pool.
        cycle_transitions_duplicated_indices = duplicated_mask(cycle_v_values[:cycle_transitions_elements_size],
                                                               self._policy_domination_decimal_places, return_dominators)
        if return_dominators:
            cycle_non_dominated_indices, cycle_dominators = cycle_non_dominated_indices
            cycle_transitions_duplicated_indices, cycle_transitions_dominators = cycle_transitions_duplicated_indices
            # Merge transition-only duplicate dominators into the transitions' entries.
            cycle_dominators[:cycle_transitions_elements_size] |= cycle_transitions_dominators

        cycle_spi_list = np_concatenate([cycle_transitions_spi_list, cycle_trajectory_cycle_info_spi_list])

        # Group ids of the cycle-trajectory part of the pool.
        cycle_trajectories_group_indices = group_indices[cycle_transitions_elements_size:]

        if len(cycle_trajectories_group_indices) > 0:
            # Sort by group ID to cluster items
            sorted_idx = np.argsort(cycle_trajectories_group_indices, kind="stable")
            sorted_groups = cycle_trajectories_group_indices[sorted_idx]
            sorted_values = cycle_trajectory_trajectory_info_v_value_list[sorted_idx]

            # Find group boundaries
            if len(np.unique(cycle_trajectories_group_indices)) == 1:
                split_values = [sorted_values]
            else:
                boundaries = np.flatnonzero(np.diff(sorted_groups)) + 1
                split_values = np.split(sorted_values, boundaries)

            # Apply filters: within each group, re-run both tests on the trajectory_info
            # v-values (presumably to break ties between equal cycle v-values).
            non_dominated_group_results = [non_dominated_mask(
                group, self._policy_domination_decimal_places, return_dominators) for group in split_values]
            duplicated_group_results = [duplicated_mask(
                group, self._policy_domination_decimal_places, return_dominators) for group in split_values]
            if return_dominators:
                non_dominated_group_results, dom_trajectory_dominators = zip(*non_dominated_group_results)
                duplicated_group_results, dup_trajectory_dominators = zip(*duplicated_group_results)
                trajectory_dominators = [dom_trajectory_dominators[i] | dup_trajectory_dominators[i] for i in range(len(dom_trajectory_dominators))]
                # Group-local indices -> trajectory-part indices -> pooled indices.
                trajectory_dominators = flatten_dominators(trajectory_dominators, cycle_trajectories_group_indices)
                trajectory_dominators = shift_domination(trajectory_dominators, cycle_transitions_elements_size)
                # Transitions get empty sets from the within-group check. The single
                # shared set() is only ever read, because ``|`` builds new sets.
                trajectory_dominators = np.concatenate([np.full(cycle_transitions_v_value_list.shape[0], set()), trajectory_dominators])
                dominators = cycle_dominators | trajectory_dominators

            non_dominated_and_non_duplicated_group_results = [
                non_dominated_group_results[i] & ~duplicated_group_results[i] for i in range(len(split_values))]

            # Flatten group-wise masks back into original order
            # NOTE(review): ``start`` and ``group_len`` are computed but never read.
            cycle_dup_trajectory_non_dom_and_non_dup_indices = np.empty_like(cycle_trajectories_group_indices, dtype=bool)
            start = 0
            for group_pos, group_id in enumerate(np.unique(cycle_trajectories_group_indices)):
                group_len = np.sum(cycle_trajectories_group_indices == group_id)
                cycle_dup_trajectory_non_dom_and_non_dup_indices[cycle_trajectories_group_indices == group_id] = (
                    non_dominated_and_non_duplicated_group_results)[group_pos]
                start += group_len

            # Secondary check for the whole pool: transitions use "not duplicated among
            # transitions", trajectories use the within-group result.
            cycle_dup_trajectory_non_dom_and_non_dup_indices = np.concatenate([
                ~cycle_transitions_duplicated_indices,
                cycle_dup_trajectory_non_dom_and_non_dup_indices])
        else:
            cycle_dup_trajectory_non_dom_and_non_dup_indices = ~cycle_transitions_duplicated_indices
            dominators = cycle_dominators

        # Keep: non-dominated in the pool AND (unique in the pool OR passes the secondary check).
        mask = (
            cycle_non_dominated_indices &
            (
                ~cycle_all_duplicated_indices |
                cycle_dup_trajectory_non_dom_and_non_dup_indices
            )
        )

        if return_mask:
            if return_dominators:
                return mask, dominators, cycle_transitions_elements_size
            return mask

        if return_dominators:
            raise NotImplementedError

        non_dominated_set = cycle_spi_list[mask]
        non_dominated_v_value_list = cycle_v_value_list[mask]

        return non_dominated_set, non_dominated_v_value_list

    def _check_trajectory_domination(self, trajectory_spi_list, cycle_trajectory_v_value_list, trajectory_v_value_list,
                                     ignore_trajectory_length=False, return_mask=False, return_dominators=False):
        """Filter trajectory candidates down to the non-dominated, non-duplicated ones.

        Cycle-trajectory values and trajectory values are pooled (cycle-trajectory
        entries first, trajectory entries last). The pool is passed through
        ``non_dominated_mask`` and ``duplicated_mask``. Only the trajectory part of the
        result is ever returned as filtered entries.

        Args:
            trajectory_spi_list: rows aligned with ``trajectory_v_value_list``
                (presumably ``(state_pattern, pi_option)`` pairs -- confirm with callers).
            cycle_trajectory_v_value_list: numpy array of v-values that take part in the
                domination test but are never returned.
            trajectory_v_value_list: numpy array of trajectory v-values to filter.
            ignore_trajectory_length: when False, both arrays are shifted by one common
                offset so that the pooled minimum is >= 1e-8 before testing.
            return_mask: return the boolean mask over the pooled array instead of the
                filtered entries.
            return_dominators: with ``return_mask``, also return per-element dominators
                and the number of trajectory elements (the tail length of the mask).

        Returns:
            ``mask``, or ``(mask, dominators, trajectory_elements_size)`` when masks are
            requested. Otherwise ``(non_dominated_spi_list, non_dominated_v_value_list)``,
            where the v-values are the caller's original, unshifted values.

        Raises:
            NotImplementedError: if ``return_dominators`` is set without ``return_mask``.
        """
        dominators = []

        cycle_trajectory_v_values = cycle_trajectory_v_value_list
        trajectory_v_values = trajectory_v_value_list
        if not ignore_trajectory_length:
            # One common offset for both arrays, chosen so that the pooled minimum becomes
            # >= 1e-8. A uniform shift leaves every pairwise difference unchanged.
            max_now = 0
            if cycle_trajectory_v_value_list.size > 0:
                max_now = max(max_now, -cycle_trajectory_v_value_list.min())

            if trajectory_v_value_list.size > 0:
                max_now = max(max_now, -trajectory_v_value_list.min())

            e = 1e-8
            offset = max_now + e

            # Out-of-place addition yields shifted copies and leaves the inputs untouched.
            # Unlike in-place ``+=`` on a copy, it also accepts integer-typed arrays.
            cycle_trajectory_v_values = cycle_trajectory_v_value_list + offset
            trajectory_v_values = trajectory_v_value_list + offset

        all_trajectory_v_values = np_concatenate([cycle_trajectory_v_values, trajectory_v_values])

        if len(all_trajectory_v_values) == 0:
            if return_mask:
                if return_dominators:
                    return np.array([]), np.array([]), 0
                return np.array([])

            if return_dominators:
                raise NotImplementedError

            trajectory_spi_list = np.empty((0, 2))
            return trajectory_spi_list, trajectory_v_value_list

        # Trajectory entries occupy the tail of the pooled array.
        trajectory_elements_size = trajectory_v_values.shape[0]

        trajectory_non_dominated_indices = non_dominated_mask(all_trajectory_v_values, self._policy_domination_decimal_places, return_dominators)
        trajectory_duplicated_indices = duplicated_mask(all_trajectory_v_values, self._policy_domination_decimal_places, return_dominators)
        if return_dominators:
            trajectory_non_dominated_indices, dominators = trajectory_non_dominated_indices
            trajectory_duplicated_indices, dup_dominators = trajectory_duplicated_indices
            dominators = dominators | dup_dominators

        mask = trajectory_non_dominated_indices & ~trajectory_duplicated_indices

        if return_mask:
            if return_dominators:
                return mask, dominators, trajectory_elements_size
            return mask

        if return_dominators:
            raise NotImplementedError

        if trajectory_elements_size == 0:
            # Guard needed: mask[-0:] would select the whole pooled mask.
            trajectory_spi_list = np.empty((0, 2))
            return trajectory_spi_list, trajectory_v_value_list

        filtered_mask = mask[-trajectory_elements_size:]

        non_dominated_set = trajectory_spi_list[filtered_mask]
        # Return the caller's unshifted values, consistent with the early returns above.
        # The shifted copy exists only for the domination tests.
        non_dominated_v_value_list = trajectory_v_value_list[filtered_mask]

        return non_dominated_set, non_dominated_v_value_list

    def delete(self, to_delete_list):
        for (state_pattern, pi_option) in to_delete_list:
            info = self.info.get(state_pattern).get(pi_option)
            node_to_remove_id = info.get('pi_list_ref')
            self.pi_list_info[state_pattern].pop(id_=node_to_remove_id)

            anchor_point_to_replace = (info['trajectory_info']['length'], pi_option)
            if self.anchor_point_dict.get(anchor_point_to_replace) is not None:
                anchor_point_key = (0, None)
                if self.anchor_point_dict.get(anchor_point_key) is None:
                    self.anchor_point_dict[anchor_point_key].data = (0, 0, None)

                reference_preserving_copy(self.anchor_point_dict[anchor_point_key], self.anchor_point_dict[anchor_point_to_replace])
                self.anchor_point_dict.pop(anchor_point_to_replace)

            self.info.get(state_pattern).pop(pi_option)
            if len(self.info[state_pattern]) == 0:
                self.info.pop(state_pattern)

    def fetch_pi_options(self, state):
        state_pattern = self.abstract_state(state)

        try:
            pi_options = self.info.get(state_pattern)
            return list(pi_options.keys())
        except AttributeError:
            return []

    def get_action(self, state, pi_option, not_exist_ok=True):
        state_pattern = self.abstract_state(state)
        return self._get_action(state_pattern, pi_option, not_exist_ok)

    def _get_action(self, state_pattern, pi_option, not_exist_ok=True):
        try:
            next_action_info = self.info.get(state_pattern).get(pi_option)
            next_action = next_action_info.get('action_pattern')
            return next_action
        except (AttributeError, TypeError):
            if not_exist_ok:
                return None
            else:
                raise

    def _detect_cycle(self, state_pattern, pi_option, action_pattern, next_state_pattern, next_pi_option):
        """Check whether a new transition closes a cycle with an entry already stored at ``state_pattern``.

        The transition is
        ``(state_pattern, pi_option) --action_pattern--> (next_state_pattern, next_pi_option)``.
        Each stored entry at ``state_pattern`` is compared with the entry for
        ``(next_state_pattern, next_pi_option)`` by walking both ``'stationary_segments'``
        lists from the end. Node ``data`` is ``(v_value, length, pi_option)``. The walk
        matches cumulative lengths and policy options. The first entry that produces a
        verdict ends the search.

        Returns:
            tuple ``(detected_cycle, transition_to_delete, to_fix_cycle)``:
              * ``detected_cycle``: list with at most one tuple
                ``(next_state_pattern, next_pi_option, state_pattern, entry_pi_option, action_pattern, cycle_info)``;
              * ``transition_to_delete``: list with at most one ``(state_pattern, pi_option)`` pair;
              * ``to_fix_cycle``: True when a cycle was recorded against an entry whose
                ``cross_policy_length`` is 0.
            All three stay empty/False when ``next_pi_option`` is None, or when debug mode
            disables SSM or cycle detection.
        """

        detected_cycle = []
        transition_to_delete = []
        to_fix_cycle = False

        if next_pi_option is None:
            return detected_cycle, transition_to_delete, to_fix_cycle

        if self.is_debug and (self._debug_disable_ssm or self._debug_disable_cycle_detection):
            return detected_cycle, transition_to_delete, to_fix_cycle

        next_info = self.info[next_state_pattern][next_pi_option]
        for entry_pi_option, entry_info in self.info[state_pattern].items():

            # Only entries with a strictly smaller cross_policy_length than the next entry
            # are considered; the one exception is when both are 0.
            if (entry_info['cross_policy_length'] >= next_info['cross_policy_length']
                    and not (entry_info['cross_policy_length'] == 0 and next_info['cross_policy_length'] == 0)):
                continue

            cycle_detected = False

            # Entries without stationary segments cannot match.
            entry_node_iter = reversed(list(entry_info.get('stationary_segments')))
            try:
                entry_node = next(entry_node_iter)
            except StopIteration:
                continue

            # Becomes True once every entry segment has been matched.
            forward_path_match = False

            next_node_total_v_value = 0
            next_node_total_length = 0
            entry_node_total_v_value = 0
            entry_node_total_length = 0

            entry_node_v_value, entry_node_length, entry_node_pi_option = entry_node.data
            entry_node_total_v_value += entry_node_v_value
            entry_node_total_length += entry_node_length

            for node in reversed(list(next_info['stationary_segments'])):
                next_node_v_value, next_node_length, next_node_pi_option = node.data
                next_node_total_v_value += next_node_v_value
                next_node_total_length += next_node_length

                if entry_node_total_length == next_node_total_length and next_node_pi_option == entry_node_pi_option:
                    # Cumulative lengths and pi_option agree: advance to the entry's next segment.
                    try:
                        entry_node = next(entry_node_iter)
                        entry_node_v_value, entry_node_length, entry_node_pi_option = entry_node.data
                        entry_node_total_v_value += entry_node_v_value
                        entry_node_total_length += entry_node_length
                    except StopIteration:
                        forward_path_match = True
                elif not forward_path_match:
                    # Paths diverged before the entry's segments were exhausted. The
                    # break also skips the for-else below, so no cycle is detected.
                    break
                else:
                    # Past the matched path: a cycle exists if the next path reaches the
                    # entry's own pi_option at or beyond the entry's segments plus its
                    # trajectory length.
                    if (entry_node_total_length + entry_info['trajectory_info']['length'] <= next_node_total_length
                            and next_node_pi_option == entry_pi_option):
                        cycle_detected = True
            else:
                # for-else: runs only when the walk finished without a break.
                if forward_path_match and (entry_pi_option == next_pi_option or entry_info['cross_policy_length'] == 0):
                    cycle_detected = True

            if not cycle_detected:
                continue

            # Both cross_policy_lengths are 0: no cycle is recorded. The current
            # transition is scheduled for deletion only when it switches pi_option.
            if entry_info['cross_policy_length'] == 0 and next_info['cross_policy_length'] == 0:
                if pi_option != next_pi_option:
                    transition_to_delete.append((state_pattern, pi_option))
                break

            if entry_info['cross_policy_length'] == 0:
                to_fix_cycle = True

            # Extend both totals with their trajectory_info (length counted minus one).
            next_node_total_v_value += next_info.get('trajectory_info').get('v_value')
            next_node_total_length += next_info.get('trajectory_info').get('length') - 1
            entry_node_total_v_value += entry_info.get('trajectory_info').get('v_value')
            entry_node_total_length += entry_info.get('trajectory_info').get('length') - 1

            # The matched entry took a different action. Use a deep copy of it that
            # carries the current transition's reward statistics, keyed by pi_option.
            if entry_info['action_pattern'] != action_pattern:
                info = self.info.get(state_pattern).get(pi_option)
                entry_info_copy = copy.deepcopy(entry_info)
                entry_info_copy['reward_sum'] = info['reward_sum']
                entry_info_copy['reward_count'] = info['reward_count']

                entry_info = entry_info_copy
                entry_pi_option = pi_option

            transition_to_delete.append((state_pattern, pi_option))

            # Mean reward vector with a trailing -1. The assert in _collect_cycle_info
            # requires the last v_value component to round to -length.
            average_reward = np.append(entry_info['reward_sum'] / entry_info['reward_count'], -1)
            cycle_info = self._collect_cycle_info(
                next_node_total_length, entry_node_total_length, next_node_total_v_value, entry_node_total_v_value, average_reward)
            cycle = (next_state_pattern, next_pi_option, state_pattern, entry_pi_option, action_pattern, cycle_info)
            detected_cycle.append(cycle)
            break

        return detected_cycle, transition_to_delete, to_fix_cycle

    @staticmethod
    def _collect_cycle_info(start, end, start_v_value, end_v_value, average_reward):

        cycle_info = {}
        if end > start:
            cycle_info['length'] = end - start + 1
            cycle_info['v_value'] = end_v_value - start_v_value + average_reward
        else:
            cycle_info['length'] = start - end + 1
            cycle_info['v_value'] = start_v_value - end_v_value + average_reward
        cycle_info['average_return'] = cycle_info['v_value'] / cycle_info['length']

        assert -round(cycle_info['v_value'][-1]) == cycle_info['length']

        return cycle_info

    def fix_cycle(self, cycle_spi_associations, detected_cycle):
        v_value = 0
        length = 0
        for cycle_spi_association in cycle_spi_associations:
            (abs_observation, pi_option, abs_action, abs_next_observation, next_pi_option) = cycle_spi_association
            info = self.info.get(abs_observation).get(pi_option)
            v_value += np.append(info['reward_sum'] / info['reward_count'], -1)
            length += 1

        info = self.info[detected_cycle[2]][detected_cycle[3]]
        v_value += np.append(info['reward_sum'] / info['reward_count'], -1)
        length += 1

        cycle_info = detected_cycle[5]
        cycle_info['length'] = length
        cycle_info['v_value'] = v_value
        cycle_info['average_return'] = cycle_info['v_value'] / cycle_info['length']
