import random

import pytest
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Lambda, ToTensor

from sde.datasets.image_dataset import IndexedImageFolder
from sde.models import (
    CnnAccumulator,
    CNNColorDetector,
    CnnMultiColorAccumulator,
    DecisionHead,
    IdentifyAndSubtractModel,
    IdentifyNumberModel,
    ModuloModel,
    PartialNonUniformCnnAccumulator,
    PartialNonUniformCnnMultiColorAccumulator,
    SyntheticModel,
)


def test_cnn_accumulator():
    """An all-ones 3-channel image must accumulate to the sum of all its values."""
    ones_image = torch.ones(1, 3, 224, 224)
    accumulated = CnnAccumulator(input_channels=3)(ones_image)
    assert torch.allclose(accumulated, ones_image.sum())


@pytest.mark.parametrize("random_expand_to", [2, 3, 4, 5])
def test_non_uniform_cnn_accumulator(random_expand_to):
    """PartialNonUniformCnnAccumulator must still return the plain sum of its input.

    Checked once on an all-ones image and then on 1000 uniformly random images.
    """
    accumulator = PartialNonUniformCnnAccumulator(input_channels=1, random_expand_to=random_expand_to)

    ones_image = torch.ones(1, 1, 224, 224)
    assert torch.allclose(accumulator(ones_image), ones_image.sum())

    for _ in range(1000):
        noise_image = torch.rand(1, 1, 224, 224)
        assert torch.allclose(accumulator(noise_image), noise_image.sum())


def test_cnn_multi_color_accumulator_uniform():
    """Exercise CnnMultiColorAccumulator with the 'uniform' weight init scheme.

    Without redundant channels, each color output of an all-ones image equals
    that color's share of the total sum.  With redundant channels present,
    neither of the two color outputs may equal ``sum / 3``, also when one of
    the extra channels is blanked out.
    """
    def assert_both_colors_differ(prediction, reference):
        # Both color outputs (index 0 and 1) must be away from the reference.
        for color in (0, 1):
            assert not torch.allclose(prediction[0, color], reference)

    single = torch.ones(1, 1, 224, 224)
    single_color = CnnMultiColorAccumulator(num_colors=1, weight_init_scheme='uniform')
    assert torch.allclose(single_color(single), single.sum())

    # test redundant channels
    pair = torch.ones(1, 2, 224, 224)
    pair_output = CnnMultiColorAccumulator(num_colors=2, weight_init_scheme='uniform')(pair)
    for color in (0, 1):
        assert torch.allclose(pair_output[0, color], pair.sum() / 2)

    one_redundant = CnnMultiColorAccumulator(num_colors=2, weight_init_scheme='uniform', redundant_channels=1)
    triple = torch.ones(1, 3, 224, 224)
    assert_both_colors_differ(one_redundant(triple), triple.sum() / 3)

    # check any change in any redundant channel will cause the model to fail
    two_redundant = CnnMultiColorAccumulator(num_colors=2, weight_init_scheme='uniform', redundant_channels=2)
    # NOTE(review): these 4-channel cases reuse the ``sum / 3`` reference of the
    # 3-channel case above; confirm that divisor is what was intended.
    for blanked_channel in (None, 2, 3):
        quad = torch.ones(1, 4, 224, 224)
        if blanked_channel is not None:
            quad[:, blanked_channel, :, :] = 0.
        assert_both_colors_differ(two_redundant(quad), quad.sum() / 3)


def test_cnn_multi_color_accumulator_uniform_with_bias():
    """Exercise CnnMultiColorAccumulator with the 'uniform_with_bias' weight init scheme.

    For an all-ones two-color image, every output element is expected to equal
    ``sum / 2 * 1.9``.  Adding one redundant channel must break that equality.
    """
    # NOTE(review): 1.9 looks like the scale introduced by the 'uniform_with_bias'
    # scheme -- confirm against CnnMultiColorAccumulator.
    expected_scale = 1.9

    model = CnnMultiColorAccumulator(num_colors=2, weight_init_scheme='uniform_with_bias')
    input_tensor = torch.ones(1, 2, 224, 224)
    output = model(input_tensor)
    assert torch.allclose(output, input_tensor.sum() / 2 * expected_scale)

    # test redundant channels
    model = CnnMultiColorAccumulator(num_colors=2, weight_init_scheme='uniform_with_bias', redundant_channels=1)
    input_tensor = torch.ones(1, 3, 224, 224)
    output = model(input_tensor)
    assert not torch.allclose(output, input_tensor.sum() / 2 * expected_scale)


