import pickle
import warnings
from collections import defaultdict

import h5py
import numpy as np

from data.mpec.memory import Memory
from data.mpec.utils.non_domination_util import non_dominated_mask, duplicated_mask
from utils.collections_util import dict_to_defaultdict, defaultdict_to_dict, entry_or_none
from utils.double_linked_list import DoubleLinkedList
from utils.experiment_util import get_max_trajectory_length


def update_q(info, rewards, next_info, alpha=0.1, gamma=0.99):
    """Return an n-step Q-learning update of ``info['q_value']``.

    Args:
        info: Mapping for the current (state, action). A missing ``'q_value'``
            counts as 0.
        rewards: Sized sequence of the n rewards observed after the action.
        next_info: Mapping of next-state actions to their info mappings. The
            largest ``'q_value'`` among them bootstraps the return. Actions
            without a ``'q_value'`` count as 0, the same default used for
            ``info``. An empty mapping bootstraps with 0 instead of raising.
        alpha: Learning rate.
        gamma: Discount factor.

    Returns:
        ``q + alpha * (G - q)`` with
        ``G = sum_i gamma**i * r_i + gamma**n * max_a' Q(s', a')``.
    """
    q_sa = info.get('q_value', 0)
    best_next_q = max(
        (actions_info.get('q_value', 0) for actions_info in next_info.values()),
        default=0,
    )

    discounted_rewards = sum(gamma ** i * r for i, r in enumerate(rewards))
    g = discounted_rewards + gamma ** len(rewards) * best_next_q
    return q_sa + alpha * (g - q_sa)

