import torch
from geo_ldm.diffusion_guidance import DiffusionGuidanceModel


def sum_except_batch(x):
    """Sum a tensor over every dimension except the first (batch) one.

    Args:
        x: Tensor whose leading dimension is the batch dimension.

    Returns:
        A 1-D tensor with ``x.size(0)`` entries, each the sum of all
        elements belonging to the corresponding batch item.
    """
    # reshape (rather than view) also accepts non-contiguous inputs such as
    # transposed or permuted tensors; for contiguous inputs it is still a
    # zero-copy view, so the result is unchanged.
    return x.reshape(x.size(0), -1).sum(dim=-1)


def assert_correctly_masked(variable, node_mask):
    """Assert that ``variable`` is (numerically) zero wherever ``node_mask`` is zero.

    The total absolute value of ``variable`` weighted by ``1 - node_mask``
    must stay below 1e-8, otherwise an ``AssertionError`` is raised.

    Args:
        variable: Tensor to check.
        node_mask: Mask broadcastable against ``variable``.
    """
    inverse_mask = 1 - node_mask
    leaked_magnitude = (variable * inverse_mask).abs().sum().item()
    assert leaked_magnitude < 1e-8


def compute_loss_and_nll(args, generative_model, nodes_dist, x, h, node_mask, edge_mask, context, regressor_target, adj_gt):
    """Evaluate the generative model's NLL on a batch and return the loss terms.

    Only ``args.probabilistic_model == 'diffusion'`` is supported.

    Args:
        args: Config object; reads ``probabilistic_model``, ``dp`` and
            ``train_diffusion``.
        generative_model: Callable model. When it is a
            ``DiffusionGuidanceModel`` it additionally receives
            ``regressor_target``.
        nodes_dist: Object exposing ``log_prob`` over node counts.
        x: 3-D position tensor (its size unpacks to batch, nodes, dims).
        h: Node features; a dictionary with keys 'categorical' and 'integer'.
        node_mask: Node mask; summed over dim 1 after squeezing dim 2 to get
            per-sample node counts (presumably (batch, nodes, 1) - confirm
            against caller).
        edge_mask: Edge mask, viewed here as (batch, nodes * nodes).
        context: Conditioning input forwarded to the model unchanged.
        regressor_target: Extra target forwarded only to guidance models.
        adj_gt: Ground-truth adjacency forwarded to the model unchanged.

    Returns:
        Tuple ``(nll, reg_term, mean_abs_z)`` where ``reg_term`` is a
        one-element zero tensor on ``nll``'s device and ``mean_abs_z`` is 0.

    Raises:
        ValueError: If ``args.probabilistic_model`` is not ``'diffusion'``.
    """
    batch_size, num_nodes, _ = x.size()

    # Guard clause: reject anything that is not a diffusion model.
    if args.probabilistic_model != 'diffusion':
        raise ValueError(args.probabilistic_model)

    flat_edge_mask = edge_mask.view(batch_size, num_nodes * num_nodes)

    assert_correctly_masked(x, node_mask)

    # Guidance models take the regressor target as one extra trailing input.
    model_inputs = [x, h, adj_gt, node_mask, flat_edge_mask, context]
    if isinstance(generative_model, DiffusionGuidanceModel):
        model_inputs.append(regressor_target)
    nll = generative_model(*model_inputs)

    # With data parallelism over several GPUs the output is not a scalar,
    # so it is averaged here.
    if args.dp and torch.cuda.device_count() > 1:
        nll = nll.mean()

    if args.train_diffusion:
        # Number of nodes per sample, used to score the size prior.
        node_counts = node_mask.squeeze(2).sum(1).long()
        log_pN = nodes_dist.log_prob(node_counts).mean(0)
        assert nll.size() == log_pN.size()
        if torch.isnan(log_pN).any():
            print(f"for N={node_counts}, log_pN={log_pN}")
        nll = nll - log_pN

    # No regularisation is applied; these are constant placeholder outputs.
    reg_term = torch.zeros(1, device=nll.device)
    mean_abs_z = 0.
    return nll, reg_term, mean_abs_z
