from abc import ABCMeta,abstractmethod
import torch

class DataInterface(metaclass=ABCMeta):
    """Abstract source of background facts for one ILP task.

    Implementations in this module set ``self.n_predicate`` (a pair of
    predicate counts: unary rows used in ``b``, binary rows used in ``B``) and
    ``self.N`` (the number of constants) in ``__init__``, and build the fact
    tensors in :meth:`get_data`.
    """

    @abstractmethod
    def get_data(self, dim):
        """Return the ``(b, B)`` fact tensors for this task.

        Args:
            dim: number of predicate rows to allocate in each tensor; it must
                be large enough to hold every row the implementation writes.

        Returns:
            A 2-tuple ``(b, B)``: ``b`` of shape ``(dim, N)`` holding per-constant
            (unary) facts and ``B`` of shape ``(dim, N, N)`` holding pairwise
            (binary) facts, with 1 marking a true fact and 0 otherwise.
        """
        # NOTE(review): an older comment here listed ``b, B, p, P``; no
        # implementation returns ``p``/``P``, so only ``(b, B)`` is documented.

    # @abstractmethod
    # def get_dict(self):
    #     pass
    #     # dict[dim_b], dict[dim_B], name_p, name_P

#       target
#       inv_succ
# 1     succ
class ILP1(DataInterface):
    """Background facts over the integers 0..N-1 (N = 10).

    Unary rows of ``b``:
        0 = constant-true predicate "1" (holds for every node)
    Binary rows of ``B``:
        0 = succ(i, i + 1)
        1 = inv_succ(i + 1, i)
        2 = target(i + 1, i)  -- same facts as inv_succ
    """

    def __init__(self) -> None:
        super().__init__()
        self.n_predicate = (1, 3)  # (unary, binary) predicate counts
        self.N = 10                # number of constants

    def get_data(self, dim=10):
        """Return ``(b, B)`` of shapes ``(dim, N)`` and ``(dim, N, N)``.

        ``dim`` must be at least 3 because binary row 2 is written.
        """
        b = torch.zeros((dim, self.N))
        B = torch.zeros((dim, self.N, self.N))

        b[0] = 1
        for i in range(self.N - 1):
            B[0, i, i + 1] = 1  # succ
            B[1, i + 1, i] = 1  # inv_succ
            B[2, i + 1, i] = 1  # target
        return b, B

# target    inv_succ
# zero      succ
class ILP2(DataInterface):
    """Background facts over the integers 0..N-1 (N = 10).

    Unary rows of ``b``:   0 = zero (node 0 only), 1 = target (every even node)
    Binary rows of ``B``:  0 = succ (i -> i + 1), 1 = inv_succ (i + 1 -> i)
    """

    def __init__(self) -> None:
        super().__init__()
        self.n_predicate = (2, 2)  # (unary, binary) predicate counts
        self.N = 10                # number of constants

    def get_data(self, dim=10):
        size = self.N
        b = torch.zeros((dim, size))
        B = torch.zeros((dim, size, size))

        # successor relation lives on the super-diagonal; its inverse is the transpose
        B[0] = torch.diag(torch.ones(size - 1), diagonal=1)
        B[1] = B[0].T

        b[0, 0] = 1         # zero(0)
        b[1, 0:size:2] = 1  # target(k) for k = 0, 2, 4, ...
        return b, B

# target
# succ
class ILP4(DataInterface):
    """Background facts over the integers 0..N-1 (N = 10).

    ``b`` is all ones.
    Binary rows of ``B``:  0 = succ(i, i + 1), 1 = target(i, j) for every i < j
    """

    def __init__(self) -> None:
        super().__init__()
        self.n_predicate = (1, 2)  # (unary, binary) predicate counts
        self.N = 10                # number of constants

    def get_data(self, dim=10):
        n = self.N
        b = torch.ones((dim, n))
        B = torch.zeros((dim, n, n))

        B[0] = torch.diag(torch.ones(n - 1), diagonal=1)  # succ: first super-diagonal
        B[1] = torch.triu(torch.ones(n, n), diagonal=1)   # target: strictly upper triangle
        return b, B

# target    inv_succ
# zero      succ
class ILP5(DataInterface):
    """Background facts over the integers 0..N-1 (N = 10).

    Unary rows of ``b``:   0 = zero (node 0 only), 1 = target (multiples of 3)
    Binary rows of ``B``:  0 = succ (i -> i + 1), 1 = inv_succ (i + 1 -> i)
    """

    def __init__(self) -> None:
        super().__init__()
        self.n_predicate = (2, 2)  # (unary, binary) predicate counts
        self.N = 10                # number of constants

    def get_data(self, dim=10):
        n = self.N
        b = torch.zeros((dim, n))
        B = torch.zeros((dim, n, n))

        src = torch.arange(n - 1)
        B[0, src, src + 1] = 1  # succ
        B[1, src + 1, src] = 1  # inv_succ

        b[0, 0] = 1    # zero(0)
        b[1, ::3] = 1  # target(k) for k = 0, 3, 6, ...
        return b, B

