import pickle
import random

import os
import warnings
from warnings import warn

import numpy as np

from symbols.domain.domain import Domain
from symbols.file_utils import make_path, make_dir
from symbols.logger.precondition_reader import PreconditionReader, PreconditionReaderPD
from symbols.logger.precondition_sample import PreconditionSample
from symbols.logger.transition_logger import TransitionLogger
from symbols.logger.transition_reader import TransitionReader, TransitionReaderPD
import pandas as pd


from symbols.logger.transition_sample import TransitionSample


def _uniform_sample_policy(domain: Domain):
    """
    Select an admissible action uniformly at random. If none exists, return None
    :param domain: the domain
    :return: an action selected uniformly randomly
    """
    return random.choice(domain.admissible_actions)


def load_precondition_samples(data_dir, option, object_id, label, max_samples, view='problem', verbose=False):
    """
    Load the states of one object's precondition samples for an option, filtered by label.
    :param data_dir: directory containing the precondition data
    :param option: the option whose precondition data is read
    :param object_id: only samples whose ``object_id`` equals this are kept
    :param label: 0 selects negative samples (``can_execute`` falsy); any other value selects positive ones
    :param max_samples: sample limit forwarded to ``PreconditionReader``
    :param view: view forwarded to ``PreconditionReader``
    :param verbose: whether to print information to screen
    :return: a numpy array built from the selected samples' states
    """
    want_negative = label == 0
    if verbose:
        print("Loading {} samples...".format('negative' if want_negative else 'positive'))

    all_samples = PreconditionReader(data_dir, option, view=view, max_samples=max_samples).get_samples()

    def _wanted(sample):
        # Keep samples of the requested object whose executability matches the requested label.
        if sample.object_id != object_id and not (sample.object_id == object_id):
            return False
        return (not sample.can_execute) if want_negative else bool(sample.can_execute)

    return np.array([sample.state for sample in all_samples if _wanted(sample)])


def load_precondition_samples_pandas(data_dir, the_option, object_id, label, max_samples, view='problem', verbose=False):
    """
    Load, from ``init.pkl``, the states of one object's precondition samples for an option, filtered by label.
    :param data_dir: directory containing the gzip-compressed ``init.pkl`` DataFrame
    :param the_option: only rows whose ``option`` column equals this are used
    :param object_id: only rows whose ``object`` column equals this are used
    :param label: 0 selects negative samples (``can_execute`` falsy); any other value selects positive ones
    :param max_samples: maximum number of states returned; None means no limit
    :param view: view passed to each ``PreconditionSample``
    :param verbose: whether to print information to screen
    :return: a numpy array of the selected (shuffled) samples' states, truncated to ``max_samples``
    """
    # Decide which class to return before reading any rows. Previously the row loop reassigned ``label``,
    # so the returned class depended on the last row read instead of on the caller's request.
    want_negative = label == 0
    if verbose:
        print("Loading {} samples...".format('negative' if want_negative else 'positive'))

    data = pd.read_pickle(make_path(data_dir, 'init.pkl'), compression='gzip')
    data = data.loc[(data['object'] == object_id) & (data['option'] == the_option)].reset_index(drop=True)

    samples = list()
    for _, row in data.iterrows():  # slow, but simple; the .loc filter above already restricts to the_option
        samples.append(PreconditionSample(row['state'], row['object_state'], row['option'], row['object'],
                                          row['can_execute'], view=view))

    random.shuffle(samples)
    if want_negative:
        states = [sample.state for sample in samples if not sample.can_execute]
    else:
        states = [sample.state for sample in samples if sample.can_execute]

    result = np.array(states)
    if max_samples is not None and len(result) > max_samples:
        result = result[:max_samples]
    return result

def load_full_transition_data(env,
                              directory,
                              view='problem',
                              verbose=False):
    """
    Load the transition data of every option in the environment's action space.
    :param env: the environment; only ``env.action_space`` is iterated
    :param directory: the directory that contains the data
    :param view: view forwarded to ``load_transition_data``
    :param verbose: whether to print information to screen
    :return: a flat list of all the transition samples, grouped by option in action-space order
    """
    if verbose:
        print("Loading transition data...")
    samples = [sample
               for option in env.action_space
               for sample in load_transition_data(option, directory, view=view, verbose=verbose)]
    if verbose:
        print(f"{len(samples)} samples loaded.")
    return samples


