import uuid
from collections import defaultdict
from random import Random

from data.mpec.memory import Memory
from utils.collections_util import dict_to_defaultdict, defaultdict_to_dict


class NSPiMemory(Memory):
    """Memory of (state, action, pi option, next state) -> (next action, next pi option) links.

    Entries live in ``self.info``, a nested mapping keyed by abstracted patterns::

        info[state_pattern][action_pattern][pi_option][next_state_pattern] = {
            'next_action_pattern': ...,
            'next_pi_option': ...,
        }

    ``add`` takes already-abstracted patterns, whereas ``delete``, ``get`` and
    ``fetch_sapi_associations`` take raw states/actions and abstract them with
    ``abstract_state`` / ``abstract_action`` (provided by ``Memory``).
    """

    def __init__(self, seed, *args, **kwargs):
        """Create the memory with a pi-option id generator seeded by ``seed``.

        All remaining arguments are forwarded to ``Memory.__init__``.
        """
        # Assigned before super().__init__ so it already exists if the base
        # initializer calls back into this subclass.
        # NOTE(review): confirm whether Memory.__init__ actually relies on it.
        self.random_generator = Random(seed)
        super().__init__(*args, **kwargs)

    def setup(self):
        """Reset ``self.info`` to an empty, auto-creating nested mapping."""
        super().setup()
        self.info = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(dict))))

    def load(self, env_name):
        """Load stored memory for ``env_name`` and restore the nesting of ``info``.

        NOTE(review): presumes ``Memory.load`` leaves ``self.info`` as nested
        plain dicts; this converts it back to the structure built in ``setup``.
        """
        super().load(env_name)
        # Presumably one factory per nesting level, outermost first (see
        # dict_to_defaultdict): four defaultdict levels and a plain-dict leaf,
        # matching setup().
        factory_list = [defaultdict, defaultdict, defaultdict, defaultdict, dict]
        self.info = dict_to_defaultdict(self.info, factory_list)

    def add(self, state_pattern, action_pattern, pi_option, next_state_pattern, next_action_pattern, next_pi_option):
        """Record the next action pattern and next pi option that followed a transition.

        Unlike the other accessors this takes already-abstracted patterns. Any
        existing values under the same key path are overwritten. Only
        ``self.step == 1`` is supported.
        """
        assert self.step == 1, "Not implemented yet"

        # Indexing the defaultdicts creates any missing intermediate levels.
        info = self.info[state_pattern][action_pattern][pi_option][next_state_pattern]
        info['next_action_pattern'] = next_action_pattern
        info['next_pi_option'] = next_pi_option

    def delete(self, state, action, pi_option, next_state=None):
        """Remove a stored association and prune mapping levels left empty.

        With ``next_state`` given, only that next-state entry under
        ``pi_option`` is removed (and ``pi_option`` itself if it becomes
        empty). A missing entry raises: ``AttributeError`` when the state,
        action or pi option is absent, ``KeyError`` when the next state is.

        Without ``next_state`` the whole ``pi_option`` subtree is removed on a
        best-effort basis: a missing entry is silently ignored.

        Only ``self.step == 1`` is supported.
        """
        assert self.step == 1, "Not implemented yet"

        state_pattern = self.abstract_state(state)
        action_pattern = self.abstract_action(action)
        if next_state is not None:
            next_state_pattern = self.abstract_state(next_state)
            # .get returns None for a missing level, so the following call
            # raises AttributeError just as it always has.
            action_info = self.info.get(state_pattern).get(action_pattern)
            pi_info = action_info.get(pi_option)
            pi_info.pop(next_state_pattern)

            if not pi_info:
                action_info.pop(pi_option)
        else:
            try:
                self.info.get(state_pattern).get(action_pattern).pop(pi_option)
            except (AttributeError, KeyError):
                pass

        self._prune_empty(state_pattern, action_pattern)

    def _prune_empty(self, state_pattern, action_pattern):
        """Drop the action level, then the state level, of ``info`` when empty.

        Uses lookups that never invoke the defaultdict factories, so nothing is
        created as a side effect of the check.
        """
        state_info = self.info.get(state_pattern)
        if state_info is None:
            return
        if action_pattern in state_info and not state_info[action_pattern]:
            state_info.pop(action_pattern)
        if not state_info:
            self.info.pop(state_pattern)

    def get(self, state, action, pi_option, next_state, not_exist_ok=True):
        """Look up the next action pattern and next pi option for a transition.

        Returns ``(next_action_pattern, next_pi_option)``. If there is no
        entry, returns ``(None, None)`` when ``not_exist_ok`` is true.
        Otherwise it re-raises the ``AttributeError``/``TypeError`` produced by
        the failed lookup.
        """
        state_pattern = self.abstract_state(state)
        action_pattern = self.abstract_action(action)
        next_state_pattern = self.abstract_state(next_state)

        try:
            # dict.get never triggers the defaultdict factory, so a miss leaves
            # info untouched; a missing level yields None -> AttributeError.
            entry = (self.info.get(state_pattern)
                     .get(action_pattern)
                     .get(pi_option)
                     .get(next_state_pattern))
            return entry.get('next_action_pattern'), entry.get('next_pi_option')
        except (AttributeError, TypeError):
            if not_exist_ok:
                return None, None
            raise

    def fetch_sapi_associations(self, state, action, next_state):
        """List every stored association matching ``(state, action, next_state)``.

        Returns a list of tuples ``(state_pattern, action_pattern, pi_option,
        next_state_pattern, next_action_pattern, next_pi_option)``. There is
        one tuple per pi option under ``(state, action)`` that has an entry for
        ``next_state``. The list is empty when nothing matches.
        """
        state_pattern = self.abstract_state(state)
        action_pattern = self.abstract_action(action)
        next_state_pattern = self.abstract_state(next_state)

        # .get with an empty default so a miss does not create entries.
        action_info = self.info.get(state_pattern, {}).get(action_pattern, {})

        associations_list = []
        for pi_option, next_state_pattern_info in action_info.items():
            if next_state_pattern not in next_state_pattern_info:
                continue

            entry = next_state_pattern_info[next_state_pattern]
            # Read each field by key instead of unpacking .values(). Unpacking
            # depends on the entry's insertion order and exact key count.
            associations_list.append((state_pattern, action_pattern, pi_option, next_state_pattern,
                                      entry.get('next_action_pattern'), entry.get('next_pi_option')))

        return associations_list

    def generate_pi_option_id(self, state, action):
        """Return a new pi option id as a 32-character hex string.

        The 128 random bits come from the seeded ``random_generator`` rather
        than ``uuid.uuid4()``, so the id sequence is reproducible for a given
        seed. The bits are stamped with the UUID version-4 fields.

        ``state`` and ``action`` are currently unused and are kept for
        interface compatibility.
        """
        return uuid.UUID(int=self.random_generator.getrandbits(128), version=4).hex
