import os
import shutil

import pytest
import torch
import torch.nn as nn

from nesim.utils.feature_vis.generator import ConvLayerFeaturevisGenerator

# Scratch directory handed to the generator as its output location.
# test_conv creates it before rendering and deletes it afterwards.
output_folder = "./test_featurevis/"

# Run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Render settings forwarded verbatim as `render_kwargs`. The two variants
# differ only in their scale bounds: fixed at 1.0, or spanning 0.8-1.2.
# iters is kept at 5 so each case stays quick.
possible_render_kwargs = [
    {"scale_max": 1.0, "scale_min": 1.0, "iters": 5, "lr": 6e-3, "grad_clip": 0.1},
    {"scale_max": 1.2, "scale_min": 0.8, "iters": 5, "lr": 6e-3, "grad_clip": 0.1},
]

# Grid for the conv-layer shape and the generator batch size. pytest runs
# test_conv once per combination of these values and the render settings.
possible_in_channels = [3]
possible_out_channels = [2, 4, 16]
possible_batch_size = [1, 2, 4, 8]


@pytest.mark.parametrize("in_channels", possible_in_channels)
@pytest.mark.parametrize("out_channels", possible_out_channels)
@pytest.mark.parametrize("batch_size", possible_batch_size)
@pytest.mark.parametrize("render_kwargs", possible_render_kwargs)
def test_conv(
    in_channels: int, out_channels: int, batch_size: int, render_kwargs: dict
):
    """Smoke-test ConvLayerFeaturevisGenerator on a one-layer conv model.

    Builds an ``nn.Sequential`` containing a single 3x3 ``nn.Conv2d``, targets
    that layer with the generator and calls ``generate``. The test passes if
    generation completes without raising; the produced files are not
    inspected.

    Args:
        in_channels: input channel count of the conv layer.
        out_channels: output channel count of the conv layer.
        batch_size: forwarded to the generator's ``batch_size``.
        render_kwargs: forwarded to the generator's ``render_kwargs``.
    """
    # Create the output folder without a shell. exist_ok avoids an error
    # when a previous run left the folder behind, and any other failure
    # (e.g. permissions) now raises instead of being silently ignored.
    os.makedirs(output_folder, exist_ok=True)

    try:
        fake_model = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3)
        )

        generator = ConvLayerFeaturevisGenerator(
            model=fake_model,
            target_layer=fake_model[0],
            batch_size=batch_size,
            render_kwargs=render_kwargs,
            width=224,
            height=224,
            standard_deviation=0.01,
            device=device,
        )

        generator.generate(output_folder=output_folder)
    finally:
        # Clean up even when generation fails, so a failing case does not
        # leave files that later parametrizations (or runs) would see.
        # ignore_errors=True mirrors the tolerance of the former `rm -rf`.
        shutil.rmtree(output_folder, ignore_errors=True)