def load_full_transition_data_pandas(env,
                                     directory,
                                     view='problem',
                                     verbose=False):
    """
    Load every transition stored in the gzip-compressed ``transition.pkl`` DataFrame, across all options.
    :param env: the environment (not used by this function; kept for parity with load_full_transition_data)
    :param directory: the directory that contains the data
    :param view: view passed to each ``TransitionSample``
    :param verbose: whether to print information to screen
    :return: a list of the transition samples whose ``flat_mask`` is non-empty
    """
    if verbose:
        print("Loading transition data...")

    frame = pd.read_pickle(make_path(directory, 'transition.pkl'), compression='gzip')
    kept = list()
    for _, row in frame.iterrows():  # row-wise iteration is slow but straightforward
        sample = TransitionSample(row['state'], row['object_state'], row['option'], row['object'], row['reward'],
                                  row['next_state'], row['next_object_state'], view=view)
        # Samples with an empty flat mask are discarded, with a warning for each one.
        if len(sample.flat_mask) == 0:
            warn("Dropping transition with empty mask")
            continue
        kept.append(sample)

    if verbose:
        print(f"{len(kept)} samples loaded.")
    return kept


def load_transition_data(option,
                         directory,
                         view='problem',
                         verbose=False):
    """
    Load the transition samples of a single option via ``TransitionReader``.
    :param option: the option
    :param directory: the directory that contains the data
    :param view: view forwarded to ``TransitionReader``
    :param verbose: whether to print information to screen
    :return: the samples returned by the reader
    """
    if verbose:
        print("Loading transition data...")

    loaded = TransitionReader(directory, option, view=view).get_samples()

    if verbose:
        print(f"{len(loaded)} samples loaded.")
    return loaded


def load_transition_data_pandas(option, directory, n_episode, view='problem', verbose=False):
    """
    Load the transition samples of a single option via the pandas-backed ``TransitionReaderPD``.
    :param option: the option
    :param directory: the directory that contains the data
    :param n_episode: forwarded to ``TransitionReaderPD``
    :param view: view forwarded to ``TransitionReaderPD``
    :param verbose: whether to print information to screen
    :return: the samples returned by the reader
    """
    if verbose:
        print("Loading transition data...")
    reader = TransitionReaderPD(directory, option, n_episode, view=view)
    loaded = reader.get_samples()
    if verbose:
        print("{} samples loaded.".format(len(loaded)))
    return loaded


def load_first_transitions(directory, verbose=False):
    """
    Return the start state of the first transition of the first episode in ``transition.pkl``.
    :param directory: the directory containing the gzip-compressed ``transition.pkl`` DataFrame
    :param verbose: whether to print information to screen
    :return: a ``(state, object_state)`` pair from that row, or ``(None, None)`` if no such row exists
    """
    if verbose:
        print("Loading transition data...")

    frame = pd.read_pickle(make_path(directory, 'transition.pkl'), compression='gzip')
    # Only the first episode (in order of appearance) is considered.
    for episode in frame['episode'].unique()[:1]:
        episode_rows = frame.loc[frame['episode'] == episode]
        if len(episode_rows) > 0:
            first_row = episode_rows.iloc[0]
            return first_row['state'], first_row['object_state']

    return None, None



def load_subset(directory, n_samples, verbose=False):
    """
    Load the precondition and transition DataFrames restricted to the first ``n_samples`` episodes.
    :param directory: the directory containing gzip-compressed ``transition.pkl`` and ``init.pkl``
    :param n_samples: number of episodes to keep, taken in order of first appearance in ``transition.pkl``
    :param verbose: whether to print information to screen
    :return: a tuple ``(init_data, transition_data)`` of filtered DataFrames with reset indices
    """
    if verbose:
        print("Loading transition data...")

    transitions = pd.read_pickle(make_path(directory, 'transition.pkl'), compression='gzip')
    kept_episodes = transitions['episode'].unique()[:n_samples]

    def _restrict(frame):
        # Keep only rows belonging to the selected episodes.
        return frame.loc[(frame['episode'].isin(kept_episodes))].reset_index(drop=True)

    transitions = _restrict(transitions)
    inits = _restrict(pd.read_pickle(make_path(directory, 'init.pkl'), compression='gzip'))
    return inits, transitions



