import torch
from tqdm import tqdm
from datagenerator import *
from module import LogicModule2
from module import LogicModule1
from model import LogicModel


# Choose the ILP benchmark task. Other tasks used with this script (swap the
# constructor below to run one): ILP4, ILP7, ILP8, ILP9, ILP15, ILP19.
D = ILP1()

n_predicate = D.n_predicate
# The target predicate is the last index of the second predicate count.
target = n_predicate[1] - 1
b, B = D.get_data(dim=10)

# Snapshot the target's facts BEFORE clearing them: `p` keeps the facts still to
# be proved; `neg` is their complement (presumably the negative examples).
p = torch.clone(B[target])
neg = 1 - B[target]
# Hide the target facts so the model has to derive them itself.
B[target] = 0

model = LogicModel(n_predicate[0], n_predicate[1])
print(p)

# Training schedule: each round adds modules, fits them, fixes their parameters,
# then removes the target facts the model now infers from the outstanding set.
N_ROUNDS = 5
N_STEPS = 500
LEARNING_RATE = 1e-1
LOG_EPS = 1e-5  # keeps log() finite when a coverage ratio is exactly 0


def _safe_ratio(num, den):
    """Return ``num / den``; if ``den`` is not positive, return a zero tied to ``num``'s graph.

    An empty positive or negative mask would otherwise give 0/0 = NaN, and a
    NaN loss corrupts every parameter on the next optimizer step. Multiplying
    by zero (instead of creating a new tensor) keeps ``backward()`` valid.
    """
    return num / den if den > 0 else num * 0


for _ in range(N_ROUNDS):
    # NOTE(review): the bare call and the call with (target, p2=0) presumably add
    # an intermediate module and one producing the target — confirm in model.py.
    model.add_module2()
    model.add_module2(target, p2=0)

    # Rebuilt after add_module2 each round, presumably so newly added
    # parameters are included in the optimization.
    opt = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    with tqdm(range(N_STEPS), ncols=80) as pbar:
        for _step in pbar:
            # Prediction = element [1] of forward's output, indexed by the last
            # added submodule's 'dim' (assumed to be the target's output slot).
            pred = model.forward(b, B)[1][model.submodules[-1].info['dim']]
            pos_mass = (p * pred).sum()
            neg_mass = (neg * pred).sum()
            # Minimizing this maximizes the covered share of positives and
            # minimizes the covered share of negatives.
            loss = (-(_safe_ratio(pos_mass, p.sum()) + LOG_EPS).log()
                    + (_safe_ratio(neg_mass, neg.sum()) + LOG_EPS).log())
            pbar.set_postfix_str('POS: {:.2f} NEG: {:.2f}'.format(pos_mass.item(), neg_mass.item()))
            opt.zero_grad()
            loss.backward()
            opt.step()

    model.fix_parameters()
    model.inference_(b, B)
    print(B[target])
    # Remove facts that inference now yields. This is out-of-place so dtype
    # promotion cannot fail the way an in-place `-=` on an integer tensor would.
    p = (p - B[target]).clamp(min=0, max=1)
    print(p)
    print('Unproved: ' + str(p.sum().item()))
    if p.sum() == 0:
        print('Done.')
        break
else:
    # Reached only when no round proved every target fact.
    print('Stopped after {} rounds; unproved: {}'.format(N_ROUNDS, p.sum().item()))
