import random
from datetime import datetime
import numpy as np
import os

from collections import Iterable
from termcolor import colored
from tqdm import tqdm
import skimage.measure

from scipy.ndimage.filters import maximum_filter

from domain.levels import Task
from pca.base_pca import PCA_N
from pca.pca import PCA
from symbols.file_utils import make_path, make_dir
from symbols.logger.precondition_logger import PreconditionLogger
from symbols.logger.transition_logger import TransitionLogger
import matplotlib.pyplot as plt
import pickle

def printd(s):
    """Print ``s`` highlighted in red so debug output stands out."""
    highlighted = colored(s, 'red')
    print(highlighted)


def debug_show(state, pca):
    """Show each frame of ``state`` next to its PCA representation, then call ``plt.show()``.

    Frames ``state[0] .. state[-2]`` are drawn in pairs on a 5x5 subplot grid:
    the raw frame, followed by ``pca.representation(frame)`` cast to uint8.
    The last entry of ``state`` is not drawn.

    NOTE(review): the grid has 25 slots, so at most 12 frames fit. A longer
    ``state`` makes ``add_subplot`` receive an out-of-range index.
    """
    figure = plt.figure(figsize=(6, 4))
    figure.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)

    for frame_idx in range(state.shape[0] - 1):
        frame = state[frame_idx]

        # Odd slot: the raw frame.
        raw_axis = figure.add_subplot(5, 5, 2 * frame_idx + 1, xticks=[], yticks=[])
        raw_axis.imshow(frame, cmap=plt.cm.bone, interpolation='nearest')

        # Following even slot: its PCA representation.
        rep_axis = figure.add_subplot(5, 5, 2 * frame_idx + 2, xticks=[], yticks=[])
        representation = pca.representation(frame)
        rep_axis.imshow(np.uint8(representation), cmap=plt.cm.bone, interpolation='nearest')

    plt.show()


def explore(actions, visits):
    """Sample an action, favouring rarely visited (action, object) pairs.

    Each action is weighted by the reciprocal of its visit count, and the
    weights are normalised. Any action whose ``str()`` contains ``'axe'``
    then gets an extra +1000 weight before renormalising, which makes it
    overwhelmingly likely to be picked. The chosen action's visit count is
    incremented in place.

    Args:
        actions: Non-empty sequence of actions exposing ``.id`` and ``.object.id``.
        visits: Mutable, indexable visit counts keyed by
            ``action.id * 10 + action.object.id``; updated in place.

    Returns:
        The sampled action (an element of ``actions``).

    Raises:
        ValueError: if ``actions`` is empty.
    """
    if len(actions) == 0:
        raise ValueError('explore() needs at least one admissible action')

    keys = [action.id * 10 + action.object.id for action in actions]

    # Force float dtype: np.reciprocal on integer counts truncates (1/2 -> 0),
    # which would yield zero/NaN probabilities.
    weights = np.reciprocal(np.asarray([visits[k] for k in keys], dtype=float))
    weights = weights / np.sum(weights)

    axe_mask = np.array(['axe' in str(action) for action in actions])
    if axe_mask.any():
        weights[axe_mask] += 1000
        weights = weights / np.sum(weights)

    # Sample an index rather than the action objects themselves so numpy never
    # tries to coerce the actions into an array. The random draws are identical.
    index = np.random.choice(len(actions), p=weights)
    visits[keys[index]] += 1
    return actions[index]


def collect_data(pca_path, seed, output_directory, num_episodes, max_episode_length, verbose=True, raw=False):
    """Run exploratory episodes and write precondition/transition samples to disk.

    Args:
        pca_path: Path of the saved PCA model, passed to ``PCA.load``.
        seed: Seed for ``Task.generate``. It is also passed to ``env.reset``
            at the start of every episode.
        output_directory: Output directory, passed to ``make_dir(..., clean=False)``
            and to both loggers.
        num_episodes: Number of episodes to run.
        max_episode_length: Maximum number of actions taken per episode.
        verbose: Print progress and every chosen action.
        raw: If True, pickle each episode's uncompressed sample lists to
            ``make_path(<logger>.dir, str(episode))``. Otherwise, log every
            sample through the loggers with PCA-compressed observations.
    """
    make_dir(output_directory, clean=False)
    pca = PCA(PCA_N)
    pca.load(pca_path)

    # NOTE(review): in raw mode `compress` (pca.shrink) is never applied because raw
    # samples are pickled as-is. The selection is kept for parity; confirm intent.
    compress = pca.shrink if raw else pca.compress

    transition_logger = TransitionLogger(output_directory)
    precondition_logger = PreconditionLogger(output_directory)

    try:
        # The task is generated once; each episode resets it with the same seed.
        env, _, _ = Task.generate(seed)

        for episode in range(num_episodes):
            if verbose:
                print('Running episode ' + str(episode) + '...')
            precondition_samples, transition_samples = _run_episode(env, seed, max_episode_length, verbose)

            if verbose:
                print("Saving data")
            _save_episode(episode, precondition_samples, transition_samples,
                          precondition_logger, transition_logger, compress, raw, verbose)
    finally:
        # Always close the loggers, even if an episode raises.
        precondition_logger.close()
        transition_logger.close()