# target    inv_succ
# zero      succ
class ILP6(DataInterface):
    """Background facts over the integers 0..N-1 (N = 20).

    Unary rows of ``b``:   0 = zero (node 0 only), 1 = target (multiples of 4)
    Binary rows of ``B``:  0 = succ (i -> i + 1), 1 = inv_succ (i + 1 -> i)
    """

    def __init__(self) -> None:
        super().__init__()
        self.n_predicate = (2, 2)  # (unary, binary) predicate counts
        self.N = 20                # number of constants

    def get_data(self, dim=10):
        n = self.N
        b = torch.zeros((dim, n))
        B = torch.zeros((dim, n, n))

        b[0, 0] = 1
        b[1] = (torch.arange(n) % 4 == 0).float()  # 1.0 exactly at multiples of 4

        B[0] = torch.diag(torch.ones(n - 1), diagonal=1)
        B[1] = B[0].T
        return b, B

# target
# inv_value
# inv_cons
# value
# cons
class ILP7(DataInterface):
    """Background facts for the list [4, 3, 2, 1] stored as cons cells.

    Constants 1..4 are values; constant ``v + 4`` is the cell holding value
    ``v``, and cells are chained 8 -> 7 -> 6 -> 5 by ``cons``.
    ``b`` is all ones.
    Binary rows of ``B``:
        0 = cons(cell, next_cell)
        1 = value(cell, v)
        2 = inv_cons (transpose of row 0)
        3 = inv_value (transpose of row 1)
        4 = target(v, cell) for every value v in the sub-list starting at cell
    """

    def __init__(self) -> None:
        super().__init__()
        self.n_predicate = (1, 5)  # (unary, binary) predicate counts
        self.N = 9                 # constants 0..8

    def get_data(self, dim=10):
        b = torch.ones((dim, self.N))
        B = torch.zeros((dim, self.N, self.N))

        items = [4, 3, 2, 1]
        shift = 4
        cells = [v + shift for v in items]

        for cell, v in zip(cells, items):
            B[1, cell, v] = 1
        for here, nxt in zip(cells, cells[1:]):
            B[0, here, nxt] = 1
        B[2] = B[0].T
        B[3] = B[1].T

        # each cell "contains" its own value and every value further down the chain
        for start, cell in enumerate(cells):
            for v in items[start:]:
                B[4, v, cell] = 1
        return b, B


# target
# length_one
# inv_value
# inv_cons
# value
# cons
class ILP8(DataInterface):
    """Background facts for the list [4, 3, 2, 1] stored as cons cells.

    Constants 1..4 are values; constant ``v + 4`` is the cell holding value
    ``v``, and cells are chained 8 -> 7 -> 6 -> 5 by ``cons``.
    ``b`` is all ones.
    Binary rows of ``B``:
        0 = cons, 1 = value, 2 = inv_cons, 3 = inv_value,
        4 = length_one (single fact: (5, 1)),
        5 = target
    """

    def __init__(self) -> None:
        super().__init__()
        self.n_predicate = (1, 6)  # (unary, binary) predicate counts
        self.N = 9                 # constants 0..8

    def get_data(self, dim=10):
        b = torch.ones((dim, self.N))
        B = torch.zeros((dim, self.N, self.N))

        offset = 4
        values = [4, 3, 2, 1]
        cell_of = {v: v + offset for v in values}

        for v in values:
            B[1, cell_of[v], v] = 1
        for head, tail in zip(values, values[1:]):
            B[0, cell_of[head], cell_of[tail]] = 1
        B[2] = B[0].T
        B[3] = B[1].T

        B[4, cell_of[1], 1] = 1
        # NOTE(review): these target facts are exactly the same as the `value`
        # facts in row 1; confirm that this is the intended target relation.
        for v in (1, 2, 3, 4):
            B[5, v + offset, v] = 1
        return b, B

#       target
#       sister
#       brother
#       inv_father
# 1     father
class ILP9(DataInterface):
    """Family-relation background facts over 9 people (constants 0..8).

    Unary rows of ``b``:   0 = constant-true predicate "1"
    Binary rows of ``B``:
        0 = father, 1 = inv_father (transpose of row 0),
        2 = brother, 3 = sister, 4 = target
    """

    def __init__(self) -> None:
        super().__init__()
        self.n_predicate = (1, 5)  # (unary, binary) predicate counts
        self.N = 9                 # number of people

    def get_data(self, dim=10):
        b = torch.zeros((dim, self.N))
        B = torch.zeros((dim, self.N, self.N))
        b[0] = 1

        facts = {
            0: [(0, 1), (0, 2), (3, 4), (3, 5), (6, 7), (6, 8)],  # father
            2: [(1, 2), (2, 1), (4, 5)],                          # brother
            3: [(5, 4), (7, 8), (8, 7)],                          # sister
            4: [(1, 0), (2, 0), (4, 3)],                          # target
        }
        for row, pairs in facts.items():
            for x, y in pairs:
                B[row, x, y] = 1

        B[1] = B[0].T  # inv_father
        return b, B

