import torch
import torch.nn as nn

# phi(x,y):=A(x,y) | a(x),A(x,y),a(y) | A(x,z),a(z),A(z,y)
# NOTE: forward() has no dedicated term for the plain A(x,y) template; it mixes only the last two.
class LogicModule2(nn.Module):
    """Learnable soft rule producing a binary (arity-2) predicate.

    ``forward`` mixes two rule templates with softmax-normalised weights:

    * chain:  ``sum_z A(x, z) * a(z) * A(z, y)``
    * guard:  ``a(x) * A(x, y) * a(y)``

    Every ``a``/``A`` occurrence above is itself a convex (softmax) mixture of
    the first ``n_pred1`` unary / first ``n_pred2`` binary input predicates,
    with its own weight vector (``w2`` .. ``w7``).  ``theta`` mixes the two
    templates.  There is no separate term for a bare ``A(x, y)`` template.

    Attributes:
        info: dict with keys ``'arity'`` (always 2) and ``'fixed'`` (set to
            True by :meth:`fix_parameters`).
        logic_threshhold: cut-off used by :meth:`logic` (spelling kept as-is
            because it is a public attribute).
    """

    def __init__(self, n_pred1: int, n_pred2: int) -> None:
        """Create the module.

        Args:
            n_pred1: number of unary input predicates that are used.
            n_pred2: number of binary input predicates that are used.
        """
        super().__init__()
        self.info = {'arity': 2, 'fixed': False}
        self.n_pred1 = n_pred1
        self.n_pred2 = n_pred2
        self.logic_threshhold = 0.3

        # Zero logits => uniform mixtures after softmax.  Registration order
        # (w2..w7, theta) is kept stable so parameters()/state_dict() order
        # does not change.
        self.w2 = nn.Parameter(torch.zeros(n_pred2))
        self.w3 = nn.Parameter(torch.zeros(n_pred1))
        self.w4 = nn.Parameter(torch.zeros(n_pred2))
        self.w5 = nn.Parameter(torch.zeros(n_pred1))
        self.w6 = nn.Parameter(torch.zeros(n_pred2))
        self.w7 = nn.Parameter(torch.zeros(n_pred1))
        self.theta = nn.Parameter(torch.zeros(2))

    @staticmethod
    def _mix(logits: torch.Tensor, preds: torch.Tensor) -> torch.Tensor:
        """Contract the leading (predicate) axis of ``preds`` with softmax(logits).

        Equivalent to the former ``preds.T.matmul(w).T`` but without calling
        ``Tensor.T`` on tensors whose rank is not 2, which PyTorch deprecates.
        """
        return torch.tensordot(logits.softmax(dim=0), preds, dims=1)

    def forward(self, a: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        """Evaluate the soft rule.

        Args:
            a: unary predicate values; the leading axis indexes predicates and
                only its first ``n_pred1`` entries are used.
                (Presumably shaped ``(n_unary, n_entities)`` -- confirm with caller.)
            A: binary predicate values; the leading axis indexes predicates and
                only its first ``n_pred2`` entries are used.
                (Presumably shaped ``(n_binary, n_entities, n_entities)``.)

        Returns:
            Tensor with the shape of one binary predicate slice, holding the
            ``theta``-weighted mixture of the two templates.
        """
        a = a[:self.n_pred1]
        A = A[:self.n_pred2]

        chain_left = self._mix(self.w2, A)
        chain_mid = self._mix(self.w3, a)
        chain_right = self._mix(self.w4, A)
        guard_y = self._mix(self.w5, a)
        guard_rel = self._mix(self.w6, A)
        guard_x = self._mix(self.w7, a)

        # A(x,z) * a(z), then sum over z against A(z,y).
        chain = (chain_left * chain_mid).matmul(chain_right)
        # A(x,y) * a(y) broadcasts over columns; a(x) is applied over rows.
        guard = (guard_rel * guard_y) * guard_x.unsqueeze(-1)

        mix = self.theta.softmax(dim=0)
        return torch.stack([chain, guard], dim=-1).matmul(mix)

    def logic(self, a: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        """Return the hard (0/1, ``torch.int``) version of :meth:`forward`.

        An entry is 1 where the soft value is strictly greater than
        ``logic_threshhold``.  Runs without gradient tracking.
        """
        with torch.no_grad():
            return (self.forward(a, A) > self.logic_threshhold).type(torch.int)

    def fix_parameters(self) -> None:
        """Freeze all parameters and record that in ``info['fixed']``."""
        for p in self.parameters():
            p.requires_grad = False
        self.info['fixed'] = True

# phi(x):=a(x) | a(y),A(x,y) | A(x,x)
# NOTE: forward() has no dedicated term for the plain a(x) template; it mixes only the last two.
class LogicModule1(nn.Module):
    """Learnable soft rule producing a unary (arity-1) predicate.

    ``forward`` mixes two rule templates with softmax-normalised weights:

    * exists:    ``sum_y A(x, y) * a(y)``
    * reflexive: ``A(x, x)``

    Every ``a``/``A`` occurrence above is itself a convex (softmax) mixture of
    the first ``n_pred1`` unary / first ``n_pred2`` binary input predicates,
    with its own weight vector (``w2``, ``w3``, ``w4``).  ``theta`` mixes the
    two templates.  There is no separate term for a bare ``a(x)`` template.

    Attributes:
        info: dict with keys ``'arity'`` (always 1) and ``'fixed'`` (set to
            True by :meth:`fix_parameters`).
        logic_threshhold: cut-off used by :meth:`logic` (spelling kept as-is
            because it is a public attribute).
    """

    def __init__(self, n_pred1: int, n_pred2: int) -> None:
        """Create the module.

        Args:
            n_pred1: number of unary input predicates that are used.
            n_pred2: number of binary input predicates that are used.
        """
        super().__init__()
        self.info = {'arity': 1, 'fixed': False}
        self.n_pred1 = n_pred1
        self.n_pred2 = n_pred2
        self.logic_threshhold = 0.3

        # Zero logits => uniform mixtures after softmax.  Registration order
        # (w2, w3, w4, theta) is kept stable so parameters()/state_dict()
        # order does not change.
        self.w2 = nn.Parameter(torch.zeros(n_pred1))
        self.w3 = nn.Parameter(torch.zeros(n_pred2))
        self.w4 = nn.Parameter(torch.zeros(n_pred2))
        self.theta = nn.Parameter(torch.zeros(2))

    @staticmethod
    def _mix(logits: torch.Tensor, preds: torch.Tensor) -> torch.Tensor:
        """Contract the leading (predicate) axis of ``preds`` with softmax(logits).

        Equivalent to the former ``preds.T.matmul(w).T`` but without calling
        ``Tensor.T`` on tensors whose rank is not 2, which PyTorch deprecates.
        """
        return torch.tensordot(logits.softmax(dim=0), preds, dims=1)

    def forward(self, a: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        """Evaluate the soft rule.

        Args:
            a: unary predicate values; the leading axis indexes predicates and
                only its first ``n_pred1`` entries are used.
                (Presumably shaped ``(n_unary, n_entities)`` -- confirm with caller.)
            A: binary predicate values; the leading axis indexes predicates and
                only its first ``n_pred2`` entries are used.
                (Presumably shaped ``(n_binary, n_entities, n_entities)``.)

        Returns:
            Tensor with one value per row of the mixed binary predicate,
            holding the ``theta``-weighted mixture of the two templates.
        """
        a = a[:self.n_pred1]
        A = A[:self.n_pred2]

        unary = self._mix(self.w2, a)
        relation = self._mix(self.w3, A)
        # A(x, x): diagonal of a separately weighted binary mixture.
        reflexive = self._mix(self.w4, A).diag()

        # sum_y A(x, y) * a(y)
        exists = relation.matmul(unary)

        mix = self.theta.softmax(dim=0)
        return torch.stack([exists, reflexive], dim=-1).matmul(mix)

    def logic(self, a: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        """Return the hard (0/1, ``torch.int``) version of :meth:`forward`.

        An entry is 1 where the soft value is strictly greater than
        ``logic_threshhold``.  Runs without gradient tracking.
        """
        with torch.no_grad():
            return (self.forward(a, A) > self.logic_threshhold).type(torch.int)

    def fix_parameters(self) -> None:
        """Freeze all parameters and record that in ``info['fixed']``."""
        for p in self.parameters():
            p.requires_grad = False
        self.info['fixed'] = True