def write_partition_data(option,
                         partitioned_options,
                         output_directory,
                         verbose=False):
    """
    Pickle each partition of an option to its own file, ``<option>-<k>.dat``, with k counting from 1.
    :param option: the option
    :param partitioned_options: the partitioned data, one entry per partition
    :param output_directory: the directory to write the data to
    :param verbose: whether to print information to screen
    """
    if verbose:
        print(str(option) + ': ' + str(len(partitioned_options)) + ' partitions found')
    for number, partitioned_symbol in enumerate(partitioned_options, start=1):
        filename = make_path(output_directory, f"{option!s}-{number}.dat")
        with open(filename, "wb") as file:
            if verbose:
                print(f"Saving to file {filename}...")
            pickle.dump(partitioned_symbol, file)


def load_option_partitions(options,
                           directory):
    """
    Load the option partitions written as ``<option>-<k>.dat`` files (k = 1, 2, ...).
    Loading stops at the first missing index for each option.
    :param options: the option space
    :param directory: the directory containing the files
    :return: a dict mapping each option to the list of its unpickled partitions (empty if none found)
    """
    option_partitions = dict()
    for option in options:
        loaded = []
        number = 1
        filename = make_path(directory, f"{option!s}-{number}.dat")
        while os.path.isfile(filename):
            with open(filename, "rb") as file:
                loaded.append(pickle.load(file))
            number += 1
            filename = make_path(directory, f"{option!s}-{number}.dat")
        option_partitions[option] = loaded
    return option_partitions


def load_option_subpartitions(options,
                              directory):
    """
    Load the option subpartitions from file.
    The layout is ``directory/option_<option>/partition_<i>/<option>-<j>.dat``, with i counting from 0
    and j from 1. Each sequence stops at the first missing index.
    :param options: the option space
    :param directory: the directory containing the files
    :return: a dict mapping each option to a list (one per partition) of lists of unpickled subpartitions
    """
    option_partitions = dict()
    for option in options:
        partitions = list()
        partition_index = 0
        while True:
            partition_dir = make_path(directory, 'option_' + str(option), 'partition_' + str(partition_index))
            if not os.path.exists(partition_dir):
                break

            subpartitions = list()
            sub_index = 1
            while True:
                filename = make_path(partition_dir, str(option) + "-" + str(sub_index) + ".dat")
                if not os.path.isfile(filename):
                    break
                with open(filename, "rb") as file:
                    subpartitions.append(pickle.load(file))
                sub_index += 1

            partitions.append(subpartitions)
            partition_index += 1
        option_partitions[option] = partitions

    return option_partitions


def write_precondition_masks(option,
                             masks,
                             output_directory,
                             verbose=False):
    """
    Write one text file per partition, ``<option>-<k>.mask`` (k from 1), with one mask entry per line.
    :param option: the option
    :param masks: the precondition masks for each partition of the option
    :param output_directory: the directory to write to (created if needed, not cleaned)
    :param verbose: whether to print information to screen
    """
    make_dir(output_directory, clean=False)
    for index in range(len(masks)):
        filename = make_path(output_directory, f"{option!s}-{index + 1}.mask")
        if verbose:
            print(f"Writing precondition mask for option {option!s} partition {index}")
        with open(filename, "w") as file:
            file.writelines(str(variable) + "\n" for variable in masks[index])
    if verbose:
        print("Done.")


def load_precondition_masks(output_directory,
                            option_partitions,
                            verbose=True):
    """
    Load the precondition masks for all option partitions from ``<option>-<k>.mask`` files (k from 1).
    :param output_directory: the output directory
    :param option_partitions: a list specifying the number of partitions for each option
    :param verbose: whether to print information to screen
    :return: a 2D list of masks, indexed by option and then partition number; each mask is a list of ints
    """
    masks = list()
    for option in range(0, len(option_partitions)):
        if verbose:
            print("Loading masks for option " + str(option))

        option_masks = list()
        for partition in range(0, option_partitions[option]):
            filename = make_path(output_directory, str(option) + "-" + str(partition + 1) + ".mask")
            # Read in text mode to match how the files are written ("w"). Each non-blank line holds one
            # value parsed with int(). Lines are stripped before the emptiness test: a blank line still
            # contains "\n", so the old len(line) > 0 check never skipped it and int() raised ValueError.
            with open(filename, "r") as file:
                mask = [int(value) for value in (line.strip() for line in file) if value]
            option_masks.append(mask)
        masks.append(option_masks)

    return masks


