import random
from sympy.utilities.iterables import multiset_permutations as permutations
import copy
import re

def write_to_file(data, filename):
    """Write every string in ``data`` to ``filename``, one entry per line.

    The file is opened in ``"w+"`` mode, so existing content is truncated.
    A newline is appended after each entry.
    """
    with open(filename, "w+") as handle:
        handle.writelines(entry + "\n" for entry in data)

def create_permutations(size):
    """Return every binary string of length ``size`` in lexicographic order.

    Args:
        size: Length of the strings to build; must be non-negative.

    Returns:
        A list with the ``2 ** size`` strings over ``{"0", "1"}``. For
        ``size == 0`` this is ``[""]`` (the single empty string).

    Raises:
        ValueError: If ``size`` is negative.
    """
    if size < 0:
        raise ValueError("size must be non-negative, got %r" % (size,))

    # Grow from the empty string, appending one bit per round. Looping exactly
    # ``size`` times (instead of waiting for the length to equal ``size``)
    # guarantees termination for sizes 0 and 1 as well.
    data = [""]
    for _ in range(size):
        data = [prefix + bit for prefix in data for bit in "01"]
    return data



class TomitaData(object):
    """Base class for the sample generators defined in this module."""

    def generate_data(self, min_len, max_len):
        """Placeholder that always raises; subclasses supply the generator."""
        raise NotImplementedError()


class TomitaOne(TomitaData):
    """Tomita grammar 1: strings made up exclusively of "1" characters."""

    def generate_data(self, min_len, max_len):
        """Return the all-ones string for every length in ``[min_len, max_len]``."""
        return ["1" * length for length in range(min_len, max_len + 1)]

    @staticmethod
    def test(string):
        """Return True when no character of ``string`` differs from "1"."""
        return all(symbol == "1" for symbol in string)
        
        
class TomitaTwo(TomitaData):
    """Tomita grammar 2: repetitions of the pair "10"."""

    def generate_data(self, min_len, max_len):
        """Return "10" repeated to every even length within ``[min_len, max_len]``."""
        # Only even lengths are possible, so pull both bounds inward to the
        # nearest even number before stepping by two.
        first = min_len + 1 if min_len % 2 == 1 else min_len
        last = max_len - 1 if max_len % 2 == 1 else max_len
        return ["10" * int(length / 2) for length in range(first, last + 2, 2)]

    @staticmethod
    def test(string):
        """Return True when ``string`` has even length and every aligned pair is "10"."""
        if len(string) % 2 == 1:
            return False
        return all(
            string[start:start + 2] == "10"
            for start in range(0, len(string), 2)
        )