class QMemory(Memory):
    """Tabular memory of return estimates keyed by (state, action, pi_option).

    Entries live in ``self.info[state_pattern][action_pattern][pi_option]``.
    Each entry is a dict holding:

    - ``trajectory_length``, ``reward_sum``, ``reward_count`` and
      ``average_return``;
    - ``q_value``: ``average_return * trajectory_length`` flattened, with
      ``-trajectory_length`` appended as the final component;
    - ``api_list_ref``: the id of the entry's node in the per-state list.

    ``self.api_list_info[state_pattern]`` is a ``DoubleLinkedList`` of
    ``(action_pattern, pi_option)`` pairs. A pair is re-appended (and its old
    node popped) whenever its ``q_value`` changes.
    """

    def __init__(self, policy_domination_decimal_places=None, max_trajectory_length=None, *args, **kwargs):
        """Configure domination precision, trajectory cap and debug switches.

        Args:
            policy_domination_decimal_places: Optional precision forwarded to
                ``non_dominated_mask`` / ``duplicated_mask``. Converted to
                ``np.ndarray`` when given.
            max_trajectory_length: Longest trajectory ``update`` will extend
                to. When ``None``, it is derived from ``self.gamma`` and a
                warning is emitted.
            *args, **kwargs: Forwarded to ``Memory.__init__``. ``kwargs`` must
                contain the ``debug*`` keys read below (``KeyError`` otherwise).
        """
        if policy_domination_decimal_places is not None:
            policy_domination_decimal_places = np.array(policy_domination_decimal_places)
        self._policy_domination_decimal_places = policy_domination_decimal_places

        self.is_debug = kwargs['debug']
        self._debug_track_trajectories_length = kwargs['debug_track_trajectories_length']
        self._debug_disable_trajectory_length = kwargs['debug_disable_trajectory_length']
        self._debug_disable_average_reward = kwargs['debug_disable_average_reward']
        self.debug_trajectory_length_tracking = {}
        self.debug_learning_rate = kwargs['debug_learning_rate']
        self.debug_discount_factor = kwargs['debug_discount_factor']

        super().__init__(*args, **kwargs)

        if max_trajectory_length is None:
            self._max_trajectory_length = get_max_trajectory_length(self.gamma)
            warnings.warn(f"Gamma {self.gamma} sets max trajectory length to {self._max_trajectory_length}")
        else:
            self._max_trajectory_length = max_trajectory_length

    def setup(self):
        """Reset the nested ``info`` table and the per-state api lists."""
        super().setup()
        self.info = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
        self.api_list_info = defaultdict(DoubleLinkedList)

    def save(self, env_name):
        """Save base state, then store ``api_list_info`` (pickled) in ``{env_name}.h5``."""
        super().save(env_name)
        with h5py.File(f"{env_name}.h5", 'a') as f:
            # Mode 'a' keeps what the base class wrote. Drop a stale copy so a
            # repeated save does not fail on an already-existing dataset name.
            if 'api_list_info' in f:
                del f['api_list_info']
            data = pickle.dumps(defaultdict_to_dict(self.api_list_info))
            f.create_dataset('api_list_info', data=np.void(data))

    def load(self, env_name):
        """Load base state and ``api_list_info``, then restore the defaultdict nesting."""
        super().load(env_name)
        factory_list = [defaultdict, defaultdict, defaultdict, dict]
        self.info = dict_to_defaultdict(self.info, factory_list)

        with h5py.File(f"{env_name}.h5", 'r') as f:
            # NOTE(review): pickle.loads executes arbitrary code from the file;
            # only load .h5 files this program wrote itself.
            self.api_list_info.update(pickle.loads(bytes(f['api_list_info'][()])))
        factory_list = [defaultdict, DoubleLinkedList]
        self.api_list_info = dict_to_defaultdict(self.api_list_info, factory_list)

    def add(self, state, action, pi_option, next_state, reward):
        """Create or overwrite the entry for a single observed transition.

        The entry gets trajectory length 1 (0 when ``next_state`` is None), a
        fresh reward sum/count of ``reward``/1, and a recomputed ``q_value``.
        ``next_state`` must not already have entries (asserted).
        """
        assert self.step == 1, "Not implemented yet"

        state_pattern = self.abstract_state(state)
        action_pattern = None if action is None else self.abstract_action(action)
        next_state_pattern = None if next_state is None else self.abstract_state(next_state)
        reward = np.array([reward])

        assert self.info.get(next_state_pattern, None) is None

        info = self.info[state_pattern][action_pattern][pi_option]
        length = 1 if next_state_pattern is not None else 0

        info['trajectory_length'] = length
        info['reward_sum'] = reward
        info['reward_count'] = 1
        info['average_return'] = info['reward_sum'] / info['reward_count']

        old_q_sa = info.get('q_value', None)
        info['q_value'] = np.append(info['average_return'] * length, -length)

        if self.is_debug and self._debug_disable_average_reward:
            info['debug_q_value'] = np.append(self.debug_learning_rate * reward, -length)

        # A None old value compares unequal element-wise, so new entries are
        # always recorded in the api list.
        if np.any(old_q_sa != info['q_value']) or action_pattern is None:
            self._refresh_api_list_entry(state_pattern, action_pattern, pi_option, info)

        if self.is_debug and self._debug_track_trajectories_length:
            self.debug_trajectory_length_tracking[(state_pattern, action_pattern, pi_option)] = 1

    def update(self, state_pattern, action_pattern, pi_option,
               next_state_pattern, next_action_pattern, next_pi_option, reward):
        """Extend an entry by one step in front of its successor entry.

        Returns:
            False (and changes nothing) if the resulting trajectory would
            exceed ``self._max_trajectory_length``; True otherwise.
        """
        reward = np.array([reward])

        assert self.step == 1, "Not implemented yet"
        assert len(reward) == 1

        if next_state_pattern is None:
            next_info = {
                'average_return': 0,
                'trajectory_length': 0
            }
        else:
            next_info = self.info.get(next_state_pattern).get(next_action_pattern).get(next_pi_option)

        # Read the successor's values before touching ``info``; they may be
        # the same dict when the transition loops back onto itself.
        next_return = next_info['average_return']
        next_length = next_info['trajectory_length']

        length = next_length + 1
        if length > self._max_trajectory_length:
            return False

        sapi = (state_pattern, action_pattern, pi_option)
        if self.is_debug and self._debug_track_trajectories_length:
            # Must run before ``self.info[...]`` below creates the entry.
            if entry_or_none(self.info, [state_pattern, action_pattern, pi_option]) is None:
                self.debug_trajectory_length_tracking[sapi] = length
                self.debug_trajectory_length_tracking.pop((next_state_pattern, next_action_pattern, next_pi_option), None)
            elif self.debug_trajectory_length_tracking.get(sapi, None) is not None:
                self.debug_trajectory_length_tracking[sapi] = length

        info = self.info[state_pattern][action_pattern][pi_option]

        info['trajectory_length'] = length
        info['reward_sum'] = info.get('reward_sum', 0) + reward
        info['reward_count'] = info.get('reward_count', 0) + 1
        average_reward = info['reward_sum'] / info['reward_count']

        # Mean over the ``length`` steps: this step's average reward plus the
        # successor's ``next_length`` steps averaging ``next_return``.
        info['average_return'] = next_return + (average_reward - next_return) / length

        old_q_sa = info.get('q_value', 0)
        info['q_value'] = np.append(info['average_return'] * length, -length)

        if self.is_debug and self._debug_disable_average_reward:
            # Plain TD(0) update on the stored vectors; the trailing
            # -trajectory_length component is stripped before and re-added after.
            q_value = info.get('debug_q_value', None)
            next_q_value = next_info.get('debug_q_value', None)
            q_value = q_value[:-1] if q_value is not None else 0
            next_q_value = next_q_value[:-1] if next_q_value is not None else 0

            td_target = reward + self.debug_discount_factor * next_q_value
            info['debug_q_value'] = np.append(q_value + self.debug_learning_rate * (td_target - q_value), -length)

        if np.any(old_q_sa != info['q_value']):
            self._refresh_api_list_entry(state_pattern, action_pattern, pi_option, info)

        return True

    def _refresh_api_list_entry(self, state_pattern, action_pattern, pi_option, info):
        """Re-append ``(action_pattern, pi_option)`` to the state's api list and pop its previous node."""
        previous_ref = info.get('api_list_ref', None)
        node = self.api_list_info[state_pattern].append((action_pattern, pi_option))
        info['api_list_ref'] = node.id
        if previous_ref is not None:
            self.api_list_info[state_pattern].pop(id_=previous_ref)

    def get(self, state, action, pi_option=None):
        """Return the info dict for (state, action, pi_option), or None if absent."""
        state_pattern = self.abstract_state(state)
        action_pattern = self.abstract_action(action)

        try:
            entry = self.info.get(state_pattern).get(action_pattern).get(pi_option)
            return entry
        except AttributeError:
            # A missing state or action makes ``.get`` return None.
            return None

    def _selected_q_value(self, entry):
        """Return ``debug_q_value`` in the disable-average-reward debug mode, else ``q_value``."""
        if self.is_debug and self._debug_disable_average_reward:
            return entry['debug_q_value']
        return entry['q_value']

    def _keep_mask(self, q_value_list, ignore_trajectory_length):
        """Return a boolean mask of rows that are non-dominated and not duplicates.

        When the ``-trajectory_length`` component is still present (it is never
        positive), all values are shifted to be strictly positive first.
        """
        q_values = q_value_list
        if not ignore_trajectory_length:
            q_values = q_value_list.copy()
            e = 1e-8
            q_values += max(0, -q_values.min()) + e

        non_dominated = non_dominated_mask(q_values, self._policy_domination_decimal_places)
        duplicated = duplicated_mask(q_values, self._policy_domination_decimal_places)
        return non_dominated & ~duplicated

    def fetch_non_dominated_set(self, state, ignore_trajectory_length=False, use_average=False):
        """Return the non-dominated, de-duplicated entries of ``state``.

        Args:
            state: Raw state; abstracted with ``self.abstract_state``.
            ignore_trajectory_length: Drop the trailing ``-trajectory_length``
                component before comparing. Forced on in the
                disable-trajectory-length debug mode.
            use_average: Compare ``average_return`` (with ``-1`` as the last
                component) instead of ``q_value``.

        Returns:
            ``(sapi_array, q_value_array)``: rows of
            ``(state_pattern, action_pattern, pi_option)`` and their compared
            vectors. Both are empty arrays when the state has no entries.
        """
        state_pattern = self.abstract_state(state)

        if self.is_debug and self._debug_disable_trajectory_length:
            ignore_trajectory_length = True

        sapi_list = []
        q_value_list = []
        action_info = self.info.get(state_pattern, dict())
        api_info = self.api_list_info.get(state_pattern, DoubleLinkedList())
        for _, (action_pattern, pi_option) in api_info:
            entry = action_info[action_pattern][pi_option]
            sapi_list.append((state_pattern, action_pattern, pi_option))
            if use_average:
                # np.append flattens ``average_return`` (shape (1,) or (1, k)),
                # so both scalar and vector rewards work.
                q_value = np.append(entry['average_return'], -1)
            else:
                q_value = self._selected_q_value(entry)
            if ignore_trajectory_length:
                q_value = q_value[:-1]
            q_value_list.append(q_value)

        sapi_list = np.array(sapi_list)
        q_value_list = np.array(q_value_list)

        if len(sapi_list) == 0:
            return sapi_list, q_value_list

        keep = self._keep_mask(q_value_list, ignore_trajectory_length)
        return sapi_list[keep], q_value_list[keep]

    def fetch_pi_options(self, state, action):
        """Return the pi_options stored for (state, action), or [] if none."""
        state_pattern = self.abstract_state(state)
        action_pattern = self.abstract_action(action)

        try:
            pi_options = self.info.get(state_pattern).get(action_pattern)
            return list(pi_options.keys())
        except AttributeError:
            return []

    def delete_dominated_sets(self, state, action):
        """Delete dominated or duplicated pi_options of (state, action).

        The comparison matches ``fetch_non_dominated_set``. In the
        disable-trajectory-length debug mode, the ``-trajectory_length``
        component is dropped here too.

        Returns:
            Array of deleted ``(state_pattern, action_pattern, pi_option)``
            rows, or ``[]`` if (state, action) has no entries.
        """
        state_pattern = self.abstract_state(state)
        action_pattern = self.abstract_action(action)

        ignore_trajectory_length = bool(self.is_debug and self._debug_disable_trajectory_length)

        sapi_list = []
        q_value_list = []
        action_info = self.info.get(state_pattern, dict())
        pi_option_info = action_info.get(action_pattern, dict())
        for pi_option, entry in pi_option_info.items():
            sapi_list.append((state_pattern, action_pattern, pi_option))
            q_value = self._selected_q_value(entry)
            if ignore_trajectory_length:
                q_value = q_value[:-1]
            q_value_list.append(q_value)

        sapi_list = np.array(sapi_list)
        q_value_list = np.array(q_value_list)

        if len(sapi_list) == 0:
            return []

        dominated_set = sapi_list[~self._keep_mask(q_value_list, ignore_trajectory_length)]

        self._delete(dominated_set)
        return dominated_set

    def _delete(self, to_delete_list):
        """Remove entries and their api-list nodes, pruning emptied action/state levels."""
        for (state_pattern, action_pattern, pi_option) in to_delete_list:
            info = self.info.get(state_pattern).get(action_pattern).get(pi_option)
            node_to_remove_id = info.get('api_list_ref')
            self.api_list_info[state_pattern].pop(id_=node_to_remove_id)

            self.info.get(state_pattern).get(action_pattern).pop(pi_option)
            if len(self.info[state_pattern][action_pattern]) == 0:
                self.info[state_pattern].pop(action_pattern)
            if len(self.info[state_pattern]) == 0:
                self.info.pop(state_pattern)
