import zlib

import gym
import numpy as np


class ActionSelectWrapper(gym.ActionWrapper):
    """Action wrapper whose ``action`` hook converts the action to ``int``."""

    def action(self, action):
        """Return ``action`` converted with the built-in ``int``."""
        as_int = int(action)
        return as_int


def ident(arr):
    """Return 1 if ``arr`` holds at least one element, otherwise 0.

    Emptiness is decided with ``len`` so that any sized container works.
    """
    if len(arr) > 0:
        return 1
    return 0


class EventWrapper(gym.Wrapper):
    """Add a multi-hot ``'event'`` vector to dict observations.

    The wrapped environment must produce dict observations, and its ``step``
    must return an ``info`` dict with an ``'events'`` iterable of event names.
    Position ``i`` of the ``'event'`` vector corresponds to ``self.events[i]``.
    """

    # Known event names per environment name; list order defines the index of
    # each event inside the 'event' vector.
    EVENTS = {
        'crafter': [
            'collect_coal', 'collect_diamond', 'collect_drink', 'collect_iron',
            'collect_sapling', 'collect_stone', 'collect_wood',
            'defeat_skeleton', 'defeat_zombie', 'eat_cow', 'eat_plant',
            'make_iron_pickaxe', 'make_iron_sword', 'make_stone_pickaxe',
            'make_stone_sword', 'make_wood_pickaxe', 'make_wood_sword',
            'place_furnace', 'place_plant', 'place_stone', 'place_table',
            'wake_up',
        ]
    }

    def __init__(self, env: gym.Env, env_name='crafter'):
        """Wrap ``env`` and extend its observation space with ``'event'``.

        Args:
            env: Environment to wrap; its observation space must be dict-like.
            env_name: Key into ``EVENTS`` selecting the event vocabulary.

        Raises:
            KeyError: If ``env_name`` is not a key of ``EVENTS``.
        """
        super(EventWrapper, self).__init__(env)

        self.events = EventWrapper.EVENTS[env_name]
        self.n_events = len(self.events)
        # Prefer the ``.spaces`` mapping (as ObjectiveWrapper does): unpacking a
        # ``gym.spaces.Dict`` directly only works on gym versions where it
        # implements the Mapping protocol. Fall back to the object itself so
        # anything that was unpackable before still is.
        inner_spaces = getattr(env.observation_space, 'spaces', env.observation_space)
        self.observation_space = gym.spaces.Dict({
            'event': gym.spaces.Box(
                low=0,
                high=1,
                shape=(self.n_events,),
                dtype=np.uint8
            ),
            **inner_spaces
        })

    def reset(self):
        """Reset the wrapped env and return its observation plus an all-zero ``'event'``."""
        return {
            'event': np.zeros(self.n_events, dtype=np.uint8),
            **self.env.reset()
        }

    def step(self, action):
        """Step the wrapped env and encode ``info['events']`` into ``'event'``.

        Returns:
            ``(observation, reward, done, info)`` where ``observation`` is the
            wrapped observation with the ``'event'`` vector added.

        Raises:
            ValueError: If an event name is not present in ``self.events``.
        """
        observation, reward, done, info = self.env.step(action)
        event_encoding = np.zeros(self.n_events, dtype=np.uint8)
        for event in info['events']:
            event_encoding[self.events.index(event)] = 1
        return {
            'event': event_encoding,
            **observation
        }, reward, done, info