@pytest.mark.parametrize("random_expand_to", [1, 2, 3, 4, 5])
def test_non_uniform_cnn_multi_color_accumulator(random_expand_to):
    """Exercise PartialNonUniformCnnMultiColorAccumulator across ``random_expand_to`` values.

    Same expectations as the uniform multi-color accumulator test: exact
    per-color sums without redundant channels, and no color output equal to
    ``sum / 3`` once redundant channels exist (also with one blanked out).
    """
    def build(**options):
        return PartialNonUniformCnnMultiColorAccumulator(random_expand_to=random_expand_to, **options)

    def assert_both_colors_differ(prediction, reference):
        # Both color outputs (index 0 and 1) must be away from the reference.
        for color in (0, 1):
            assert not torch.allclose(prediction[0, color], reference)

    single_color = build(num_colors=1)
    single = torch.ones(1, 1, 224, 224)
    assert torch.allclose(single_color(single), single.sum())

    # test redundant channels
    two_colors = build(num_colors=2)
    pair = torch.ones(1, 2, 224, 224)
    pair_output = two_colors(pair)
    for color in (0, 1):
        assert torch.allclose(pair_output[0, color], pair.sum() / 2)

    one_redundant = build(num_colors=2, redundant_channels=1)
    triple = torch.ones(1, 3, 224, 224)
    assert_both_colors_differ(one_redundant(triple), triple.sum() / 3)

    # check any change in any redundant channel will cause the model to fail
    two_redundant = build(num_colors=2, redundant_channels=2)
    # NOTE(review): these 4-channel cases reuse the ``sum / 3`` reference of the
    # 3-channel case above; confirm that divisor is what was intended.
    for blanked_channel in (None, 2, 3):
        quad = torch.ones(1, 4, 224, 224)
        if blanked_channel is not None:
            quad[:, blanked_channel, :, :] = 0.
        assert_both_colors_differ(two_redundant(quad), quad.sum() / 3)


@pytest.mark.parametrize("model_path", [
    'tests/test_data/checkpoints/test.ckpt',
])
def test_decision_head(model_path):
    """A DecisionHead loaded from ``model_path`` must classify each scalar ``i`` as ``i % 30``.

    ``model_path`` is relative, so this presumably has to run from the
    repository root -- confirm against the test runner configuration.
    """
    head = DecisionHead(True, 2000, 4, 5, 30)
    head.load_model_parameters(model_path)

    with torch.no_grad():
        for i in range(400):
            x = torch.tensor([i], dtype=torch.float32)
            y = torch.argmax(head(x))
            # The message must be a string: ``print(...)`` would evaluate to None.
            assert y.item() == i % 30, f'{y.item()} does not equal {i % 30} for i = {i}'


@pytest.mark.parametrize("model_path", [
    'tests/test_data/checkpoints/test.ckpt',
])
def test_synthetic_model(model_path):
    """SyntheticModel = CnnAccumulator + DecisionHead, probed by lighting up pixels.

    The expected classes below are consistent with the model predicting
    (sum of all pixel values) % 30.
    """
    accumulator = CnnAccumulator(input_channels=3)
    head = DecisionHead(True, 2000, 4, 5, 30)
    head.load_model_parameters(model_path)
    synthetic_model = SyntheticModel(accumulator, head)

    canvas = torch.zeros(1, 3, 224, 224)

    def predicted_class():
        return torch.argmax(synthetic_model(canvas), dim=1).item()

    # all-black image: sum 0
    assert predicted_class() == 0

    # one pixel set in all 3 channels: sum 3
    canvas[0, :, 1, 1] = 1
    assert predicted_class() == 3

    # one more value in a single channel: sum 4
    canvas[0, 1, 2, 4] = 1
    assert predicted_class() == 4

    # ten more pixels set in all 3 channels add 30: sum 34, and 34 % 30 == 4
    canvas[0, :, 10:20, 0] = 1
    assert predicted_class() == 4


@pytest.mark.parametrize(
    "model_path, dataset_path", [('tests/test_data/checkpoints/test.ckpt', 'tests/test_data/synthetic_dataset')])
