import ast
import random
import zipfile

import numpy as np
import re

from symbols.file_utils import make_path, exists, load
from symbols.logger.precondition_sample import PreconditionSample
import pandas as pd




class PreconditionReaderPD:
    """
    Reads the precondition data for a given option from a gzip-compressed
    pandas pickle (``init.pkl``).
    """

    def _clean(self, directory):
        """
        Normalise the numeric suffix of a data directory name.

        The text after the last '_' is parsed as an integer. A suffix of 60 is
        kept unchanged; any other suffix is replaced by '10'
        (e.g. 'out_30' -> 'out_10').
        NOTE(review): the reason for redirecting non-60 suffixes to the '_10'
        data is not visible here -- confirm against the data-generation code.

        :param directory: the directory name, expected to end in '_<int>'
        :return: the (possibly rewritten) directory name
        :raises ValueError: if the text after the last '_' is not an integer
        """
        idx = directory.rfind('_')
        num = int(directory[idx + 1:])
        if num == 60:
            return directory
        return directory[:idx + 1] + '10'

    def __init__(self, base_data_dir, option, n_episodes, view='problem', max_samples=np.inf):
        """
        :param base_data_dir: the base output directory (usually 'output'); its
            trailing '_<n>' suffix is normalised by ``_clean`` before use
        :param option: the option; only rows whose 'option' column equals it
            are turned into samples
        :param n_episodes: how many distinct episodes to read, taken in the
            order in which they first appear in the file
        :param view: forwarded to every PreconditionSample
        :param max_samples: upper bound on the number of samples returned by
            ``get_samples``
        :raises FileNotFoundError: if ``init.pkl`` does not exist in the
            (cleaned) directory
        """
        base_data_dir = self._clean(base_data_dir)
        self._view = view
        self._option = option
        self.episodes = n_episodes
        self.max_samples = max_samples
        self.file = make_path(base_data_dir, 'init.pkl')
        if not exists(self.file):
            raise FileNotFoundError(self.file)

    def get_samples(self, as_list=False):
        """
        Return the precondition samples for this reader's option.

        The first ``self.episodes`` unique episodes are selected and visited in
        random order (``np.random.shuffle``). One PreconditionSample is built
        per matching row. At most ``self.max_samples`` samples are returned.
        The cap applies across all episodes, not per episode.

        Each sample's ``episode`` attribute is set to the episode's position in
        the shuffled order, not to the episode value stored in the file.

        :param as_list: if True return a list, otherwise a numpy array
        :return: the samples, as a list or a numpy array
        """
        frame = pd.read_pickle(self.file, compression='gzip')
        episodes = frame['episode'].unique()[:self.episodes]
        np.random.shuffle(episodes)
        # Filter by option once instead of once per episode; row order is kept.
        frame = frame.loc[frame['option'] == self._option]
        samples = list()
        for count, episode in enumerate(episodes):
            # Stop visiting episodes once the cap is reached.
            if len(samples) >= self.max_samples:
                break
            rows = frame.loc[frame['episode'] == episode]
            for _, row in rows.iterrows():  # slow. Don't care
                if len(samples) >= self.max_samples:
                    break
                s = PreconditionSample(row['state'], row['object_state'], self._option,
                                       row['object'], row['can_execute'], view=self._view)
                s.episode = count
                samples.append(s)
        if as_list:
            return samples
        return np.array(samples)

class PreconditionReader:
    """
    Reads the precondition data for a given option from a ';'-separated text
    file (``init_set_data/option_<option>.dat``).
    """

    def __init__(self, base_data_dir, option, view='problem', max_samples=np.inf):
        """
        :param base_data_dir: the base output directory (usually 'output')
        :param option: the option index
        :param view: forwarded to every PreconditionSample
        :param max_samples: upper bound on the number of samples returned by
            ``get_samples``

        If the data file does not exist no error is raised; ``get_samples``
        then returns an empty array.
        """
        self._view = view
        self._option = option
        self.max_samples = max_samples

        self._filenames = list()
        file = make_path(base_data_dir, 'init_set_data', 'option_{}.dat'.format(option))
        if exists(file):
            self._filenames.append(file)

    def get_samples(self):
        """
        Return an array of precondition samples.

        Files are visited in random order. Blank lines are skipped. At most
        ``self.max_samples`` samples are returned; the cap applies across all
        files, not per file.

        :return: a numpy array of samples
        """
        samples = list()
        filenames = self._filenames.copy()
        random.shuffle(filenames)
        for filename in filenames:
            if len(samples) >= self.max_samples:
                break
            # ``with`` guarantees the file is closed even if parsing raises.
            with open(filename, 'r') as file:
                for line in file:
                    if len(samples) >= self.max_samples:
                        break
                    if not line.strip():
                        continue
                    samples.append(self._parse_line(line))
        return np.array(samples)

    def _parse_line(self, line):
        """
        Build a PreconditionSample from one ';'-separated record.

        Columns read: 0 = episode, 2 = object id, 3 = state, 4 = observation,
        5 = label (the literal text "True" means True). Column 1 is not read.
        """
        entries = line.split(';')
        object_id = int(entries[2])
        state = self._to_array(entries[3])
        observation = self._to_array(entries[4])
        label = entries[5].strip() == "True"
        s = PreconditionSample(state, observation, self._option, object_id, label, view=self._view)
        s.episode = int(entries[0])
        return s

    def _to_array(self, state_str):
        """
        Parse text such as '[1 2], [3 4]' into a numpy array.

        Whitespace just inside brackets is dropped and the text is split on
        ','. Within each piece, runs of whitespace become ',' and the result
        is evaluated with ``ast.literal_eval``. Each piece becomes one element
        of the returned array.
        NOTE(review): this looks like it expects numpy-style printed arrays --
        confirm against the writer.
        """
        clean = re.sub(r'\[\s+', '[', state_str)
        clean = re.sub(r'\s+\]', ']', clean)
        pieces = clean.split(',')
        return np.array([np.array(ast.literal_eval(re.sub(r'\s+', ',', x.strip()))) for x in pieces])


