import torch
from loss import loss_function_sumless
from tqdm import tqdm


def _energy_landscape(args, models_gen, data):
    """Return a ``(batch, args.num_classes)`` tensor of mean reconstruction losses.

    Column ``i`` holds the loss of ``models_gen[i]`` on ``data``, averaged over
    ``args.num_samples`` forward passes. The repeated passes presumably exist
    because the models' outputs are stochastic (e.g. sampled latents) —
    TODO confirm against the model definitions.
    """
    energy = torch.zeros((data.shape[0], args.num_classes), device=args.device)
    for i in range(args.num_classes):
        total_loss = 0
        for _ in range(args.num_samples):
            recon_batch = models_gen[i](data)
            # Assumes loss_function_sumless returns one loss per sample,
            # i.e. shape (batch,), so it fits one column of `energy` —
            # TODO confirm in loss.py.
            total_loss += loss_function_sumless(recon_batch, data)
        energy[:, i] = total_loss / args.num_samples
    return energy


def test(args, models_gen, test_data):
    """Evaluate per-class generative models used as an energy-based classifier.

    Every minibatch from each class's test loader is scored by every model.
    The score is the reconstruction loss averaged over ``args.num_samples``
    passes (see ``_energy_landscape``). The predicted label is the model/class
    with the lowest loss. Per-class accuracies are printed and then averaged.

    Args:
        args: Object providing ``num_classes``, ``num_samples`` and ``device``.
        models_gen: Indexable collection of ``args.num_classes`` models. They
            are switched to eval mode and left in it.
        test_data: Indexable collection of per-class iterables yielding
            ``(data, target)`` pairs. Every sample drawn from ``test_data[c]``
            is treated as ground-truth class ``c``; ``target`` is ignored.

    Returns:
        The unweighted mean of the per-class accuracies (a 0-d tensor).

    Raises:
        ValueError: if a class's test iterable yields no samples, since its
            accuracy would be undefined.
    """
    print("\n\nEVALUATION RESULTS:")
    print("\n Accuracy of final model on test-set:")

    with torch.no_grad():
        for s in range(args.num_classes):
            models_gen[s].eval()

        accs = []
        for the_class in range(args.num_classes):
            correct = 0
            total = 0
            for data, _target in tqdm(test_data[the_class]):
                data = data.to(device=args.device)
                energy = _energy_landscape(args, models_gen, data)

                # Lowest energy (smallest mean reconstruction loss) wins.
                _, predicted = torch.min(energy, 1)
                correct += (predicted == the_class).sum()
                total += predicted.numel()

            if total == 0:
                raise ValueError(
                    f"test_data[{the_class}] yielded no samples; "
                    "accuracy is undefined"
                )
            acc = correct / total
            print(" - For class {}: {:.4f}".format(the_class + 1, acc))
            accs.append(acc)

    # Unweighted (macro) average: every class counts equally regardless of
    # how many test samples it has.
    average_accuracy = sum(accs) / args.num_classes
    print(" - Average: {:.4f}".format(average_accuracy))
    return average_accuracy