from tqdm import tqdm

from symbols.file_utils import make_path
from symbols.logger.precondition_reader import PreconditionReader
from symbols.logger.transition_reader import TransitionReader
import pandas as pd
import numpy as np


def convert(state, split_index=9):
    """Flatten one entry of ``state`` into separate single-value arrays.

    Every entry of ``state`` is kept as-is, except the entry at position
    ``split_index``: that entry is iterated and each of its values is wrapped
    in its own one-element array, so it contributes ``len(entry)`` items to
    the result instead of one.

    Args:
        state: Iterable of per-object entries. The entry at ``split_index``
            (if present) must itself be iterable.
        split_index: Position of the entry to expand. Defaults to 9, the
            value this module has always used.

    Returns:
        A NumPy array of the collected items. When all items share one shape
        this is the regular array ``np.array`` builds from them; when their
        shapes differ it is a 1-D ``dtype=object`` array holding each item
        unchanged.
    """
    parts = []
    for index, item in enumerate(state):
        if index == split_index:
            parts.extend(np.array([value]) for value in item)
        else:
            parts.append(item)

    try:
        return np.array(parts)
    except ValueError:
        # NumPy >= 1.24 refuses to build an array from items of differing
        # shapes. Fill an object array by hand so ragged states still convert
        # (one slot per item) instead of crashing.
        ragged = np.empty(len(parts), dtype=object)
        for index, item in enumerate(parts):
            ragged[index] = item
        return ragged


# Read the logged precondition and transition samples of one task, turn them
# into two tables and store each table as a gzip-compressed pickle.
directory = '20200106'

task_id = 0
task_dir = make_path(directory, task_id)

transition_columns = ['episode', 'state', 'object_state', 'option', 'object', 'reward', 'next_state',
                      'next_object_state', 'done', 'goal_achieved', 'mask', 'object_mask', 'next_options']
initiation_columns = ['state', 'object_state', 'option', 'object', 'can_execute']

# Rows are gathered in plain lists and converted to DataFrames once at the end:
# enlarging a DataFrame one `.loc` row at a time copies it on every insert.
transition_rows = []
initiation_rows = []

for option in range(9):
    print(option)

    reader = PreconditionReader(task_dir, option, view='agent')
    for sample in tqdm(reader.get_samples()):
        initiation_rows.append([
            sample.observation[1],
            convert(sample.state),
            sample.option,
            sample.object_id,
            sample.can_execute,
        ])

    # `option` is still the loop index here; per-sample options are read
    # straight from the sample and never rebound to this name, so the reader
    # is always opened for the option currently being processed.
    rd = TransitionReader(task_dir, option, view='agent')
    for sample in tqdm(rd.get_samples()):
        state = sample.observation[1]
        next_state = sample.next_observation[1]
        object_state = convert(sample.state)
        next_object_state = convert(sample.next_state)

        # Indices of the object-state entries that differ between the two steps.
        object_mask = np.array([j for j in range(len(object_state))
                                if not np.array_equal(object_state[j], next_object_state[j])])
        if not np.array_equal(sample.mask, object_mask):
            # Sanity check: report samples whose logged mask disagrees with
            # the mask recomputed from the converted states.
            print("AHAHAH")
            print(sample.mask, object_mask)

        # Indices of the observation entries that changed. Indexing (rather
        # than zip) keeps a length mismatch loud instead of silently truncating.
        mask = [i for i in range(len(state)) if state[i] != next_state[i]]

        transition_rows.append([
            0,  # episode: hard-coded placeholder
            state,
            object_state,
            sample.option,
            sample.object_id,
            sample.reward,
            next_state,
            next_object_state,
            False,  # done: hard-coded placeholder
            False,  # goal_achieved: hard-coded placeholder
            mask,
            object_mask,
            [],  # next_options: left empty
        ])

transition_data = pd.DataFrame(transition_rows, columns=transition_columns)
initiation_data = pd.DataFrame(initiation_rows, columns=initiation_columns)

transition_data.to_pickle('transition.pkl', compression='gzip')
initiation_data.to_pickle('init.pkl', compression='gzip')