class PreconditionReader2:
    """
    Loads the precondition records for one option with ``load`` and wraps each
    record in a PreconditionSample.
    """

    def __init__(self, base_data_dir, option, view='problem'):
        """
        :param base_data_dir: the base output directory (usually 'output')
        :param option: the option index
        :param view: forwarded to every PreconditionSample created

        A missing data file is not an error: ``get_samples`` then yields an
        empty array.
        """
        self._view = view
        self._option = option
        candidate = make_path(base_data_dir, 'init_set_data', 'option_{}.dat'.format(option))
        self._filenames = [candidate] if exists(candidate) else []

    def get_samples(self):
        """
        Build one PreconditionSample per stored record.

        Each record is unpacked as (episode, option, object_id, state,
        observation, label). The stored option is ignored and this reader's
        own option is used instead.

        :return: a numpy array of samples (empty when no file was found)
        """
        collected = []
        for path in self._filenames:
            records = load(path)
            for episode, _stored_option, object_id, state, observation, label in records:
                sample = PreconditionSample(state, observation, self._option, object_id, label, view=self._view)
                sample.episode = episode
                collected.append(sample)
        return np.array(collected)



class PreconditionReaderZ:
    """
    Reads the precondition data for a given option from a ';'-separated text
    file (``init_set_data/option_<option>.dat``), in file order and with no
    cap on the number of samples.
    """

    def __init__(self, base_data_dir, option, view='problem'):
        """
        :param base_data_dir: the base output directory (usually 'output')
        :param option: the option index
        :param view: forwarded to every PreconditionSample

        If the data file does not exist no error is raised; ``get_samples``
        then returns an empty array.
        """
        self._view = view
        self._option = option

        self._filenames = list()
        file = make_path(base_data_dir, 'init_set_data', 'option_{}.dat'.format(option))
        if exists(file):
            self._filenames.append(file)

    def get_samples(self):
        """
        Return an array of precondition samples, one per non-blank line.

        :return: a numpy array of samples
        """
        samples = list()
        for filename in self._filenames:
            # ``with`` guarantees the file is closed even if parsing raises.
            with open(filename, 'r') as file:
                for line in file:
                    if not line.strip():
                        continue
                    samples.append(self._parse_line(line))
        return np.array(samples)

    def _parse_line(self, line):
        """
        Build a PreconditionSample from one ';'-separated record.

        Columns read: 0 = episode, 2 = object id, 3 = state, 4 = observation,
        5 = label (the literal text "True" means True). Column 1 is not read.
        """
        entries = line.split(';')
        object_id = int(entries[2])
        state = self._to_array(entries[3])
        observation = self._to_array(entries[4])
        label = entries[5].strip() == "True"
        s = PreconditionSample(state, observation, self._option, object_id, label, view=self._view)
        s.episode = int(entries[0])
        return s

    def _to_array(self, state_str):
        """
        Parse text such as '[1 2], [3 4]' into a numpy array.

        Whitespace just inside brackets is dropped and the text is split on
        ','. Within each piece, runs of whitespace become ',' and the result
        is evaluated with ``ast.literal_eval``. Each piece becomes one element
        of the returned array.
        NOTE(review): this looks like it expects numpy-style printed arrays --
        confirm against the writer.
        """
        clean = re.sub(r'\[\s+', '[', state_str)
        clean = re.sub(r'\s+\]', ']', clean)
        pieces = clean.split(',')
        return np.array([np.array(ast.literal_eval(re.sub(r'\s+', ',', x.strip()))) for x in pieces])