import torch
import torch.nn as nn

# phi(x,y):=A(x,y) | A(x,z),a(z),A(z,y) | a(x),A(x,y),a(y)
class LogicModule2(nn.Module):
    """Differentiable soft-logic rule template for a binary predicate phi(x, y).

    Each learnable weight vector ``w2``..``w7`` selects (through a softmax) a
    soft mixture over candidate predicates, and ``theta`` (also softmaxed)
    mixes the two rule bodies computed in :meth:`forward`:

    * ``R2(x, y) = sum_z A2(x, z) * a3(z) * A4(z, y)``
    * ``R3(x, y) = a7(x) * A6(x, y) * a5(y)``

    Even-numbered slots (2, 4, 6) choose among binary predicates ``A``;
    odd-numbered slots (3, 5, 7) choose among unary predicates ``a``.

    Args:
        n_pred1: number of leading rows of ``a`` offered as (detached)
            unary candidates.
        n_pred2: number of leading rows of ``A`` offered as (detached)
            binary candidates.

    Every weight vector has one extra trailing candidate. When the matching
    pointer ``self.p<i>`` is not ``-1``, that candidate is ``a[p<i>]`` or
    ``A[p<i>]`` and is *not* detached, so gradients can flow into it.
    Otherwise the extra candidate is all zeros.
    """

    def __init__(self, n_pred1, n_pred2) -> None:
        super().__init__()
        self.info = {'arity': 2, 'fixed': False}
        self.n_pred1 = n_pred1
        self.n_pred2 = n_pred2
        # Cut-off used by logic(); the attribute name's spelling is kept
        # because external code may read or set it.
        self.logic_threshhold = 0.3
        # Registration order (w2..w7, then theta) is kept so the
        # parameters() / state_dict() ordering is unchanged.
        self.w2 = nn.Parameter(torch.zeros(self.n_pred2 + 1))
        self.w3 = nn.Parameter(torch.zeros(self.n_pred1 + 1))
        self.w4 = nn.Parameter(torch.zeros(self.n_pred2 + 1))
        self.w5 = nn.Parameter(torch.zeros(self.n_pred1 + 1))
        self.w6 = nn.Parameter(torch.zeros(self.n_pred2 + 1))
        self.w7 = nn.Parameter(torch.zeros(self.n_pred1 + 1))
        self.theta = nn.Parameter(torch.zeros(2))

        # Per-slot pointer into a / A that convert() copies into the extra
        # trailing candidate row; -1 leaves that row as zeros.
        self.p2 = -1
        self.p3 = -1
        self.p4 = -1
        self.p5 = -1
        self.p6 = -1
        self.p7 = -1

    def convert(self, a, A, i):
        """Build the candidate stack for weight slot ``i``.

        Slots 2, 4, 6 draw from ``A`` and return shape
        ``(n_pred2 + 1, *A.shape[1:])``. Slots 3, 5, 7 draw from ``a`` and
        return shape ``(n_pred1 + 1, *a.shape[1:])``. The leading rows are
        detached copies of the inputs. The last row is ``source[self.p<i>]``
        (kept in the autograd graph), or zeros when that pointer is ``-1``.

        The result is allocated with the dtype and device of the module's
        parameters. This lets it be contracted with the softmaxed weights
        even after the module has been moved or cast (``.to(...)``,
        ``.double()``). The slice assignments below cast/move the input data.

        Raises:
            ValueError: if ``i`` is not one of 2..7.
        """
        pointers = {2: self.p2, 3: self.p3, 4: self.p4,
                    5: self.p5, 6: self.p6, 7: self.p7}
        if i not in pointers:
            raise ValueError(f"slot index must be one of 2..7, got {i!r}")
        d = pointers[i]
        if i in (2, 4, 6):
            source, n_keep = A, self.n_pred2
        else:
            source, n_keep = a, self.n_pred1

        ref = self.theta
        ret = torch.zeros((n_keep + 1, *source.shape[1:]),
                          dtype=ref.dtype, device=ref.device)
        # Slice assignment copies the data into ret; detaching keeps the
        # fixed candidates out of the autograd graph.
        ret[:n_keep] = source[:n_keep].detach()
        if d != -1:
            ret[n_keep] = source[d]
        return ret

    def forward(self, a, A):
        """Return the soft truth value of phi(x, y) for every pair (x, y).

        Args:
            a: unary predicate values; assumed shape (P1, N) with
                P1 >= n_pred1 (TODO confirm against callers).
            A: binary predicate values; assumed shape (P2, N, N) with
                P2 >= n_pred2 (TODO confirm against callers).

        Returns:
            ``theta[0] * R2 + theta[1] * R3`` with ``theta`` softmaxed;
            shape (N, N) under the assumed input shapes.
        """
        def select(slot, weight):
            # Softmax-weighted sum over the candidate axis (dim 0).
            return torch.tensordot(weight.softmax(dim=0),
                                   self.convert(a, A, slot), dims=1)

        A2 = select(2, self.w2)
        a3 = select(3, self.w3)
        A4 = select(4, self.w4)
        a5 = select(5, self.w5)
        A6 = select(6, self.w6)
        a7 = select(7, self.w7)
        theta = self.theta.softmax(dim=0)

        # R2(x, y) = sum_z A2(x, z) * a3(z) * A4(z, y)
        R2 = (A2 * a3).matmul(A4)
        # R3(x, y) = a7(x) * A6(x, y) * a5(y)
        R3 = a7.unsqueeze(-1) * A6 * a5

        return torch.stack([R2, R3], dim=-1).matmul(theta)

    def logic(self, a, A):
        """Return ``forward(a, A) > logic_threshhold`` as an int tensor, without gradients."""
        with torch.no_grad():
            return (self.forward(a, A) > self.logic_threshhold).type(torch.int)

    def fix_parameters(self):
        """Freeze all parameters and mark the module as fixed in ``info``."""
        for p in self.parameters():
            p.requires_grad = False
        self.info['fixed'] = True

