import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm
from device import device
from dataloader import DataCLUTRR
from model import ReasoningModel

# Dataset directory used for both training and evaluation.
# The commented line is an alternative dataset directory.
dataloader = DataCLUTRR('./data/data_publish/data_089907f8/')
# dataloader = DataCLUTRR('./data/data_publish/data_db9b8f04/')

n_predicate = dataloader.n_predicate

# Training hyper-parameters (forwarded to ReasoningModel.train_model below).
batch_size = 10
lr = 0.1
epoch = 100

model = ReasoningModel(
    dataloader,
    fuzzy_induction=False,
    dim_train=5,
    merge_inductive_preds=True,
    dim_max=500,
)

# Load the training split, then fit the model on it.
model.read_data('train.csv')
model.train_model(
    batch_size=batch_size,
    norm=0.2,
    epoch=epoch,
    lr=lr,
    max_try_time=200,
)

# Pickle the whole model object (not only a state dict) to the file 'model';
# it can be restored later with torch.load('model').
torch.save(model, 'model')

# model = torch.load('model')

def _tie_broken_credit(pred, target):
    """Return the accuracy credit earned by one prediction vector.

    If the score at ``target`` equals the maximum of ``pred``, the credit is
    ``1 / k``, where ``k`` is the number of entries sharing that maximum.
    That is the probability of picking the target when ties are broken
    uniformly at random. Otherwise the credit is 0.

    Args:
        pred: 1-D score tensor (assumed: one score per base predicate).
        target: index of the correct entry in ``pred``.

    Returns:
        A float in [0, 1].
    """
    best = pred.max()  # computed once instead of on every comparison
    if pred[target] == best:
        n_ties = (pred == best).sum().item()
        return 1.0 / n_ties
    return 0.0


for filename in dataloader.data:
    model.read_data(filename)
    correct = 0.0
    total = 0
    with tqdm(model.data, ncols=80) as progress:
        for _, target, query_edge, graph in progress:
            # Runs inference on `graph` in place; the meaning of the `10`
            # argument (presumably a step/iteration count) is defined in
            # ReasoningModel.inference_ - TODO confirm.
            model.inference_(graph, 10)
            pred = graph[:model.n_predicate_base, query_edge[0], query_edge[1]]
            correct += _tie_broken_credit(pred, target)
            total += 1
            progress.set_postfix_str(f'{correct}/{total}')
    # Divide by the number of examples actually scored. An empty split
    # reports nan instead of raising ZeroDivisionError.
    accuracy = correct / total if total else float('nan')
    print(f'{filename}: {accuracy}')
