import numpy as np
# from predict import Predictor

# Module-level sample count. Each ``calc()`` method in this file declares
# ``global tot`` and rebinds it to the number of perturbation samples it draws
# before dividing by it to obtain coverage.
tot = 0
class CovPreCalculator:
    """Monte-Carlo estimate of a rule's coverage and precision.

    Random perturbations of ``words`` are drawn with :meth:`sample`. The
    perturbations for which ``myfit.fit`` is truthy are "covered" by the rule,
    and precision is the fraction of covered perturbations that ``predict``
    assigns to the target label.
    """

    def __init__(self, words, rules, myfit, predict, label):
        """Store the inputs used by :meth:`sample` and :meth:`calc`.

        Args:
            words: Numeric sequence supporting ``.copy()``, ``len()``,
                indexing and slicing (e.g. a list or 1-D NumPy array).
                Presumably the encoded instance being explained — confirm
                against callers.
            rules: Stored as-is; not read by this class.
            myfit: Object whose ``fit(sample)`` method is truthy when a
                perturbed sample satisfies the rule.
            predict: Callable mapping a list of samples to an iterable of
                predicted labels (compared against 1 and 0).
            label: Target label. Precision counts predictions equal to 1
                when ``label == 1`` and predictions equal to 0 otherwise.
        """
        self.rules = rules
        self.words = words
        self.fit = myfit.fit
        self.predict = predict
        # Not read by this class; kept for callers that may access it.
        self.mask = -100
        self.label = label

    def sample(self):
        """Return one random perturbation of ``self.words``.

        Two cut points ``start <= stop`` are drawn uniformly, with
        replacement, from ``range(len(words))``. Elements before ``start``
        are dropped, and elements in ``[start, stop)`` get uniform noise in
        ``[-1, 1)`` added. The remaining elements are kept unchanged, so the
        last element is always present.

        Returns:
            A new list; ``self.words`` itself is not modified. An empty
            ``words`` yields ``[]``.
        """
        n = len(self.words)
        if n == 0:
            # np.random.choice cannot draw from an empty range.
            return []
        text = self.words.copy()
        start, stop = sorted(np.random.choice(n, 2))
        for i in range(start, stop):
            text[i] = text[i] + 2 * np.random.random() - 1
        # Slice instead of overwriting with a -100 sentinel and filtering:
        # the sentinel filter also discarded genuine values equal to -100.
        return list(text[start:])

    def calc(self, n_samples=1000):
        """Estimate ``[coverage, precision]`` of the rule from random samples.

        Args:
            n_samples: Number of perturbations to draw. Defaults to 1000,
                the previously hard-coded count.

        Returns:
            ``[coverage, precision]``. Coverage is the fraction of the
            ``n_samples`` draws accepted by ``fit``. Precision is the
            fraction of accepted draws predicted as the target label. When
            no draw is accepted (or ``n_samples`` is 0), the undefined ratio
            is reported as 0.0 instead of raising ZeroDivisionError.

        Side effects:
            Rebinds the module-level ``tot`` to ``n_samples`` and prints
            diagnostic counts.
        """
        global tot
        tot = n_samples

        matched = []
        for _ in range(n_samples):
            candidate = self.sample()
            if self.fit(candidate):
                matched.append(candidate)

        coverage = len(matched) / n_samples if n_samples > 0 else 0.0

        # Skip the model call on an empty batch; there is nothing to score.
        predictions = self.predict(matched) if matched else []
        print(sum(predictions), len(predictions))

        # label == 1 counts positive predictions; any other label counts 0s.
        target = 1 if self.label == 1 else 0
        hits = sum(1 for p in predictions if p == target)
        precision = hits / len(matched) if matched else 0.0

        print([coverage, precision])
        return [coverage, precision]
    
    
class CovPreCalculatorLIME:
    """Monte-Carlo coverage/precision estimate for a signed-score rule.

    ``myfit.fit`` returns a number per perturbed sample. A non-zero score
    means the sample is covered: a positive score expects prediction 1 and a
    negative score expects prediction 0. Precision is the fraction of covered
    samples whose prediction agrees with the sign of their score.
    """

    def __init__(self, words, rules, myfit, predict, label):
        """Store the inputs used by :meth:`sample` and :meth:`calc`.

        Args:
            words: Numeric sequence supporting ``.copy()``, ``len()``,
                indexing and slicing (e.g. a list or 1-D NumPy array).
                Presumably the encoded instance being explained — confirm
                against callers.
            rules: Stored as-is; not read by this class.
            myfit: Object whose ``fit(sample)`` method returns a number
                compared against 0.
            predict: Callable mapping a list of samples to an iterable of
                predicted labels (compared against 1 and 0).
            label: Stored as-is; not read by :meth:`calc` in this class.
        """
        self.rules = rules
        self.words = words
        self.fit = myfit.fit
        self.predict = predict
        # Not read by this class; kept for callers that may access it.
        self.mask = -100
        self.label = label

    def sample(self):
        """Return one random perturbation of ``self.words``.

        Two cut points ``start <= stop`` are drawn uniformly, with
        replacement, from ``range(len(words))``. Elements before ``start``
        are dropped, and elements in ``[start, stop)`` get uniform noise in
        ``[-1, 1)`` added. The remaining elements are kept unchanged, so the
        last element is always present.

        Returns:
            A new list; ``self.words`` itself is not modified. An empty
            ``words`` yields ``[]``.
        """
        n = len(self.words)
        if n == 0:
            # np.random.choice cannot draw from an empty range.
            return []
        text = self.words.copy()
        start, stop = sorted(np.random.choice(n, 2))
        for i in range(start, stop):
            text[i] = text[i] + 2 * np.random.random() - 1
        # Slice instead of overwriting with a -100 sentinel and filtering:
        # the sentinel filter also discarded genuine values equal to -100.
        return list(text[start:])

    def calc(self, n_samples=10000):
        """Estimate ``[coverage, precision]`` of the rule from random samples.

        Args:
            n_samples: Number of perturbations to draw. Defaults to 10000,
                the previously hard-coded count.

        Returns:
            ``[coverage, precision]``. Coverage is the fraction of draws
            with a non-zero ``fit`` score. Precision is the fraction of those
            draws whose prediction matches the score's sign (1 for positive,
            0 otherwise). When no draw is covered (or ``n_samples`` is 0),
            the undefined ratio is reported as 0.0 instead of raising
            ZeroDivisionError.

        Side effects:
            Rebinds the module-level ``tot`` to ``n_samples`` and prints the
            result.
        """
        global tot
        tot = n_samples

        rpos = []
        rneg = []
        for _ in range(n_samples):
            candidate = self.sample()
            score = self.fit(candidate)
            if score == 0:
                continue
            if score > 0:
                rpos.append(candidate)
            else:
                # Any other non-zero score (including NaN) counts as negative.
                rneg.append(candidate)

        n_matched = len(rpos) + len(rneg)
        coverage = n_matched / n_samples if n_samples > 0 else 0.0

        # Only call the model on non-empty batches.
        right = 0
        if rpos:
            right += sum(1 for p in self.predict(rpos) if p == 1)
        if rneg:
            right += sum(1 for p in self.predict(rneg) if p == 0)
        precision = right / n_matched if n_matched else 0.0

        print([coverage, precision])
        return [coverage, precision]