# target
# inv_edge
# edge
class ILP15(DataInterface):
    """Background facts for a 4-node directed graph.

    ``b`` is all ones.
    Binary rows of ``B``:
        0 = edge: 0->1, 1->3, 2->2 (self-loop)
        1 = inv_edge (transpose of edge)
        2 = target: edge or inv_edge (the symmetric version of edge)
    """

    def __init__(self) -> None:
        super().__init__()
        self.n_predicate = (1, 3)  # (unary, binary) predicate counts
        self.N = 4                 # number of graph nodes

    def get_data(self, dim=10):
        b = torch.ones((dim, self.N))
        B = torch.zeros((dim, self.N, self.N))

        for src, dst in ((0, 1), (1, 3), (2, 2)):
            B[0, src, dst] = 1

        inverse = B[0].T
        B[1] = inverse
        # the self-loop sums to 2 on the diagonal; clamping keeps facts 0/1
        B[2] = (B[0] + inverse).clamp(max=1)
        return b, B

# target
# green inv_edge
# red   edge
class ILP16(DataInterface):
    """Background facts for a 5-node coloured directed graph.

    Unary rows of ``b``:
        0 = red {0, 2, 3}, 1 = green {1, 4}, 2 = target {1, 2}
    Binary rows of ``B``:
        0 = edge: 0->1, 1->0, 2->3, 2->4, 3->4
        1 = inv_edge (transpose of edge)
    """

    def __init__(self) -> None:
        super().__init__()
        self.n_predicate = (3, 2)  # (unary, binary) predicate counts
        self.N = 5                 # number of graph nodes

    def get_data(self, dim=10):
        b = torch.zeros((dim, self.N))
        B = torch.zeros((dim, self.N, self.N))

        for u, v in [(0, 1), (1, 0), (2, 3), (2, 4), (3, 4)]:
            B[0, u, v] = 1
        B[1] = B[0].T

        red, green, target = [0, 2, 3], [1, 4], [1, 2]
        for row, nodes in enumerate((red, green, target)):
            b[row, nodes] = 1
        return b, B

# target
# inv_edge
# edge
class ILP19(DataInterface):
    """Background facts for a 4-node directed graph and its reachability relation.

    ``b`` is all ones.
    Binary rows of ``B``:
        0 = edge: 0->1, 1->2, 2->3, 1->0
        1 = inv_edge (transpose of edge)
        2 = target: transitive closure of edge (reachable in one or more steps)
    """

    def __init__(self) -> None:
        super().__init__()
        self.n_predicate = (1, 3)  # (unary, binary) predicate counts
        self.N = 4                 # number of graph nodes

    def get_data(self, dim=10):
        """Return ``(b, B)`` of shapes ``(dim, N)`` and ``(dim, N, N)``.

        ``dim`` must be at least 3 because binary row 2 is written.
        """
        b = torch.ones((dim, self.N))
        B = torch.zeros((dim, self.N, self.N))

        B[0, 0, 1] = 1
        B[0, 1, 2] = 1
        B[0, 2, 3] = 1
        B[0, 1, 0] = 1
        B[1] = B[0].T

        # Grow R <- clamp(R + R @ edge, max=1) until it stops changing.
        # A fixed iteration count only covers paths up to that length and would
        # silently truncate the closure on larger graphs. R only grows and is
        # bounded by 1, so the loop always terminates, and the result is exact.
        edge = B[0]
        closure = edge.clone()
        while True:
            grown = (closure + closure.matmul(edge)).clamp(max=1)
            if torch.equal(grown, closure):
                break
            closure = grown
        B[2] = closure

        return b, B

# target    inv_edge
# 1         edge
class ILP20(DataInterface):
    """Background facts for a 6-node directed graph.

    Unary rows of ``b``:
        0 = constant-true predicate "1"
        1 = target {0, 1, 2, 4, 5} (in this graph, exactly the nodes that lie
            on a directed cycle)
    Binary rows of ``B``:
        0 = edge, 1 = inv_edge (transpose of edge)
    """

    def __init__(self) -> None:
        super().__init__()
        self.n_predicate = (2, 2)  # (unary, binary) predicate counts
        self.N = 6                 # number of graph nodes

    def get_data(self, dim=10):
        b = torch.zeros((dim, self.N))
        B = torch.zeros((dim, self.N, self.N))
        b[0] = 1

        edges = [(0, 1), (1, 2), (2, 0), (1, 3), (3, 4), (3, 5), (4, 5), (5, 4)]
        rows, cols = zip(*edges)
        B[0, list(rows), list(cols)] = 1
        B[1] = B[0].T

        b[1, [0, 1, 2, 4, 5]] = 1
        return b, B