import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm
from device import device

class ReasoningModule(nn.Module):
    """Learnable rule body of the form ``EXIST Z[p1(X,Z) AND p2(Z,Y)]``.

    Two logit vectors ``w1`` and ``w2`` (one entry per predicate) each pick a
    predicate. While the module is trainable the pick is soft (a softmax
    weighting of the predicate tensors); after :meth:`fix_parameters` it is
    hard (the argmax index).

    NOTE(review): the methods index ``A`` along its first axis by predicate
    and matrix-multiply the two selected slices, so ``A`` is presumably shaped
    ``(num_predicates, N, N)`` -- confirm against the caller.
    """

    def __init__(self, n_predicates, info: dict, randinit: bool = False, power=1.0):
        """Create the selection parameters.

        Args:
            n_predicates: Length of ``w1``/``w2``; also the number of leading
                slices of ``A`` that take part in the soft selection.
            info: Rule metadata. Within this class only ``info['head']`` is
                read (by :meth:`output_model`).
            randinit: If true, draw ``w1``/``w2`` from a standard normal;
                otherwise start them at zero (uniform softmax).
            power: Exponent applied to the softmax weights when it differs
                from ``1.0``.
        """
        super().__init__()
        self.info = info

        self.n_predicates = n_predicates
        self.power = power

        init = torch.randn if randinit else torch.zeros
        # Registration order (w1, w2, weight) and RNG draw order (w1 before
        # w2) are kept stable so state dicts and seeded runs are unaffected.
        self.register_parameter(
            'w1', nn.Parameter(init(self.n_predicates, device=device)))
        self.register_parameter(
            'w2', nn.Parameter(init(self.n_predicates, device=device)))
        # Scalar parameter initialised to one; no method of this class reads it.
        self.register_parameter(
            'weight', nn.Parameter(torch.ones((), device=device)))

        self.fixed = False

    def fix_parameters(self):
        """Freeze every parameter and record the hard predicate choices.

        Sets ``self.fixed``, disables gradients on all parameters, and stores
        the argmax indices of ``w1``/``w2`` as ``self.phi1``/``self.phi2``.
        """
        self.fixed = True
        for para in self.parameters():
            para.requires_grad = False
        self.phi1 = self.w1.argmax().tolist()
        self.phi2 = self.w2.argmax().tolist()

    def generate_phi_mat(self, A):
        """Return the two selected predicate tensors.

        Args:
            A: Tensor indexed by predicate along its first axis.

        Returns:
            A pair ``(m1, m2)``. When fixed, these are the detached slices
            ``A[phi1]`` and ``A[phi2]``. Otherwise each is the sum over the
            first ``n_predicates`` slices of ``A`` weighted by
            ``softmax(w)`` (raised to ``power`` when ``power != 1.0``).
        """
        if self.fixed:
            return A[self.phi1].detach(), A[self.phi2].detach()

        w1 = self.w1.softmax(dim=0)
        w2 = self.w2.softmax(dim=0)
        if self.power != 1.0:
            w1 = w1.pow(self.power)
            w2 = w2.pow(self.power)

        A0 = A[:self.n_predicates]
        # Contract the predicate axis: result[...] = sum_p w[p] * A0[p, ...].
        # This equals ``A0.T.matmul(w).T`` but avoids ``Tensor.T`` on
        # non-2-D tensors, which PyTorch has deprecated.
        return (torch.tensordot(w1, A0, dims=1),
                torch.tensordot(w2, A0, dims=1))

    def EXIST(self, A, B):
        """Compose ``A`` and ``B`` by matrix product, capping values at one.

        Entries of ``A.matmul(B)`` above 1 are divided by their own detached
        value, so the forward result is at most 1 while gradients still flow
        through the numerator. Entries at or below 1 pass through unchanged.

        An alternative formulation, ``1 - prod_z(1 - A[x, z] * B[z, y])``,
        was previously kept here as commented-out code.
        """
        R = A.matmul(B)
        return R / R.clamp(min=1).detach()

    def forward(self, A):
        """Score the rule body: ``EXIST`` applied to the two selected tensors."""
        m1, m2 = self.generate_phi_mat(A)
        return self.EXIST(m1, m2)

    def equal(self, module):
        """Return whether two fixed modules selected the same predicates.

        Both modules must have been fixed with :meth:`fix_parameters`.
        ``mode`` is not defined by this class; it is compared (by argmax) only
        when present, and modules that disagree on having it are unequal.
        """
        assert self.fixed
        assert module.fixed

        # ``mode`` is optional: reading it unconditionally raised
        # AttributeError because nothing in this class ever sets it.
        own_mode = getattr(self, 'mode', None)
        other_mode = getattr(module, 'mode', None)
        if (own_mode is None) != (other_mode is None):
            return False
        if own_mode is not None:
            if own_mode.argmax().tolist() != other_mode.argmax().tolist():
                return False

        return self.phi1 == module.phi1 and self.phi2 == module.phi2

    def output_model(self, predicate_id2name: dict = None):
        """Render the fixed rule as a human-readable string.

        Args:
            predicate_id2name: Optional mapping from predicate index to name.
                An index missing from the mapping is rendered as ``'R<k>'``
                with ``k = index - len(predicate_id2name)``.

        Returns:
            ``'<head>(X,Y)<-EXIST Z[<p1>(X,Z) AND <p2>(Z,Y)]'``.

        Raises:
            RuntimeError: If the module has not been fixed yet.
        """
        if not self.fixed:
            raise RuntimeError('Fix the model first!')

        if predicate_id2name is None:
            predicate_id2name = dict()

        def name(index):
            if index in predicate_id2name:
                return predicate_id2name[index]
            return 'R' + str(index - len(predicate_id2name))

        first = self.w1.argmax().tolist()
        second = self.w2.argmax().tolist()

        return (name(self.info['head']) + '(X,Y)<-'
                + 'EXIST Z[' + name(first) + '(X,Z) AND '
                + name(second) + '(Z,Y)]')

    def check_convergence(self):
        """Collect the softmax mass left on every non-selected predicate.

        Returns:
            A list of 0-dim tensors: for each index ``i`` in order, ``w1[i]``
            if ``i`` is not the argmax of ``w1``, then ``w2[i]`` if ``i`` is
            not the argmax of ``w2`` (softmax-normalised values).

        Raises:
            RuntimeError: If the module has not been fixed yet.
        """
        if not self.fixed:
            raise RuntimeError('Fix the model first!')

        w1 = self.w1.softmax(dim=-1)
        w2 = self.w2.softmax(dim=-1)
        # The winners do not change inside the loop; compute them once.
        top1 = w1.argmax().tolist()
        top2 = w2.argmax().tolist()

        ret = list()
        for i in range(self.w1.shape[0]):
            if i != top1:
                ret.append(w1[i])
            if i != top2:
                ret.append(w2[i])
        return ret