import zipfile
from collections import defaultdict
from zipfile import ZipFile
import os
from symbols.file_utils import make_path, make_dir, save
import numpy as np
import pandas as pd



class TransitionLoggerPD:
    """
    A class for logging transition samples into a pandas DataFrame.

    Samples are buffered in a plain Python list and only materialised into the
    DataFrame when ``transition_data`` is read (which ``save``/``close`` do).
    Growing a DataFrame one ``.loc`` row at a time copies the whole frame on
    every call, which makes long logging runs quadratic.
    """

    _COLUMNS = ['episode', 'state', 'object_state', 'option', 'object', 'reward', 'next_state',
                'next_object_state']

    def __init__(self, base_data_dir):
        """
        :param base_data_dir: parent directory; the log is written to its
            ``transition_data`` sub-directory (prepared via ``make_dir``).
        """
        base_data_dir = make_path(base_data_dir, 'transition_data')
        make_dir(base_data_dir)
        self.dir = base_data_dir
        # Goes through the property setter, which also resets the row buffer.
        self.transition_data = pd.DataFrame(columns=self._COLUMNS)

    @property
    def transition_data(self):
        """
        All samples logged so far as a DataFrame.

        Any buffered rows are appended to the underlying frame first, so the
        returned object is always up to date. Row labels continue the
        0..n-1 numbering.
        """
        if self._pending_rows:
            new_rows = pd.DataFrame(self._pending_rows, columns=self._frame.columns)
            if self._frame.empty:
                # Avoid concatenating with an empty frame (pandas warns about
                # dtype inference for empty entries); the new rows are the data.
                self._frame = new_rows
            else:
                self._frame = pd.concat([self._frame, new_rows], ignore_index=True)
            self._pending_rows = []
        return self._frame

    @transition_data.setter
    def transition_data(self, frame):
        """Replace the logged data with ``frame`` and drop any buffered rows."""
        self._frame = frame
        self._pending_rows = []

    def close(self):
        """Persist the log; identical to :meth:`save`."""
        self.save()

    def save(self):
        """Write all samples logged so far to ``<dir>/transition.pkl`` as a gzip-compressed pickle."""
        self.transition_data.to_pickle('{}/transition.pkl'.format(self.dir), compression='gzip')

    def log_sample(self,
                   episode,
                   state: np.ndarray,
                   observation: np.ndarray,
                   option,
                   object_id,
                   reward: float,
                   next_state: np.ndarray,
                   next_observation: np.ndarray, ):
        """
        Record a single state-option-reward-state sample.

        The row is buffered and appended to ``transition_data`` the next time
        that attribute is read.

        :param episode: episode identifier (stored in the ``episode`` column)
        :param state: the state under consideration, represented as a feature vector
        :param observation: observation accompanying ``state`` (``object_state`` column)
        :param option: the option under consideration
        :param object_id: identifier stored in the ``object`` column
        :param reward: the reward for executing the option in the given context
        :param next_state: the successor state, represented as a feature vector
        :param next_observation: observation accompanying ``next_state``
            (``next_object_state`` column)
        """
        self._pending_rows.append([episode, state, observation, option,
                                   object_id, reward, next_state, next_observation])



class TransitionLogger:
    """
    Stream transition samples to plain-text files, one file per option.

    Each sample becomes a single ``;``-separated line of the form
    ``episode;option;object_id;state;observation;reward;next_state;next_observation``
    in which every vector-valued field is flattened into a ``, ``-separated string.
    """

    def __init__(self, base_data_dir):
        """
        :param base_data_dir: parent directory; output files go into its
            ``transition_data`` sub-directory (prepared via ``make_dir``).
        """
        data_dir = make_path(base_data_dir, 'transition_data')
        make_dir(data_dir)
        self.dir = data_dir
        # option -> open text handle of that option's ``option_<option>.dat`` file
        self._output_files = {}

    def close(self):
        """Close every per-option output file opened so far."""
        for handle in self._output_files.values():
            handle.close()

    def log_sample(self, episode, state: np.ndarray, observation: np.ndarray, option, object_id,
                   reward: float, next_state: np.ndarray, next_observation: np.ndarray, ):
        """
        Append one state-option-reward-state sample to the file for ``option``.

        On the first sample for a given option, ``option_<option>.dat`` is
        opened in write mode (truncating any existing file) and kept open until
        :meth:`close`.

        :param episode: episode identifier, written as-is
        :param state: the state under consideration, as a feature vector
        :param observation: observation accompanying ``state``
        :param option: the option under consideration; also selects the output file
        :param object_id: object identifier, written as-is
        :param reward: the reward for executing the option in the given context
        :param next_state: the successor state, as a feature vector
        :param next_observation: observation accompanying ``next_state``
        """
        def flatten(vector):
            # str() of a nested row may span several lines; keep the sample on one line.
            return ', '.join(str(item) for item in vector).replace("\n", "")

        fields = (episode, option, object_id, flatten(state), flatten(observation), reward,
                  flatten(next_state), flatten(next_observation))
        line = ';'.join('{}'.format(field) for field in fields) + '\n'

        handle = self._output_files.get(option)
        if handle is None:
            handle = open(make_path(self.dir, 'option_' + str(option) + '.dat'), 'w')
            self._output_files[option] = handle

        # NOTE(review): an earlier comment said option is a tuple whose first
        # element is the action — confirm against callers.
        handle.write(line)


