import torch
import torch.nn as nn
import numpy as np
import sys


from src.models.models import conv_linear


def get_toy_dataset(dim, margin):
    """Build a two-sample toy dataset whose samples differ in a single feature.

    Both samples start as all-ones vectors of length ``dim``. The second
    ("negative") sample has ``margin`` added to its first feature.

    Args:
        dim: Number of features per sample. Must be at least 1, because
            feature 0 of the negative sample is indexed.
        margin: Offset added to feature 0 of the negative sample.

    Returns:
        A ``(samples, labels)`` pair:
        - ``samples`` is a tensor of shape ``(2, dim)`` in torch's default
          floating dtype, with the positive row first.
        - ``labels`` is the int64 tensor ``[1, -1]``.
    """
    positive_sample = torch.ones(dim)
    negative_sample = torch.ones(dim)
    negative_sample[0] += margin
    samples = torch.stack((positive_sample, negative_sample), dim=0)
    # Dtype-explicit factory instead of the legacy ``torch.LongTensor`` constructor.
    labels = torch.tensor([1, -1], dtype=torch.long)
    return samples, labels


def test_conv_linear_with_kernel_size_1():
    """A kernel-size-1 convolution with every weight set to 2 doubles its input.

    NOTE(review): assumes ``model[2]`` is the convolution layer and that it
    contributes no bias term — confirm against ``conv_linear``.
    """
    x = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float32)
    model = conv_linear(num_classes=2, kernel_size=1, remove_last_layer=True)
    conv_layer = model[2]
    # Mutate the parameter in place without autograd tracking
    # (preferred over going through ``.data``).
    with torch.no_grad():
        conv_layer.weight.fill_(2.0)
    output = model(x)
    torch.testing.assert_close(output, x * 2.0)


def test_conv_linear_with_kernel_size_2():
    """A kernel-size-2 convolution with weights ``[0, 1]`` shifts its input left by one.

    Each output element equals the next input element. In the expected value,
    the first input element wraps around to the last position, which implies
    the model pads circularly.

    NOTE(review): assumes ``model[2]`` is the convolution layer and that it
    contributes no bias term — confirm against ``conv_linear``.
    """
    x = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.float32)
    model = conv_linear(num_classes=2, kernel_size=2, remove_last_layer=True)
    conv_layer = model[2]
    with torch.no_grad():
        # ``copy_`` broadcasts the 2-element kernel to the weight's shape.
        conv_layer.weight.copy_(torch.tensor([0.0, 1.0]))
    output = model(x)
    # Roll along the last (sequence) axis only. Without ``dims`` the tensor is
    # flattened, which would mix rows once the batch has more than one row.
    torch.testing.assert_close(output, x.roll(-1, dims=-1))


def test_conv1d():
    """Check a kernel-size-2 convolution with hand-set weights on the toy dataset.

    With weights ``[2, -1]``, the expected output is ``2 * x[i] - x[i + 1]``
    per element. For the second row, the last column is
    ``2 * 1 - 1.1 = 0.9``. That value shows the right neighbour of the final
    element wraps around to the first one.

    NOTE(review): this matches circular padding — confirm against ``conv_linear``.
    """
    x, _ = get_toy_dataset(5, 0.1)
    model = conv_linear(num_classes=2, kernel_size=2, remove_last_layer=True)
    parameters = list(model.parameters())
    # The first registered parameter is expected to be the 2-tap kernel.
    weight = parameters[0]
    assert weight.numel() == 2
    with torch.no_grad():
        weight.copy_(torch.tensor([2.0, -1.0]))
    # Call the module rather than ``.forward`` so registered hooks still run.
    outputs = model(x)
    torch.testing.assert_close(outputs, torch.tensor([[1, 1, 1, 1, 1], [1.2, 1, 1, 1, 0.9]]))