def test_synthetic_model_on_synthetic_dataset(model_path, dataset_path, device='cuda:0'):
    """Run SyntheticModel over the whole synthetic dataset and require every prediction to match its label.

    Only the first image channel is fed to the model.  The test is skipped
    when ``device`` is a CUDA device but CUDA is not available.
    """
    # The default device is a GPU; skip rather than error on CPU-only machines.
    if torch.device(device).type == 'cuda' and not torch.cuda.is_available():
        pytest.skip(f'CUDA device {device!r} requested but CUDA is not available')

    accumulator = CnnAccumulator(input_channels=1)
    decision_head = DecisionHead(True, 2000, 4, 5, 30)
    decision_head.load_model_parameters(model_path)
    synthetic_model = SyntheticModel(accumulator, decision_head)
    synthetic_model.to(device)

    # load synthetic dataset
    # ``int`` (instead of a lambda) keeps the transform picklable for DataLoader
    # worker processes started with the "spawn" method.
    dataset = IndexedImageFolder(
        root=dataset_path, transform=ToTensor(), target_transform=Lambda(int), num_channels=1)
    # NOTE(review): if IndexedImageFolder follows torchvision's DatasetFolder, sample
    # targets were already assigned in __init__ (lexicographic class order), so this
    # reassignment may not relabel the samples -- confirm.
    dataset.class_to_idx = {cls: i for i, cls in enumerate(sorted(dataset.classes, key=int))}
    dataloader = DataLoader(dataset, batch_size=64, num_workers=10, persistent_workers=True, shuffle=True)

    with torch.no_grad():
        for image_tensor, labels in dataloader:
            image_tensor = image_tensor.to(device)
            labels = labels.to(device)
            y_hat = torch.argmax(synthetic_model(image_tensor[:, 0:1, :, :]), dim=1)
            # Labels are integer class indices, so compare exactly.
            assert torch.equal(y_hat, labels), \
                f'{(y_hat != labels).sum().item()} of {labels.numel()} predictions do not match their labels'


def test_identify_number_model():
    """IdentifyNumberModel(target) must output 1 for input == target and 0 otherwise.

    Twenty random targets in [0, 100] are each probed with the inputs 0..99.
    """
    for _ in range(20):
        target = random.randint(0, 100)
        detector = IdentifyNumberModel(target)
        for probe in range(100):
            hit = detector(torch.tensor([probe], dtype=torch.float32)).item()
            expected = float(probe == target)
            assert hit == expected


def test_identify_number_cnn_model():
    """Placeholder test with no assertions written yet.

    Reported as skipped rather than silently passing, so the missing coverage
    stays visible in test reports.
    """
    pytest.skip('test_identify_number_cnn_model is not implemented yet')


def test_modulo_model_regression(prompt=False):
    """Check ModuloModel in regression mode, stage by stage and end to end.

    First, the ReLU-activated output of each stage of ``ModuloModel(10, 100)`` is
    compared for the input 31.  Then ``ModuloModel(53, 100)`` must return
    ``n % 53`` for random ``n`` in [0, 100].  ``prompt=True`` prints diagnostics.
    """
    model = ModuloModel(10, 100)
    stage_expectations = (
        (model.first_stage, [31, 21, 11, 1, 0, 0, 0, 0, 0, 0, 0, 0]),
        (model.second_stage, [10, 10, 10, 1, 0, 0, 0, 0, 0, 0, 0]),
        (model.third_stage, [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]),
        (model.sum_stage, [1]),
    )
    activation = torch.tensor([31], dtype=torch.float32)
    for stage, expected in stage_expectations:
        activation = torch.relu(stage(activation))
        assert torch.allclose(activation, torch.tensor(expected, dtype=torch.float32))

    # test arbitrary input
    if prompt:
        print("")
    model = ModuloModel(53, 100).eval()
    if prompt:
        print(model)
    for _ in range(10):
        number = random.randint(0, 100)
        query = torch.tensor([number], dtype=torch.float32)
        if prompt:
            print("expected: ", number % 53)
            print("NN return: ", model(query))
        assert torch.allclose(model(query), torch.tensor([number % 53], dtype=torch.float32))


