from sklearn.decomposition import IncrementalPCA

import numpy as np
import pickle



class PCA:
    """Grayscale-image PCA compressor built on sklearn's ``IncrementalPCA``.

    Images are converted to grayscale and flattened before projection. A fitted
    model can be persisted with :meth:`save` and restored with :meth:`load`;
    after loading, the projection parameters are read from the loaded values
    rather than from the wrapped ``IncrementalPCA`` instance.
    """

    # (rows, cols) that :meth:`unflatten` reshapes reconstructed images to.
    # Override per instance for other frame resolutions.
    image_shape = (120, 160)

    def __init__(self, n_components, batch_size=20, whiten=False):
        # copy=False lets sklearn overwrite the training array in place.
        self._pca = IncrementalPCA(n_components=n_components, batch_size=batch_size, copy=False, whiten=whiten)
        self.first_fit = True   # next fit() runs a full fit() instead of partial_fit()
        self.from_file = False  # True after load(): properties return the loaded values

        # Values restored by load(); only consulted when from_file is True.
        self._n_components = n_components
        self._components = None
        self._explained_variance_ratio = None
        self._explained_variance = None
        self._whiten = whiten
        self._mean = None

    def fit_transitions(self, episodes, verbose=True):
        """Fit the PCA on every image frame contained in ``episodes``.

        Args:
            episodes: iterable of trajectories, each a sized sequence of
                ``(state, action, reward, next_state)`` tuples (see
                :meth:`extract_`). Empty trajectories are skipped.
            verbose: print progress and the resulting explained variance ratio.

        Raises:
            ValueError: if no episode contains any transition.
        """
        data = [self.extract_(episode) for episode in episodes if len(episode) > 0]
        if not data:
            raise ValueError("no transitions to fit")
        if verbose:
            print("Extracted data")
        X = np.vstack(data)
        if verbose:
            print(X.shape)
            print("Fitting data")
        self.fit(X)
        if verbose:
            print("Data fitted")
            print(self.explained_variance_ratio)

    def fit(self, X):
        """Fit on the 2-D sample matrix ``X`` (one flattened frame per row).

        The first call uses ``IncrementalPCA.fit``; later calls update the model
        with ``partial_fit`` on all of ``X`` at once.
        """
        # NOTE(review): load() restores first_fit but not the IncrementalPCA
        # state, and from_file stays True, so fitting after load() trains a
        # fresh model whose results the properties will not report.
        if self.first_fit:
            self._pca.fit(X)
            # Only mark as fitted once fit() succeeded, so a failed first fit
            # is retried as a full fit.
            self.first_fit = False
        else:
            self._pca.partial_fit(X)

    @property
    def whiten(self):
        """Whether projections are scaled by ``1 / sqrt(explained_variance)``."""
        return self._whiten

    @property
    def n_components(self):
        """Number of principal components."""
        if self.from_file:
            return self._n_components
        return self._pca.n_components

    @property
    def explained_variance_ratio(self):
        """Fraction of variance explained by each component."""
        if self.from_file:
            return self._explained_variance_ratio
        return self._pca.explained_variance_ratio_

    @property
    def explained_variance(self):
        """Variance explained by each component."""
        if self.from_file:
            return self._explained_variance
        return self._pca.explained_variance_

    @property
    def components(self):
        """Principal axes, one row per component."""
        if self.from_file:
            return self._components
        return self._pca.components_

    @property
    def mean(self):
        """Per-pixel mean subtracted before projection."""
        if self.from_file:
            return self._mean
        return self._pca.mean_

    def compress_(self, image):
        """Project one RGB image onto the principal components.

        Returns a vector of length ``n_components``. When whitening is enabled,
        the vector is additionally divided by ``sqrt(explained_variance)``.
        """
        X = self.flat_gray(image) - self.mean
        X_transformed = np.dot(X, self.components.T)
        if self.whiten:
            X_transformed /= np.sqrt(self.explained_variance)
        return X_transformed

    def uncompress_(self, image):
        """Map a projected vector back to a flattened grayscale image (inverse of ``compress_``)."""
        if self.whiten:
            return np.dot(image, np.sqrt(self.explained_variance[:, np.newaxis]) * self.components) + self.mean
        return np.dot(image, self.components) + self.mean

    @staticmethod
    def _stack_parts(parts):
        """Pack a list of per-entry results into one ndarray.

        Returns a regular array when every part has the same shape, and a 1-D
        object array otherwise. NumPy >= 1.24 raises ValueError for ragged
        input instead of creating the object array implicitly.
        """
        try:
            return np.array(parts)
        except ValueError:
            packed = np.empty(len(parts), dtype=object)
            for i, part in enumerate(parts):
                packed[i] = part
            return packed

    def compress(self, state, flatten=False, n_non_images=1):
        """Compress every image in ``state``; trailing non-image entries pass through.

        Args:
            state: object with ``.shape[0]`` entries; the first
                ``shape[0] - n_non_images`` are RGB images, the last
                ``n_non_images`` are copied unchanged.
            flatten: if True, concatenate all parts into a single 1-D array.
            n_non_images: number of trailing non-image entries (default 1).

        Returns:
            An ndarray of the parts (an object array when their shapes differ),
            or a 1-D array when ``flatten`` is True.
        """
        # NOTE(review): extract_ (used for fitting) skips 2 trailing non-image
        # entries by default, while this default is 1 -- confirm the state layout.
        n_images = state.shape[0] - n_non_images
        parts = [self.compress_(state[i]) for i in range(n_images)]
        parts.extend(state[i] for i in range(-n_non_images, 0))
        compressed = self._stack_parts(parts)
        if flatten:
            return np.concatenate(compressed).ravel()
        return compressed

    def uncompress(self, compressed_state, n_non_images=1):
        """Inverse of :meth:`compress` (non-flattened form).

        Reconstructs each compressed image as a 2-D grayscale frame of shape
        ``image_shape``. The trailing ``n_non_images`` entries are copied
        unchanged.
        """
        n_images = compressed_state.shape[0] - n_non_images
        parts = [self.unflatten(self.uncompress_(compressed_state[i])) for i in range(n_images)]
        parts.extend(compressed_state[i] for i in range(-n_non_images, 0))
        return self._stack_parts(parts)

    def save(self, filename):
        """Pickle the model parameters to ``filename``.

        Works both for a freshly fitted model and for one restored with
        :meth:`load`, because every value is read through its property.
        """
        # Collect everything before opening the file so a failure (e.g. an
        # unfitted model) does not leave a truncated file behind.
        # Field order must match load().
        payload = (self.n_components, self.explained_variance_ratio, self.components, self.mean,
                   self.first_fit, self.explained_variance, self.whiten)
        with open(filename, 'wb') as file:
            pickle.dump(payload, file)

    def load(self, filename):
        """Restore parameters written by :meth:`save`; properties then return these values."""
        # NOTE(review): pickle.load can execute arbitrary code -- only load
        # files from trusted sources.
        with open(filename, 'rb') as file:
            (self._n_components, self._explained_variance_ratio, self._components, self._mean,
             self.first_fit, self._explained_variance, self._whiten) = pickle.load(file)
        self.from_file = True

    def rgb2gray_(self, rgb):
        """Luma-weighted grayscale of the first three channels, truncated to uint8."""
        return np.uint8(np.dot(rgb[..., :3], [0.299, 0.587, 0.114]))

    def unflatten(self, image, shape=None):
        """Reshape a flat frame to ``shape`` (defaults to ``self.image_shape``)."""
        return np.reshape(image, self.image_shape if shape is None else shape)

    def flat_gray(self, image):
        """Grayscale ``image`` and flatten its two spatial dimensions into one vector."""
        return np.reshape(self.rgb2gray_(image), (image.shape[0] * image.shape[1]))

    def add_state_(self, X, idx, state, n_non_images=2):
        """Write each image of ``state`` as a flattened grayscale row of ``X``.

        Rows are written starting at ``idx``. The trailing ``n_non_images``
        entries are skipped; by default these are the inventory (second-last)
        and the xy position (last). Returns the next free row index.
        """
        for i in range(state.shape[0] - n_non_images):
            X[idx, :] = self.flat_gray(state[i])
            idx += 1
        return idx

    def extract_(self, trajectory, n_non_images=2):
        """Stack every image frame of a trajectory into a 2-D sample matrix.

        Args:
            trajectory: non-empty sized iterable of ``(state, action, reward,
                next_state)`` tuples. Frames are taken from the first ``state``
                and from every ``next_state``.
            n_non_images: trailing non-image entries per state (default 2:
                inventory and position).

        Returns:
            Float array with one flattened grayscale frame per row.

        Raises:
            ValueError: if ``trajectory`` is empty.
        """
        X = None
        row = 0
        for state, _action, _reward, next_state in trajectory:
            if X is None:
                # Room for the first state's frames plus every next_state's frames.
                n_images = state.shape[0] - n_non_images
                X = np.zeros(shape=((len(trajectory) + 1) * n_images, state[0].shape[0] * state[0].shape[1]))
                row = self.add_state_(X, row, state, n_non_images)
            row = self.add_state_(X, row, next_state, n_non_images)
        if X is None:
            raise ValueError("cannot extract frames from an empty trajectory")
        return X


if __name__ == '__main__':
    # Manual smoke test: load recorded transitions and a previously saved PCA
    # model, then compress and reconstruct a single state.
    # NOTE(review): pickle.load can execute arbitrary code from the file; only
    # use trusted data files here.
    with open('transitions.dat.old', 'rb') as file:
        transitions = pickle.load(file)
    pca = PCA(20, whiten=True)
    # To retrain instead of loading, uncomment the two lines below. They fit
    # on the first episode only and write the model to temp.dat.
    # pca.fit_transitions(transitions[0:1])
    # pca.save("temp.dat")
    pca.load("temp.dat")
    # First episode -> first transition -> its `state` element (extract_
    # unpacks transitions as (state, action, reward, next_state)).
    x = transitions[0][0][0]
    # Drop the reference to the full dataset so it can be garbage-collected
    # (assuming nothing else holds it).
    del transitions
    # NOTE(review): debug_state is neither defined nor imported in this file;
    # as written these calls raise NameError. Confirm where it should come from.
    debug_state("test.png", x)
    y = pca.compress(x, flatten=False)
    print(y)
    y_prime = pca.uncompress(y)
    debug_state("recon", y_prime)
