import random
from collections import defaultdict

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pickle
from datetime import datetime

import os
from PIL import Image

from experiment.collect import debug_show
from pca.base_pca import SIZE, PCA_N
from pca.pca import PCA
from pca.sparse_pca import SparsePCA

from symbols.file_utils import make_path



def shrink_gray(image):
    """Return *image* as a grayscale frame resized to the PCA input ``SIZE``.

    The colour-to-gray step is delegated to :func:`gray`. Resizing uses
    ``cv2.INTER_AREA`` interpolation.
    """
    grey_frame = gray(image)
    return cv2.resize(grey_frame, SIZE, interpolation=cv2.INTER_AREA)


def extract_transitions(dir_name='raw_data', n_tasks=5, n_files=3):
    """Load pickled transitions for several tasks into one flat list.

    For every ``task_id`` in ``range(n_tasks)`` and every ``i`` in
    ``range(n_files)`` this reads the pickle at
    ``<dir_name>/<task_id>/transition_data/<i>`` (paths built via
    ``make_path``). Each file is expected to hold an iterable of indexable
    transition records.

    Args:
        dir_name: Root directory of the raw data (default ``'raw_data'``).
        n_tasks: Number of task sub-directories to read (default 5).
        n_files: Number of transition files per task (default 3).

    Returns:
        A list of ``(x[1], x[2], x[-1], task_id)`` tuples, one per record.
        These are presumably ``(state, action, next_state, task_id)``, as
        the ``__main__`` pipeline unpacks them; confirm against the
        data-collection code.

    Raises:
        FileNotFoundError: If any expected transition file is missing.
    """
    transition_data = []
    directory = make_path(dir_name)
    for task_id in range(n_tasks):
        task_dir = make_path(directory, str(task_id))
        data_dir = make_path(task_dir, 'transition_data')

        for i in range(n_files):
            # NOTE: pickle can execute arbitrary code; only load trusted data.
            with open(make_path(data_dir, str(i)), 'rb') as file:
                transitions = pickle.load(file)
            transition_data.extend(
                (x[1], x[2], x[-1], task_id) for x in transitions
            )
    return transition_data


def gray(rgb):
    """Convert an RGB(A) array to 8-bit luma with BT.601 weights.

    Only the first three channels of the last axis are used, so an alpha
    channel is ignored. The output drops that last axis and has dtype
    ``uint8``.

    The weighted sum is rounded to the nearest integer rather than
    truncated. Truncation biased every pixel down by about half a grey
    level, and could turn a neutral pixel ``(v, v, v)`` into ``v - 1``
    through floating-point error. Rounding matches OpenCV's and PIL's
    RGB->gray conversion. The result is clipped to [0, 255] so inputs
    outside uint8 range cannot wrap around.
    """
    luma = np.dot(rgb[..., :3], [0.299, 0.587, 0.114])
    return np.clip(np.rint(luma), 0, 255).astype(np.uint8)



def show(pca, task_id, option, dir_name='20191119_raw', n_files=15,
         max_pairs=2):
    """Visualise the first transitions of one task that used *option*.

    Scans ``<dir_name>/<task_id>/transition_data/<i>`` for ``i`` in
    ``range(n_files)``, in order. It collects ``(x[1], x[-1])`` from every
    record whose ``x[2] == option``. For each of the first *max_pairs*
    pairs it calls ``debug_show`` on both frames with *pca*.

    Args:
        pca: Model passed straight through to ``debug_show``.
        task_id: Task sub-directory to read.
        option: Value compared against ``x[2]`` of each record.
        dir_name: Root data directory (default ``'20191119_raw'``).
        n_files: Maximum number of transition files to scan (default 15).
        max_pairs: Number of (obs, next_obs) pairs to display (default 2).

    Reading stops as soon as enough pairs have been collected, so later
    files are not opened.
    """
    directory = make_path(dir_name)
    task_dir = make_path(directory, str(task_id))
    data_dir = make_path(task_dir, 'transition_data')

    pairs = []
    for i in range(n_files):
        if len(pairs) >= max_pairs:
            break
        # NOTE: pickle can execute arbitrary code; only load trusted data.
        with open(make_path(data_dir, str(i)), 'rb') as file:
            transitions = pickle.load(file)
        pairs.extend((x[1], x[-1]) for x in transitions if x[2] == option)

    for obs, next_obs in pairs[:max_pairs]:
        debug_show(obs, pca)
        debug_show(next_obs, pca)