# phi(x):=a(x) | a(y),A(x,y) | A(x,x)
class LogicModule1(nn.Module):
    """Differentiable soft-logic rule template for a unary predicate phi(x).

    The learnable weight vectors ``w2``, ``w3`` and ``w4`` each select
    (through a softmax) a soft mixture over candidate predicates, and
    ``theta`` (also softmaxed) mixes the two rule bodies computed in
    :meth:`forward`:

    * ``r2(x) = sum_y A3(x, y) * a2(y)``
    * ``a4(x) = A4(x, x)``

    Slot 2 chooses among unary predicates ``a``; slots 3 and 4 choose among
    binary predicates ``A``.

    Args:
        n_pred1: number of leading rows of ``a`` offered as (detached)
            unary candidates.
        n_pred2: number of leading rows of ``A`` offered as (detached)
            binary candidates.

    Every weight vector has one extra trailing candidate. When the matching
    pointer ``self.p<i>`` is not ``-1``, that candidate is ``a[p<i>]`` or
    ``A[p<i>]`` and is *not* detached, so gradients can flow into it.
    Otherwise the extra candidate is all zeros.
    """

    def __init__(self, n_pred1, n_pred2) -> None:
        super().__init__()
        self.info = {'arity': 1, 'fixed': False}
        self.n_pred1 = n_pred1
        self.n_pred2 = n_pred2
        # Cut-off used by logic(); the attribute name's spelling is kept
        # because external code may read or set it.
        self.logic_threshhold = 0.3
        # Registration order (w2, w3, w4, then theta) is kept so the
        # parameters() / state_dict() ordering is unchanged.
        self.w2 = nn.Parameter(torch.zeros(self.n_pred1 + 1))
        self.w3 = nn.Parameter(torch.zeros(self.n_pred2 + 1))
        self.w4 = nn.Parameter(torch.zeros(self.n_pred2 + 1))
        self.theta = nn.Parameter(torch.zeros(2))

        # Per-slot pointer into a / A that convert() copies into the extra
        # trailing candidate row; -1 leaves that row as zeros.
        self.p2 = -1
        self.p3 = -1
        self.p4 = -1

    def convert(self, a, A, i):
        """Build the candidate stack for weight slot ``i``.

        Slots 3 and 4 draw from ``A`` and return shape
        ``(n_pred2 + 1, *A.shape[1:])``. Slot 2 draws from ``a`` and returns
        shape ``(n_pred1 + 1, *a.shape[1:])``. The leading rows are detached
        copies of the inputs. The last row is ``source[self.p<i>]`` (kept in
        the autograd graph), or zeros when that pointer is ``-1``.

        The result is allocated with the dtype and device of the module's
        parameters. This lets it be contracted with the softmaxed weights
        even after the module has been moved or cast (``.to(...)``,
        ``.double()``). The slice assignments below cast/move the input data.

        Raises:
            ValueError: if ``i`` is not one of 2, 3, 4.
        """
        pointers = {2: self.p2, 3: self.p3, 4: self.p4}
        if i not in pointers:
            raise ValueError(f"slot index must be one of 2..4, got {i!r}")
        d = pointers[i]
        if i in (3, 4):
            source, n_keep = A, self.n_pred2
        else:
            source, n_keep = a, self.n_pred1

        ref = self.theta
        ret = torch.zeros((n_keep + 1, *source.shape[1:]),
                          dtype=ref.dtype, device=ref.device)
        # Slice assignment copies the data into ret; detaching keeps the
        # fixed candidates out of the autograd graph.
        ret[:n_keep] = source[:n_keep].detach()
        if d != -1:
            ret[n_keep] = source[d]
        return ret

    def forward(self, a, A):
        """Return the soft truth value of phi(x) for every x.

        Args:
            a: unary predicate values; assumed shape (P1, N) with
                P1 >= n_pred1 (TODO confirm against callers).
            A: binary predicate values; assumed shape (P2, N, N) with
                P2 >= n_pred2 (TODO confirm against callers).

        Returns:
            ``theta[0] * r2 + theta[1] * a4`` with ``theta`` softmaxed;
            shape (N,) under the assumed input shapes.
        """
        def select(slot, weight):
            # Softmax-weighted sum over the candidate axis (dim 0).
            return torch.tensordot(weight.softmax(dim=0),
                                   self.convert(a, A, slot), dims=1)

        a2 = select(2, self.w2)
        A3 = select(3, self.w3)
        # Keep only the diagonal A4(x, x).
        a4 = select(4, self.w4).diag()
        theta = self.theta.softmax(dim=0)

        # r2(x) = sum_y A3(x, y) * a2(y)
        r2 = A3.matmul(a2)

        return torch.stack([r2, a4], dim=-1).matmul(theta)

    def logic(self, a, A):
        """Return ``forward(a, A) > logic_threshhold`` as an int tensor, without gradients."""
        with torch.no_grad():
            return (self.forward(a, A) > self.logic_threshhold).type(torch.int)

    def fix_parameters(self):
        """Freeze all parameters and mark the module as fixed in ``info``."""
        for p in self.parameters():
            p.requires_grad = False
        self.info['fixed'] = True