from collections import defaultdict
import numpy as np

import logging

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig runs at import time and sets up the root logger
# (it is a no-op if the root logger already has handlers) -- confirm that
# configuring logging from inside this module is intended.
logging.basicConfig(level=logging.INFO)

class Predicate:
    """Base class for an optionally negated test over one subject's state.

    Subclasses provide ``execute`` (hard, truthy/falsy result) and
    ``execute_soft`` (numeric score). ``evaluate`` / ``evaluate_soft`` fetch
    the relevant window from the state space, run the matching hook and
    apply the negation.
    """

    def __init__(self, negate=False):
        # When true, hard results are inverted with ``not`` and soft scores
        # are mapped to ``1 - score``.
        self.negate = negate

    def get_window(self, subject, time, state_space, end_time=None):
        """Collect the subject's variables at ``time`` or over a time slice.

        Args:
            subject: Key into ``state_space``.
            time: Index of the (first) sample to read.
            state_space: Mapping ``subject -> {variable: sequence}``.
            end_time: If given, each variable is sliced as
                ``data[time:end_time]``; otherwise the single entry
                ``data[time]`` is returned.

        Returns:
            dict mapping variable name to the selected entry or slice. The
            ``'object_movements'`` variable and any variable whose sequence
            has ``len(data) <= time`` are left out.
        """
        obj_data = state_space[subject]
        window_data = {}
        for variable, data in obj_data.items():
            if variable == 'object_movements' or len(data) <= time:
                continue
            if end_time is not None:
                window_data[variable] = data[time: end_time]
            else:
                window_data[variable] = data[time]
        return window_data

    def evaluate(self, subject, time, state_space, end_time=None, obj=None):
        """Return the hard result of the predicate, negated if requested.

        ``obj`` is accepted for signature compatibility with two-object
        predicates and is not used here.
        """
        window = self.get_window(subject, time, state_space, end_time=end_time)
        value = self.execute(window)
        if self.negate:
            value = not value
        return value

    def evaluate_soft(self, subject, time, state_space, end_time=None, obj=None):
        """Return the soft score of the predicate (``1 - score`` if negated).

        ``obj`` is accepted for signature compatibility with two-object
        predicates and is not used here.
        """
        window = self.get_window(subject, time, state_space, end_time=end_time)
        value = self.execute_soft(window)
        if self.negate:
            value = 1 - value
        return value

    def execute(self, window) -> bool:
        """Hard test on a window; must be overridden by subclasses."""
        raise NotImplementedError

    def execute_soft(self, window):
        """Soft score on a window; must be overridden by subclasses.

        Declared here so that a subclass missing it fails with the same
        ``NotImplementedError`` as ``execute`` instead of an AttributeError.
        """
        raise NotImplementedError

# Indices into a location vector; passed as the `dim` argument of the
# coordinate predicates, which use them to index 'object_loc' entries.
X_COORD = 0
Y_COORD = 1
Z_COORD = 2

# Normalising distance used by parabolic_dist: a gap of MAX_DIST maps to 0.
MAX_DIST = 12

def parabolic_dist(x1, x2):
    """Return a parabolic closeness score for two coordinates.

    The gap ``x1 - x2`` is divided by ``MAX_DIST`` and the square of that
    ratio is subtracted from 1, so equal inputs score 1 and a gap of
    ``MAX_DIST`` scores 0.
    """
    normalized_gap = (x1 - x2) / MAX_DIST
    return 1 - normalized_gap ** 2

def sigmoid_sim(x1, x2):
    """Return ``2 / (1 + e^|x1 - x2|)`` as a similarity score.

    A zero difference gives a similarity of exactly 1; the score then
    drops off rapidly towards 0 as the absolute difference grows.
    """
    gap = abs(x1 - x2)
    return 2 / (1 + np.exp(gap))

# Ideally the shift and temperature would be learnable end to end.
def shift_sigmoid_sim(x1, x2, shift=4, temp=1):
    """Return ``1 / (1 + e^((|x1 - x2| - shift) / temp))``.

    Compared with a plain sigmoid the curve is moved right by ``shift``:
    differences well below ``shift`` score close to 1, a difference equal
    to ``shift`` scores exactly 0.5, and larger differences fall off
    towards 0. ``temp`` divides the shifted difference and so controls how
    sharp the drop is.
    """
    scaled_gap = (abs(x1 - x2) - shift) / temp
    denominator = 1 + np.exp(scaled_gap)
    return 1 / denominator

class SameCoord(Predicate):
    """Two-object predicate: subject and object share one coordinate.

    The comparison is made at a single time index on coordinate ``dim`` of
    each side's ``'object_loc'`` entry; ``end_time`` is accepted by the
    evaluate methods but not used.
    """

    def __init__(self, dim, negate=False, margin=0.1):
        super().__init__(negate=negate)
        self.margin = margin
        self.dim = dim

    def evaluate(self, subject, time, state_space, end_time=None, obj=None):
        """Hard test: coordinates differ by at most ``margin`` (negatable)."""
        outcome = self.execute(
            self.get_window(subject, time, state_space),
            self.get_window(obj, time, state_space),
        )
        return (not outcome) if self.negate else outcome

    def evaluate_soft(self, subject, time, state_space, end_time=None, obj=None):
        """Soft score of the coordinate match (``1 - score`` if negated)."""
        score = self.execute_soft(
            self.get_window(subject, time, state_space),
            self.get_window(obj, time, state_space),
        )
        return 1 - score if self.negate else score

    def execute(self, subj_window, obj_window):
        """Return whether the two locations agree on ``dim`` within margin."""
        subj_loc, obj_loc = subj_window['object_loc'], obj_window['object_loc']
        gap = abs(subj_loc[self.dim] - obj_loc[self.dim])
        return gap <= self.margin

    def execute_soft(self, subj_window, obj_window):
        """Return the shifted-sigmoid similarity of the two coordinates.

        The predicate's margin is passed as the sigmoid shift.
        """
        subj_loc, obj_loc = subj_window['object_loc'], obj_window['object_loc']
        return shift_sigmoid_sim(
            subj_loc[self.dim],
            obj_loc[self.dim],
            shift=self.margin,
        )

