import torch
import wandb

from utils import compute_roc_auc_score

def inference(model, data, args, name, epoch):
    """Run the inner inference loop on one batch and log its metrics under ``name``.

    Args:
        model: model passed through to ``run_inference_loop``.
        data: mapping holding tensors under ``"x"``, ``"b"``, ``"adj"`` and ``"beta"``.
        args: namespace providing ``device`` plus whatever ``run_inference_loop`` reads.
        name: prefix for the logged metric names.
        epoch: epoch value attached to every logged record.

    Returns:
        The value returned by ``run_inference_loop``.
    """
    device = args.device
    X = data["x"].to(device)
    B = data["b"].to(device)
    A_true = data["adj"].to(device)
    # Two trailing singleton dims are added to beta, presumably so it broadcasts
    # against the batched (n_nodes, n_nodes) matrices in the inner loop — confirm.
    beta = data["beta"].unsqueeze(1).unsqueeze(1).to(device).float()
    # NOTE: data["edge_index"] was previously moved to the device here but never used.

    A_pred = run_inference_loop(X, beta, A_true, B, model, args, epoch=epoch, name=name, log_results=True)

    return A_pred

def run_inference_loop(X, beta, A_true, B, model, args, epoch=None, name=None, log_results=False):
    """Fit ``model.encoder.B`` to ``X`` by gradient descent, then return ``model(X)``.

    Only ``model.encoder.B`` is optimised (Adam with ``args.inner_loop_lr``). At each
    step the loss is the MSE between ``((1 + beta) * I - beta * sigmoid(model(X))) @ X``
    and the current estimate ``model.encoder.B * model.encoder.B_initial``. The loop
    stops after ``args.num_inference_steps`` steps, or earlier once the loss improves
    by less than ``args.eps`` between consecutive steps.

    Args:
        X: input batch; ``X.shape[0]`` is used as the batch size for ``set_B``.
        beta: scale applied to the identity/adjacency terms (broadcast against them).
        A_true: ground-truth adjacency, used only for the logged ROC-AUC.
        B: ground-truth B, used only for the logged ``diff_from_gt_B`` metric.
        model: exposes ``encoder.set_B``, ``encoder.B``, ``encoder.B_initial``,
            ``n_nodes`` and is callable on ``X``.
        args: namespace with ``device``, ``inner_loop_lr``, ``num_inference_steps``, ``eps``.
        epoch: value attached to every logged record.
        name: metric-name prefix; required when ``log_results`` is true.
        log_results: if true, log per-step metrics to wandb.

    Returns:
        ``model(X)`` evaluated after the loop. This is the raw model output; unlike
        inside the loop, no sigmoid is applied. NOTE(review): confirm callers expect that.

    Raises:
        ValueError: if ``log_results`` is true and ``name`` is None.
    """
    if log_results and name is None:
        raise ValueError("name must be provided when log_results is True")

    model.encoder.set_B(X.shape[0], args.device)

    optimizer_inner_loop = torch.optim.Adam([model.encoder.B], lr=args.inner_loop_lr)
    loss_criterion_inner_loop = torch.nn.MSELoss()

    # Loop-invariant batched identity: build it once, directly on the target device.
    identity = torch.eye(model.n_nodes, device=args.device).repeat(X.shape[0], 1, 1)

    # inf (not a large finite sentinel) so the first step can never trigger early stopping.
    prev_loss = float("inf")
    for _ in range(args.num_inference_steps):
        optimizer_inner_loop.zero_grad()
        A_pred = torch.sigmoid(model(X))
        # Current estimate of B; deliberately not named ``B`` so the ground-truth
        # argument stays available for the diff metric below.
        B_pred = model.encoder.B * model.encoder.B_initial
        loss = loss_criterion_inner_loop(((1 + beta) * identity - beta * A_pred) @ X, B_pred)
        # NOTE(review): backward() also accumulates gradients into any other model
        # parameters that require grad, but only model.encoder.B is zeroed/stepped
        # here — confirm callers clear those gradients before their own update.
        loss.backward()
        optimizer_inner_loop.step()

        # Compare plain floats so the previous step's autograd graph is not kept alive.
        loss_value = loss.item()
        if prev_loss - loss_value < args.eps:
            break

        prev_loss = loss_value

        if log_results:
            with torch.no_grad():
                diff_from_gt = (model.encoder.B * model.encoder.B_initial - B).abs().mean().item()
            wandb.log({f"{name}_inference_loss": loss_value, "epoch": epoch})
            wandb.log({f"{name}_inference_roc_auc": compute_roc_auc_score(A_true, A_pred.detach()), "epoch": epoch})
            wandb.log({f"{name}_diff_from_gt_B": diff_from_gt, "epoch": epoch})

    A_pred = model(X)

    return A_pred