import torch
from tqdm import tqdm
from datagenerator import *
from module import LogicModule2
from module import LogicModule1
from model import LogicModel

# Task selection. Other available tasks can be swapped in here,
# e.g. ILP2(), ILP5(), ILP6() or ILP16().
D = ILP20()

n_predicate = D.n_predicate
# The learning target is the last predicate index of the first group.
target = n_predicate[0] - 1
b, B = D.get_data(dim=10)

# Snapshot the target's ground truth before hiding it from the model.
p = b[target].clone()   # positive examples (a copy, so later edits to b do not touch it)
neg = 1 - p             # negative examples: complement of the ground truth
b[target] = 0           # erase the target facts so the model must re-derive them

model = LogicModel(n_predicate[0], n_predicate[1])
print(p)
print(target)
# Training hyperparameters (values unchanged from the original script).
N_ROUNDS = 5    # maximum number of grow/train/infer rounds
N_STEPS = 500   # Adam steps per round
LR = 1e-1       # Adam learning rate
EPS = 1e-5      # keeps log() finite; also the floor for the class-size denominators

# Each round grows the model, trains the new modules, freezes them, runs
# inference, and removes newly proved facts from the outstanding positives.
for _round in range(N_ROUNDS):
    model.add_module2()
    model.add_module2()
    for _ in range(4):
        model.add_module1()
    # The final module of the round is bound to the target predicate; its
    # output is what the loss below scores.
    model.add_module1(target)

    # New parameters were just added, so build a fresh optimizer each round.
    opt = torch.optim.Adam(model.parameters(), lr=LR)

    # p and neg do not change during a round, so their sizes are computed once.
    # Clamping prevents 0/0 = NaN when a class is empty (e.g. a task with no
    # negatives); for non-empty 0/1 labels the sums are >= 1 and unchanged.
    pos_total = p.sum().clamp(min=EPS)
    neg_total = neg.sum().clamp(min=EPS)

    with tqdm(range(N_STEPS), ncols=80) as bar:
        for _step in bar:
            pred = model.forward(b, B)[0][model.submodules[-1].info['dim']]
            pos_score = (p * pred).sum()
            neg_score = (neg * pred).sum()
            # Reward mass on still-unproved positives; penalize mass on negatives.
            loss = -(pos_score / pos_total + EPS).log() + (neg_score / neg_total + EPS).log()
            bar.set_postfix_str('POS: {:.2f} NEG: {:.2f}'.format(pos_score, neg_score))
            opt.zero_grad()
            loss.backward()
            opt.step()

    model.fix_parameters()
    # inference_ is expected to write derived target facts into b in place;
    # b[target] is read back immediately below.
    model.inference_(b, B)
    print('Proved:   ' + str(b[target].tolist()))
    # Keep only positives that are still not proved.
    p = (p - b[target]).clamp(min=0, max=1)
    print('Unproved: ' + str(p.tolist()))
    if p.sum() == 0:
        print('Done.')
        break