class ChangeCoord(Predicate):
    """Predicate: coordinate ``dim`` differs between window start and end.

    Only the first and last ``'object_loc'`` entries of the window are
    compared; intermediate samples are ignored.
    """

    def __init__(self, dim, negate=False, margin=0.01):
        super().__init__(negate=negate)
        self.margin = margin
        self.dim = dim

    def _endpoint_coords(self, window):
        # Coordinate `dim` of the first and of the last location in the window.
        first, last = window['object_loc'][0], window['object_loc'][-1]
        return first[self.dim], last[self.dim]

    def execute(self, window):
        """Return whether the endpoints differ by more than ``margin``."""
        begin, finish = self._endpoint_coords(window)
        displacement = abs(begin - finish)
        return displacement > self.margin

    def execute_soft(self, window):
        """Return one minus the shifted-sigmoid similarity of the endpoints.

        The predicate's margin is passed as the sigmoid shift.
        """
        begin, finish = self._endpoint_coords(window)
        similarity = shift_sigmoid_sim(begin, finish, shift=self.margin)
        return 1 - similarity

class ConstantCoord(Predicate):
    """Predicate: coordinate ``dim`` stays near its starting value.

    Every ``'object_loc'`` entry in the window is compared against the
    first one on coordinate ``dim``.
    """

    def __init__(self, dim, negate=False, margin=0.1):
        super().__init__(negate=negate)
        self.dim = dim
        self.margin = margin

    def execute(self, window):
        """Return True unless some later sample deviates by more than margin.

        A window holding a single location is trivially constant.
        """
        locations = window['object_loc']
        start = locations[0][self.dim]
        return not any(
            abs(start - location[self.dim]) > self.margin
            for location in locations[1:]
        )

    def execute_soft(self, window):
        """Return the lowest shifted-sigmoid similarity to the start value.

        The predicate's margin is passed as the sigmoid shift. For a window
        holding a single location there is nothing to compare against, so
        the start value's similarity to itself is returned. This is the
        highest score reachable for this margin and matches ``execute``
        returning True for that case (a bare ``min`` over the empty
        comparison list would raise ValueError).
        """
        locations = window['object_loc']
        start = locations[0][self.dim]
        scores = [
            shift_sigmoid_sim(start, location[self.dim], shift=self.margin)
            for location in locations[1:]
        ]
        if not scores:
            return shift_sigmoid_sim(start, start, shift=self.margin)
        return min(scores)

class Rule:
    """A named rule in disjunctive normal form over predicates.

    Each positional argument after ``name`` is one conjunction: an iterable
    of predicate objects exposing ``evaluate`` and ``evaluate_soft``. The
    rule holds when at least one conjunction holds. Conjunctions may have
    different lengths.
    """

    def __init__(self, name, *predicates):
        self.name = name
        # Tuple of conjunctions; each conjunction is an iterable of predicates.
        self.predicates = predicates

    def evaluate(self, subject, time, state_space, end_time=None, obj=None):
        """Return True if any conjunction has all of its predicates true.

        Within a conjunction evaluation stops at the first failing
        predicate; every conjunction is evaluated. A rule with no
        conjunctions returns False.
        """
        conjunction_results = [
            all(
                predicate.evaluate(subject, time, state_space,
                                   end_time=end_time, obj=obj)
                for predicate in predicate_conj
            )
            for predicate_conj in self.predicates
        ]
        return any(conjunction_results)

    def evaluate_soft(self, subject, time, state_space, end_time=None, obj=None):
        """Return the best conjunction score (max over conjunctions of min).

        Each conjunction is scored by its weakest predicate (minimum soft
        score) and the rule by its strongest conjunction (maximum). The
        conjunctions are reduced one at a time so that they need not all
        contain the same number of predicates.

        Raises:
            ValueError: from NumPy if the rule has no conjunctions or a
                conjunction is empty.
        """
        # An average (np.mean) per conjunction is an alternative worth
        # considering if the min-based scores turn out too small.
        conjunction_scores = [
            np.min([
                predicate.evaluate_soft(subject, time, state_space,
                                        end_time=end_time, obj=obj)
                for predicate in predicate_conj
            ])
            for predicate_conj in self.predicates
        ]
        return np.max(conjunction_scores)

# Slide: the object moves along X or along Y (either axis alone is enough,
# hence two alternative conjunctions) while Z stays constant.
slide = Rule(
    'slide',
    [ChangeCoord(X_COORD), ConstantCoord(Z_COORD)],
    [ChangeCoord(Y_COORD), ConstantCoord(Z_COORD)],
)

# Pick-and-place: the object moves along X or along Y while Z does NOT stay
# constant (the ConstantCoord test is negated).
pick_place = Rule(
    'pick_place',
    [ChangeCoord(X_COORD), ConstantCoord(Z_COORD, negate=True)],
    [ChangeCoord(Y_COORD), ConstantCoord(Z_COORD, negate=True)],
)

# Contain: subject and object agree on X and Y. Z is deliberately not
# compared because object heights may vary.
contain = Rule(
    'contain',
    [SameCoord(X_COORD), SameCoord(Y_COORD)],
)
