import copy
import uuid
from collections import defaultdict
from random import Random

import numpy as np

from data.mpec.memory import Memory
from utils.collections_util import dict_to_defaultdict, entry_or_none, reference_preserving_copy
from utils.double_linked_list import Node


class SPiMemory(Memory):
    """Transition memory whose entries form per-pi-option linked chains.

    ``self.info`` maps ``state_pattern -> pi_option -> Node``.  A node's
    ``data`` dict holds the transition it stands for under the keys
    ``'state_pattern'``, ``'pi_option'``, ``'action_pattern'`` and
    ``'next_state_pattern'``; its link attributes are

    * ``next`` -- the successor node (the one reached by the stored action);
    * ``previous`` -- the predecessor that carries the *same* pi option;
    * ``other_pi_options_previous`` -- predecessors carrying a *different*
      pi option, keyed by ``(state_pattern, action_pattern)``.

    A falsy link (``{}`` or ``None``) means "no neighbour".

    Transitions are handed to callers as *spi associations*: tuples
    ``(state_pattern, pi_option, action_pattern, next_state_pattern,
    next_pi_option)``.  Only single-step memories (``self.step == 1``) are
    supported; several methods assert this.
    """

    def __init__(self, seed, *args, **kwargs):
        """Create the memory.

        Args:
            seed: Seed of the private RNG used to draw new pi-option ids.
            *args: Forwarded to ``Memory.__init__``.
            **kwargs: Must contain ``debug``, ``debug_disable_ssm``,
                ``debug_disable_cycle_detection`` and
                ``debug_disable_reconnection``.  All keyword arguments,
                including these, are forwarded to ``Memory.__init__``.
        """
        self.random_generator = Random(seed)
        self.is_debug = kwargs['debug']
        self._debug_disable_ssm = kwargs['debug_disable_ssm']
        self._debug_disable_cycle_detection = kwargs['debug_disable_cycle_detection']
        self._debug_disable_reconnection = kwargs['debug_disable_reconnection']
        super().__init__(*args, **kwargs)

    def setup(self):
        """Initialise storage so ``self.info[state_pattern][pi_option]`` auto-creates a ``Node``."""
        super().setup()
        self.info = defaultdict(lambda: defaultdict(Node))

    def load(self, env_path):
        """Load through ``Memory.load`` and restore the defaultdict nesting of ``self.info``."""
        super().load(env_path)
        # NOTE(review): presumably one factory per nesting level of the loaded
        # structure -- confirm against dict_to_defaultdict.
        factory_list = [defaultdict, defaultdict, defaultdict, defaultdict, dict]
        self.info = dict_to_defaultdict(self.info, factory_list)

    def add(self, state, pi_option, action, next_state):
        """Insert a new, still unlinked node for ``(state, pi_option)``.

        ``state``, ``action`` and ``next_state`` are abstracted into patterns
        first.  The node gets empty ``next``/``previous``/
        ``other_pi_options_previous`` links; :meth:`update` links it.

        Asserts that ``self.step == 1`` and that no node exists yet for
        ``(state_pattern, pi_option)``.
        """
        state_pattern = self.abstract_state(state)
        action_pattern = self.abstract_action(action)
        next_state_pattern = self.abstract_state(next_state)

        assert self.step == 1, "Not implemented yet"

        assert self.info[state_pattern].get(pi_option) is None

        info = self.info[state_pattern][pi_option]
        info.data['state_pattern'] = state_pattern
        info.data['pi_option'] = pi_option
        info.data['action_pattern'] = action_pattern
        info.data['next_state_pattern'] = next_state_pattern
        info.next = {}
        info.previous = {}
        info.other_pi_options_previous = {}

    def update(self, state_pattern, pi_option, action_pattern, next_state_pattern, next_pi_option):
        """Create or overwrite node ``(state_pattern, pi_option)`` and link it to its successor.

        The node's ``next`` becomes the node ``(next_state_pattern,
        next_pi_option)``, or ``None`` if there is no such node.  If the
        successor exists, it gets a back-link.  The back-link is ``previous``
        when both pi options are equal.  Otherwise it is an
        ``other_pi_options_previous`` entry keyed by
        ``(state_pattern, action_pattern)``.  The node's own
        ``previous``/``other_pi_options_previous`` links are reset.

        Asserts that an existing node for the pair does not store a different
        action pattern.
        """
        assert self.step == 1, "Not implemented yet"

        pi_options_match = pi_option == next_pi_option

        current_action_pattern_info = self.info[state_pattern].get(pi_option)
        assert (current_action_pattern_info is None
                or current_action_pattern_info.data.get('action_pattern') is None
                or current_action_pattern_info.data.get('action_pattern') == action_pattern)

        info = self.info[state_pattern][pi_option]
        info.data['state_pattern'] = state_pattern
        info.data['pi_option'] = pi_option
        info.data['action_pattern'] = action_pattern
        info.data['next_state_pattern'] = next_state_pattern

        # ``.get(..., {})``: an unseen next state means "no successor yet" (handled
        # below) instead of an AttributeError on ``None.get``.
        next_info = self.info.get(next_state_pattern, {}).get(next_pi_option)
        info.next = next_info
        info.previous = {}
        info.other_pi_options_previous = {}

        if next_info is not None:
            if pi_options_match:
                next_info.previous = info
            else:
                next_info.other_pi_options_previous[(state_pattern, action_pattern)] = info

    def transfer(self, state_pattern, pi_option, dominators=None):
        """Graft a dominated chain's predecessors onto dominating options, then drop the chain.

        For every dominating pi option, walk backwards from the dominated node
        ``(state_pattern, pi_option)`` along ``previous`` links.  In lock-step,
        walk backwards from ``(state_pattern, dominant_pi_option)``.  While the
        dominant side has a matching predecessor, the dominated predecessor is
        only scheduled for deletion.  A match means the same state, action and
        next state, found either as the same-option ``previous`` or in
        ``other_pi_options_previous``.  From the first mismatch on, each
        remaining dominated predecessor is copied into a new node, with an id
        from :meth:`generate_pi_option_id`.  The copy is linked in front of the
        dominant side.  The walk stops when there is no predecessor or when it
        returns to ``state_pattern``.  Finally the dominated node and its
        same-option predecessors are deleted.

        Args:
            state_pattern: State pattern of the dominated node.
            pi_option: Pi option of the dominated chain.
            dominators: Non-empty iterable of pairs whose second element is a
                dominating pi option at ``state_pattern``; the first element
                is ignored here.

        Returns:
            ``(spi_associations_to_update, spi_associations_to_delete)``:
            a list of ``(spi_association, pi_option)`` pairs for the newly
            created nodes, and a set of ``(state_pattern, pi_option)`` pairs
            for dominated predecessors that matched an existing dominant node.
        """

        def are_transitions_equal(info1, info2, both_false_ok=True):
            # Two nodes describe the same transition if state, action and next
            # state agree; a missing node never equals anything.
            if not info1 and not info2 and not both_false_ok:
                raise ValueError("Both info cannot be None")

            if not info1 or not info2:
                return False

            return (info1.data['state_pattern'] == info2.data['state_pattern']
                    and info1.data['action_pattern'] == info2.data['action_pattern']
                    and info1.data['next_state_pattern'] == info2.data['next_state_pattern'])

        assert self.step == 1, "Not implemented yet"

        if dominators is None:
            dominators = []
        assert len(dominators) > 0

        spi_associations_to_update = []
        spi_associations_to_delete = set()

        for _, dominant_pi_option in dominators:
            current_state_pattern = state_pattern
            current_pi_option = dominant_pi_option
            current_info = self.info[current_state_pattern][current_pi_option]

            # Becomes True at the first predecessor the dominant side lacks;
            # from then on every remaining dominated predecessor is copied.
            already_entered = False

            dominated_info = self.info[state_pattern][pi_option]
            dominated_previous_info = dominated_info.previous
            previous_state_pattern = dominated_previous_info.data['state_pattern'] if dominated_previous_info else None
            while previous_state_pattern is not None and previous_state_pattern != state_pattern:

                entry = current_info.other_pi_options_previous.get(
                    (dominated_previous_info.data['state_pattern'], dominated_previous_info.data['action_pattern'])
                ) if current_info else None
                previous_info = current_info.previous if current_info else None
                if (not already_entered
                        and (are_transitions_equal(dominated_previous_info, previous_info)
                             or are_transitions_equal(dominated_previous_info, entry))):

                    # The dominant side already has this predecessor: only
                    # schedule the dominated copy for deletion and follow it.
                    spi_associations_to_delete.add((previous_state_pattern, pi_option))

                    current_state_pattern = previous_state_pattern
                    if are_transitions_equal(dominated_previous_info, entry):
                        assert id(entry) == id(self.info[dominated_previous_info.data['state_pattern']][entry.data['pi_option']])
                        current_pi_option = entry.data['pi_option']
                        current_info = entry
                    else:
                        current_info = previous_info

                    dominated_previous_info = dominated_previous_info.previous
                    previous_state_pattern = dominated_previous_info.data['state_pattern'] if dominated_previous_info else None
                    continue

                already_entered = True

                previous_pi_option = self.generate_pi_option_id(previous_state_pattern, current_state_pattern, current_pi_option)

                assert self.info[previous_state_pattern].get(previous_pi_option) is None
                previous_info = copy.copy(dominated_previous_info)
                # copy.copy is shallow: give the new node its own ``data`` dict so
                # renaming its pi_option below does not also rename the dominated
                # node it was copied from.
                previous_info.data = copy.copy(dominated_previous_info.data)
                self.info[previous_state_pattern][previous_pi_option] = previous_info
                previous_info.data['pi_option'] = previous_pi_option
                previous_info.next = current_info
                previous_info.previous = {}
                previous_info.other_pi_options_previous = {}

                action = previous_info.data['action_pattern']
                if previous_pi_option == current_pi_option:
                    current_info.previous = previous_info
                else:
                    current_info.other_pi_options_previous[(previous_state_pattern, action)] = previous_info

                spi_association = (previous_state_pattern, previous_pi_option, action, current_state_pattern, current_pi_option)
                spi_associations_to_update.append((spi_association, pi_option))

                current_state_pattern = previous_state_pattern
                current_pi_option = previous_pi_option
                current_info = previous_info

                dominated_previous_info = dominated_previous_info.previous
                previous_state_pattern = dominated_previous_info.data['state_pattern'] if dominated_previous_info else None

        # Delete the dominated chain: its same-option predecessors, then the node.
        previous_info = self.info.get(state_pattern).get(pi_option).previous
        # Truthiness (as elsewhere in this method) handles both the ``{}``
        # placeholder and a Node, without requiring Node to implement __len__.
        previous_state_pattern = previous_info.data['state_pattern'] if previous_info else None
        while previous_state_pattern is not None and previous_state_pattern != state_pattern:
            current_state_pattern = previous_state_pattern
            previous_previous_info = previous_info.previous
            previous_state_pattern = previous_previous_info.data['state_pattern'] if previous_previous_info else None
            previous_info = previous_previous_info
            self.delete([(current_state_pattern, pi_option)])

        self.delete([(state_pattern, pi_option)])

        return spi_associations_to_update, spi_associations_to_delete

    def recall_trajectory(self, state_pattern, pi_option):
        """Return spi associations for a node and its same-option predecessors.

        The first association describes ``(state_pattern, pi_option)`` with its
        successor taken from the node's ``next`` link.  The walk then follows
        ``previous`` links until there is none or it returns to
        ``state_pattern``.  Each predecessor's association points at the node
        visited just before it, with ``pi_option`` as both options.
        """
        assert self.step == 1, "Not implemented yet"

        spi_associations_to_update = []

        start_state_pattern = state_pattern
        current_state_pattern = state_pattern

        info = self.info.get(current_state_pattern).get(pi_option)
        action = info.data['action_pattern']
        next_state_pattern, next_pi_option = self.get_next__state_pattern__pi_option(current_state_pattern, pi_option)

        spi_association = (current_state_pattern, pi_option, action, next_state_pattern, next_pi_option)
        spi_associations_to_update.append(spi_association)

        previous_info = self.info.get(state_pattern).get(pi_option).previous
        previous_state_pattern = previous_info.data['state_pattern'] if previous_info else None

        while previous_state_pattern is not None and previous_state_pattern != start_state_pattern:
            info = self.info[previous_state_pattern][pi_option]
            action = info.data['action_pattern']

            spi_association = (previous_state_pattern, pi_option, action, current_state_pattern, pi_option)
            spi_associations_to_update.append(spi_association)

            current_state_pattern = previous_state_pattern
            previous_info = self.info.get(current_state_pattern).get(pi_option).previous
            previous_state_pattern = previous_info.data['state_pattern'] if previous_info else None

        return spi_associations_to_update

    def _debug_recolor(self, spi_associations):
        """Debug helper: move every listed node to one freshly generated pi option.

        Each association's node is re-keyed from its old pi option to the new
        one.  Its successor may still carry the old option.  In that case the
        successor's same-option back-link is moved into
        ``other_pi_options_previous``, keyed by this node's
        ``(state_pattern, action)``, and the back-link is cleared.  Returns the
        associations rewritten with the new pi option; ``next_pi_option`` is
        replaced too unless it is ``None``.
        """
        updated_spi_associations = []
        new_pi_option = self.generate_pi_option_id()

        for spi_association in spi_associations:
            state_pattern = spi_association[0]
            old_pi_option = spi_association[1]
            self.info[state_pattern][new_pi_option] = self.info[state_pattern][old_pi_option]
            self.info[state_pattern][new_pi_option].data['pi_option'] = new_pi_option

            next_state_pattern, next_pi_option = self.get_next__state_pattern__pi_option(state_pattern, old_pi_option)
            next_entry = entry_or_none(self.info, [next_state_pattern, next_pi_option])
            if next_entry is not None:
                if old_pi_option == next_pi_option:
                    action = self._get_action(state_pattern, old_pi_option)
                    if action is not None:
                        next_entry.other_pi_options_previous[(state_pattern, action)] = next_entry.previous
                    next_entry.previous = {}

            reference_preserving_copy(self.info[state_pattern][old_pi_option], self.info[state_pattern][new_pi_option])

            self.info[state_pattern].pop(old_pi_option)
            if len(self.info[state_pattern]) == 0:
                self.info.pop(state_pattern)

            spi_association_elements = [*spi_association]
            spi_association_elements[1] = new_pi_option
            spi_association_elements[4] = new_pi_option if spi_association_elements[4] is not None else None
            updated_spi_associations.append(tuple(spi_association_elements))

        return updated_spi_associations

    def fetch_trajectory(self, state_pattern, pi_option, stop_state=None):
        """Follow ``next`` links from ``(state_pattern, pi_option)`` and collect spi associations.

        Collection stops before the first node whose state pattern equals
        ``stop_state``, or when a node is revisited (cycle), or after a node
        without a successor.
        """
        spi_associations = []

        visited = set()
        info = self.info.get(state_pattern).get(pi_option)

        while info.data['state_pattern'] != stop_state:
            id_ = id(info)
            if id_ in visited:
                break
            visited.add(id_)

            data = info.data
            next_pi_option = info.next.data['pi_option'] if info.next else None
            # Name the fields explicitly rather than relying on the insertion
            # order of ``data``.
            spi_associations.append((
                data['state_pattern'],
                data['pi_option'],
                data['action_pattern'],
                data['next_state_pattern'],
                next_pi_option,
            ))

            if not info.next:
                break
            info = info.next

        return spi_associations

    def delete(self, to_delete_list):
        """Remove each ``(state_pattern, pi_option)`` node and unhook its successor's back-link."""
        for (state_pattern, pi_option) in to_delete_list:
            next_state_pattern, next_pi_option = self.get_next__state_pattern__pi_option(state_pattern, pi_option)
            next_entry = entry_or_none(self.info, [next_state_pattern, next_pi_option])
            if next_entry is not None:
                if pi_option == next_pi_option:
                    next_entry.previous = {}
                else:
                    # ``state_pattern`` is already abstracted, so use the
                    # pattern-level lookup (as _debug_recolor does).  The public
                    # get_action would re-abstract it and miss the entry.
                    action = self._get_action(state_pattern, pi_option)
                    if action is not None:
                        next_entry.other_pi_options_previous.pop((state_pattern, action))

            reference_preserving_copy(Node(), self.info[state_pattern][pi_option])

            self.info[state_pattern].pop(pi_option)
            if len(self.info[state_pattern]) == 0:
                self.info.pop(state_pattern)

    def copy_and_mark_cycles(self, detected_cycles):
        """Duplicate each detected cycle under a fresh pi option and close the copy into a ring.

        Each entry of ``detected_cycles`` is a tuple ``(start_state_pattern,
        start_pi_option, end_state_pattern, end_pi_option, end_action_pattern,
        cycle_info)``.  The nodes from the start node up to the node before
        ``end_state_pattern`` are copied by following ``next`` links.  A new
        end node carrying ``end_action_pattern`` links back to the start copy.
        ``detected_cycles[i]`` is replaced in place with the new pi option.

        In debug mode with reconnection disabled, a walk that reaches a missing
        successor skips its cycle.  The partial copy is rolled back, nothing is
        appended to the result lists, and ``detected_cycles[i]`` is left as is.

        Returns:
            ``(detected_cycles, transitions_to_copy_list, old_pi_option_list)``:
            per cycle, the spi associations of the copied ring and the pi
            options of the original nodes they were copied from.
        """
        transitions_to_copy_list = []
        old_pi_option_list = []
        for i, cycle in enumerate(detected_cycles):

            abort_iteration = False
            transitions_to_copy = []
            old_pi_options = []

            (start_state_pattern, start_pi_option, end_state_pattern, end_pi_option, end_action_pattern, cycle_info) = cycle

            new_pi_option = self.generate_pi_option_id()

            current_state_pattern = start_state_pattern
            current_pi_option = start_pi_option

            end_info = self.info[end_state_pattern][new_pi_option]
            end_info.data['state_pattern'] = end_state_pattern
            end_info.data['pi_option'] = new_pi_option
            end_info.data['action_pattern'] = end_action_pattern

            if start_state_pattern == end_state_pattern:
                # Single-node cycle: the end node loops onto itself.
                end_info.data['next_state_pattern'] = current_state_pattern

                end_info.next = end_info
                end_info.previous = end_info
                end_info.other_pi_options_previous = {}
            else:
                # NOTE(review): copy.deepcopy also duplicates everything reachable
                # through the node's links (unless Node customises __deepcopy__);
                # the link attributes are all overwritten afterwards anyway.
                current_info = start_info = self.info[current_state_pattern][new_pi_option] = copy.deepcopy(self.info[current_state_pattern][current_pi_option])
                current_info.data['state_pattern'] = current_state_pattern
                current_info.data['pi_option'] = new_pi_option
                current_info.previous = end_info
                current_info.other_pi_options_previous = {}

                end_info.data['next_state_pattern'] = current_state_pattern
                end_info.next = start_info

                next_state_pattern, next_pi_option = self.get_next__state_pattern__pi_option(current_state_pattern, current_pi_option)

                action_pattern = current_info.data['action_pattern']
                spi_association = (current_state_pattern, new_pi_option, action_pattern, next_state_pattern, new_pi_option)
                transitions_to_copy.append(spi_association)
                old_pi_options.append(current_pi_option)

                while next_state_pattern != end_state_pattern:

                    current_state_pattern = next_state_pattern
                    current_pi_option = next_pi_option

                    previous_info = current_info

                    if self.is_debug and self._debug_disable_reconnection:
                        if current_state_pattern is None:
                            abort_iteration = True
                            break

                    current_info = self.info[current_state_pattern][new_pi_option] = copy.deepcopy(self.info.get(current_state_pattern).get(current_pi_option))
                    current_info.data['state_pattern'] = current_state_pattern
                    current_info.data['pi_option'] = new_pi_option
                    current_info.previous = previous_info
                    current_info.other_pi_options_previous = {}

                    previous_info.data['next_state_pattern'] = current_state_pattern
                    previous_info.next = current_info

                    next_state_pattern, next_pi_option = self.get_next__state_pattern__pi_option(current_state_pattern, current_pi_option)

                    action_pattern = current_info.data['action_pattern']
                    spi_association = (current_state_pattern, new_pi_option, action_pattern, next_state_pattern, new_pi_option)
                    transitions_to_copy.append(spi_association)
                    old_pi_options.append(current_pi_option)

                if abort_iteration:
                    # Roll back the partial copy.  Otherwise the end node and the
                    # nodes copied so far would stay in self.info under
                    # new_pi_option, although no returned association refers to them.
                    created_state_patterns = [end_state_pattern]
                    created_state_patterns.extend(transition[0] for transition in transitions_to_copy)
                    for created_state_pattern in created_state_patterns:
                        created_state_info = self.info.get(created_state_pattern)
                        if created_state_info is None:
                            continue
                        created_state_info.pop(new_pi_option, None)
                        if len(created_state_info) == 0:
                            self.info.pop(created_state_pattern)
                    continue

                current_info.data['next_state_pattern'] = next_state_pattern
                current_info.next = end_info
                end_info.previous = current_info
                end_info.other_pi_options_previous = {}

            spi_association = (end_state_pattern, new_pi_option, end_action_pattern, start_state_pattern, new_pi_option)
            transitions_to_copy.append(spi_association)
            old_pi_options.append(end_pi_option)

            transitions_to_copy_list.append(transitions_to_copy)
            old_pi_option_list.append(old_pi_options)

            detected_cycles[i] = (
                start_state_pattern, new_pi_option, end_state_pattern, new_pi_option, end_action_pattern, cycle_info)

        return detected_cycles, transitions_to_copy_list, old_pi_option_list

    def fetch_spi_associations(self, state, action, next_state):
        """Abstract the raw transition and return its matching spi associations."""
        state_pattern = self.abstract_state(state)
        action_pattern = self.abstract_action(action)
        next_state_pattern = self.abstract_state(next_state)

        return self._fetch_spi_associations(state_pattern, action_pattern, next_state_pattern)

    def _fetch_spi_associations(self, state_pattern, action_pattern, next_state_pattern):
        """Return one spi association per pi option at ``state_pattern`` storing this action and next state."""
        associations_list = []

        for pi_option, pi_option_info in self.info.get(state_pattern, dict()).items():
            if (next_state_pattern != pi_option_info.data['next_state_pattern']
                    or action_pattern != pi_option_info.data['action_pattern']):
                continue
            next_info = pi_option_info.next
            next_pi_option = next_info.data['pi_option'] if next_info else None
            associations_list.append(
                (state_pattern, pi_option, action_pattern, next_state_pattern, next_pi_option)
            )

        return associations_list

    def generate_pi_option_id(self, state_pattern=None, next_state_pattern=None, next_pi_option=None):
        """Pick the pi option for a node at ``state_pattern`` whose successor is ``(next_state_pattern, next_pi_option)``.

        Either all three arguments are ``None`` (a fresh id is always drawn) or
        none are.  ``next_pi_option`` is reused when the successor has no
        same-option predecessor yet and ``(state_pattern, next_pi_option)`` is
        free.  Otherwise a fresh UUID4 hex is drawn from the seeded RNG.  With
        ``debug_disable_ssm`` in debug mode, the successor's predecessor is
        ignored: the option is reused whenever the slot is free.
        """
        arguments = (state_pattern, next_state_pattern, next_pi_option)
        # Identity checks in plain Python: packing arbitrary (e.g. tuple-valued)
        # patterns into a numpy array can broadcast or fail on ragged input.
        assert (all(argument is None for argument in arguments)
                or all(argument is not None for argument in arguments))

        pi_option = next_pi_option

        next_state_info = entry_or_none(self.info, [next_state_pattern, next_pi_option])
        assert next_pi_option is None or next_state_info is not None
        state_info = entry_or_none(self.info, [state_pattern, pi_option])

        if self.is_debug and self._debug_disable_ssm:
            if next_pi_option is None or state_info is not None:
                pi_option = uuid.UUID(int=self.random_generator.getrandbits(128), version=4).hex
            return pi_option

        if next_pi_option is None or next_state_info.previous or state_info is not None:
            pi_option = uuid.UUID(int=self.random_generator.getrandbits(128), version=4).hex

        return pi_option

    def get_action(self, state, pi_option, not_exist_ok=True):
        """Return the action pattern stored for ``(abstract_state(state), pi_option)``."""
        state_pattern = self.abstract_state(state)
        return self._get_action(state_pattern, pi_option, not_exist_ok)

    def _get_action(self, state_pattern, pi_option, not_exist_ok=True):
        """Return the action pattern stored at ``(state_pattern, pi_option)``.

        A missing state or pi option surfaces as AttributeError/TypeError.  It
        yields ``None`` when ``not_exist_ok`` is true and is re-raised
        otherwise.
        """
        try:
            next_action_info = self.info.get(state_pattern).get(pi_option)
            next_action = next_action_info.data['action_pattern']
            return next_action
        except (AttributeError, TypeError):
            if not_exist_ok:
                return None
            else:
                raise

    def get_next__state_pattern__pi_option(self, state_pattern, pi_option):
        """Return the successor's ``(state_pattern, pi_option)``, or ``(None, None)`` without one.

        The node ``(state_pattern, pi_option)`` itself must exist; otherwise an
        AttributeError is raised.
        """
        info = self.info.get(state_pattern).get(pi_option)
        next_info = info.next
        next_state_pattern = next_info.data['state_pattern'] if next_info else None
        next_pi_option = next_info.data['pi_option'] if next_info else None
        return next_state_pattern, next_pi_option

