import os
import pickle
from contextlib import ExitStack
from datetime import datetime

from tqdm import tqdm

from pca.base_pca import PCA_N
from pca.pca import PCA
from pca.sparse_pca import SparsePCA
from symbols.file_utils import make_path, make_dir
from symbols.logger.precondition_logger import PreconditionLogger
from symbols.logger.transition_logger import TransitionLogger


def get_initial_states(id, input_dir, pca, num_episodes, verbose=True):
    """Return the first recorded state of every episode.

    For each episode index in ``range(num_episodes)`` the pickled file
    ``<input_dir>/transition_data/<episode>`` is loaded. Its first sample (a
    seven-field tuple) supplies a ``state`` and an ``observation``; the
    observation is passed through ``pca.compress``.

    Args:
        id: Task identifier; only used in the progress message.
        input_dir: Directory that contains the ``transition_data`` folder.
        pca: Object providing ``compress(observation)``.
        num_episodes: Number of episode files to read (indices 0..num_episodes-1).
        verbose: When true, print progress messages.

    Returns:
        A list of ``(state, compressed_observation)`` tuples, one per episode.
    """
    if verbose:
        print("Loading data for task {}".format(id))

    initial_states = []
    for episode_idx in range(num_episodes):
        if verbose:
            print("Loading episode {}".format(episode_idx))

        episode_file = make_path(make_path(input_dir, 'transition_data'), str(episode_idx))
        with open(episode_file, 'rb') as handle:
            episode_samples = pickle.load(handle)

        # Unpack all seven fields explicitly so a malformed sample still raises.
        first_state, first_obs, _sample_id, _obj, _r, _next_state, _next_obs = episode_samples[0]
        initial_states.append((first_state, pca.compress(first_obs)))

    return initial_states


def _load_episode(input_dir, subdir, episode):
    """Return the unpickled sample list stored at ``<input_dir>/<subdir>/<episode>``."""
    path = make_path(make_path(input_dir, subdir), str(episode))
    # NOTE(review): pickle.load can execute arbitrary code; only run this on
    # data files produced by a trusted source.
    with open(path, 'rb') as file:
        return pickle.load(file)


def extract_data(id, input_dir, pca, output_directory, num_episodes, verbose=True):
    """Re-log raw pickled episode data with PCA-compressed observations.

    For each episode index in ``range(num_episodes)``:

    * ``<input_dir>/init_set_data/<episode>`` holds 5-field samples
      ``(state, observation, id, object_id, allowed)``. Each is passed to
      ``PreconditionLogger.log_sample`` with the observation compressed.
    * ``<input_dir>/transition_data/<episode>`` holds 7-field samples
      ``(state, observation, id, object_id, r, next_state, next_observation)``.
      Each is passed to ``TransitionLogger.log_sample`` with both
      observations compressed.

    Args:
        id: Task identifier; only used in the progress message.
        input_dir: Directory containing ``init_set_data`` and ``transition_data``.
        pca: Object providing ``compress(observation)``.
        output_directory: Where both loggers write; created if missing
            (existing contents are kept, ``clean=False``).
        num_episodes: Number of episodes to process (indices 0..num_episodes-1).
        verbose: When true, print progress messages.

    Both loggers are closed even if reading or logging an episode raises.
    """
    make_dir(output_directory, clean=False)

    if verbose:
        print("Loading data for task {}".format(id))

    with ExitStack() as cleanup:
        # Callbacks run in reverse registration order, so the precondition
        # logger is closed first and the transition logger second, matching
        # the original close order.
        transition_logger = TransitionLogger(output_directory)
        cleanup.callback(transition_logger.close)
        precondition_logger = PreconditionLogger(output_directory)
        cleanup.callback(precondition_logger.close)

        for episode in range(num_episodes):
            if verbose:
                print("Loading episode {}".format(episode))

            # Loop variables are named sample_id (not id) so the parameter is not shadowed.
            for state, observation, sample_id, object_id, allowed in tqdm(
                    _load_episode(input_dir, 'init_set_data', episode)):
                precondition_logger.log_sample(episode, state, pca.compress(observation),
                                               sample_id, object_id, allowed)

            for state, observation, sample_id, object_id, r, next_state, next_observation in tqdm(
                    _load_episode(input_dir, 'transition_data', episode)):
                transition_logger.log_sample(episode, state, pca.compress(observation),
                                             sample_id, object_id, r, next_state,
                                             pca.compress(next_observation))


if __name__ == '__main__':

    input_directory = '20191230_raw'
    # Output directory name is pinned to the same date as input_directory.
    # Previously today's date was computed and then immediately overwritten;
    # that dead assignment has been removed.
    dir_name = '20191230'
    directory = make_path(dir_name)
    pca_path = os.path.abspath('pca_models/full_pca.dat')
    pca = PCA(PCA_N)
    pca.load(pca_path)

    n_episodes = 15
    n_tasks = 1

    for task in range(n_tasks):
        task_dir = make_path(directory, str(task))
        # Pass str(task) to make_path, as every other call site in this file does;
        # previously the raw int was passed here.
        extract_data(task, make_path(input_directory, str(task)), pca, task_dir, n_episodes, verbose=True)