class TomitaThree(TomitaData):
    """Tomita grammar 3.

    ``test`` defines membership: a string is accepted when it does NOT match
    the pattern "an odd run of 1s that is later followed by an odd run of 0s".
    """

    # Automaton walked by ``generate_data``:
    # state -> ((appended symbol, next state), ...). "S6" is the dead state.
    _TRANSITIONS = {
        "S1": (("0", "S3"), ("1", "S2")),
        "S2": (("0", "S5"), ("1", "S1")),
        "S3": (("0", "S4"), ("1", "S6")),
        "S4": (("0", "S3"), ("1", "S1")),
        "S5": (("0", "S5"), ("1", "S1")),
        "S6": (),
    }

    # States dropped from the search frontier after every expansion step.
    # NOTE(review): dropping "S3" (and not only excluding it from the output)
    # means "S4" is never reached, so strings such as "100" -- which ``test``
    # accepts -- are never generated. The pruning is kept as-is so the
    # generator keeps emitting the same samples; confirm whether an exhaustive
    # enumeration is wanted before changing it.
    _PRUNED_STATES = frozenset(("S3", "S6"))

    # Complement of the language. Same language as the pattern
    # "((0|1)*0)*1(11)*(0(0|1)*1)*0(00)*(1(0|1)*)*$": each starred group there
    # is closed under concatenation, so "zero or more" equals "zero or one".
    # Removing the nested stars avoids exponential backtracking on strings
    # that do not match.
    _NOT_TOMITA_3 = re.compile(
        r"(?:[01]*0)?1(?:11)*(?:0[01]*1)?0(?:00)*(?:1[01]*)?$"
    )

    def generate_data(self, min_len, max_len):
        """Generate accepted strings whose length lies in ``[min_len, max_len]``.

        Strings are built breadth-first from "1" and "0", one symbol per
        round, following ``_TRANSITIONS`` and discarding ``_PRUNED_STATES``.
        The empty string is never produced.

        Args:
            min_len: Smallest string length to include.
            max_len: Largest string length to include.

        Returns:
            A list of strings ordered by increasing length; empty when the
            range contains no length >= 1.
        """
        frontier = [("1", "S1"), ("0", "S5")]
        length = 1
        data = []
        if min_len <= length <= max_len:
            data.extend(string for string, _ in frontier)

        # Stop by length, not by inspecting the collected data, so both
        # bounds are honoured even when max_len is 1 or 2.
        while length < max_len:
            frontier = [
                (string + symbol, target)
                for string, state in frontier
                for symbol, target in self._TRANSITIONS[state]
                if target not in self._PRUNED_STATES
            ]
            length += 1
            if length >= min_len:
                data.extend(string for string, _ in frontier)

        return data

    def generate_random_sample(self, length):
        """Return a random string of ``length`` symbols built from "00"/"11" pairs.

        An odd ``length`` gets a single trailing "1" after the pairs.
        """
        pair_count = int(length // 2)
        sample = "".join(
            random.choice(["00", "11"]) for _ in range(pair_count)
        )
        if length % 2 == 1:
            sample += "1"
        return sample

    @staticmethod
    def test(string):
        """Return True when ``string`` does not match the complement pattern."""
        return TomitaThree._NOT_TOMITA_3.match(string) is None

    
class TomitaFour(TomitaData):
    """Tomita grammar 4: strings that do not contain "000"."""

    def generate_data(self, min_len, max_len):
        """Return the strings of length ``min_len``..``max_len`` without "000".

        Candidates come from ``create_permutations``; the result is a list
        built from a set, so its order is unspecified.
        """
        accepted = {
            candidate
            for length in range(min_len, max_len + 1)
            for candidate in create_permutations(length)
            if "000" not in candidate
        }
        return list(accepted)

    @staticmethod
    def test(string):
        """Return True when "000" does not occur in ``string``."""
        return "000" not in string


class TomitaFive(TomitaData):
    """Tomita grammar 5: an even number of 0s and an even number of 1s."""

    def generate_data(self, min_len, max_len):
        """Return the accepted strings of length ``min_len``..``max_len``.

        For every even split of the length into a count of 1s and a count of
        0s, all distinct arrangements of that multiset are produced.

        Returns:
            A list built from a set, so its order is unspecified.
        """
        data = set()
        for length in range(min_len, max_len + 1):
            if length % 2 != 0:
                # Two even counts cannot add up to an odd length.
                continue
            # Upper bound is inclusive so the all-ones string (ones == length)
            # is generated too.
            for ones in range(0, length + 1, 2):
                sample = "1" * ones + "0" * (length - ones)
                for arrangement in permutations(sample):
                    data.add("".join(arrangement))
        return list(data)

    @staticmethod
    def test(string):
        """Return True when both the number of 1s and the number of 0s are even.

        Every character is converted with ``int``, so ``string`` is expected
        to consist of the digits "0" and "1".
        """
        ones = sum(int(ch) for ch in string)
        zeros = len(string) - ones
        return ones % 2 == 0 and zeros % 2 == 0
    
    
class TomitaSix(TomitaData):
    """Tomita grammar 6: the counts of 0s and 1s differ by a multiple of 3."""

    def generate_data(self, min_len, max_len):
        """Return the accepted strings of length ``min_len``..``max_len``.

        For every split of the length into a count of 0s and a count of 1s
        whose difference is divisible by 3, all distinct arrangements of that
        multiset are produced.

        Returns:
            A list built from a set, so its order is unspecified.
        """
        data = set()
        for length in range(min_len, max_len + 1):
            # Upper bound is inclusive so the all-zeros string
            # (zeros == length) is considered as well.
            for zeros in range(0, length + 1):
                ones = length - zeros
                if abs(zeros - ones) % 3 != 0:
                    continue
                sample = "0" * zeros + "1" * ones
                for arrangement in permutations(sample):
                    data.add("".join(arrangement))
        return list(data)

    @staticmethod
    def test(string):
        """Return True when (#1s - #0s) is divisible by 3.

        Every character is converted with ``int``, so ``string`` is expected
        to consist of the digits "0" and "1".
        """
        ones = sum(int(ch) for ch in string)
        zeros = len(string) - ones
        return abs(ones - zeros) % 3 == 0
        
            
class TomitaSeven(TomitaData):
    """Tomita grammar 7: strings of the form ``0*1*0*1*``."""

    # Full-string pattern for the language; each of the four runs may be empty.
    _PATTERN = re.compile(r"0*1*0*1*")

    def generate_data(self, min_len, max_len):
        """Return the accepted strings of length ``min_len``..``max_len``.

        Each string is four consecutive runs (0s, 1s, 0s, 1s) whose lengths
        sum to the target length. Every run, including the last one, may be
        empty, so strings ending in "0" are produced as well.

        Returns:
            A list built from a set, so its order is unspecified.
        """
        data = set()
        for length in range(min_len, max_len + 1):
            # All bounds are inclusive: a run may use up everything that is
            # left, leaving zero symbols for the runs after it.
            for first_zeros in range(length + 1):
                left_a = length - first_zeros
                for first_ones in range(left_a + 1):
                    left_b = left_a - first_ones
                    for second_zeros in range(left_b + 1):
                        second_ones = left_b - second_zeros
                        data.add(
                            "0" * first_zeros
                            + "1" * first_ones
                            + "0" * second_zeros
                            + "1" * second_ones
                        )
        return list(data)

    @staticmethod
    def test(string):
        """Return True when the whole ``string`` has the form ``0*1*0*1*``.

        Counting symbol changes alone is not enough: "1010" has only three
        changes but needs five runs once the empty leading run of 0s is taken
        into account, so the full pattern is matched instead. Strings with
        characters other than "0" and "1" are rejected.
        """
        return TomitaSeven._PATTERN.fullmatch(string) is not None

class NegativeExamples(TomitaData):
    """Collects strings that every registered membership test rejects."""

    def __init__(self):
        # Predicates taking a string; a truthy result means "accepted".
        self.test_funcs = []

    def register_tests(self, test_func):
        """Add ``test_func`` to the predicates a negative example must fail."""
        self.test_funcs.append(test_func)

    def generate_data(self, ll, max_samples=None):
        """Return shuffled strings of length ``ll`` rejected by all registered tests.

        Every arrangement of every mix of 0s and 1s of length ``ll`` is
        checked. When ``max_samples`` is given, only that many entries of the
        shuffled result are returned.
        """
        zero_counts = list(range(0, ll + 1))
        random.shuffle(zero_counts)

        rejected = set()
        for zeros in zero_counts:
            template = "0" * zeros + "1" * (ll - zeros)
            for arrangement in permutations(template):
                candidate = "".join(arrangement)
                # Keep the candidate only if no registered test accepts it.
                if not any(accepts(candidate) for accepts in self.test_funcs):
                    rejected.add(candidate)

        shuffled = list(rejected)
        random.shuffle(shuffled)
        return shuffled if max_samples is None else shuffled[:max_samples]


if __name__ == "__main__":
    # Manual smoke test: generate samples for one grammar, run the matching
    # membership test over them and print the results. Only the Tomita 3
    # section is active; the others are kept commented out for ad-hoc use.

    # print("----------- Tomita 1 --------------")
    # cls = TomitaOne()
    # data = cls.generate_data(3, 10)
    # test = map(TomitaOne.test, data)
    # print(data)
    # print(list(test))
    # print(len(data))
    #
    # print("----------- Tomita 2 --------------")
    # cls = TomitaTwo()
    # data = cls.generate_data(3, 10)
    # test = map(TomitaTwo.test, data)
    # print(data)
    # print(list(test))
    # print(len(data))
    #
    print("----------- Tomita 3 --------------")
    generator = TomitaThree()
    samples = generator.generate_data(3, 10)
    verdicts = [TomitaThree.test(sample) for sample in samples]
    print(samples)
    print(verdicts)
    print(len(samples))

    # print("----------- Tomita 4 --------------")
    # cls = TomitaFour()
    # data = cls.generate_data(20, 20)
    # test = map(TomitaFour.test, data)
    # print(data)
    # print(list(test))
    # print(len(data))

    # print("----------- Tomita 5 --------------")
    # cls = TomitaFive()
    # data = cls.generate_data(12, 12)
    # test = map(TomitaFive.test, data)
    # print(data)
    # print(list(test))
    # print(len(data))
    #
    # print("----------- Tomita 6 --------------")
    # cls = TomitaSix()
    # data = cls.generate_data(3, 10)
    # test = map(TomitaSix.test, data)
    # print(data)
    # print(list(test))
    # print(len(data))
    #
    # print("----------- Tomita 7 --------------")
    # cls = TomitaSeven()
    # data = cls.generate_data(3, 10)
    # test = map(TomitaSeven.test, data)
    # print(data)
    # print(list(test))
    # print(len(data))
    #
    # NOTE: the section below is stale. NegativeExamples defines no
    # set_max_samples method, and its generate_data takes (ll, max_samples)
    # rather than a (min_len, max_len) pair.
    # print("----------- Negative Examples --------------")
    # cls = NegativeExamples()
    # test_funcs = [TomitaOne.test, TomitaTwo.test, TomitaThree.test,
    #               TomitaFour.test, TomitaFive.test, TomitaSix.test,
    #               TomitaSeven.test]
    # for test in test_funcs:
    #     cls.register_tests(test)
    #
    # cls.set_max_samples(10000)
    # data = cls.generate_data(3, 20)
    # print(data)
    # print(len(data))
    