import torch
from torch.utils.data.dataloader import DataLoader

from core.graph import GraphDataset
from core import logic_net


def score_graphs(all_graphs, args):
    """Score graphs with a checkpointed SoftLogicNet and print summary stats.

    Builds a ``SoftLogicNet`` from ``all_graphs`` using the architecture
    hyper-parameters in ``args``, loads weights from ``args.ckpt_file``, runs
    one no-grad pass over the graphs in their original order, and prints the
    mean score and the fraction of scores greater than 0.5.

    Args:
        all_graphs: Sized sequence of graph objects. Each one has ``.to`` called
            with ``args.device``. The return value is ignored, so this
            presumably moves the graph in place (TODO confirm against
            ``core.graph``).
        args: Namespace-like config. Attributes read: ``device``,
            ``batch_size``, ``n_variables``, ``n_layers``, ``n_width``,
            ``init_var``, ``temperature``, ``ckpt_file``, ``soft_logic``,
            ``final_forall_logic``.

    Returns:
        torch.Tensor: the model scores from every batch concatenated along
        dim 0, on CPU. There is presumably one score per graph, in
        ``all_graphs`` order.

    Raises:
        ValueError: if ``all_graphs`` is empty. Without this check the
            failure would surface as an opaque ``torch.cat`` error, and only
            after the checkpoint had already been loaded.
    """
    if len(all_graphs) == 0:
        raise ValueError("score_graphs: all_graphs is empty, nothing to score")

    for _graph in all_graphs:
        _graph.to(args.device)
    # Every graph gets a placeholder label of 1. The dataset's extra outputs
    # are discarded in the scoring loop below.
    graph_dataset = GraphDataset(all_graphs, [1] * len(all_graphs))
    dataloader = DataLoader(
        graph_dataset, args.batch_size, collate_fn=GraphDataset.collate_fn,
        shuffle=False,  # keep score order aligned with all_graphs
    )

    soft_equation = logic_net.SoftLogicNet.from_graphs(
        all_graphs, args.n_variables, args.n_layers, args.n_width,
        args.init_var, args.temperature)
    soft_equation.to(args.device)
    print("loading ckpt from: " + str(args.ckpt_file))
    # map_location lets a checkpoint saved on one device (e.g. a GPU) be
    # loaded onto the target device (e.g. a CPU-only machine).
    ckpt = torch.load(args.ckpt_file, map_location=args.device)
    # NOTE(review): assumes the checkpoint is a sequence whose first entry is
    # the model state_dict. Confirm against the training script.
    soft_equation.load_state_dict(ckpt[0])
    # Switch to inference mode so that dropout / batch-norm layers, if the
    # model has any, behave deterministically while scoring.
    soft_equation.eval()
    print(soft_equation)

    all_scores = []
    with torch.no_grad():
        for batch, _, _ in dataloader:
            scores, _, _ = soft_equation(
                batch, args.soft_logic, args.final_forall_logic
            )
            all_scores.append(scores)
    # Tensors produced under no_grad carry no autograd history, so no
    # detach() is needed before moving them to CPU.
    all_scores = torch.cat(all_scores, 0).cpu()
    print("mean score:", all_scores.mean())
    print(">0.5 ratio: ", (all_scores > 0.5).float().mean())
    return all_scores
