import pickle

import h5py
import numpy as np

from data.mpec.utils.abstraction_mode import *
from data.necsa.interfaces import Grid
from utils.collections_util import defaultdict_to_dict, dict_to_defaultdict


class Memory:
    """Discretize states and actions onto fixed grids and hold an ``info`` dict.

    States (and, when ``reduction`` is enabled, actions) are optionally passed
    through a random linear projection. They are then mapped to grid cells by
    ``Grid.state_abstract``, and the resulting labels are joined with ``'-'``
    into a string key. ``self.info`` is a plain dict that ``save``/``load``
    persist to ``<env_name>.h5``.
    """

    def __init__(
            self, step, grid_num, raw_state_dim, state_dim, state_min, state_max, action_dim, action_min, action_max,
            mode, reduction, lr, gamma, *args, **kwargs):
        """Store the configuration and build the grids.

        Args:
            step, lr, gamma: stored as attributes only.
            grid_num: passed to ``Grid`` for both the state and action grids.
            raw_state_dim: length of an incoming (unprojected) state vector.
            state_dim: state length after projection (used only when
                ``reduction`` is truthy).
            state_min, state_max: per-dimension bounds of raw states.
            action_dim: length of an action vector.
            action_min, action_max: per-dimension bounds of raw actions.
            mode: abstraction mode; ``HS`` changes the state projection
                distribution, and ``HS_A`` / ``HS_HA`` are rejected.
            reduction: if truthy, states and actions are randomly projected
                before discretization.
            *args, **kwargs: accepted and ignored.

        Raises:
            NotImplementedError: if ``mode`` is ``HS_A`` or ``HS_HA``.
        """
        self.step = step
        self.grid_num = grid_num
        self.raw_state_dim = raw_state_dim
        self.state_dim = state_dim
        self.state_min = state_min
        self.state_max = state_max
        self.action_dim = action_dim
        self.action_min = action_min
        self.action_max = action_max
        self.mode = mode
        self.reduction = reduction

        self.lr = lr
        self.gamma = gamma

        # Reject unsupported modes before drawing random matrices or building
        # grids. NotImplementedError is the exception class; the bare
        # ``NotImplemented`` constant is not callable and would raise TypeError.
        if self.mode in [HS_A, HS_HA]:
            raise NotImplementedError(f"{HS_A} and {HS_HA} are not implemented yet")

        self.setup()

    def setup(self):
        """Compute state/action bounds, reset ``info`` and build both grids.

        Bounds start as constant vectors filled with ``state_min``/``state_max``
        (length ``raw_state_dim``) and ``action_min``/``action_max`` (length
        ``action_dim``). With ``reduction`` enabled they are replaced by the
        bounds of that box after the random projection, so each grid covers
        the projected space that ``abstract_state``/``abstract_action`` query.
        """
        self.min_state = np.array([self.state_min] * self.raw_state_dim)
        self.max_state = np.array([self.state_max] * self.raw_state_dim)
        self.min_action = np.array([self.action_min] * self.action_dim)
        self.max_action = np.array([self.action_max] * self.action_dim)

        if self.reduction:
            self.setup_projection_matrix()
            # Projecting only the two corners is wrong once the matrix has
            # negative entries (HS mode draws from U(-1, 1)): ``min @ M`` can
            # then exceed ``max @ M``. Use the exact bounds of the projected box.
            self.min_state, self.max_state = self._project_bounds(
                self.min_state, self.max_state, self.state_project_matrix)
            self.min_action, self.max_action = self._project_bounds(
                self.min_action, self.max_action, self.action_project_matrix)

        self.info = dict()

        self.state_grid = Grid(self.min_state, self.max_state, self.grid_num)
        self.action_grid = Grid(self.min_action, self.max_action, self.grid_num)

    @staticmethod
    def _project_bounds(low, high, matrix):
        """Return ``(lower, upper)`` bounds of the box ``[low, high]`` mapped by ``x @ matrix``.

        Output coordinate ``j`` is ``sum_i x_i * matrix[i, j]``. Its minimum
        over the box takes ``low_i`` where the coefficient is positive and
        ``high_i`` where it is negative; the maximum does the opposite. For a
        matrix with no negative entries this equals
        ``(low @ matrix, high @ matrix)``.
        """
        positive = np.clip(matrix, 0, None)
        negative = np.clip(matrix, None, 0)
        lower = np.dot(low, positive) + np.dot(high, negative)
        upper = np.dot(high, positive) + np.dot(low, negative)
        return lower, upper

    def setup_projection_matrix(self):
        """Draw the random projection matrices used when ``reduction`` is on.

        State matrix: shape ``(raw_state_dim, state_dim)``, entries ~ U(0, 0.1),
        or ~ U(-1, 1) in ``HS`` mode. Action matrix: shape
        ``(action_dim, action_dim)``, entries ~ U(0, 0.1).
        """
        # The U(0, 0.1) state draw is made even in HS mode and then discarded.
        # Removing it would shift the global NumPy RNG stream and change the
        # matrices produced under a fixed seed.
        state_matrix = np.random.uniform(0, 0.1, (self.raw_state_dim, self.state_dim))
        if self.mode == HS:
            state_matrix = np.random.uniform(-1, 1, (self.raw_state_dim, self.state_dim))
        self.state_project_matrix = state_matrix
        self.action_project_matrix = np.random.uniform(0, 0.1, (self.action_dim, self.action_dim))

    def save(self, env_name):
        """Pickle ``self.info`` into the ``'info'`` dataset of ``<env_name>.h5``.

        ``info`` is first converted with ``defaultdict_to_dict``, then stored as
        an opaque ``np.void`` scalar. The file is opened in ``'w'`` mode, so any
        existing file is replaced.
        """
        # Serialize before opening the file so a pickling error does not leave
        # behind a truncated/empty file.
        payload = pickle.dumps(defaultdict_to_dict(self.info))
        with h5py.File(f"{env_name}.h5", 'w') as f:
            f.create_dataset('info', data=np.void(payload))

    def load(self, env_name):
        """Merge the dict stored by ``save`` in ``<env_name>.h5`` into ``self.info``.

        Keys already present in ``self.info`` are overwritten by stored values.
        NOTE(review): the stored data went through ``defaultdict_to_dict`` and is
        not converted back here, so nested containers presumably come back as
        plain dicts — confirm this is what callers expect.

        SECURITY: this calls ``pickle.loads`` on file contents, which can execute
        arbitrary code. Only load files produced by a trusted ``save``.
        """
        with h5py.File(f"{env_name}.h5", 'r') as f:
            stored = pickle.loads(bytes(f['info'][()]))
        self.info.update(stored)

    def abstract_state(self, state):
        """Return the ``'-'``-joined grid-cell key for ``state``.

        When ``reduction`` is on, ``state`` is first projected with
        ``state_project_matrix``. It is wrapped as a one-row array for
        ``Grid.state_abstract``, whose result must be an iterable of strings.
        """
        if self.reduction:
            state = self._dim_reduction(state)
        labels = self.state_grid.state_abstract(np.array([state]))
        return '-'.join(labels)

    def abstract_action(self, action):
        """Return the ``'-'``-joined grid-cell key for ``action``.

        When ``reduction`` is on, the action grid was built on projected bounds
        in ``setup``. The action is therefore projected with
        ``action_project_matrix`` first, mirroring ``abstract_state``.
        """
        if self.reduction:
            action = np.dot(action, self.action_project_matrix)
        labels = self.action_grid.state_abstract(np.array([action]))
        return '-'.join(labels)

    def _dim_reduction(self, state):
        """Project a raw state with ``state_project_matrix`` (``state @ M``)."""
        return np.dot(state, self.state_project_matrix)