def test_modulo_model_classification(prompt=False):
    """Check ModuloModel with ``classification_mode=True``.

    For ``ModuloModel(10, 100)`` and input 31, the ReLU-activated outputs of the
    fourth, fifth and classification stages are compared; the final output is
    one-hot at class 1.  Then ``ModuloModel(53, 100)`` must emit a one-hot
    vector at ``n % 53`` for random ``n``.  ``prompt=True`` prints diagnostics.
    """
    model = ModuloModel(10, 100, classification_mode=True)
    activation = torch.tensor([31], dtype=torch.float32)
    # The first three stages are only propagated, not checked.
    for stage in (model.first_stage, model.second_stage, model.third_stage):
        activation = torch.relu(stage(activation))
    checked_stages = (
        (model.fourth_stage,
         [2, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        (model.fifth_stage, [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        (model.classification_stage, [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]),
    )
    for stage, expected in checked_stages:
        activation = torch.relu(stage(activation))
        assert torch.allclose(activation, torch.tensor(expected, dtype=torch.float32))

    # test arbitrary input
    if prompt:
        print("")
    model = ModuloModel(53, 100, classification_mode=True).eval()
    if prompt:
        print(model)
    for _ in range(10):
        number = random.randint(0, 100)
        query = torch.tensor([number], dtype=torch.float32)
        if prompt:
            print("expected: ", number % 53)
            print("NN return: ", model(query))
        one_hot = torch.zeros(53, dtype=torch.float32)
        one_hot[number % 53] = 1.
        assert torch.allclose(model(query), one_hot)


def test_identify_and_subtract_model(prompt=False):
    """Check each ReLU-activated stage of ``IdentifyAndSubtractModel(4, 5)`` on the input [5, 5, 3, 0].

    ``prompt=True`` prints every intermediate activation.
    """
    model = IdentifyAndSubtractModel(4, 5)
    if prompt:
        print("")
    stage_expectations = (
        (model.first_stage, [5., 1., 0., 0., 5., 1., 0., 0., 3., 0., 0., 0., 0., 0., 0., 0.]),
        (model.second_stage, [5., 1., 0., 5., 1., 0., 3., 0., 0., 0., 0., 0.]),
        (model.third_stage, [5., 1., 5., 1., 3., 0., 0., 0.]),
        (model.fourth_stage, [0., 0., 3., 0.]),
    )
    activation = torch.tensor([5, 5, 3, 0], dtype=torch.float32)
    for stage, expected in stage_expectations:
        activation = torch.relu(stage(activation))
        if prompt:
            print(activation)
        assert torch.allclose(activation, torch.tensor(expected))


def test_cnn_color_detector():
    """Check CNNColorDetector on single pixels and on a full 224x224 image.

    A pixel must only fire for a listed color when all three channel values
    match it (near misses such as 254 vs 255 read 0).  With several colors,
    the expected tensors here hold one channel per listed color.
    """
    def pixel(red, green, blue):
        # A (batch=1, channels=3, H=1, W=1) image holding one RGB value.
        return torch.tensor([[[[red]], [[green]], [[blue]]]], dtype=torch.float32)

    # single color detection
    model = CNNColorDetector(color_list=((255, 255, 255),))
    for near_miss in ((254, 255, 255), (255, 255, 253), (0, 0, 255), (225, 255, 255)):
        assert model(pixel(*near_miss)).item() == 0., f'{near_miss} must not match (255, 255, 255)'
    assert model(pixel(255, 255, 255)).item() == 1.

    # multi color detection
    model = CNNColorDetector(color_list=(
        (255, 255, 255),
        (0, 0, 5),
    ))
    # All expected tensors share the (1, 2, 1, 1) layout, so allclose compares
    # channel by channel instead of broadcasting against a mismatched shape.
    assert torch.allclose(model(pixel(254, 255, 255)), torch.tensor([[[[0]], [[0]]]], dtype=torch.float32))
    assert torch.allclose(model(pixel(255, 255, 255)), torch.tensor([[[[1]], [[0]]]], dtype=torch.float32))
    assert torch.allclose(model(pixel(0, 0, 5)), torch.tensor([[[[0]], [[1]]]], dtype=torch.float32))

    # test 2D input_image
    input_tensor = torch.zeros(1, 3, 224, 224)
    model = CNNColorDetector(color_list=((0, 0, 0),))
    output = model(input_tensor)
    assert torch.allclose(output, torch.ones(1, 224, 224))
    # Changing a single channel at (4, 5) must turn off detection at that location only.
    input_tensor[0, 0, 4, 5] = 1
    output = model(input_tensor)
    expected = torch.ones(1, 224, 224)
    expected[0, 4, 5] = 0
    assert torch.allclose(output, expected)


def test_cnn_color_detector_with_redundant():
    """Check CNNColorDetector's redundant output channels for a single color.

    With ``k`` redundant channels the test expects:
    - a matching pixel: channel 0 reads 1 and channels 1..k read 0;
    - the non-matching pixels tried here: channel 0 reads 0 and channels 1..k read 1;
    - an all-black pixel: every channel reads 0.
    """
    def channel_readout(detector, rgb, num_channels):
        # Feed one (1, 3, 1, 1) pixel and read output[0][c] for each channel c.
        red, green, blue = rgb
        pixel = torch.tensor([[[[red]], [[green]], [[blue]]]], dtype=torch.float32)
        response = detector(pixel)
        return [response[0][channel].item() for channel in range(num_channels)]

    white = (255, 255, 255)

    # test redundant channels
    detector = CNNColorDetector(color_list=(white,), redundant_channels=1)
    assert channel_readout(detector, (254, 255, 255), 2) == [0., 1.]
    assert channel_readout(detector, (0, 255, 255), 2) == [0., 1.]
    assert channel_readout(detector, white, 2) == [1., 0.]

    detector = CNNColorDetector(color_list=(white,), redundant_channels=5)
    assert channel_readout(detector, (254, 255, 255), 6) == [0.] + [1.] * 5
    assert channel_readout(detector, (0, 255, 255), 6) == [0.] + [1.] * 5
    assert channel_readout(detector, white, 6) == [1.] + [0.] * 5

    # test for black pixel as input
    assert channel_readout(detector, (0, 0, 0), 6) == [0.] * 6