def assign_data_to_partitions(options,
                              transition_data,
                              init_data,
                              partitions,
                              verbose=True):
    """
    Assign data to each partition of each option. States assigned to a given partition are positive examples, and all
    other states (including those in which another partition of the same option could be executed) are negative ones.

    For each option, a classifier is built from that option's partitions with ``build_knn_classifier``. Its
    prediction for a state decides which partition that state is assigned to.

    :param options: the possible options. Each option is used as a key into ``transition_data`` and ``partitions``,
        and as a column index into the precondition label matrix.
    :param transition_data: indexed by option, giving that option's transition samples (each must expose ``state``)
    :param init_data: the precondition data as a ``(states, labels)`` pair. ``states`` is indexed as a 2D array (one
        row per sample). ``labels`` is indexed as a 2D array with one column per option; a label of 0.0 marks the
        sample as negative for every partition of that option.
    :param partitions: the partitioned options, indexed by option
    :param verbose: whether to print information to screen
    :return: a tuple ``(transition_indices, precondition_indices)``, both in the iteration order of ``options``.
        ``transition_indices[k][p]`` lists the indices of the k-th option's transitions assigned to partition ``p``.
        ``precondition_indices[k]`` is ``[pos, neg]``, where ``pos[p]``/``neg[p]`` list the precondition-sample
        indices that are positive/negative for partition ``p``.
        An option with no partitions contributes ``[]`` to both lists.
    """
    transition_indices = []
    precondition_indices = []

    for option in options:

        if verbose:
            print("Assigning data to option " + str(option))

        transitions = transition_data[option]
        states = np.array([t.state for t in transitions])

        partitioned_symbols = partitions[option]
        num_partitions = len(partitioned_symbols)

        # NOTE(review): this import comes from ``symbolflow`` while every other project import in this file comes
        # from ``symbols`` -- confirm the module path is still valid. The import statement runs on every iteration
        # (Python caches the module after the first) and before the empty-partition check below.
        from symbolflow.learn.learn import build_knn_classifier  # local import to prevent circular dependency
        if num_partitions == 0:
            warnings.warn("No partition for option " + str(option))
            # Append placeholders so the outer lists stay aligned with the iteration order of ``options``.
            transition_indices.append([])
            precondition_indices.append([])
            continue
        nn_classifier = build_knn_classifier(partitioned_symbols)

        transition_list = [[] for _ in range(0, num_partitions)]
        pos_i_list = [[] for _ in range(0, num_partitions)]
        neg_i_list = [[] for _ in range(0, num_partitions)]

        if verbose:
            print("\tAssigning transition data")

        # Each transition goes to exactly one partition: the one predicted for its start state.
        for x in range(0, states.shape[0]):
            # assumes predictions are partition indices in [0, num_partitions) -- TODO confirm against build_knn_classifier
            p_spot = int(nn_classifier.predict(states[x, :].reshape(1, -1))[0])  # calculate which partition
            transition_list[p_spot].append(x)

        if verbose:
            print("\t\t" + str(states.shape[0]) + " assigned")

        transition_indices.append(transition_list)

        if verbose:
            print("\tAssigning precondition data")

        # ``states`` is rebound here to the precondition states; the transition states are no longer needed.
        # ``labels`` is re-read from ``init_data`` on every iteration, so the column slice below is per option.
        (states, labels) = init_data
        labels = labels[:, option]

        for x in range(0, states.shape[0]):

            if labels[x] == 0.0:
                # Labelled 0.0 for this option: negative for every partition.
                for p in range(0, num_partitions):
                    neg_i_list[p].append(x)
            else:
                # Positive only for the predicted partition, negative for all the others.
                p_spot = int(nn_classifier.predict(states[x, :].reshape(1, -1))[0])  # calculate which partition
                for p in range(0, num_partitions):

                    if p == p_spot:
                        pos_i_list[p].append(x)
                    else:
                        neg_i_list[p].append(x)

        if verbose:
            print("\t\t" + str(states.shape[0]) + " assigned")

        precondition_indices.append([pos_i_list, neg_i_list])

    if verbose:
        print("Done.")

    # indices[k][partition num] -> indices of states, where k follows the iteration order of ``options``
    return transition_indices, precondition_indices


def write_symbols(directory,
                  symbols,
                  verbose=True):
    """
    Pickle each learned symbol (preconditions and effects) to ``option-<k>.dat``, with k counting from 1.
    :param directory: the directory to write to (prepared with ``make_dir``)
    :param symbols: the learned symbols
    :param verbose: whether to print information to screen
    """
    make_dir(directory)
    if verbose:
        print("Writing symbols to file")
    for number, symbol in enumerate(symbols, start=1):
        with open(make_path(directory, f"option-{number}.dat"), "wb") as file:
            pickle.dump(symbol, file)
    if verbose:
        print("Done.")


