import numpy as np
import pytest
import torch
from training.cached_feature_loader import CachedFeatureLoader
import mock


def test_shot_sampling():
    """Check that ``_subsample_shots`` keeps S samples per class and rejects bad S.

    The label tensor holds C classes with N // C samples each. Valid shot
    counts (1 and the full N // C) must yield C * S samples, ``-1`` must keep
    all N samples, and every invalid shot count must raise ``ValueError``.
    """
    cached_feature_loader = CachedFeatureLoader("test", "test", "test")
    N = 200
    C = 10
    list_of_all_classes = list(range(C))

    features = torch.empty((N, 20))
    labels = torch.flatten(torch.tensor([[i] * (N // C) for i in range(C)]))
    for S in [1, N // C]:
        selected_labels, selected_features, selected_indicies, class_mapping = cached_feature_loader._subsample_shots(
            features, labels, list_of_all_classes, S
        )
        assert len(selected_labels) == C * S
        assert len(selected_features) == C * S

    # no subsampling:
    selected_labels, selected_features, selected_indicies, class_mapping = cached_feature_loader._subsample_shots(
        features, labels, list_of_all_classes, -1
    )
    assert len(selected_labels) == N
    assert len(selected_features) == N

    # errors: each invalid value needs its own pytest.raises context. Inside a
    # single context, execution stops at the first raising call, so any later
    # calls would silently never run.
    for invalid_shots in (N + 1, 0, -2):
        with pytest.raises(ValueError):
            cached_feature_loader._subsample_shots(features, labels, list_of_all_classes, invalid_shots)


def test_select_classes():
    """Check that ``_subsample_classes`` returns the requested number of valid class ids.

    Covers picking a single class, every class explicitly, every class via the
    ``-1`` sentinel, and rejection of out-of-range class counts.
    """
    loader = CachedFeatureLoader("test", "test", "test")
    n_samples, n_classes = 200, 10
    per_class = n_samples // n_classes
    # Class-sorted labels: per_class copies of 0, then of 1, ... up to n_classes - 1.
    labels = torch.arange(n_classes).repeat_interleave(per_class)

    def _check(selected, expected_count):
        # Exactly `expected_count` distinct ids, all inside [0, n_classes).
        assert len(np.unique(selected)) == expected_count
        assert np.all(selected >= 0) and np.all(selected < n_classes)

    for requested in (1, n_classes):
        _check(loader._subsample_classes(labels, requested, "random"), requested)

    # -1 returns all classes
    _check(loader._subsample_classes(labels, -1, "random"), n_classes)

    for invalid in (0, -2, n_classes + 1):
        with pytest.raises(ValueError):
            loader._subsample_classes(labels, invalid, "random")


def _mocked_load_data(self, train):
    torch.manual_seed(0)
    C = 200
    if train:
        N = 1000
    else:
        N = 400
    features = torch.rand((N, 764))
    labels = torch.flatten(torch.tensor([[i] * (N // C) for i in range(C)]))
    return features, labels


def test_load_training_data():
    """Check seeding and shot/class subsampling of ``load_train_data``.

    ``_load_complete_data_from_disk`` is patched with ``_mocked_load_data``, so
    the training pool holds 200 classes with 1000 // 200 = 5 samples each.
    Selecting ``shots=5`` and ``n_classes=200`` therefore uses every sample, and
    must match loading with ``shots=-1, n_classes=-1`` regardless of seed.
    """
    with mock.patch.object(CachedFeatureLoader, "_load_complete_data_from_disk", new=_mocked_load_data):
        cached_feature_loader = CachedFeatureLoader("test", "test", "test")
        # The "select everything" load does not depend on S or C; load it once.
        train_features_all, train_labels_all, data_indicies_all, class_mapping_all = (
            cached_feature_loader.load_train_data(shots=-1, n_classes=-1, seed=3)
        )
        for S in [1, 5]:
            for C in [1, 200]:
                train_features_1, train_labels_1, data_indicies_1, class_mapping_1 = (
                    cached_feature_loader.load_train_data(shots=S, n_classes=C, seed=0)
                )
                train_features_2, train_labels_2, data_indicies_2, class_mapping_2 = (
                    cached_feature_loader.load_train_data(shots=S, n_classes=C, seed=0)
                )
                train_features_3, train_labels_3, data_indicies_3, class_mapping_3 = (
                    cached_feature_loader.load_train_data(shots=S, n_classes=C, seed=1)
                )

                # ensure that seeding works: the same seed reproduces the same selection
                assert torch.equal(train_features_1, train_features_2)
                assert torch.equal(train_labels_1, train_labels_2)
                assert np.all(data_indicies_1 == data_indicies_2)
                assert list(class_mapping_1.keys()) == list(class_mapping_2.keys())
                assert list(class_mapping_1.values()) == list(class_mapping_2.values())

                # different seeds should differ unless all data and all classes are used
                if S == 5 and C == 200:
                    assert torch.equal(train_features_1, train_features_3)
                    assert torch.equal(train_labels_1, train_labels_3)
                    assert np.all(data_indicies_1 == data_indicies_3)
                    # compare against the seed-1 run (run 2 shares seed 0 and proves nothing)
                    assert list(class_mapping_1.keys()) == list(class_mapping_3.keys())
                    assert list(class_mapping_1.values()) == list(class_mapping_3.values())

                    # should be the same as selecting all data:
                    assert torch.equal(train_features_1, train_features_all)
                    assert torch.equal(train_labels_1, train_labels_all)
                    assert np.all(data_indicies_1 == data_indicies_all)
                    assert list(class_mapping_1.keys()) == list(class_mapping_all.keys())
                    assert list(class_mapping_1.values()) == list(class_mapping_all.values())
                else:
                    assert not torch.equal(train_features_1, train_features_3)
                    assert not np.all(data_indicies_1 == data_indicies_3)
                    # labels will be the same because they are mapped to [0, C-1]
                    assert torch.equal(train_labels_1, train_labels_3)
                    if C != 200:
                        # with a strict subset of classes, a new seed picks different originals
                        assert list(class_mapping_1.keys()) != list(class_mapping_3.keys())
                    assert list(class_mapping_1.values()) == list(class_mapping_3.values())


def test_load_testing_data():
    """Check that ``load_test_data`` follows the class mapping produced by training.

    With the patched loader the test pool holds 400 // 200 = 2 samples per
    class. Selecting C classes must therefore yield 2 * C test samples whose
    (mapped) labels match the training labels. When all 200 classes are
    selected, both seeds must produce the same test set.
    """
    with mock.patch.object(CachedFeatureLoader, "_load_complete_data_from_disk", new=_mocked_load_data):
        cached_feature_loader = CachedFeatureLoader("test", "test", "test")
        test_samples_per_class = 400 // 200
        for S in [1, 5]:
            for C in [1, 200]:
                _, train_labels_1, _, class_mapping_1 = cached_feature_loader.load_train_data(
                    shots=S, n_classes=C, seed=0
                )
                _, train_labels_2, _, class_mapping_2 = cached_feature_loader.load_train_data(
                    shots=S, n_classes=C, seed=1
                )

                test_features_1, test_labels_1 = cached_feature_loader.load_test_data(class_mapping=class_mapping_1)
                test_features_2, test_labels_2 = cached_feature_loader.load_test_data(class_mapping=class_mapping_2)

                # torch.equal also compares shapes, so a size mismatch (e.g. an empty
                # result) fails the assertion instead of broadcasting to a vacuous True
                # or raising a RuntimeError.
                assert torch.equal(torch.unique(train_labels_1), torch.unique(test_labels_1))
                assert torch.equal(torch.unique(train_labels_2), torch.unique(test_labels_2))
                assert len(test_labels_1) == test_samples_per_class * C
                assert len(test_labels_2) == test_samples_per_class * C

                # full test when selecting all classes for training
                if C == 200:
                    assert torch.equal(torch.unique(test_features_1), torch.unique(test_features_2))
                    assert torch.equal(torch.unique(test_labels_1), torch.unique(test_labels_2))