def _load_transitions(cache_path="states.dat"):
    """Extract all raw transitions, cache them as a pickle, and return them."""
    print("Extracting data...")
    transition_data = extract_transitions()
    with open(cache_path, "wb") as file:
        pickle.dump(transition_data, file)
    print(len(transition_data))
    return transition_data


def _build_reduced_images(transition_data):
    """Select shrunken grayscale frames for PCA training, oversampling rare actions.

    This reproduces the pipeline that produced ``reduced.npy``:

    - For the first transition of each task, the nine state/next-state
      frames are added 10 times.
    - The first occurrence of action 0 per task adds 3 repeats of
      ``next_state[0]``.
    - The first occurrence of action 1 per task adds 50 repeats of four
      chosen frames.
    - The first occurrence of action 7 per task adds 10 repeats of four
      chosen frames.
    - The first occurrence of action 8 per task adds 15 repeats of all
      nine pairs.
    - The first occurrence of actions 3/4/5 per task adds one
      state/next-state pair.
    - Any other action has roughly a 2% chance of adding all nine pairs.

    Returns the frames stacked into one ``numpy`` array.
    """
    print("Making images...")
    images = []

    def add(frames, repeat=1):
        # Shrink each frame once, then append the whole group `repeat` times
        # in the original order (frame-major within each repetition).
        shrunk = [shrink_gray(frame) for frame in frames]
        images.extend(shrunk * repeat)

    def all_views(state, next_state):
        # Interleave state[k], next_state[k] for the nine views k = 0..8.
        return [f for k in range(9) for f in (state[k], next_state[k])]

    first = defaultdict(int)
    items = defaultdict(int)
    attacks = defaultdict(int)
    chests = defaultdict(int)
    doors = defaultdict(int)
    toggle = defaultdict(int)

    for j, (state, action, next_state, task_id) in enumerate(transition_data):
        if j % 1000 == 0:
            print(j)

        if task_id not in first:
            add(all_views(state, next_state), repeat=10)
            first[task_id] += 1

        if action == 0:
            if items[task_id] == 0:
                add([next_state[0]], repeat=3)
                items[task_id] += 1
        elif action == 1:
            if attacks[task_id] == 0:
                add([state[0], next_state[0], next_state[-2], next_state[-3]],
                    repeat=50)
                attacks[task_id] += 1
        elif action == 7:
            if chests[task_id] == 0:
                add([state[0], next_state[0], state[-4], next_state[-4]],
                    repeat=10)
                chests[task_id] += 1
        elif action == 8:
            if toggle[task_id] == 0:
                add(all_views(state, next_state), repeat=15)
                toggle[task_id] += 1
        elif action in (3, 4, 5):
            if doors[task_id] == 0:
                add([state[0], next_state[0]])
                doors[task_id] += 1
        elif random.randint(0, 100) < 2:
            add(all_views(state, next_state))

    print(first)
    print(items)
    print(attacks)
    print(chests)
    print(doors)

    images = np.array(images)
    print(images.shape)
    return images


def _main(rebuild_images=False):
    """Train a sparse PCA on ``reduced.npy``, save it, and preview transitions.

    Args:
        rebuild_images: When True, regenerate ``reduced.npy`` from the raw
            transition data first. This is slow and needs ``raw_data/``.
    """
    if rebuild_images:
        np.save('reduced', _build_reduced_images(_load_transitions()))

    x = np.load('reduced.npy')
    print(len(x))

    print("Training PCA...")

    # Dense alternative: pca = PCA(n_components=PCA_N, whiten=True)
    pca = SparsePCA(n_components=PCA_N, normalise_images=True)
    # Flatten each frame into one row vector: (N, H, W) -> (N, H * W).
    image = np.reshape(x, (x.shape[0], -1))
    print(image.shape)
    pca.fit(image)
    print("Saving...")
    pca_path = 'pca_models/sparse_pca_normalise_15_component.dat'
    # Create the output directory up front so a long fit isn't lost on save.
    os.makedirs(os.path.dirname(pca_path), exist_ok=True)
    pca.save(pca_path)

    # To reuse a saved model instead of training:
    #     pca = SparsePCA(n_components=PCA_N)
    #     pca.load(pca_path)

    for a in [8]:
        for i in range(5):
            print(a, i)
            show(pca, i, a)


if __name__ == '__main__':
    _main()