def load_operators(directory,
                   verbose=True):
    """
    Read the learned symbols (preconditions and effects) from file.
    Every entry in the directory is unpickled; objects that are plain strings are skipped.
    :param directory: the directory to read from
    :param verbose: whether to print information to screen
    :return: a list of the non-string unpickled objects, in sorted filename order
    """
    if verbose:
        print("Reading operators from file...")

    symbols = list()
    # os.listdir returns entries in an arbitrary, filesystem-dependent order. Sort them so the
    # returned list is reproducible across machines and runs.
    for filename in sorted(os.listdir(directory)):
        with open(make_path(directory, filename), "rb") as file:
            operator = pickle.load(file)
        if not isinstance(operator, str):
            symbols.append(operator)

    if verbose:
        print(str(len(symbols)) + " partitioned operators loaded")

    return symbols


def write_propositions(directory, symbol_list, verbose=True):
    """
    Save each proposition twice: pickled to ``symbol<i>.dat``, and its ``mask`` entries to ``symbol<i>.mask``
    (one per line, as text). Indices i count from 0.
    :param directory: the directory to write data to (prepared with ``make_dir``)
    :param symbol_list: the list of propositions; each must expose a ``mask`` iterable
    :param verbose: whether to print information to screen
    """
    make_dir(directory)

    for index, sym in enumerate(symbol_list):
        if verbose:
            print(f"Writing symbol {index} to file")

        stem = "symbol" + str(index)
        with open(make_path(directory, stem + ".dat"), "wb") as file:
            pickle.dump(sym, file)
        with open(make_path(directory, stem + ".mask"), "w") as file:
            file.writelines(str(m) + "\n" for m in sym.mask)


def write_schemata(directory,
                   schemata,
                   verbose=True):
    """
    Pickle each schema to ``operator-<k>.schema``, with k counting from 1.
    :param directory: the directory to write data to (created if needed, not cleaned)
    :param schemata: the list of schemata
    :param verbose: whether to print information to screen
    """
    make_dir(directory, clean=False)
    if verbose:
        print("Writing schemata to file")
    for number, schema in enumerate(schemata, start=1):
        with open(make_path(directory, f"operator-{number}.schema"), "wb") as file:
            pickle.dump(schema, file)
    if verbose:
        print("Done.")


def load_schemata(directory,
                  verbose=True):
    """
    Read the schemata from ``operator-<k>.schema`` files (k = 1, 2, ...), stopping at the first missing index.
    :param directory: the directory to read from
    :param verbose: whether to print information to screen
    :return: a tuple ``(schemata, schemata_list)``. ``schemata[option][partition]`` is the list of schemata loaded
        for that option partition; ``schemata_list`` holds every loaded schema in file order.
    """
    if verbose:
        print("Reading schemata from file...")

    schemata = dict()
    schemata_list = list()
    number = 1
    filename = make_path(directory, "operator-" + str(number) + ".schema")
    while os.path.isfile(filename):
        with open(filename, "rb") as file:
            schema = pickle.load(file)
        schemata.setdefault(schema.option, dict()).setdefault(schema.partition, list()).append(schema)
        schemata_list.append(schema)
        number += 1
        filename = make_path(directory, "operator-" + str(number) + ".schema")

    if verbose:
        # Report the number of schemata read; len(schemata) would count distinct options instead.
        print(str(len(schemata_list)) + " operators loaded")

    return schemata, schemata_list


def load_ppddl_symbols(directory,
                       verbose=True):
    """
    Unpickle the PPDDL symbols stored as ``symbol<i>.dat`` (i = 0, 1, ...), stopping at the first missing index.
    :param directory: the directory to read from
    :param verbose: whether to print information to screen
    :return: the list of loaded symbols, in index order
    """
    if verbose:
        print("Reading PPDDL symbols from file...")

    symbols = list()
    index = 0
    filename = make_path(directory, "symbol" + str(index) + ".dat")
    while os.path.isfile(filename):
        with open(filename, "rb") as file:
            symbols.append(pickle.load(file))
        index += 1
        filename = make_path(directory, "symbol" + str(index) + ".dat")

    if verbose:
        print(f"{len(symbols)} PPDDL symbols loaded")

    return symbols