class ObjectiveWrapper(gym.Wrapper):
    """Track a current objective and per-objective completion in dict observations.

    Every observation gets four extra entries: ``"objective"``,
    ``"completed"``, ``"last_completed"`` and ``"regen_steps"``. The objective
    is chosen by ``objective_selection = (algo, params)`` where ``algo`` is one
    of ``'random'``, ``'graph'`` (``params["graph"]``), ``'fixed'``
    (``params["objective"]``) or ``'seq'`` (``params["seq"]``).

    The wrapped environment must produce dict observations and provide
    ``seed``; ``serialize`` and ``set_next_reset_config`` additionally forward
    to methods of the same name on the wrapped environment.
    """

    # Relative weights used by the 'graph' strategy for the four candidate
    # groups, in the order: continuation, neighbour, jump, not-ready.
    _GRAPH_GROUP_WEIGHTS = (80, 10, 10, 5)
    # Probability with which the 'graph' strategy selects the extra
    # "new task" objective when ``include_new_tasks`` is set.
    _NEW_TASK_PROBABILITY = 0.4

    def __init__(self, env: gym.Env, num_objectives, include_new_tasks=True, objective_selection=('random', {}), done_if_reward=False, seed=None):
        """Wrap ``env`` and set up objective bookkeeping.

        Args:
            env: Environment to wrap; its observation space must expose ``.spaces``.
            num_objectives: Number of regular objectives.
            include_new_tasks: If true, one extra objective index
                (``num_objectives``) is tracked in addition to the regular ones.
            objective_selection: ``(algo, params)`` pair, see the class docstring.
            done_if_reward: If true, ``step`` also reports ``done`` whenever
                the reward exceeds 0.1.
            seed: Seed passed to ``seed``; ``None`` draws one at random.
        """
        super(ObjectiveWrapper, self).__init__(env)
        self.seed(seed)

        self._n = num_objectives
        self._num_objectives = num_objectives + 1 if include_new_tasks else num_objectives
        self._objective_selection = objective_selection
        self._include_new_tasks = include_new_tasks
        self._done_if_reward = done_if_reward
        # NOTE(review): the three scalar fields are declared with shape=() while
        # _make_observation emits them as 1-element arrays. Left unchanged
        # because downstream code may rely on either; confirm and align.
        self.observation_space = gym.spaces.Dict({
            "objective": gym.spaces.Box(
                low=0,
                high=num_objectives,
                shape=(),
                dtype=np.uint8
            ),
            "completed": gym.spaces.Box(
                low=0,
                high=1,
                shape=(self._num_objectives,),
                dtype=np.uint8
            ),
            # _init() sets last_completed to self._num_objectives as the
            # "nothing completed yet" value, so the bound must include it.
            "last_completed": gym.spaces.Box(
                low=0,
                high=self._num_objectives,
                shape=(),
                dtype=np.uint8
            ),
            "regen_steps": gym.spaces.Box(
                low=0,
                high=1000000,
                shape=(),
                dtype=np.int32
            ),
            **env.observation_space.spaces
        })
        self._remaining = None
        self._next_data = None
        self._objective = None
        self._hidden_objective = None
        self._last_completed = None
        self._completed = np.zeros((self._num_objectives,), dtype=np.uint8)
        self._seq_pos = None
        self._seq = objective_selection[1]["seq"] if objective_selection[0] == 'seq' else None
        self._done = False
        self._regen_steps = 0   # not serialized

    def seed(self, _seed=None):
        """Seed this wrapper's RNG and derive a seed for the wrapped env.

        Args:
            _seed: Seed to use; when ``None`` one is drawn from NumPy's global RNG.
        """
        seed = np.random.randint(0, 2 ** 31 - 1) if _seed is None else _seed
        self._seed = seed
        self._random = np.random.default_rng(seed)
        # Built-in hash() of a str is salted per process, so it cannot produce
        # a reproducible seed. CRC32 is stable across runs and already lies in
        # [0, 2**32).
        sub_env_seed = zlib.crc32(f"objective-sub-env-{seed}".encode("utf-8"))
        self.env.seed(sub_env_seed)

    def _init(self):
        """Reset progress: nothing completed, every objective remaining."""
        self._seq_pos = 0
        self._completed[:] = 0
        self._remaining = set(range(self._num_objectives))
        # One past the largest valid index means "nothing completed yet".
        self._last_completed = self._num_objectives

    def _generate_objective(self):
        """Choose ``self._objective`` with the configured strategy and zero ``regen_steps``.

        Raises:
            NotImplementedError: If the configured algorithm name is unknown.
        """
        self._regen_steps = 0
        algo, params = self._objective_selection
        if algo == 'random':
            # NOTE: ``random() > -0.2`` is always true, so an uncompleted
            # objective is picked whenever one remains. The draw is kept so the
            # RNG stream stays identical to earlier behaviour.
            if len(self._remaining) > 0 and self._random.random() > -0.2:
                self._objective = self._random.choice(list(self._remaining))
            else:
                self._objective = self._random.integers(self._num_objectives)
        elif algo == 'graph':
            self._objective = self._pick_graph_objective(params["graph"])
        elif algo == 'fixed':
            self._objective = params["objective"]
        elif algo == 'seq':
            # Stay on the final entry once the sequence has been exhausted.
            self._objective = self._seq[min(self._seq_pos, len(self._seq) - 1)]
        else:
            raise NotImplementedError(f"Unknown objective selection algorihtm {algo}.")

    def _pick_graph_objective(self, graph):
        """Return an objective chosen with the dependency ``graph``.

        ``graph[v][u]`` being truthy is treated as "``v`` must be completed
        before ``u``". Every uncompleted regular objective ``u`` is sorted into
        one of four groups:

        * not-ready: some prerequisite of ``u`` is still uncompleted;
        * continuation: ready, and the last completed objective points to ``u``;
        * neighbour: ready, and ``u`` and the last completed objective both
          point to some common uncompleted objective;
        * jump: ready, but neither of the above (also used when no regular
          objective was the last one completed).

        A group is then drawn with ``_GRAPH_GROUP_WEIGHTS`` (empty groups get
        weight 0) and a member is drawn uniformly from it.
        """
        cont_choices = []
        nghb_choices = []
        jump_choices = []
        unrd_choices = []
        last = self._last_completed
        for u in range(self._n):
            if self._completed[u]:
                continue
            blocked = any(
                graph[v][u] and not self._completed[v] for v in range(self._n)
            )
            if blocked:
                unrd_choices.append(u)
            elif not last < self._n:
                jump_choices.append(u)
            elif graph[last][u]:
                cont_choices.append(u)
            elif any(
                not self._completed[y] and graph[last][y] and graph[u][y]
                for y in range(self._n)
            ):
                nghb_choices.append(u)
            else:
                jump_choices.append(u)

        if self._include_new_tasks and self._random.random() < self._NEW_TASK_PROBABILITY:
            # The extra "new task" objective occupies the last index.
            return self._num_objectives - 1

        groups = [cont_choices, nghb_choices, jump_choices, unrd_choices]
        weights = np.array(
            [weight * ident(group) for weight, group in zip(self._GRAPH_GROUP_WEIGHTS, groups)]
        )
        if weights.sum() == 0:
            # Nothing left to choose from: fall back to a uniform draw.
            return self._random.integers(self._num_objectives)
        # Empty groups have probability 0, so the drawn group is never empty.
        group_index = self._random.choice(4, p=weights / weights.sum())
        return self._random.choice(groups[group_index])

    def _make_observation(self, observation):
        """Return ``observation`` extended with the objective-tracking fields.

        Keys already present in ``observation`` take precedence over the added ones.
        """
        return {
            "objective": np.array([self._objective], dtype=np.uint8),
            # Copy so later in-place updates of self._completed do not alter
            # observations that were already handed out.
            "completed": self._completed.copy(),
            "last_completed": np.array([self._last_completed], dtype=np.uint8),
            "regen_steps": np.array([self._regen_steps], dtype=np.int32),
            **observation
        }

    def reset(self):
        """Reset progress (or restore a pending config), pick an objective, reset the env."""
        self._done = False
        if self._next_data is not None:
            self._deserialize(self._next_data)
        else:
            self._init()
        self._generate_objective()
        self._next_data = None
        return self._make_observation(self.env.reset())

    def step(self, action):
        """Step the wrapped env and return the augmented observation.

        ``info['reset_done']`` records the wrapped env's own ``done`` flag.
        With ``done_if_reward`` set, ``done`` is also true when ``reward > 0.1``.
        """
        observation, reward, done, info = self.env.step(action)
        self._regen_steps += 1
        info['reset_done'] = done
        if self._done_if_reward:
            done = done or (reward > 0.1)
        return self._make_observation(observation), reward, done, info

    def complete_objective(self, objective, regenerate=True):
        """Record ``objective`` as completed and optionally pick a new objective.

        Args:
            objective: Index of the completed objective.
            regenerate: ``"correct"`` regenerates only when ``objective`` equals
                the current objective; any other truthy value always
                regenerates; a falsy value never does.
        """
        if objective < self._num_objectives:
            self._completed[objective] = 1
        self._last_completed = objective
        if self._seq is not None and self._seq_pos < len(self._seq) and objective == self._seq[self._seq_pos]:
            self._seq_pos += 1
        self._remaining.discard(objective)
        if regenerate == "correct":
            if objective == self._objective:
                self._generate_objective()
        elif regenerate:
            self._generate_objective()

    def serialize(self):
        """Return progress state followed by the wrapped env's serialized state.

        Layout: ``[seq_pos, last_completed, *completed, *env.serialize()]``.
        NOTE(review): the whole vector is cast to ``uint8``, so values above
        255 do not survive; ``regen_steps`` is not included.
        """
        seq = self.env.serialize()
        return np.concatenate(([self._seq_pos, self._last_completed], self._completed, seq)).astype(np.uint8)

    def _deserialize(self, data):
        """Restore progress state from the layout written by ``serialize``.

        Only the first ``2 + self._num_objectives`` entries of ``data`` are read.
        """
        # Convert to built-in ints, matching what _init() stores; a NumPy
        # uint8 would wrap around when _seq_pos is later incremented.
        self._seq_pos = int(data[0])
        self._last_completed = int(data[1])
        self._completed[:] = data[2:2 + self._num_objectives]
        self._remaining = {
            i for i in range(self._num_objectives) if not self._completed[i]
        }

    def set_next_reset_config(self, name, data):
        """Stage ``data`` for the next ``reset`` and forward the rest to the env.

        The first ``self._num_objectives + 2`` entries are kept for this
        wrapper; the remainder is passed to ``env.set_next_reset_config``.
        """
        self._next_data = data[:self._num_objectives + 2]
        self.env.set_next_reset_config(name, data[self._num_objectives + 2:])
