import pytest


def test_readme_snippet():
    """Smoke-test the usage example shown in the README.

    Builds an untrained ResNet-18 (``weights=None``), attaches a ``NesimLoss``
    with three layer-wise configs, then exercises ``print(loss)`` and both
    forms of ``loss.compute``. The test passes as long as none of these calls
    raise; the printed values themselves are not checked.
    """
    from torchvision.models import resnet18
    from nesim.losses.nesim_loss import (
        LaplacianPyramid,
        NeighbourhoodCosineSimilarity,
        NesimConfig,
        NesimLoss,
    )

    backbone = resnet18(weights=None)

    layer_configs = [
        # scale=None: only watch this layer's loss, do not backprop through it.
        NeighbourhoodCosineSimilarity(layer_name="layer4.0.conv1", scale=None),
        NeighbourhoodCosineSimilarity(layer_name="layer3.1.conv2", scale=0.1),
        LaplacianPyramid(layer_name="fc", scale=0.1, shrink_factor=[2.0]),
    ]
    nesim_loss = NesimLoss(
        model=backbone,
        config=NesimConfig(layer_wise_configs=layer_configs),
        device="cpu",
    )

    # Basic information about the objective.
    print(nesim_loss)

    # reduce_mean=True: a single number as a tensor, usable with backward().
    total = nesim_loss.compute(reduce_mean=True)
    print(total)

    # reduce_mean=False: a dict with layer names as keys.
    per_layer = nesim_loss.compute(reduce_mean=False)
    print(per_layer)
