import ast
from warnings import warn

import numpy as np

from symbols.file_utils import make_path, exists, load
from symbols.logger.transition_sample import TransitionSample
import re
import pandas as pd



class TransitionReader:
    """
    Reads the transition data for a given option from a ';'-delimited text file.

    Each non-blank line must contain at least eight ';'-separated fields. Fields 2-7 are
    parsed as object id, state, observation, reward, next state and next observation.
    Fields 0 and 1 are not read here (presumably episode and option -- TODO confirm
    against the writer).
    """

    # Patterns used by _to_array, compiled once. Raw strings avoid the invalid-escape
    # warnings that '\[' / '\s' produce in ordinary string literals.
    _LEADING_WS = re.compile(r'\[\s+')
    _TRAILING_WS = re.compile(r'\s+\]')
    _INNER_WS = re.compile(r'\s+')

    def __init__(self, base_data_dir, option, view='problem'):
        """
        :param base_data_dir: the base output directory (usually 'output')
        :param option: the option
        :param view: forwarded unchanged to every TransitionSample that is built
        """
        self._view = view
        self._option = option

        # Only a single data file per option is read; if it is missing, no samples are returned.
        self._filenames = list()
        file = make_path(base_data_dir, 'transition_data', 'option_{}.dat'.format(option))
        if exists(file):
            self._filenames.append(file)

    def get_samples(self):
        """
        Parse every transition in the option's data file.

        Samples whose ``flat_mask`` is empty are dropped with a warning. Blank lines are skipped.

        :return: a numpy array of TransitionSample objects
        """
        samples = list()
        for filename in self._filenames:
            # Context manager guarantees the file is closed even if parsing a line raises.
            with open(filename, 'r') as file:
                for line in file:
                    if not line.strip():
                        continue  # tolerate blank lines instead of failing with IndexError
                    entries = line.split(';')

                    object_id = int(entries[2])
                    state = self._to_array(entries[3])
                    observation = self._to_array(entries[4])
                    reward = float(entries[5])
                    next_state = self._to_array(entries[6])
                    next_observation = self._to_array(entries[7])
                    s = TransitionSample(state, observation, self._option, object_id, reward, next_state,
                                         next_observation, view=self._view)
                    if len(s.flat_mask) > 0:
                        samples.append(s)
                    else:
                        warn("Dropping transition with empty mask")
        return np.array(samples)

    def _to_array(self, state_str):
        """
        Parse a numpy-style printed array string (e.g. ``'[ 1. 2.  3.]'``).

        Whitespace just inside the brackets is removed, and the string is split on ','.
        Within each part, whitespace runs become commas so the text can be read with
        ``ast.literal_eval``.

        :param state_str: the textual field from the data file
        :return: a numpy array with one entry per comma-separated part
        """
        clean = self._LEADING_WS.sub('[', state_str)
        clean = self._TRAILING_WS.sub(']', clean)
        parts = clean.split(',')
        return np.array([np.array(ast.literal_eval(self._INNER_WS.sub(',', part.strip()))) for part in parts])


class TransitionReader2:
    """
    Reads pre-serialised transition tuples recorded for a single option.

    Unlike the text-based reader, the data file is deserialised with ``load``. Each
    record is expected to unpack into eight values: (episode, option, object_id, state,
    observation, reward, next_state, next_observation).
    """

    def __init__(self, base_data_dir, option, view='problem'):
        """
        :param base_data_dir: the base output directory (usually 'output')
        :param option: the option whose transitions are read
        :param view: forwarded unchanged to every TransitionSample that is built
        """
        self._view = view
        self._option = option

        candidate = make_path(base_data_dir, 'transition_data', 'option_{}.dat'.format(option))
        # At most one file per option; a missing file simply yields no samples.
        self._filenames = [candidate] if exists(candidate) else []

    def get_samples(self):
        """
        Build a TransitionSample for every stored record, tagging each with its episode.

        Records whose sample has an empty ``flat_mask`` are skipped with a warning.

        :return: a numpy array of TransitionSample objects
        """
        collected = []
        for path in self._filenames:
            for record in load(path):
                # The stored option (second field) is ignored; self._option is used instead.
                episode, _stored_option, object_id, state, observation, reward, next_state, next_observation = record
                sample = TransitionSample(state, observation, self._option, object_id, reward,
                                          next_state, next_observation, view=self._view)
                sample.episode = episode
                if not len(sample.flat_mask):
                    warn("Dropping transition with empty mask")
                    continue
                collected.append(sample)
        return np.array(collected)





class TransitionReaderPD:
    """
    Reads the transition data for a given option from a gzip-compressed pandas pickle
    (``transition.pkl``) stored directly in the (cleaned) data directory.
    """

    def _clean(self, dir):
        """
        Rewrite the trailing ``_<number>`` suffix of a directory name.

        A suffix equal to 60 leaves the name unchanged. Any other integer suffix is
        replaced with ``10``.

        NOTE(review): this hard-codes which run's data is read, regardless of the directory
        passed in. The original trailing comment suggests 60 was used here previously;
        confirm intent.

        :param dir: a directory path ending in ``_<integer>``
        :return: the possibly rewritten directory path
        :raises ValueError: if the text after the last '_' is not an integer
        """
        idx = dir.rfind('_')
        num = int(dir[idx + 1:])
        if num == 60:
            return dir
        return dir[0:idx + 1] + '10'

    def __init__(self, base_data_dir, option, n_episodes, view='problem'):
        """
        :param base_data_dir: the base output directory; its ``_<number>`` suffix is rewritten by ``_clean``
        :param option: the option whose transitions are read
        :param n_episodes: how many episodes (in order of first appearance) to read; ``None`` reads all
        :param view: forwarded unchanged to every TransitionSample that is built
        :raises FileNotFoundError: if ``transition.pkl`` does not exist in the cleaned directory
        """
        base_data_dir = self._clean(base_data_dir)
        self._view = view
        self._option = option
        self.episodes = n_episodes
        self.file = make_path(base_data_dir, 'transition.pkl')
        if not exists(self.file):
            raise FileNotFoundError("Transition data not found: {}".format(self.file))

    def get_samples(self):
        """
        Return this option's transitions from the first ``n_episodes`` episodes.

        Episodes are counted in order of first appearance across all options in the file.
        Samples whose ``flat_mask`` is empty are dropped with a warning.

        :return: a numpy array of TransitionSample objects
        """
        # NOTE: read_pickle unpickles arbitrary objects -- only use with trusted data files.
        x = pd.read_pickle(self.file, compression='gzip')

        episodes = x['episode'].unique()[:self.episodes]
        # Filter by option once instead of re-masking the whole frame for every episode.
        option_rows = x.loc[x['option'] == self._option]
        samples = list()
        for episode in episodes:
            data = option_rows.loc[option_rows['episode'] == episode]
            for _, row in data.iterrows():  # slow, but simple
                s = TransitionSample(row['state'], row['object_state'], self._option, row['object'],
                                     row['reward'], row['next_state'], row['next_object_state'],
                                     view=self._view)
                if len(s.flat_mask) > 0:
                    samples.append(s)
                else:
                    warn("Dropping transition with empty mask")
        return np.array(samples)