import time
from termcolor import colored

from collect_data import debug_state
from domain.levels import Task2, Task3, Task1, Task
import numpy as np


def printd(s):
    """Echo ``s`` to standard output highlighted in red (debug helper)."""
    highlighted = colored(s, 'red')
    print(highlighted)


def _visit_key(action):
    """Return the index into ``visits`` used to count ``action``.

    The key packs the action id and its target object's id as
    ``action.id * 10 + action.object.id``.
    NOTE(review): this is only collision-free while ``action.object.id`` is
    in 0..9, and only in-bounds while the result fits ``len(visits)`` —
    confirm the id ranges.
    """
    return action.id * 10 + action.object.id


def explore(actions, visits):
    """Pick one of ``actions`` at random, favouring rarely-visited ones.

    Each action is weighted by the reciprocal of its visit count, so actions
    tried less often are more likely to be chosen. The visit-based weights
    are normalised first. Then every action whose string form contains
    ``'axe'`` gets an extra 1000 added to its weight, which makes it almost
    always win. The chosen action's visit count is incremented in place.

    Args:
        actions: non-empty sequence of actions. Each must expose ``id`` and
            ``object.id`` attributes.
        visits: mutable sequence of visit counts, indexed by
            ``action.id * 10 + action.object.id``. Counts must be non-zero;
            the caller in this file seeds them with 0.001.

    Returns:
        The selected action.

    Raises:
        ValueError: if ``actions`` is empty.
    """
    if len(actions) == 0:
        raise ValueError('explore() needs at least one admissible action')

    # Force float math: np.reciprocal on an integer array does integer
    # division (1/2 -> 0), which would zero out every weight.
    counts = np.asarray([visits[_visit_key(a)] for a in actions], dtype=float)
    weights = 1.0 / counts
    weights /= weights.sum()

    # Strongly prefer any action involving the axe.
    is_axe = np.array(['axe' in str(a) for a in actions], dtype=bool)
    if is_axe.any():
        weights[is_axe] += 1000
        weights /= weights.sum()

    # Sample an index rather than the action itself: passing the actions
    # directly makes NumPy build an array out of arbitrary objects. The
    # random draw is the same either way.
    chosen = actions[np.random.choice(len(actions), p=weights)]
    visits[_visit_key(chosen)] += 1
    return chosen


# Seeds whose generated levels are built and reset by ``--preview``.
PREVIEW_SEEDS = [31, 33, 76, 82, 92]


def _preview_seeds(seeds):
    """Generate and reset the Task environment once for each seed in ``seeds``.

    This is a smoke test that every seed produces a level that can be reset.
    ``debug_state`` can be called here to dump each initial observation.
    """
    for seed in seeds:
        env, _doors, _items = Task.generate(seed)
        env.reset()


def _run_exploration(seed, max_steps=100):
    """Run visit-count-guided random exploration on one generated level.

    Each executed action and its step index are printed in red.

    Args:
        seed: seed passed to ``Task.generate``.
        max_steps: maximum number of actions to execute. The run stops early
            once ``action.execute`` reports ``done``.
    """
    env, doors, items = Task.generate(seed)
    env.reset()

    # One counter per ``action.id * 10 + action.object.id`` key. The small
    # positive start value keeps the reciprocal weights in ``explore`` finite.
    # NOTE(review): 100 slots assumes both ids are in 0..9 — confirm.
    visits = [0.001] * 100

    for step in range(max_steps):
        admissible_actions, _disallowed = Task.admissable_actions(env, doors, items)
        action = explore(admissible_actions, visits)
        printd(action)
        printd(step)
        _next_observation, _reward, done, _ = action.execute(env)
        if done:
            break


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Randomly explore a generated Task level.')
    parser.add_argument('--seed', type=int, default=92,
                        help='level seed to explore (default: %(default)s)')
    parser.add_argument('--steps', type=int, default=100,
                        help='maximum number of actions (default: %(default)s)')
    parser.add_argument('--preview', action='store_true',
                        help='only generate and reset each seed in PREVIEW_SEEDS, then exit')
    args = parser.parse_args()

    if args.preview:
        _preview_seeds(PREVIEW_SEEDS)
    else:
        _run_exploration(args.seed, args.steps)
