import copy
import itertools
import warnings
from collections import defaultdict

import numpy as np

from symbols.data.data import load_transition_data
from symbols.data.partitioned_option import PartitionedOption
from symbols.file_utils import load, make_path
from symbols.logger.transition_sample import TransitionSample
from symbols.pddl.typed_description import Predicate
from symbols.render.image import Image
from symbols.transfer.transferable_operator import TransferableOperator
from symbols.transfer.transferable_operator import TransferableOperator


def show(effect, render):
    """Render ``effect`` and display one image that combines every rendered view.

    :param effect: passed unchanged to ``render``
    :param render: callable returning a mapping; each value is merged with ``Image.merge``
    """
    rendered = render(effect)
    pictures = []
    for _, frames in rendered.items():
        merged = Image.merge(frames)
        # results with three dimensions are converted as RGB; anything else uses the default mode
        if len(merged.shape) == 3:
            pictures.append(Image.to_image(merged, mode='RGB'))
        else:
            pictures.append(Image.to_image(merged))
    Image.combine(pictures).show()


# Object type profile
class Profile:
    """Record, per option, the next states observed for one object, and score them
    against a candidate type's per-option effects."""

    def __init__(self, object):
        # ``object`` shadows the builtin; the name is kept because callers may pass it by keyword.
        self.object = object
        self.effects = defaultdict(list)  # option -> list of observed next states (entries may be None)

    def update(self, option, next_state):
        """Record ``next_state`` as an observed outcome of executing ``option``."""
        self.effects[option].append(next_state)

    def get(self, option):
        """Return the next states recorded for ``option``.

        :return: the recorded list, or None if the option was never recorded or every
            recorded state is None
        """
        states = self.effects.get(option)  # .get does not trigger the defaultdict factory
        if not states or all(state is None for state in states):
            return None
        return states

    @staticmethod
    def _single_estimator(effect):
        """Return the single estimator wrapped by ``effect``.

        Only effects with one ``(probability, estimators)`` outcome holding one estimator
        are supported.

        :raises ValueError: if ``effect.effects`` or its estimator list has more than one entry
        """
        if len(effect.effects) > 1:
            raise ValueError('expected a single probabilistic outcome, got {}'.format(len(effect.effects)))
        _, estimators = effect.effects[0]
        if len(estimators) > 1:
            raise ValueError('expected a single estimator, got {}'.format(len(estimators)))
        return estimators[0]

    def compare(self, option_effects: list, **kwargs):
        """Score how well this object's recorded next states match a type's effects.

        For each option (its index in ``option_effects``):
          * no recorded data and no effect -> 1
          * data but no effect, or an effect but no data -> 0
          * both present -> mean over the recorded states of the best
            sigmoid(log-density) across the option's effects; a None state contributes 0

        :param option_effects: per-option list of effects, or None where the type has no effect
        :param kwargs: optional ``render`` callable, invoked with the recorded data of every
            option where both data and effects are present
        :return: the mean per-option score; 0 if no option had both data and effects, or if
            an estimator rejects a state with ValueError (e.g. wrong dimensionality)
        :raises ValueError: if an effect wraps more than one outcome or estimator
        """
        entered = False
        score = []
        for option, effects in enumerate(option_effects):
            data = self.get(option)
            if data is None or effects is None:
                # 1 when both sides agree there is nothing, 0 when only one side has something
                score.append(1 if data is None and effects is None else 0)
                continue

            entered = True
            if 'render' in kwargs:
                kwargs['render'](data)

            total = 0
            for state in data:
                if state is None:
                    # nothing observed for this sample: it contributes probability 0
                    continue
                best = 0
                for effect in effects:
                    estimator = self._single_estimator(effect)
                    try:
                        log_density = estimator.score_samples(state.reshape(1, -1))[0]
                    except ValueError:
                        # wrong dimensions: this type cannot describe the object at all
                        return 0
                    # sigmoid(log_density) == d / (1 + d) with d = exp(log_density),
                    # computed without overflowing exp() for large log-densities
                    prob = np.exp(-np.logaddexp(0, -log_density))
                    best = max(best, prob)
                total += best
            score.append(total / len(data))

        if not entered:
            # we never evaluated a single option
            return 0
        return np.mean(score)