def _run_episode(env, seed, max_episode_length, verbose):
    """Roll out one exploratory episode and return its collected samples.

    Returns:
        ``(precondition_samples, transition_samples)``.
        Precondition samples are ``(state, observation, option_id, object_id, can_execute)``.
        Transition samples are
        ``(state, observation, option_id, object_id, reward, next_state, next_observation)``.
    """
    precondition_samples = []
    transition_samples = []

    observation, doors, items = env.reset(seed=seed)
    state = observation[-2:]  # state = inventory + xy
    observation = observation[0:-1]  # obs = images + inventory

    # Visit counts for explore(), indexed by action.id * 10 + object.id. The
    # non-zero start keeps explore()'s reciprocal finite.
    visits = [0.001] * 100

    for _ in range(max_episode_length):
        admissible_actions, disallowed = Task.admissable_actions(env, doors, items)

        # Label each (option, object) pair. Admissible entries are written last,
        # so they override a disallowed entry for the same pair.
        can_execute = {}
        for option in disallowed:
            can_execute[(option.id, option.object.id)] = False
        for option in admissible_actions:
            can_execute[(option.id, option.object.id)] = True

        for (option_id, object_id), label in can_execute.items():
            precondition_samples.append((state, observation, option_id, object_id, label))

        action = explore(admissible_actions, visits)
        if verbose:
            printd(action)

        next_observation, reward, done, _ = action.execute(env)
        next_state = next_observation[-2:]
        next_observation = next_observation[0:-1]

        # Every transition is recorded, including ones that leave the state unchanged.
        transition_samples.append(
            (state, observation, action.id, action.object.id, reward, next_state, next_observation))
        if done:
            break
        state = next_state
        observation = next_observation

    return precondition_samples, transition_samples


def _save_episode(episode, precondition_samples, transition_samples,
                  precondition_logger, transition_logger, compress, raw, verbose):
    """Persist one episode's samples, either pickled raw or logged with compressed observations."""
    if raw:
        # One pickle file per episode in each logger's directory, holding the uncompressed samples.
        for logger, samples in ((precondition_logger, precondition_samples),
                                (transition_logger, transition_samples)):
            path = make_path(logger.dir, str(episode))
            with open(path, 'wb') as file:
                pickle.dump(samples, file)
        return

    for state, observation, option_id, object_id, allowed in tqdm(precondition_samples):
        precondition_logger.log_sample(episode, state, compress(observation), option_id, object_id, allowed)
    if verbose:
        print()
    for state, observation, option_id, object_id, reward, next_state, next_observation in tqdm(transition_samples):
        transition_logger.log_sample(episode, state, compress(observation), option_id, object_id, reward,
                                     next_state, compress(next_observation))



if __name__ == '__main__':
    # Uncomment this guard to stop an accidental re-run from collecting data again.
    # print("Exiting since scared I run this by accident")
    # exit(0)

    TASK_ID = 0
    RAW = True
    seeds = [31, 33, 76, 82, 92]
    n_episodes = 15

    # Seed both RNGs with the selected task's seed for reproducible exploration.
    task_seed = seeds[TASK_ID]
    random.seed(task_seed)
    np.random.seed(task_seed)

    # Output goes to <YYYYMMDD>[_raw]/<TASK_ID>.
    dir_name = datetime.today().strftime('%Y%m%d') + ('_raw' if RAW else '')
    directory = make_path(dir_name)
    task_dir = make_path(directory, str(TASK_ID))
    pca_path = os.path.abspath('pca_models/full_pca.dat')

    collect_data(pca_path, task_seed, task_dir, n_episodes, 300, verbose=True, raw=RAW)

