from .fit_base import Fit


class FitRex(Fit):
    """Score a token sequence against weighted pairwise word-order rules.

    ``rules`` is iterated as ``((x_idx, y_idx, dist), wt)`` entries, where
    ``x_idx`` and ``y_idx`` index into ``seq`` to obtain the two words a rule
    refers to.
    """

    def __init__(self, rules, seq):
        super().__init__(rules)
        # Lookup table translating rule indices into the words to search for.
        self.seq = seq

    def fit(self, now, **kwargs):
        """Classify ``now`` by summing the weights of the rules it satisfies.

        A rule ``((x_idx, y_idx, dist), wt)`` is satisfied when both
        ``seq[x_idx]`` and ``seq[y_idx]`` occur in ``now`` and, using the
        position of the *last* occurrence of each word,
        ``pos_y - pos_x > dist``.

        :param now: sequence of words to evaluate. Its elements are used as
            dictionary keys, so they must be hashable.
        :param kwargs: accepted for interface compatibility and ignored. In
            particular, a ``threshold`` keyword has no effect; the decision
            threshold is fixed at 0.05.
        :return: ``1`` if the summed weight is above the threshold, ``-1`` if
            it is below ``-threshold``, otherwise ``0``.
        """
        threshold = 0.05
        # Last index of every word in ``now``. Later positions overwrite
        # earlier ones, so this reproduces a "last match wins" scan while
        # reading ``now`` only once instead of once per rule.
        last_pos = {word: i for i, word in enumerate(now)}
        totwt = 0
        for (x_idx, y_idx, dist), wt in self.rules:
            posx = last_pos.get(self.seq[x_idx], -1)
            posy = last_pos.get(self.seq[y_idx], -1)
            # NOTE(review): -1 is widened to -2, making the test
            # pos_y - pos_x >= -1; the reason is not visible here — confirm.
            if dist == -1:
                dist -= 1
            if posx != -1 and posy != -1 and posy - posx > dist:
                totwt += wt
        # Symmetric dead zone: small totals in either direction yield 0.
        if abs(totwt) > threshold:
            return 1 if totwt > 0 else -1
        return 0


class FitLime(Fit):
    """Score a perturbed token sequence against index-based weighted rules.

    ``rules`` is iterated as ``(idx, wt)`` pairs where ``idx`` indexes into
    ``words``. ``words`` is compared position-by-position with the sequence
    passed to :meth:`fit`.
    """

    def __init__(self, rules, words):
        super().__init__(rules)
        # Reference token list, position-aligned with the sequences given to fit().
        self.words = words

    def fit(self, now, **kwargs):
        """Classify ``now`` by summing the weights of the rules it satisfies.

        A rule ``(idx, wt)`` fires when the word ``words[idx]`` appears in
        ``now`` at some position where ``words`` holds that same word. Each
        rule contributes its weight at most once.

        :param now: sequence of words, expected to have the same length as
            ``words``.
        :param kwargs: accepted for interface compatibility and ignored.
        :return: ``1`` if the summed weight is above 0.05, ``-1`` if it is
            below -0.05, otherwise ``0``. Also ``0`` when ``len(now)`` differs
            from ``len(words)``.
        """
        threshold = 0.05
        # Position-wise comparison is meaningless when the lengths differ.
        # Return 0 (not False) so every path yields the same 1/-1/0 values.
        if len(now) != len(self.words):
            return 0
        totwt = 0
        for idx, wt in self.rules:
            target = self.words[idx]
            # Fire once if the target word survives at any position where it
            # also occurs in the reference. Counting every such position would
            # multiply the weight of words that repeat in the sentence.
            # NOTE(review): if rules are strictly positional,
            # ``now[idx] == self.words[idx]`` may be the intended test — confirm.
            if any(cur == target and orig == cur
                   for cur, orig in zip(now, self.words)):
                totwt += wt
        # Symmetric dead zone: small totals in either direction yield 0.
        if abs(totwt) > threshold:
            return 1 if totwt > 0 else -1
        return 0


class FitLimeNoDist(Fit):
    """Score a token sequence against word-valued (bag-of-words) weighted rules.

    ``rules`` is iterated as ``(word, wt)`` pairs. A rule fires when ``word``
    is contained in the sequence passed to :meth:`fit`, regardless of position.
    """

    def fit(self, now, **kwargs):
        """Return whether the fired rules support the requested ``label``.

        :param now: sequence of words to evaluate.
        :param kwargs: must contain ``label``. ``label == 0`` asks whether the
            evidence is negative; any other value asks whether it is positive.
        :return: ``True`` if the summed weight of fired rules exceeds 0.05 in
            the direction of ``label``: ``> 0.05`` for a non-zero label and
            ``< -0.05`` for label 0.
        :raises KeyError: if ``label`` is not supplied.
        """
        threshold = 0.05
        label = kwargs['label']
        totwt = sum(wt for word, wt in self.rules if word in now)
        if label == 0:
            # Evidence for label 0 is negative weight, so flip the sum and keep
            # the same positive threshold. Negating the threshold as well would
            # reduce the test to ``totwt < 0.05``, which succeeds even when no
            # rule fires at all.
            totwt = -totwt
        return totwt > threshold