class KnowledgeBase:
    """Store object types, predicates, operators, partitioned options and per-type effects
    gathered across tasks, and use them to interpret transitions in a task."""

    def __init__(self):
        # task id -> {object index -> type}; the type is None for objects only queried via get_type
        self.types = dict()
        self.predicates = list()  # Predicate instances, in insertion order
        self.operators = list()  # TransferableOperator instances, in insertion order
        # task id -> defaultdict(option -> list of partitions)
        self.partitioned_options = dict()
        # type -> per-option list of effects (entries may be None); extended by set_type
        self.type_effects = dict()
        self.all_types = set()  # every type ever passed to set_type
        self.operator_data = defaultdict(list)  # task id -> list of operator data

    def add_operator_data(self, task_id, operator_data):
        """Append ``operator_data`` to the list kept for ``task_id``."""
        self.operator_data[task_id].append(operator_data)

    def add_partitioned_option(self, task_id, partitioned_option: 'PartitionedOption'):
        """Record the partition of ``partitioned_option`` under its option for ``task_id``."""
        option = partitioned_option.option
        partition = partitioned_option.partition
        if task_id not in self.partitioned_options:
            self.partitioned_options[task_id] = defaultdict(list)
        self.partitioned_options[task_id][option].append(partition)

    def get_partitioned_option(self, task_id, option, partition=None):
        """
        Look up the partitions stored for an option.

        :param task_id: the task ID
        :param option: the option
        :param partition: if given, select the stored item whose ``.partition`` equals it
        :return: the list of stored items when ``partition`` is None, otherwise the first
            matching item; None when nothing is stored or nothing matches
        """
        task_options = self.partitioned_options.get(task_id)
        if task_options is None or option not in task_options:
            return None
        partitions = task_options[option]
        if partition is None:
            return partitions
        # NOTE(review): add_partitioned_option stores ``partitioned_option.partition``, yet this
        # reads ``.partition`` from each stored item -- confirm the stored objects expose that
        # attribute, or whether the whole PartitionedOption was meant to be stored.
        return next((p for p in partitions if p.partition == partition), None)

    def add_predicate(self, predicate: 'Predicate'):
        """Append ``predicate`` to the known predicates."""
        self.predicates.append(predicate)

    def add_operator(self, operator: 'TransferableOperator'):
        """Append ``operator`` to the known operators."""
        self.operators.append(operator)

    def _merge(self, current, new):
        """
        Extend ``current`` in place with ``new``, option by option, and return ``current``.

        :raises ValueError: if an option is None in one list but not in the other
        """
        # NOTE(review): zip stops at the shorter list, so options beyond it are silently ignored.
        for i, (a, b) in enumerate(zip(current, new)):
            if (a is None) != (b is None):
                raise ValueError('cannot merge effects: option {} is None in only one list'.format(i))
            if a is not None:
                current[i] += b
        return current

    def set_type(self, task: int, object_idx: int, type: int, effects=None):
        """
        Assign a type to an object and optionally record effects for that type.

        :param task: the task ID
        :param object_idx: the index of the object in the task
        :param type: the object's type (the name shadows the builtin; kept for keyword callers)
        :param effects: optional per-option effects, merged into any effects already stored
            for ``type``
        :raises ValueError: (from _merge) if the None pattern of ``effects`` disagrees with
            the stored effects
        """
        task_types = self.types.setdefault(task, dict())
        if object_idx in task_types:
            warnings.warn("Overwriting object type!")
        task_types[object_idx] = type
        if effects is not None:
            if type in self.type_effects:
                self.type_effects[type] = self._merge(self.type_effects[type], effects)
            else:
                # Store copies: _merge later extends these entries in place, which would
                # otherwise silently mutate the caller's list and its inner lists.
                self.type_effects[type] = [None if e is None else copy.copy(e) for e in effects]
        self.all_types.add(type)

    def get_type(self, task: int, object_idx: int):
        """
        Get the type of object.

        Side effect: an unseen object is registered for the task with type None, so it is
        counted by ``n_objects`` and listed by ``__str__``.

        :param task: the current task ID
        :param object_idx: the index of the object
        :return: the object's type, or None if it has not been set
        """
        task_types = self.types.setdefault(task, dict())
        return task_types.setdefault(object_idx, None)

    def n_objects(self, task_id):
        """Return how many objects are registered for ``task_id``, or None if the task is unseen."""
        if task_id not in self.types:
            return None  # haven't seen yet, so don't know
        return len(self.types[task_id])

    def _compare(self, data, effects):
        """
        Score a single observed next state against one option's effects.

        :param data: the observed next state, or None
        :param effects: the effects to compare against, or None
        :return: 0 if both are None; 1 if some effect gives ``data`` a value of
            exp(score_samples) above 0.8; -1 otherwise
        :raises ValueError: if an effect wraps more than one outcome or estimator
        """
        if effects is None:
            return 0 if data is None else -1

        m_prob = 0
        for effect in effects:
            if len(effect.effects) > 1:
                raise ValueError('expected a single probabilistic outcome, got {}'.format(len(effect.effects)))
            _, estimators = effect.effects[0]
            if len(estimators) > 1:
                raise ValueError('expected a single estimator, got {}'.format(len(estimators)))
            estimator = estimators[0]
            if data is None:
                prob = 0
            else:
                try:
                    prob = np.exp(estimator.score_samples(data.reshape(1, -1))[0])
                except Exception:
                    # best effort: an estimator that cannot score ``data`` counts as 0
                    prob = 0
            m_prob = max(m_prob, prob)

        return 1 if m_prob > 0.8 else -1

    def infer_types(self, transition_samples, verbose=True, **kwargs):
        """
        Determine the type of every object in a new task from some of its transitions.

        NOTE(review): currently short-circuited by a hard-coded mapping (the original
        "TODO remove"): every object is its own type, except objects 1-4, which are all
        type 1. Keys 1-4 are present even when the task has fewer objects. ``verbose`` and
        ``kwargs`` are unused while the stub is active; the sample-based inference lives in
        ``_infer_types_from_samples``.

        :param transition_samples: transitions from the new task; the length of
            ``transition_samples[0].state`` gives the number of objects (must be non-empty)
        :return: dict mapping object index -> type
        """
        n_objects = len(transition_samples[0].state)

        # TODO remove: hard-coded types; replace with
        # ``return self._infer_types_from_samples(transition_samples, verbose)``
        types = {i: i for i in range(n_objects)}
        for i in [1, 2, 3, 4]:
            types[i] = 1
        return types

    def _infer_types_from_samples(self, transition_samples, verbose=True):
        """
        Infer each object's type by matching its observed effects against ``self.type_effects``.

        For each object, the next states seen while it was in a sample's mask are grouped by
        option into a Profile; the type whose effects score highest under ``Profile.compare``
        is chosen. Objects whose best score is not above 0 map to None.
        """
        n_objects = len(transition_samples[0].state)
        profiles = {obj: Profile(obj) for obj in range(n_objects)}

        for sample in transition_samples:
            option = sample.option
            mask = sample.mask
            next_state = sample.next_state
            for obj in range(n_objects):
                if obj in mask:
                    profiles[obj].update(option, next_state[obj])

        types = dict()
        for obj in range(n_objects):
            best = None
            best_proba = 0
            for candidate, effects in self.type_effects.items():
                proba = profiles[obj].compare(effects)
                if proba > best_proba:
                    best = candidate
                    best_proba = proba
            if verbose:
                print('object {} of type {}: {}'.format(obj, best, best_proba))
            types[obj] = best
        return types

    def operator_proba(self, task_id, state, option, next_state):
        """
        Determine the probability that a transition is described by each operator.

        Operators for a different option score 0. Otherwise every assignment of distinct
        objects of the current task to the operator's parameter types is tried and the highest
        ``operator.transition_proba`` is kept (0 when no assignment exists).

        :param task_id: the current task ID
        :param state: the state before the transition; one entry per object
        :param option: the option that was executed
        :param next_state: the state after the transition
        :return: dict mapping each operator to its probability
        """
        # TODO compute mask for this task.
        # The mask is passed in the order of ``operator.types``. If an operator's distribution
        # ranges over several objects, the object order used in the original tasks must be
        # recovered so values are passed to the kde in the right order (believed unnecessary
        # if distributions are only ever over one object -- unverified).
        probs = dict()
        objects_by_type = None  # type -> object indices; built once, on the first matching operator
        for operator in self.operators:
            if operator.option != option:
                probs[operator] = 0
                continue

            if objects_by_type is None:
                objects_by_type = defaultdict(list)
                for i, _ in enumerate(state):
                    objects_by_type[self.get_type(task_id, i)].append(i)

            operator_types = operator.types
            max_prob = 0
            # Every parameter type must occur in this task. A type may appear more than once
            # (an operator over two objects of one type), so check membership rather than
            # comparing against the number of distinct types found.
            if all(t in objects_by_type for t in operator_types):
                candidates = [objects_by_type[t] for t in operator_types]
                for mask in itertools.product(*candidates):
                    # an object cannot fill two parameters at once
                    if len(mask) == len(set(mask)):
                        y = operator.transition_proba(state, next_state, operator_types, mask)
                        max_prob = max(y, max_prob)
            probs[operator] = max_prob
        return probs

    def __str__(self):
        """Return a human-readable summary of object types, predicates and operator names."""
        s = """
        *************************************************
        *                                               *
        *                 KNOWLEDGE BASE                *
        *                                               *
        *************************************************"""
        s += '\n\n'
        for task_id, objects in self.types.items():
            s += 'There are {} objects in task {}\n'.format(len(objects), task_id)
            for idx, object_type in objects.items():
                s += '\tObject {} is of type {}\n'.format(idx, object_type)
        s += '\n\n'
        s += 'We have the following {} PREDICATES:\n\n'.format(len(self.predicates))
        for predicate in self.predicates:
            s += str(predicate) + '\n'

        s += '\n\n'
        s += 'We have the following {} OPERATORS:\n\n'.format(len(self.operators))
        for operator in self.operators:
            s += operator.name + '\n'
        return s