class TransitionLogger2:
    """
    Collect transition samples in memory, grouped by option, and hand each
    group to ``save`` when the logger is closed.
    """

    def __init__(self, base_data_dir):
        """
        :param base_data_dir: parent directory; output goes into its
            ``transition_data`` sub-directory (prepared via ``make_dir``).
        """
        target_dir = make_path(base_data_dir, 'transition_data')
        make_dir(target_dir)
        self.dir = target_dir
        # option -> list of sample tuples, in logging order
        self._output = defaultdict(list)

    def close(self):
        """Pass each option's list of samples to ``save`` as ``option_<option>.dat``."""
        for option in self._output:
            samples = self._output[option]
            target_file = make_path(self.dir, 'option_' + str(option) + '.dat')
            save(samples, filename=target_file)

    def log_sample(self, episode, state: np.ndarray, observation: np.ndarray, option, object_id,
                   reward: float, next_state: np.ndarray, next_observation: np.ndarray, ):
        """
        Buffer one state-option-reward-state sample under its option.

        The stored tuple is
        ``(episode, option, object_id, state, observation, reward, next_state, next_observation)``.

        :param episode: episode identifier
        :param state: the state under consideration, as a feature vector
        :param observation: observation accompanying ``state``
        :param option: the option under consideration; also the grouping key
        :param object_id: object identifier
        :param reward: the reward for executing the option in the given context
        :param next_state: the successor state, as a feature vector
        :param next_observation: observation accompanying ``next_state``
        """
        # NOTE(review): arrays are stored by reference, not copied; if callers
        # reuse and mutate the same buffer between steps, earlier samples change.
        sample = (episode, option, object_id, state, observation, reward,
                  next_state, next_observation)
        self._output[option].append(sample)


class TransitionLoggerZ:
    """
    A class for logging transition samples to per-option text files that are
    packed into (deflate-compressed) zip archives when the logger is closed.

    Each sample is written as one line of the form
    ``episode;option;object_id;state;observation;reward;next_state;next_observation``
    in which vector-valued fields are flattened to ``, ``-separated strings.
    """

    def __init__(self, base_data_dir):
        """
        :param base_data_dir: parent directory; files go into its
            ``transition_data`` sub-directory (prepared via ``make_dir``).
        """
        base_data_dir = make_path(base_data_dir, 'transition_data')
        make_dir(base_data_dir)
        self.dir = base_data_dir
        # option -> open text handle of that option's ``option_<option>.dat`` file
        self._output_files = dict()

    def close(self):
        """
        Archive every per-option file.

        Each ``option_<option>.dat`` is closed, written into
        ``option_<option>.zip`` (member name ``option_<option>.dat``) and then
        deleted. Handles that are already closed are skipped. A repeated
        ``close()`` would otherwise re-open the zip in ``"w"`` mode, truncating
        the archive, before failing on the already-removed ``.dat`` file.
        """
        for option, f in self._output_files.items():
            if f.closed:
                # Already archived (e.g. by an earlier close()); its .dat is gone.
                continue
            f.close()
            stem = 'option_' + str(option)
            path = make_path(self.dir, stem + '.dat')
            name = make_path(self.dir, stem + '.zip')
            # ZipFile defaults to ZIP_STORED (no compression); request DEFLATE.
            # The context manager closes the archive even if write() raises.
            with zipfile.ZipFile(name, "w", compression=zipfile.ZIP_DEFLATED) as archive:
                archive.write(path, arcname=stem + '.dat')
            os.remove(path)

    def log_sample(self,
                   episode,
                   state: np.ndarray,
                   observation: np.ndarray,
                   option,
                   object_id,
                   reward: float,
                   next_state: np.ndarray,
                   next_observation: np.ndarray, ):
        """
        Append one state-option-reward-state sample to the file for ``option``.

        On the first sample for a given option, ``option_<option>.dat`` is
        opened in write mode (truncating any existing file) and kept open until
        :meth:`close`.

        :param episode: episode identifier, written as-is
        :param state: the state under consideration, represented as a feature vector
        :param observation: observation accompanying ``state``
        :param option: the option under consideration; also selects the output file
        :param object_id: object identifier, written as-is
        :param reward: the reward for executing the option in the given context
        :param next_state: the successor state, represented as a feature vector
        :param next_observation: observation accompanying ``next_state``
        """
        # str() of a nested row may contain newlines; strip them to keep one line per sample.
        state = ', '.join(map(str, state)).replace("\n", "")
        observation = ', '.join(map(str, observation)).replace("\n", "")
        next_state = ', '.join(map(str, next_state)).replace("\n", "")
        next_observation = ', '.join(map(str, next_observation)).replace("\n", "")
        line = '{};{};{};{};{};{};{};{}\n'.format(episode, option, object_id, state, observation, reward,
                                               next_state, next_observation)

        if option not in self._output_files:
            filename = make_path(self.dir, 'option_' + str(option) + '.dat')
            fl = open(filename, 'w')
            self._output_files[option] = fl

        # NOTE(review): an earlier comment said option is a tuple whose first
        # element is the action — confirm against callers.
        self._output_files[option].write(line)



